Skip to content

Commit e138d4f

Browse files
committed
Access I/O tensors via the created index mapping.
Signed-off-by: intelgaoxiong <xiong.gao@intel.com>
1 parent aa50ab8 commit e138d4f

File tree

3 files changed

+290
-31
lines changed

3 files changed

+290
-31
lines changed

src/plugins/intel_npu/src/plugin/npuw/host_flash_attention.cpp

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,172 @@ std::optional<HostFlashAttention> HostFlashAttention::from(const std::shared_ptr
572572
hfa._kv_cache_size = kv_cache_size;
573573
hfa._sdpa_attention = std::move(attention_opt.value()); // Store SDPA attention metadata
574574

575+
// Build SDPA input parameter index mapping from pattern nodes
576+
// This mapping will be transferred to compiled::HostFlashAttentionInfo
577+
LOG_INFO("Building SDPA input parameter index mapping...");
578+
579+
// Helper lambda to safely extract parameter from node (skipping Convert ops)
580+
auto extract_param = [&](const std::shared_ptr<ov::Node>& node) -> std::shared_ptr<ov::op::v0::Parameter> {
581+
auto current = node;
582+
// Skip Convert nodes to get to actual Parameter
583+
while (current && ov::is_type<ov::op::v0::Convert>(current.get())) {
584+
if (current->get_input_size() > 0) {
585+
current = current->get_input_node_shared_ptr(0);
586+
} else {
587+
break;
588+
}
589+
}
590+
return ov::as_type_ptr<ov::op::v0::Parameter>(current);
591+
};
592+
593+
// Extract Q (query) parameter - input 0 of MatMul1
594+
if (auto q_param = extract_param(pattern_nodes.matmul1_node->get_input_node_shared_ptr(0))) {
595+
std::size_t q_idx = model->get_parameter_index(q_param);
596+
hfa._sdpa_param_index_map[SDPAInputId::QUERY] = q_idx;
597+
LOG_DEBUG("Mapped QUERY to parameter index " << q_idx);
598+
}
599+
600+
// Extract past_key parameter - input 0 of past_key_concat
601+
if (pattern_nodes.past_key_concat_node) {
602+
if (auto past_k_param = extract_param(pattern_nodes.past_key_concat_node->get_input_node_shared_ptr(0))) {
603+
std::size_t past_k_idx = model->get_parameter_index(past_k_param);
604+
hfa._sdpa_param_index_map[SDPAInputId::PAST_KEY] = past_k_idx;
605+
LOG_DEBUG("Mapped PAST_KEY to parameter index " << past_k_idx);
606+
}
607+
608+
// Extract present_key parameter - input 1 of past_key_concat
609+
if (auto present_k_param = extract_param(pattern_nodes.past_key_concat_node->get_input_node_shared_ptr(1))) {
610+
std::size_t present_k_idx = model->get_parameter_index(present_k_param);
611+
hfa._sdpa_param_index_map[SDPAInputId::PRESENT_KEY] = present_k_idx;
612+
LOG_DEBUG("Mapped PRESENT_KEY to parameter index " << present_k_idx);
613+
}
614+
}
615+
616+
// Extract past_value parameter - input 0 of past_value_concat
617+
if (pattern_nodes.past_value_concat_node) {
618+
if (auto past_v_param = extract_param(pattern_nodes.past_value_concat_node->get_input_node_shared_ptr(0))) {
619+
std::size_t past_v_idx = model->get_parameter_index(past_v_param);
620+
hfa._sdpa_param_index_map[SDPAInputId::PAST_VALUE] = past_v_idx;
621+
LOG_DEBUG("Mapped PAST_VALUE to parameter index " << past_v_idx);
622+
}
623+
624+
// Extract present_value parameter - input 1 of past_value_concat
625+
if (auto present_v_param = extract_param(pattern_nodes.past_value_concat_node->get_input_node_shared_ptr(1))) {
626+
std::size_t present_v_idx = model->get_parameter_index(present_v_param);
627+
hfa._sdpa_param_index_map[SDPAInputId::PRESENT_VALUE] = present_v_idx;
628+
LOG_DEBUG("Mapped PRESENT_VALUE to parameter index " << present_v_idx);
629+
}
630+
}
631+
632+
// Extract mask parameter - from SDPA attention metadata
633+
std::size_t mask_idx = model->get_parameter_index(hfa._sdpa_attention._mask);
634+
hfa._sdpa_param_index_map[SDPAInputId::ATTENTION_MASK] = mask_idx;
635+
LOG_DEBUG("Mapped ATTENTION_MASK to parameter index " << mask_idx);
636+
637+
LOG_INFO("Built SDPA input mapping with " << hfa._sdpa_param_index_map.size() << " entries");
638+
639+
// Print the complete mapping table
640+
std::cout << "\n========== SDPA Input Index Mapping ==========\n";
641+
std::cout << "Total entries: " << hfa._sdpa_param_index_map.size() << "\n";
642+
643+
// Helper to convert enum to string for printing
644+
auto sdpa_input_id_to_string = [](SDPAInputId id) -> const char* {
645+
switch (id) {
646+
case SDPAInputId::PAST_KEY:
647+
return "PAST_KEY";
648+
case SDPAInputId::PAST_VALUE:
649+
return "PAST_VALUE";
650+
case SDPAInputId::QUERY:
651+
return "QUERY";
652+
case SDPAInputId::PRESENT_KEY:
653+
return "PRESENT_KEY";
654+
case SDPAInputId::ATTENTION_MASK:
655+
return "ATTENTION_MASK";
656+
case SDPAInputId::PRESENT_VALUE:
657+
return "PRESENT_VALUE";
658+
default:
659+
return "UNKNOWN";
660+
}
661+
};
662+
663+
for (const auto& [input_id, param_idx] : hfa._sdpa_param_index_map) {
664+
std::cout << " " << sdpa_input_id_to_string(input_id) << " -> parameter[" << param_idx << "]" << std::endl;
665+
}
666+
std::cout << "=============================================\n" << std::endl;
667+
668+
// Build HFA Tile Model input index mapping
669+
// This mapping allows accessing tile model inputs by semantic name
670+
LOG_INFO("Building HFA Tile Model input index mapping...");
671+
672+
// Parse tile model inputs by their tensor names
673+
// Expected input order: [past_acc, past_max, past_d, k_tile, v_tile, q, mask_tile]
674+
const auto& tile_inputs = tile_model->inputs();
675+
for (std::size_t i = 0; i < tile_inputs.size(); ++i) {
676+
const auto& tensor_names = tile_inputs[i].get_names();
677+
if (tensor_names.empty()) {
678+
LOG_WARN("Tile model input[" << i << "] has no tensor name");
679+
continue;
680+
}
681+
682+
const std::string& name = *tensor_names.begin();
683+
684+
// Map tensor name to enum ID
685+
if (name == "past_acc") {
686+
hfa._tile_param_index_map[HFATileInputId::PAST_ACC] = i;
687+
LOG_DEBUG("Mapped PAST_ACC to tile input[" << i << "]");
688+
} else if (name == "past_max") {
689+
hfa._tile_param_index_map[HFATileInputId::PAST_MAX] = i;
690+
LOG_DEBUG("Mapped PAST_MAX to tile input[" << i << "]");
691+
} else if (name == "past_d") {
692+
hfa._tile_param_index_map[HFATileInputId::PAST_D] = i;
693+
LOG_DEBUG("Mapped PAST_D to tile input[" << i << "]");
694+
} else if (name == "k_tile") {
695+
hfa._tile_param_index_map[HFATileInputId::K_TILE] = i;
696+
LOG_DEBUG("Mapped K_TILE to tile input[" << i << "]");
697+
} else if (name == "v_tile") {
698+
hfa._tile_param_index_map[HFATileInputId::V_TILE] = i;
699+
LOG_DEBUG("Mapped V_TILE to tile input[" << i << "]");
700+
} else if (name == "q") {
701+
hfa._tile_param_index_map[HFATileInputId::Q] = i;
702+
LOG_DEBUG("Mapped Q to tile input[" << i << "]");
703+
} else if (name == "mask_tile") {
704+
hfa._tile_param_index_map[HFATileInputId::MASK_TILE] = i;
705+
LOG_DEBUG("Mapped MASK_TILE to tile input[" << i << "]");
706+
} else {
707+
LOG_WARN("Unknown tile model input name: " << name);
708+
}
709+
}
710+
711+
// Print the tile input mapping
712+
std::cout << "\n========== HFA Tile Model Input Mapping ==========\n";
713+
std::cout << "Total entries: " << hfa._tile_param_index_map.size() << "\n";
714+
715+
auto tile_input_id_to_string = [](HFATileInputId id) -> const char* {
716+
switch (id) {
717+
case HFATileInputId::PAST_ACC:
718+
return "PAST_ACC";
719+
case HFATileInputId::PAST_MAX:
720+
return "PAST_MAX";
721+
case HFATileInputId::PAST_D:
722+
return "PAST_D";
723+
case HFATileInputId::K_TILE:
724+
return "K_TILE";
725+
case HFATileInputId::V_TILE:
726+
return "V_TILE";
727+
case HFATileInputId::Q:
728+
return "Q";
729+
case HFATileInputId::MASK_TILE:
730+
return "MASK_TILE";
731+
default:
732+
return "UNKNOWN";
733+
}
734+
};
735+
736+
for (const auto& [input_id, input_idx] : hfa._tile_param_index_map) {
737+
std::cout << " " << tile_input_id_to_string(input_id) << " -> input[" << input_idx << "]" << std::endl;
738+
}
739+
std::cout << "==================================================\n" << std::endl;
740+
575741
LOG_INFO("Successfully created HostFlashAttention");
576742
std::cout << "HostFlashAttention created with tile_size=" << hfa._tile_size
577743
<< ", kv_cache_size=" << hfa._kv_cache_size << std::endl;
@@ -600,6 +766,7 @@ HostFlashAttention::HostFlashAttention(const function::HostFlashAttention& func_
600766
const auto& sdpa_attn = func_hfa._sdpa_attention;
601767
const auto& original_model = func_hfa._original_model;
602768

769+
// Build parameter info for past key/value tensors
603770
_sdpa_attention_info.params.reserve(sdpa_attn._inputs.size());
604771
for (const auto& input : sdpa_attn._inputs) {
605772
std::size_t p_idx = original_model->get_parameter_index(input.param);
@@ -608,8 +775,16 @@ HostFlashAttention::HostFlashAttention(const function::HostFlashAttention& func_
608775
_sdpa_attention_info.mask_idx = original_model->get_parameter_index(sdpa_attn._mask);
609776
_sdpa_attention_info.query_size = sdpa_attn.query_len();
610777

778+
// Copy SDPA input index mapping from function HFA (already built in from() method)
779+
_sdpa_attention_info.sdpa_param_index_map = func_hfa._sdpa_param_index_map;
780+
781+
// Copy HFA Tile Model input index mapping from function HFA
782+
_sdpa_attention_info.tile_param_index_map = func_hfa._tile_param_index_map;
783+
611784
LOG_INFO("Extracted HFA config: tile_size=" << _tile_size << ", kv_cache_size=" << _kv_cache_size);
612785
LOG_INFO("Extracted " << _sdpa_attention_info.params.size() << " past KV parameters from original SDPA model");
786+
LOG_INFO("Copied SDPA input mapping with " << _sdpa_attention_info.sdpa_param_index_map.size() << " entries");
787+
LOG_INFO("Copied Tile input mapping with " << _sdpa_attention_info.tile_param_index_map.size() << " entries");
613788

614789
// Note: _compiled_tile_model and _compiled_final_tile_model will be set later by
615790
// compile_host_flash_attention_model()

src/plugins/intel_npu/src/plugin/npuw/host_flash_attention.hpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,37 @@
1717
namespace ov {
1818
namespace npuw {
1919

20+
// Semantic identifiers for the input tensors of an SDPA
// (Scaled Dot-Product Attention) subgraph.
// Placed at namespace scope so both the function:: and compiled::
// namespaces share one definition of the input layout.
enum class SDPAInputId : uint8_t {
    PAST_KEY = 0,        // Historical key cache tensor
    PAST_VALUE = 1,      // Historical value cache tensor
    QUERY = 2,           // Query tensor for current iteration
    PRESENT_KEY = 3,     // Current key tensor (new tokens)
    ATTENTION_MASK = 4,  // Attention mask tensor
    PRESENT_VALUE = 5,   // Current value tensor (new tokens)

    COUNT  // Sentinel: number of valid identifiers above
};
34+
35+
// Semantic identifiers for the inputs of a Host Flash Attention
// tile model.
// Tile model input names, in order:
//   [past_acc, past_max, past_d, k_tile, v_tile, q, mask_tile]
enum class HFATileInputId : uint8_t {
    PAST_ACC = 0,   // Accumulated attention output from previous tiles
    PAST_MAX = 1,   // Running maxima from previous tiles (numerical stability)
    PAST_D = 2,     // Normalization denominator from previous tiles
    K_TILE = 3,     // Current K (key) tile slice
    V_TILE = 4,     // Current V (value) tile slice
    Q = 5,          // Query tensor (full, not tiled)
    MASK_TILE = 6,  // Current attention mask tile slice

    COUNT  // Sentinel: number of valid identifiers above
};
50+
2051
namespace function {
2152

2253
// HostFlashAttention structure definition
@@ -43,6 +74,16 @@ struct HostFlashAttention {
4374
// Total KV cache size for tiling
4475
int64_t _kv_cache_size = 0;
4576

77+
// SDPA model parameter index mapping
78+
// Maps semantic SDPA parameter IDs to actual parameter indices in the original SDPA model
79+
// This is created during pattern analysis in from() method
80+
std::map<SDPAInputId, std::size_t> _sdpa_param_index_map;
81+
82+
// HFA Tile Model parameter index mapping
83+
// Maps semantic tile parameter IDs to actual parameter indices in the tile model
84+
// This is created after tile model generation in from() method
85+
std::map<HFATileInputId, std::size_t> _tile_param_index_map;
86+
4687
// Validation helpers
4788
bool is_valid() const {
4889
return _tile_model != nullptr && _final_tile_model != nullptr && _tile_size > 0 && _kv_cache_size > 0;
@@ -66,6 +107,16 @@ struct HostFlashAttentionInfo {
66107
std::vector<Param> params; // past key/value parameters from original SDPA
67108
std::size_t mask_idx = 0u; // mask parameter index in original SDPA model
68109
std::size_t query_size = 0u; // query size for selector compatibility
110+
111+
// Mapping from SDPA parameter identifier to actual parameter index in original SDPA model
112+
// This allows accessing SDPA model parameters by semantic name rather than hardcoded indices
113+
// Populated from function::HostFlashAttention::_sdpa_param_index_map
114+
std::map<SDPAInputId, std::size_t> sdpa_param_index_map;
115+
116+
// Mapping from HFA Tile parameter identifier to actual parameter index in tile model
117+
// This allows accessing tile model parameters by semantic name
118+
// Populated from function::HostFlashAttention::_tile_param_index_map
119+
std::map<HFATileInputId, std::size_t> tile_param_index_map;
69120
};
70121

71122
// Compile-time host flash attention information

0 commit comments

Comments
 (0)