@@ -1335,45 +1335,17 @@ void ov::npuw::JustInferRequest::run_hfa_tiled_inference(std::size_t real_idx, s
13351335 auto attention_output_tensor = hfa_outputs.at (0 );
13361336
13371337 // ================================================================================================
1338- // SECTION 4: Index Pre-caching and State Initialization
1338+ // SECTION 4: State Initialization
13391339 // ================================================================================================
13401340
1341- const auto & tile_input_map = hfa_desc._sdpa_attention_info ._tile_param_index_map ;
1342- auto get_tile_param_idx = [&](ov::npuw::HFATileInputId input_id) -> std::size_t {
1343- auto it = tile_input_map.find (input_id);
1344- if (it == tile_input_map.end ()) {
1345- OPENVINO_THROW (" HFA: Tile input mapping not found for input ID: " , static_cast <uint8_t >(input_id));
1346- }
1347- return it->second ;
1348- };
1349-
1350- // Pre-cache tile input indices
1351- const std::size_t tile_idx_q = get_tile_param_idx (ov::npuw::HFATileInputId::Q);
1352- const std::size_t tile_idx_k = get_tile_param_idx (ov::npuw::HFATileInputId::K_TILE);
1353- const std::size_t tile_idx_v = get_tile_param_idx (ov::npuw::HFATileInputId::V_TILE);
1354- const std::size_t tile_idx_mask = get_tile_param_idx (ov::npuw::HFATileInputId::MASK_TILE);
1355- const std::size_t tile_idx_acc = get_tile_param_idx (ov::npuw::HFATileInputId::PAST_ACC);
1356- const std::size_t tile_idx_max = get_tile_param_idx (ov::npuw::HFATileInputId::PAST_MAX);
1357- const std::size_t tile_idx_d = get_tile_param_idx (ov::npuw::HFATileInputId::PAST_D);
1358-
1359- // Pre-cache tile output indices
1360- const auto & tile_output_map = hfa_desc._sdpa_attention_info ._tile_output_index_map ;
1361- auto get_tile_output_idx = [&](ov::npuw::HFATileOutputId output_id) -> std::size_t {
1362- auto it = tile_output_map.find (output_id);
1363- if (it == tile_output_map.end ()) {
1364- OPENVINO_THROW (" HFA: Tile output mapping not found for output ID: " , static_cast <uint8_t >(output_id));
1365- }
1366- return it->second ;
1367- };
1368-
1369- const std::size_t regular_tile_output_acc = get_tile_output_idx (ov::npuw::HFATileOutputId::ACC);
1370- const std::size_t regular_tile_output_max = get_tile_output_idx (ov::npuw::HFATileOutputId::MAXX);
1371- const std::size_t regular_tile_output_d = get_tile_output_idx (ov::npuw::HFATileOutputId::D);
1341+ // Use pre-cached indices (populated during compilation)
1342+ const auto & tile_in = sdpa_info._tile_input_indices ;
1343+ const auto & tile_out = sdpa_info._tile_output_indices ;
13721344
1373- // Initialize state tensors (acc, max, d) to zero/negative infinity
1374- auto state_acc = regular_tile_request->get_tensor (hfa_desc._compiled_tile_model ->inputs ()[tile_idx_acc ]);
1375- auto state_max = regular_tile_request->get_tensor (hfa_desc._compiled_tile_model ->inputs ()[tile_idx_max ]);
1376- auto state_sum = regular_tile_request->get_tensor (hfa_desc._compiled_tile_model ->inputs ()[tile_idx_d ]);
1345+ // Initialize state tensors to zero/negative infinity
1346+ auto state_acc = regular_tile_request->get_tensor (hfa_desc._compiled_tile_model ->inputs ()[tile_in. acc ]);
1347+ auto state_max = regular_tile_request->get_tensor (hfa_desc._compiled_tile_model ->inputs ()[tile_in. max ]);
1348+ auto state_sum = regular_tile_request->get_tensor (hfa_desc._compiled_tile_model ->inputs ()[tile_in. d ]);
13771349
13781350 const auto acc_element_type = state_acc->get_element_type ();
13791351 if (acc_element_type == ov::element::f16 ) {
@@ -1399,8 +1371,8 @@ void ov::npuw::JustInferRequest::run_hfa_tiled_inference(std::size_t real_idx, s
13991371 const size_t present_seq_length = present_key_tensor->get_shape ()[K_SEQ_DIM];
14001372
14011373 // Set query tensor once (constant across all tiles)
1402- regular_tile_request->set_tensor (hfa_desc._compiled_tile_model ->inputs ()[tile_idx_q ], query_tensor);
1403- final_tile_request->set_tensor (hfa_desc._compiled_final_tile_model ->inputs ()[tile_idx_q ], query_tensor);
1374+ regular_tile_request->set_tensor (hfa_desc._compiled_tile_model ->inputs ()[tile_in. q ], query_tensor);
1375+ final_tile_request->set_tensor (hfa_desc._compiled_final_tile_model ->inputs ()[tile_in. q ], query_tensor);
14041376
14051377 // ================================================================================================
14061378 // SECTION 6: Helper Functions
@@ -1505,17 +1477,17 @@ void ov::npuw::JustInferRequest::run_hfa_tiled_inference(std::size_t real_idx, s
15051477 }
15061478
15071479 // 7.3: Get tile input buffers
1508- auto k_tile_buffer = current_request->get_tensor (current_model->inputs ()[tile_idx_k ]);
1509- auto v_tile_buffer = current_request->get_tensor (current_model->inputs ()[tile_idx_v ]);
1510- auto mask_tile_buffer = current_request->get_tensor (current_model->inputs ()[tile_idx_mask ]);
1480+ auto k_tile_buffer = current_request->get_tensor (current_model->inputs ()[tile_in. k ]);
1481+ auto v_tile_buffer = current_request->get_tensor (current_model->inputs ()[tile_in. v ]);
1482+ auto mask_tile_buffer = current_request->get_tensor (current_model->inputs ()[tile_in. mask ]);
15111483
15121484 // 7.4: Extract K tile
15131485 if (can_reuse_tensor_zero_copy (source_k_tensor,
15141486 k_tile_buffer,
15151487 K_SEQ_DIM,
15161488 kv_tile_offset,
15171489 current_tile_length)) {
1518- current_request->set_tensor (current_model->inputs ()[tile_idx_k ], source_k_tensor);
1490+ current_request->set_tensor (current_model->inputs ()[tile_in. k ], source_k_tensor);
15191491 } else {
15201492 extract_and_copy_tile (source_k_tensor, k_tile_buffer, K_SEQ_DIM, kv_tile_offset, current_tile_length, " K" );
15211493 }
@@ -1526,7 +1498,7 @@ void ov::npuw::JustInferRequest::run_hfa_tiled_inference(std::size_t real_idx, s
15261498 V_SEQ_DIM,
15271499 kv_tile_offset,
15281500 current_tile_length)) {
1529- current_request->set_tensor (current_model->inputs ()[tile_idx_v ], source_v_tensor);
1501+ current_request->set_tensor (current_model->inputs ()[tile_in. v ], source_v_tensor);
15301502 } else {
15311503 extract_and_copy_tile (source_v_tensor, v_tile_buffer, V_SEQ_DIM, kv_tile_offset, current_tile_length, " V" );
15321504 }
@@ -1541,7 +1513,7 @@ void ov::npuw::JustInferRequest::run_hfa_tiled_inference(std::size_t real_idx, s
15411513 MASK_KV_SEQ_DIM,
15421514 mask_tile_offset,
15431515 current_tile_length)) {
1544- current_request->set_tensor (current_model->inputs ()[tile_idx_mask ], attention_mask_tensor);
1516+ current_request->set_tensor (current_model->inputs ()[tile_in. mask ], attention_mask_tensor);
15451517 } else {
15461518 extract_and_copy_tile (attention_mask_tensor,
15471519 mask_tile_buffer,
@@ -1552,9 +1524,9 @@ void ov::npuw::JustInferRequest::run_hfa_tiled_inference(std::size_t real_idx, s
15521524 }
15531525
15541526 // 7.7: Set state tensors
1555- current_request->set_tensor (current_model->inputs ()[tile_idx_acc ], state_acc);
1556- current_request->set_tensor (current_model->inputs ()[tile_idx_max ], state_max);
1557- current_request->set_tensor (current_model->inputs ()[tile_idx_d ], state_sum);
1527+ current_request->set_tensor (current_model->inputs ()[tile_in. acc ], state_acc);
1528+ current_request->set_tensor (current_model->inputs ()[tile_in. max ], state_max);
1529+ current_request->set_tensor (current_model->inputs ()[tile_in. d ], state_sum);
15581530
15591531 // 7.8: Execute tile inference
15601532 current_request->infer ();
@@ -1564,9 +1536,9 @@ void ov::npuw::JustInferRequest::run_hfa_tiled_inference(std::size_t real_idx, s
15641536 auto final_attention_output = current_request->get_tensor (current_model->outputs ()[0 ]);
15651537 final_attention_output->copy_to (attention_output_tensor._ptr );
15661538 } else {
1567- auto output_acc = current_request->get_tensor (current_model->outputs ()[regular_tile_output_acc ]);
1568- auto output_max = current_request->get_tensor (current_model->outputs ()[regular_tile_output_max ]);
1569- auto output_sum = current_request->get_tensor (current_model->outputs ()[regular_tile_output_d ]);
1539+ auto output_acc = current_request->get_tensor (current_model->outputs ()[tile_out. acc ]);
1540+ auto output_max = current_request->get_tensor (current_model->outputs ()[tile_out. max ]);
1541+ auto output_sum = current_request->get_tensor (current_model->outputs ()[tile_out. d ]);
15701542
15711543 output_acc->copy_to (state_acc._ptr );
15721544 output_max->copy_to (state_max._ptr );
0 commit comments