Commit 8968944

Extract IO indices during compilation.
Signed-off-by: intelgaoxiong <xiong.gao@intel.com>
1 parent 1d9e498 commit 8968944

File tree

3 files changed: +86 −65 lines changed
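Before this commit, the tile model's input and output indices were resolved from `std::map` lookups by semantic ID on every tiled inference call; the change below resolves them once, while the HFA models are compiled, and stores plain struct members instead. A minimal sketch of the pattern, with simplified names (`TileInput`, `TileInputIndices`, `resolve`, and `cache_indices` are illustrative, not the plugin's types):

```cpp
#include <cstddef>
#include <cstdint>
#include <map>
#include <stdexcept>

// Semantic IDs for tile-model inputs (mirrors the patch's HFATileInputId).
enum class TileInput : uint8_t { Q, K_TILE, V_TILE };

// Pre-cached index block: one plain member per semantic slot.
struct TileInputIndices {
    std::size_t q = 0u, k = 0u, v = 0u;
};

// Resolve a semantic ID to a concrete parameter index, failing fast on a gap.
std::size_t resolve(const std::map<TileInput, std::size_t>& m, TileInput id) {
    auto it = m.find(id);
    if (it == m.end()) {
        throw std::runtime_error("tile input mapping not found");
    }
    return it->second;
}

// Done once at compilation; inference then reads members directly.
TileInputIndices cache_indices(const std::map<TileInput, std::size_t>& m) {
    return {resolve(m, TileInput::Q), resolve(m, TileInput::K_TILE), resolve(m, TileInput::V_TILE)};
}
```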

src/plugins/intel_npu/src/plugin/npuw/host_flash_attention.cpp

Lines changed: 44 additions & 6 deletions
```diff
@@ -806,19 +806,57 @@ HostFlashAttention::HostFlashAttention(const function::HostFlashAttention& func_
     // Copy SDPA input index mapping from function HFA (already built in from() method)
     _sdpa_attention_info._sdpa_param_index_map = func_hfa._sdpa_param_index_map;

-    // Copy HFA Tile Model input index mapping from function HFA
-    _sdpa_attention_info._tile_param_index_map = func_hfa._tile_param_index_map;
-
-    // Copy HFA Tile Model output index mapping from function HFA
-    _sdpa_attention_info._tile_output_index_map = func_hfa._tile_output_index_map;
-
     // Copy query size directly from function HFA (no need to extract from model)
     _sdpa_attention_info._query_size = func_hfa._query_size;

     // Copy K/V sequence dimensions from function HFA
     _sdpa_attention_info._k_seq_dim = func_hfa._k_seq_dim;
     _sdpa_attention_info._v_seq_dim = func_hfa._v_seq_dim;

+    // Pre-cache tile indices from function HFA maps
+    LOG_INFO("Pre-caching tile indices...");
+
+    auto get_tile_input_idx = [&](HFATileInputId input_id) -> std::size_t {
+        auto it = func_hfa._tile_param_index_map.find(input_id);
+        if (it == func_hfa._tile_param_index_map.end()) {
+            OPENVINO_THROW("HFA: Tile input mapping not found for input ID: ", static_cast<uint8_t>(input_id));
+        }
+        return it->second;
+    };
+
+    auto get_tile_output_idx = [&](HFATileOutputId output_id) -> std::size_t {
+        auto it = func_hfa._tile_output_index_map.find(output_id);
+        if (it == func_hfa._tile_output_index_map.end()) {
+            OPENVINO_THROW("HFA: Tile output mapping not found for output ID: ", static_cast<uint8_t>(output_id));
+        }
+        return it->second;
+    };
+
+    // Cache all tile input indices
+    _sdpa_attention_info._tile_input_indices.q = get_tile_input_idx(HFATileInputId::Q);
+    _sdpa_attention_info._tile_input_indices.k = get_tile_input_idx(HFATileInputId::K_TILE);
+    _sdpa_attention_info._tile_input_indices.v = get_tile_input_idx(HFATileInputId::V_TILE);
+    _sdpa_attention_info._tile_input_indices.mask = get_tile_input_idx(HFATileInputId::MASK_TILE);
+    _sdpa_attention_info._tile_input_indices.acc = get_tile_input_idx(HFATileInputId::PAST_ACC);
+    _sdpa_attention_info._tile_input_indices.max = get_tile_input_idx(HFATileInputId::PAST_MAX);
+    _sdpa_attention_info._tile_input_indices.d = get_tile_input_idx(HFATileInputId::PAST_D);
+
+    // Cache all tile output indices
+    _sdpa_attention_info._tile_output_indices.acc = get_tile_output_idx(HFATileOutputId::ACC);
+    _sdpa_attention_info._tile_output_indices.max = get_tile_output_idx(HFATileOutputId::MAXX);
+    _sdpa_attention_info._tile_output_indices.d = get_tile_output_idx(HFATileOutputId::D);
+
+    LOG_INFO("Pre-cached indices: inputs[q=" << _sdpa_attention_info._tile_input_indices.q
+             << ", k=" << _sdpa_attention_info._tile_input_indices.k
+             << ", v=" << _sdpa_attention_info._tile_input_indices.v
+             << ", mask=" << _sdpa_attention_info._tile_input_indices.mask
+             << ", acc=" << _sdpa_attention_info._tile_input_indices.acc
+             << ", max=" << _sdpa_attention_info._tile_input_indices.max
+             << ", d=" << _sdpa_attention_info._tile_input_indices.d
+             << "], outputs[acc=" << _sdpa_attention_info._tile_output_indices.acc
+             << ", max=" << _sdpa_attention_info._tile_output_indices.max
+             << ", d=" << _sdpa_attention_info._tile_output_indices.d << "]");
+
     // Note: _compiled_tile_model and _compiled_final_tile_model will be set later by
     // compile_host_flash_attention_model()
 }
```
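With the indices cached, inference-time call sites shrink to direct vector indexing. A hedged sketch of the consuming side using the public OpenVINO 2.0 API (the plugin itself goes through internal request types, so `bind_query` and its parameters are illustrative):

```cpp
#include <openvino/openvino.hpp>

// Bind the query tensor to the tile model's Q parameter by cached index:
// a constant-time vector access instead of a per-call std::map lookup.
void bind_query(ov::InferRequest& req, const ov::CompiledModel& model,
                std::size_t q_idx, const ov::Tensor& query) {
    req.set_tensor(model.inputs()[q_idx], query);
}
```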

src/plugins/intel_npu/src/plugin/npuw/host_flash_attention.hpp

Lines changed: 20 additions & 9 deletions
```diff
@@ -184,15 +184,26 @@ struct HostFlashAttentionInfo {
     // Populated from function::HostFlashAttention::_sdpa_param_index_map
     std::map<SDPAInputId, std::size_t> _sdpa_param_index_map;

-    // Mapping from HFA Tile parameter identifier to actual parameter index in tile model
-    // This allows accessing tile model parameters by semantic name
-    // Populated from function::HostFlashAttention::_tile_param_index_map
-    std::map<HFATileInputId, std::size_t> _tile_param_index_map;
-
-    // Mapping from HFA Tile output identifier to actual output index in tile model
-    // This allows accessing tile model outputs by semantic name rather than hardcoded indices
-    // Populated from function::HostFlashAttention::_tile_output_index_map
-    std::map<HFATileOutputId, std::size_t> _tile_output_index_map;
+    // NOTE: Tile input/output maps are not stored to save memory.
+    // Indices are pre-cached below during compilation.
+
+    // Pre-cached tile input indices
+    struct {
+        std::size_t q = 0u;
+        std::size_t k = 0u;
+        std::size_t v = 0u;
+        std::size_t mask = 0u;
+        std::size_t acc = 0u;
+        std::size_t max = 0u;
+        std::size_t d = 0u;
+    } _tile_input_indices;
+
+    // Pre-cached tile output indices
+    struct {
+        std::size_t acc = 0u;
+        std::size_t max = 0u;
+        std::size_t d = 0u;
+    } _tile_output_indices;
 };

 // Compile-time host flash attention information
```
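Replacing the two `std::map` members with anonymous structs trades flexibility for footprint: the set of semantic slots is now fixed in the header, but the info block carries no heap-allocated containers for them and every access is a plain member read. A small illustrative check (a hypothetical mirror of the new block, not the plugin's header):

```cpp
#include <cstddef>
#include <type_traits>

struct TileOutputIndices {
    std::size_t acc = 0u;
    std::size_t max = 0u;
    std::size_t d = 0u;
};
// Unlike a std::map, the cached block is a trivially copyable aggregate.
static_assert(std::is_trivially_copyable_v<TileOutputIndices>);
```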

src/plugins/intel_npu/src/plugin/npuw/just_sync_infer_request.cpp

Lines changed: 22 additions & 50 deletions
```diff
@@ -1335,45 +1335,17 @@ void ov::npuw::JustInferRequest::run_hfa_tiled_inference(std::size_t real_idx, s
     auto attention_output_tensor = hfa_outputs.at(0);

     // ================================================================================================
-    // SECTION 4: Index Pre-caching and State Initialization
+    // SECTION 4: State Initialization
     // ================================================================================================

-    const auto& tile_input_map = hfa_desc._sdpa_attention_info._tile_param_index_map;
-    auto get_tile_param_idx = [&](ov::npuw::HFATileInputId input_id) -> std::size_t {
-        auto it = tile_input_map.find(input_id);
-        if (it == tile_input_map.end()) {
-            OPENVINO_THROW("HFA: Tile input mapping not found for input ID: ", static_cast<uint8_t>(input_id));
-        }
-        return it->second;
-    };
-
-    // Pre-cache tile input indices
-    const std::size_t tile_idx_q = get_tile_param_idx(ov::npuw::HFATileInputId::Q);
-    const std::size_t tile_idx_k = get_tile_param_idx(ov::npuw::HFATileInputId::K_TILE);
-    const std::size_t tile_idx_v = get_tile_param_idx(ov::npuw::HFATileInputId::V_TILE);
-    const std::size_t tile_idx_mask = get_tile_param_idx(ov::npuw::HFATileInputId::MASK_TILE);
-    const std::size_t tile_idx_acc = get_tile_param_idx(ov::npuw::HFATileInputId::PAST_ACC);
-    const std::size_t tile_idx_max = get_tile_param_idx(ov::npuw::HFATileInputId::PAST_MAX);
-    const std::size_t tile_idx_d = get_tile_param_idx(ov::npuw::HFATileInputId::PAST_D);
-
-    // Pre-cache tile output indices
-    const auto& tile_output_map = hfa_desc._sdpa_attention_info._tile_output_index_map;
-    auto get_tile_output_idx = [&](ov::npuw::HFATileOutputId output_id) -> std::size_t {
-        auto it = tile_output_map.find(output_id);
-        if (it == tile_output_map.end()) {
-            OPENVINO_THROW("HFA: Tile output mapping not found for output ID: ", static_cast<uint8_t>(output_id));
-        }
-        return it->second;
-    };
-
-    const std::size_t regular_tile_output_acc = get_tile_output_idx(ov::npuw::HFATileOutputId::ACC);
-    const std::size_t regular_tile_output_max = get_tile_output_idx(ov::npuw::HFATileOutputId::MAXX);
-    const std::size_t regular_tile_output_d = get_tile_output_idx(ov::npuw::HFATileOutputId::D);
+    // Use pre-cached indices (populated during compilation)
+    const auto& tile_in = sdpa_info._tile_input_indices;
+    const auto& tile_out = sdpa_info._tile_output_indices;

-    // Initialize state tensors (acc, max, d) to zero/negative infinity
-    auto state_acc = regular_tile_request->get_tensor(hfa_desc._compiled_tile_model->inputs()[tile_idx_acc]);
-    auto state_max = regular_tile_request->get_tensor(hfa_desc._compiled_tile_model->inputs()[tile_idx_max]);
-    auto state_sum = regular_tile_request->get_tensor(hfa_desc._compiled_tile_model->inputs()[tile_idx_d]);
+    // Initialize state tensors to zero/negative infinity
+    auto state_acc = regular_tile_request->get_tensor(hfa_desc._compiled_tile_model->inputs()[tile_in.acc]);
+    auto state_max = regular_tile_request->get_tensor(hfa_desc._compiled_tile_model->inputs()[tile_in.max]);
+    auto state_sum = regular_tile_request->get_tensor(hfa_desc._compiled_tile_model->inputs()[tile_in.d]);

     const auto acc_element_type = state_acc->get_element_type();
     if (acc_element_type == ov::element::f16) {
@@ -1399,8 +1371,8 @@ void ov::npuw::JustInferRequest::run_hfa_tiled_inference(std::size_t real_idx, s
     const size_t present_seq_length = present_key_tensor->get_shape()[K_SEQ_DIM];

     // Set query tensor once (constant across all tiles)
-    regular_tile_request->set_tensor(hfa_desc._compiled_tile_model->inputs()[tile_idx_q], query_tensor);
-    final_tile_request->set_tensor(hfa_desc._compiled_final_tile_model->inputs()[tile_idx_q], query_tensor);
+    regular_tile_request->set_tensor(hfa_desc._compiled_tile_model->inputs()[tile_in.q], query_tensor);
+    final_tile_request->set_tensor(hfa_desc._compiled_final_tile_model->inputs()[tile_in.q], query_tensor);

     // ================================================================================================
     // SECTION 6: Helper Functions
@@ -1505,17 +1477,17 @@ void ov::npuw::JustInferRequest::run_hfa_tiled_inference(std::size_t real_idx, s
         }

         // 7.3: Get tile input buffers
-        auto k_tile_buffer = current_request->get_tensor(current_model->inputs()[tile_idx_k]);
-        auto v_tile_buffer = current_request->get_tensor(current_model->inputs()[tile_idx_v]);
-        auto mask_tile_buffer = current_request->get_tensor(current_model->inputs()[tile_idx_mask]);
+        auto k_tile_buffer = current_request->get_tensor(current_model->inputs()[tile_in.k]);
+        auto v_tile_buffer = current_request->get_tensor(current_model->inputs()[tile_in.v]);
+        auto mask_tile_buffer = current_request->get_tensor(current_model->inputs()[tile_in.mask]);

         // 7.4: Extract K tile
         if (can_reuse_tensor_zero_copy(source_k_tensor,
                                        k_tile_buffer,
                                        K_SEQ_DIM,
                                        kv_tile_offset,
                                        current_tile_length)) {
-            current_request->set_tensor(current_model->inputs()[tile_idx_k], source_k_tensor);
+            current_request->set_tensor(current_model->inputs()[tile_in.k], source_k_tensor);
         } else {
             extract_and_copy_tile(source_k_tensor, k_tile_buffer, K_SEQ_DIM, kv_tile_offset, current_tile_length, "K");
         }
@@ -1526,7 +1498,7 @@ void ov::npuw::JustInferRequest::run_hfa_tiled_inference(std::size_t real_idx, s
                                        V_SEQ_DIM,
                                        kv_tile_offset,
                                        current_tile_length)) {
-            current_request->set_tensor(current_model->inputs()[tile_idx_v], source_v_tensor);
+            current_request->set_tensor(current_model->inputs()[tile_in.v], source_v_tensor);
         } else {
             extract_and_copy_tile(source_v_tensor, v_tile_buffer, V_SEQ_DIM, kv_tile_offset, current_tile_length, "V");
         }
@@ -1541,7 +1513,7 @@ void ov::npuw::JustInferRequest::run_hfa_tiled_inference(std::size_t real_idx, s
                                        MASK_KV_SEQ_DIM,
                                        mask_tile_offset,
                                        current_tile_length)) {
-            current_request->set_tensor(current_model->inputs()[tile_idx_mask], attention_mask_tensor);
+            current_request->set_tensor(current_model->inputs()[tile_in.mask], attention_mask_tensor);
         } else {
             extract_and_copy_tile(attention_mask_tensor,
                                   mask_tile_buffer,
@@ -1552,9 +1524,9 @@ void ov::npuw::JustInferRequest::run_hfa_tiled_inference(std::size_t real_idx, s
         }

         // 7.7: Set state tensors
-        current_request->set_tensor(current_model->inputs()[tile_idx_acc], state_acc);
-        current_request->set_tensor(current_model->inputs()[tile_idx_max], state_max);
-        current_request->set_tensor(current_model->inputs()[tile_idx_d], state_sum);
+        current_request->set_tensor(current_model->inputs()[tile_in.acc], state_acc);
+        current_request->set_tensor(current_model->inputs()[tile_in.max], state_max);
+        current_request->set_tensor(current_model->inputs()[tile_in.d], state_sum);

         // 7.8: Execute tile inference
         current_request->infer();
@@ -1564,9 +1536,9 @@ void ov::npuw::JustInferRequest::run_hfa_tiled_inference(std::size_t real_idx, s
             auto final_attention_output = current_request->get_tensor(current_model->outputs()[0]);
             final_attention_output->copy_to(attention_output_tensor._ptr);
         } else {
-            auto output_acc = current_request->get_tensor(current_model->outputs()[regular_tile_output_acc]);
-            auto output_max = current_request->get_tensor(current_model->outputs()[regular_tile_output_max]);
-            auto output_sum = current_request->get_tensor(current_model->outputs()[regular_tile_output_d]);
+            auto output_acc = current_request->get_tensor(current_model->outputs()[tile_out.acc]);
+            auto output_max = current_request->get_tensor(current_model->outputs()[tile_out.max]);
+            auto output_sum = current_request->get_tensor(current_model->outputs()[tile_out.d]);

             output_acc->copy_to(state_acc._ptr);
             output_max->copy_to(state_max._ptr);
```
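For context on the `acc`/`max`/`d` state tensors threaded through the tile loop above: they are the running accumulator, running score maximum, and running softmax denominator of online (tiled) softmax, which is why `max` is initialized to negative infinity and the other two to zero. A scalar sketch of the recurrence, assuming a single query position (the real tile model evaluates this per row, on whole tensors, inside the compiled graph):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Running flash-attention state carried between K/V tiles.
struct TileState {
    double acc = 0.0;                                      // weighted sum of V values seen so far
    double mx = -std::numeric_limits<double>::infinity();  // running maximum of attention scores
    double d = 0.0;                                        // running softmax denominator
};

// Fold one tile of (score, value) pairs into the state: online softmax.
void fold_tile(TileState& s, const std::vector<double>& scores, const std::vector<double>& values) {
    for (std::size_t i = 0; i < scores.size(); ++i) {
        const double new_max = std::max(s.mx, scores[i]);
        const double rescale = std::exp(s.mx - new_max);  // shift history onto the new max
        const double w = std::exp(scores[i] - new_max);   // weight of the new element
        s.acc = s.acc * rescale + w * values[i];
        s.d = s.d * rescale + w;
        s.mx = new_max;
    }
}
// After the final tile, the attention output for this row is s.acc / s.d.
```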
