
Commit 31fe728

Create output name to id mapping.
Signed-off-by: intelgaoxiong <xiong.gao@intel.com>
1 parent b350b5d commit 31fe728

File tree

3 files changed: +105 -4 lines changed

src/plugins/intel_npu/src/plugin/npuw/host_flash_attention.cpp
src/plugins/intel_npu/src/plugin/npuw/host_flash_attention.hpp
src/plugins/intel_npu/src/plugin/npuw/just_sync_infer_request.cpp
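Before the diffs, a minimal standalone sketch of the pattern this commit introduces: output tensor names are parsed once into an enum-keyed index map, and downstream code then addresses outputs by semantic ID instead of hardcoded positions. TileOutputId, build_output_mapping(), and the vector of names below are simplified stand-ins for illustration, not the plugin's actual types.

    #include <cstddef>
    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    // Simplified stand-in for HFATileOutputId.
    enum class TileOutputId { ACC, MAXX, D };

    // Build an enum-ID -> index map from a list of output tensor names,
    // mirroring the shape of build_tile_output_mapping() in the diff below.
    std::map<TileOutputId, std::size_t> build_output_mapping(const std::vector<std::string>& names) {
        std::map<TileOutputId, std::size_t> mapping;
        for (std::size_t i = 0; i < names.size(); ++i) {
            if (names[i] == "acc") {
                mapping[TileOutputId::ACC] = i;
            } else if (names[i] == "maxx") {
                mapping[TileOutputId::MAXX] = i;
            } else if (names[i] == "d") {
                mapping[TileOutputId::D] = i;
            }  // unknown names are simply skipped in this sketch
        }
        return mapping;
    }

    int main() {
        // The lookup works regardless of the order in which the outputs appear.
        const auto mapping = build_output_mapping({"maxx", "d", "acc"});
        std::cout << "acc is output[" << mapping.at(TileOutputId::ACC) << "]\n";  // prints 2
        return 0;
    }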

src/plugins/intel_npu/src/plugin/npuw/host_flash_attention.cpp

Lines changed: 51 additions & 0 deletions
@@ -579,6 +579,49 @@ static void build_tile_param_mapping(HostFlashAttention& hfa, const std::shared_
     std::cout << "==================================================\n" << std::endl;
 }
 
+// ============================================================================
+// Helper function: Build tile model output index mapping
+// ============================================================================
+static void build_tile_output_mapping(HostFlashAttention& hfa, const std::shared_ptr<ov::Model>& tile_model) {
+    LOG_INFO("Building HFA Tile Model output index mapping...");
+
+    // Parse tile model outputs by their tensor names
+    // Expected output order: [acc, maxx, d]
+    const auto& tile_outputs = tile_model->outputs();
+    for (std::size_t i = 0; i < tile_outputs.size(); ++i) {
+        const auto& tensor_names = tile_outputs[i].get_names();
+        if (tensor_names.empty()) {
+            LOG_WARN("Tile model output[" << i << "] has no tensor name");
+            continue;
+        }
+
+        const std::string& name = *tensor_names.begin();
+
+        // Map tensor name to enum ID
+        if (name == "acc") {
+            hfa._tile_output_index_map[HFATileOutputId::ACC] = i;
+            LOG_DEBUG("Mapped ACC to tile output[" << i << "]");
+        } else if (name == "maxx") {
+            hfa._tile_output_index_map[HFATileOutputId::MAXX] = i;
+            LOG_DEBUG("Mapped MAXX to tile output[" << i << "]");
+        } else if (name == "d") {
+            hfa._tile_output_index_map[HFATileOutputId::D] = i;
+            LOG_DEBUG("Mapped D to tile output[" << i << "]");
+        } else {
+            LOG_WARN("Unknown tile model output name: " << name);
+        }
+    }
+
+    // Print the tile output mapping
+    std::cout << "\n========== HFA Tile Model Output Mapping ==========\n";
+    std::cout << "Total entries: " << hfa._tile_output_index_map.size() << "\n";
+
+    for (const auto& [output_id, output_idx] : hfa._tile_output_index_map) {
+        std::cout << "  " << hfa_tile_output_id_to_string(output_id) << " -> output[" << output_idx << "]" << std::endl;
+    }
+    std::cout << "==================================================\n" << std::endl;
+}
+
 // ============================================================================
 // Helper function: Extract sequence dimension from Concat node
 // ============================================================================
@@ -734,6 +777,11 @@ std::optional<HostFlashAttention> HostFlashAttention::from(const std::shared_ptr
     // ========================================================================
     build_tile_param_mapping(hfa, tile_model);
 
+    // ========================================================================
+    // Step 10: Build tile model output index mapping
+    // ========================================================================
+    build_tile_output_mapping(hfa, tile_model);
+
     LOG_INFO("Successfully created HostFlashAttention with query_size=" << query_size << ", tile_size=" << query_size);
 
     return hfa;
@@ -761,6 +809,9 @@ HostFlashAttention::HostFlashAttention(const function::HostFlashAttention& func_
     // Copy HFA Tile Model input index mapping from function HFA
     _sdpa_attention_info._tile_param_index_map = func_hfa._tile_param_index_map;
 
+    // Copy HFA Tile Model output index mapping from function HFA
+    _sdpa_attention_info._tile_output_index_map = func_hfa._tile_output_index_map;
+
     // Copy query size directly from function HFA (no need to extract from model)
     _sdpa_attention_info._query_size = func_hfa._query_size;
 
src/plugins/intel_npu/src/plugin/npuw/host_flash_attention.hpp

Lines changed: 36 additions & 0 deletions
@@ -48,6 +48,18 @@ enum class HFATileInputId : uint8_t {
     COUNT
 };
 
+// HFA Regular Tile Model output tensor identifiers
+// Represents the output layout for regular (non-final) tile models
+// Output names: [acc, maxx, d]
+enum class HFATileOutputId : uint8_t {
+    ACC = 0,   // Accumulated attention output
+    MAXX = 1,  // Maximum values for numerical stability
+    D = 2,     // Normalization denominator
+
+    // Sentinel value for enum range
+    COUNT
+};
+
 // Helper functions to convert enum values to string representations for logging/debugging
 inline const char* sdpa_input_id_to_string(SDPAInputId id) {
     switch (id) {
@@ -89,6 +101,19 @@ inline const char* hfa_tile_input_id_to_string(HFATileInputId id) {
     }
 }
 
+inline const char* hfa_tile_output_id_to_string(HFATileOutputId id) {
+    switch (id) {
+    case HFATileOutputId::ACC:
+        return "ACC";
+    case HFATileOutputId::MAXX:
+        return "MAXX";
+    case HFATileOutputId::D:
+        return "D";
+    default:
+        return "UNKNOWN";
+    }
+}
+
 namespace function {
 
 // HostFlashAttention structure definition
@@ -124,6 +149,12 @@ struct HostFlashAttention {
     // This is created after tile model generation in from() method
     std::map<HFATileInputId, std::size_t> _tile_param_index_map;
 
+    // Tile model output index mapping
+    // Maps tile output IDs (ACC, MAXX, D) to actual output indices
+    // Only applicable to the regular tile model (the final tile has a single output at index 0)
+    // This is created after tile model generation in the from() method
+    std::map<HFATileOutputId, std::size_t> _tile_output_index_map;
+
     // Validation helpers
     bool is_valid() const {
         return _tile_model != nullptr && _final_tile_model != nullptr && _tile_size > 0;
@@ -157,6 +188,11 @@ struct HostFlashAttentionInfo {
     // This allows accessing tile model parameters by semantic name
     // Populated from function::HostFlashAttention::_tile_param_index_map
     std::map<HFATileInputId, std::size_t> _tile_param_index_map;
+
+    // Mapping from HFA Tile output identifier to actual output index in tile model
+    // This allows accessing tile model outputs by semantic name rather than hardcoded indices
+    // Populated from function::HostFlashAttention::_tile_output_index_map
+    std::map<HFATileOutputId, std::size_t> _tile_output_index_map;
 };
 
 // Compile-time host flash attention information
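For reference, a small self-contained sketch of how the new HFATileOutputId enum and the hfa_tile_output_id_to_string() helper can be exercised; the enum and helper bodies are taken from the diff above, while the surrounding main() is illustrative only.

    #include <cstdint>
    #include <iostream>

    enum class HFATileOutputId : uint8_t { ACC = 0, MAXX = 1, D = 2, COUNT };

    inline const char* hfa_tile_output_id_to_string(HFATileOutputId id) {
        switch (id) {
        case HFATileOutputId::ACC:
            return "ACC";
        case HFATileOutputId::MAXX:
            return "MAXX";
        case HFATileOutputId::D:
            return "D";
        default:
            return "UNKNOWN";
        }
    }

    int main() {
        // Walk the enum range using COUNT as the sentinel and print each name.
        for (int i = 0; i < static_cast<int>(HFATileOutputId::COUNT); ++i) {
            std::cout << i << " -> " << hfa_tile_output_id_to_string(static_cast<HFATileOutputId>(i)) << "\n";
        }
        return 0;
    }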

src/plugins/intel_npu/src/plugin/npuw/just_sync_infer_request.cpp

Lines changed: 18 additions & 4 deletions
@@ -1377,6 +1377,21 @@ void ov::npuw::JustInferRequest::run_hfa_tiled_inference(std::size_t real_idx, s
     const std::size_t tile_idx_max = get_tile_param_idx(ov::npuw::HFATileInputId::PAST_MAX);
     const std::size_t tile_idx_d = get_tile_param_idx(ov::npuw::HFATileInputId::PAST_D);
 
+    // Pre-cache regular tile model output indices (final tile has only 1 output at index 0)
+    // Regular tile outputs: [ACC, MAXX, D]
+    const auto& tile_output_map = hfa_desc._sdpa_attention_info._tile_output_index_map;
+    auto get_tile_output_idx = [&](ov::npuw::HFATileOutputId output_id) -> std::size_t {
+        auto it = tile_output_map.find(output_id);
+        if (it == tile_output_map.end()) {
+            OPENVINO_THROW("HFA: Tile output mapping not found for output ID: ", static_cast<uint8_t>(output_id));
+        }
+        return it->second;
+    };
+
+    const std::size_t regular_tile_output_acc = get_tile_output_idx(ov::npuw::HFATileOutputId::ACC);
+    const std::size_t regular_tile_output_max = get_tile_output_idx(ov::npuw::HFATileOutputId::MAXX);
+    const std::size_t regular_tile_output_d = get_tile_output_idx(ov::npuw::HFATileOutputId::D);
+
     auto state_acc = regular_tile_request->get_tensor(hfa_desc._compiled_tile_model->inputs()[tile_idx_acc]);
     auto state_max = regular_tile_request->get_tensor(hfa_desc._compiled_tile_model->inputs()[tile_idx_max]);
     auto state_sum = regular_tile_request->get_tensor(hfa_desc._compiled_tile_model->inputs()[tile_idx_d]);
@@ -1632,10 +1647,9 @@ void ov::npuw::JustInferRequest::run_hfa_tiled_inference(std::size_t real_idx, s
         final_attention_output->copy_to(attention_output_tensor._ptr);
     } else {
         // Regular tile: Update accumulation state for next iteration
-        // Tile model outputs: [0] updated_acc, [1] updated_max, [2] updated_sum
-        auto output_acc = current_request->get_tensor(current_model->outputs()[0]);
-        auto output_max = current_request->get_tensor(current_model->outputs()[1]);
-        auto output_sum = current_request->get_tensor(current_model->outputs()[2]);
+        auto output_acc = current_request->get_tensor(current_model->outputs()[regular_tile_output_acc]);
+        auto output_max = current_request->get_tensor(current_model->outputs()[regular_tile_output_max]);
+        auto output_sum = current_request->get_tensor(current_model->outputs()[regular_tile_output_d]);
 
         // Copy updated state back to input buffers for next tile
         output_acc->copy_to(state_acc._ptr);
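The two hunks above first cache the mapped output indices once via the get_tile_output_idx lambda and then use them when carrying the accumulation state (acc, max, d) from one tile to the next. Below is a minimal standalone approximation of that flow; the plain std types, output_idx(), and the toy per-tile "tensors" stand in for the plugin's tensors and OPENVINO_THROW and are illustrative only.

    #include <cstddef>
    #include <map>
    #include <stdexcept>
    #include <vector>

    enum class TileOutputId { ACC, MAXX, D };  // simplified stand-in for HFATileOutputId

    // Checked lookup mirroring the get_tile_output_idx lambda; std::runtime_error
    // stands in for OPENVINO_THROW in this sketch.
    std::size_t output_idx(const std::map<TileOutputId, std::size_t>& m, TileOutputId id) {
        const auto it = m.find(id);
        if (it == m.end()) {
            throw std::runtime_error("Tile output mapping not found");
        }
        return it->second;
    }

    int main() {
        // Hypothetical mapping; in the plugin it comes from _tile_output_index_map.
        const std::map<TileOutputId, std::size_t> mapping{
            {TileOutputId::ACC, 0}, {TileOutputId::MAXX, 1}, {TileOutputId::D, 2}};

        // Resolve the indices once, outside the tile loop.
        const std::size_t idx_acc = output_idx(mapping, TileOutputId::ACC);
        const std::size_t idx_max = output_idx(mapping, TileOutputId::MAXX);
        const std::size_t idx_d = output_idx(mapping, TileOutputId::D);

        // Toy state: one float per quantity (acc, running max, denominator).
        std::vector<float> state{0.0f, -1e30f, 0.0f};
        for (int tile = 0; tile < 4; ++tile) {
            const std::vector<float> outputs{1.0f, 0.5f, 2.0f};  // pretend tile-model outputs
            // Copy updated values back into the state carried to the next tile, by mapped index.
            state[0] = outputs[idx_acc];
            state[1] = outputs[idx_max];
            state[2] = outputs[idx_d];
        }
        return 0;
    }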
