@@ -579,6 +579,49 @@ static void build_tile_param_mapping(HostFlashAttention& hfa, const std::shared_
     std::cout << "==================================================\n" << std::endl;
 }
 
+// ============================================================================
+// Helper function: Build tile model output index mapping
+// ============================================================================
+static void build_tile_output_mapping(HostFlashAttention& hfa, const std::shared_ptr<ov::Model>& tile_model) {
+    LOG_INFO("Building HFA Tile Model output index mapping...");
+
+    // Parse tile model outputs by their tensor names.
+    // Expected output order: [acc, maxx, d]
+    const auto& tile_outputs = tile_model->outputs();
+    for (std::size_t i = 0; i < tile_outputs.size(); ++i) {
+        const auto& tensor_names = tile_outputs[i].get_names();
+        if (tensor_names.empty()) {
+            LOG_WARN("Tile model output[" << i << "] has no tensor name");
+            continue;
+        }
+
+        const std::string& name = *tensor_names.begin();
+
+        // Map tensor name to enum ID
+        if (name == "acc") {
+            hfa._tile_output_index_map[HFATileOutputId::ACC] = i;
+            LOG_DEBUG("Mapped ACC to tile output[" << i << "]");
+        } else if (name == "maxx") {
+            hfa._tile_output_index_map[HFATileOutputId::MAXX] = i;
+            LOG_DEBUG("Mapped MAXX to tile output[" << i << "]");
+        } else if (name == "d") {
+            hfa._tile_output_index_map[HFATileOutputId::D] = i;
+            LOG_DEBUG("Mapped D to tile output[" << i << "]");
+        } else {
+            LOG_WARN("Unknown tile model output name: " << name);
+        }
+    }
+
+    // Print the tile output mapping
+    std::cout << "\n========== HFA Tile Model Output Mapping ==========\n";
+    std::cout << "Total entries: " << hfa._tile_output_index_map.size() << "\n";
+
+    for (const auto& [output_id, output_idx] : hfa._tile_output_index_map) {
+        std::cout << "  " << hfa_tile_output_id_to_string(output_id) << " -> output[" << output_idx << "]" << std::endl;
+    }
+    std::cout << "==================================================\n" << std::endl;
+}
+
 // ============================================================================
 // Helper function: Extract sequence dimension from Concat node
 // ============================================================================
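For context on the hunk above: build_tile_output_mapping relies on an HFATileOutputId enum, the _tile_output_index_map member, and an hfa_tile_output_id_to_string helper, none of which appear in this diff. A minimal sketch of plausible definitions, assuming a scoped enum and an ordered map keyed by it (the exact shapes in the codebase may differ):

    // Hypothetical sketch -- the real declarations live elsewhere in the codebase.
    enum class HFATileOutputId { ACC, MAXX, D };

    // e.g. std::map<HFATileOutputId, std::size_t> _tile_output_index_map;

    static const char* hfa_tile_output_id_to_string(HFATileOutputId id) {
        switch (id) {
        case HFATileOutputId::ACC:  return "ACC";
        case HFATileOutputId::MAXX: return "MAXX";
        case HFATileOutputId::D:    return "D";
        }
        return "UNKNOWN";
    }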
@@ -734,6 +777,11 @@ std::optional<HostFlashAttention> HostFlashAttention::from(const std::shared_ptr
     // ========================================================================
     build_tile_param_mapping(hfa, tile_model);
 
+    // ========================================================================
+    // Step 10: Build tile model output index mapping
+    // ========================================================================
+    build_tile_output_mapping(hfa, tile_model);
+
     LOG_INFO("Successfully created HostFlashAttention with query_size=" << query_size << ", tile_size=" << query_size);
 
     return hfa;
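With Step 10 in place, downstream code can address tile model outputs by role rather than by hard-coded position. A hedged usage sketch, assuming the map is a std::map<HFATileOutputId, std::size_t> and the tile model executes through a standard ov::InferRequest (the request variable and surrounding setup are illustrative, not part of this PR):

    // Look up the accumulator output wherever it landed in the tile model.
    // .at() throws if the mapping was never built, which is preferable to
    // operator[] silently inserting a default index of 0.
    const std::size_t acc_idx = hfa._tile_output_index_map.at(HFATileOutputId::ACC);
    ov::Tensor acc = tile_infer_request.get_output_tensor(acc_idx);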
@@ -761,6 +809,9 @@ HostFlashAttention::HostFlashAttention(const function::HostFlashAttention& func_
     // Copy HFA Tile Model input index mapping from function HFA
     _sdpa_attention_info._tile_param_index_map = func_hfa._tile_param_index_map;
 
+    // Copy HFA Tile Model output index mapping from function HFA
+    _sdpa_attention_info._tile_output_index_map = func_hfa._tile_output_index_map;
+
     // Copy query size directly from function HFA (no need to extract from model)
     _sdpa_attention_info._query_size = func_hfa._query_size;
 
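The assignment above copies the mapping by value, which is cheap for a three-entry map and keeps _sdpa_attention_info self-contained. One assumption baked in is that all three outputs were actually matched by name in Step 10; if a tile model ever renames an output, the copy would silently propagate an incomplete map. A possible guard (a suggestion, not part of this change), using OpenVINO's standard assertion macro:

    // Hypothetical sanity check: acc, maxx, and d must all have been mapped.
    OPENVINO_ASSERT(func_hfa._tile_output_index_map.size() == 3,
                    "HFA tile model output index mapping is incomplete");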