@@ -572,6 +572,172 @@ std::optional<HostFlashAttention> HostFlashAttention::from(const std::shared_ptr
572572 hfa._kv_cache_size = kv_cache_size;
573573 hfa._sdpa_attention = std::move (attention_opt.value ()); // Store SDPA attention metadata
574574
575+ // Build SDPA input parameter index mapping from pattern nodes
576+ // This mapping will be transferred to compiled::HostFlashAttentionInfo
577+ LOG_INFO (" Building SDPA input parameter index mapping..." );
578+
579+ // Helper lambda to safely extract parameter from node (skipping Convert ops)
580+ auto extract_param = [&](const std::shared_ptr<ov::Node>& node) -> std::shared_ptr<ov::op::v0::Parameter> {
581+ auto current = node;
582+ // Skip Convert nodes to get to actual Parameter
583+ while (current && ov::is_type<ov::op::v0::Convert>(current.get ())) {
584+ if (current->get_input_size () > 0 ) {
585+ current = current->get_input_node_shared_ptr (0 );
586+ } else {
587+ break ;
588+ }
589+ }
590+ return ov::as_type_ptr<ov::op::v0::Parameter>(current);
591+ };
592+
593+ // Extract Q (query) parameter - input 0 of MatMul1
594+ if (auto q_param = extract_param (pattern_nodes.matmul1_node ->get_input_node_shared_ptr (0 ))) {
595+ std::size_t q_idx = model->get_parameter_index (q_param);
596+ hfa._sdpa_param_index_map [SDPAInputId::QUERY] = q_idx;
597+ LOG_DEBUG (" Mapped QUERY to parameter index " << q_idx);
598+ }
599+
600+ // Extract past_key parameter - input 0 of past_key_concat
601+ if (pattern_nodes.past_key_concat_node ) {
602+ if (auto past_k_param = extract_param (pattern_nodes.past_key_concat_node ->get_input_node_shared_ptr (0 ))) {
603+ std::size_t past_k_idx = model->get_parameter_index (past_k_param);
604+ hfa._sdpa_param_index_map [SDPAInputId::PAST_KEY] = past_k_idx;
605+ LOG_DEBUG (" Mapped PAST_KEY to parameter index " << past_k_idx);
606+ }
607+
608+ // Extract present_key parameter - input 1 of past_key_concat
609+ if (auto present_k_param = extract_param (pattern_nodes.past_key_concat_node ->get_input_node_shared_ptr (1 ))) {
610+ std::size_t present_k_idx = model->get_parameter_index (present_k_param);
611+ hfa._sdpa_param_index_map [SDPAInputId::PRESENT_KEY] = present_k_idx;
612+ LOG_DEBUG (" Mapped PRESENT_KEY to parameter index " << present_k_idx);
613+ }
614+ }
615+
616+ // Extract past_value parameter - input 0 of past_value_concat
617+ if (pattern_nodes.past_value_concat_node ) {
618+ if (auto past_v_param = extract_param (pattern_nodes.past_value_concat_node ->get_input_node_shared_ptr (0 ))) {
619+ std::size_t past_v_idx = model->get_parameter_index (past_v_param);
620+ hfa._sdpa_param_index_map [SDPAInputId::PAST_VALUE] = past_v_idx;
621+ LOG_DEBUG (" Mapped PAST_VALUE to parameter index " << past_v_idx);
622+ }
623+
624+ // Extract present_value parameter - input 1 of past_value_concat
625+ if (auto present_v_param = extract_param (pattern_nodes.past_value_concat_node ->get_input_node_shared_ptr (1 ))) {
626+ std::size_t present_v_idx = model->get_parameter_index (present_v_param);
627+ hfa._sdpa_param_index_map [SDPAInputId::PRESENT_VALUE] = present_v_idx;
628+ LOG_DEBUG (" Mapped PRESENT_VALUE to parameter index " << present_v_idx);
629+ }
630+ }
631+
632+ // Extract mask parameter - from SDPA attention metadata
633+ std::size_t mask_idx = model->get_parameter_index (hfa._sdpa_attention ._mask );
634+ hfa._sdpa_param_index_map [SDPAInputId::ATTENTION_MASK] = mask_idx;
635+ LOG_DEBUG (" Mapped ATTENTION_MASK to parameter index " << mask_idx);
636+
637+ LOG_INFO (" Built SDPA input mapping with " << hfa._sdpa_param_index_map .size () << " entries" );
638+
639+ // Print the complete mapping table
640+ std::cout << " \n ========== SDPA Input Index Mapping ==========\n " ;
641+ std::cout << " Total entries: " << hfa._sdpa_param_index_map .size () << " \n " ;
642+
643+ // Helper to convert enum to string for printing
644+ auto sdpa_input_id_to_string = [](SDPAInputId id) -> const char * {
645+ switch (id) {
646+ case SDPAInputId::PAST_KEY:
647+ return " PAST_KEY" ;
648+ case SDPAInputId::PAST_VALUE:
649+ return " PAST_VALUE" ;
650+ case SDPAInputId::QUERY:
651+ return " QUERY" ;
652+ case SDPAInputId::PRESENT_KEY:
653+ return " PRESENT_KEY" ;
654+ case SDPAInputId::ATTENTION_MASK:
655+ return " ATTENTION_MASK" ;
656+ case SDPAInputId::PRESENT_VALUE:
657+ return " PRESENT_VALUE" ;
658+ default :
659+ return " UNKNOWN" ;
660+ }
661+ };
662+
663+ for (const auto & [input_id, param_idx] : hfa._sdpa_param_index_map ) {
664+ std::cout << " " << sdpa_input_id_to_string (input_id) << " -> parameter[" << param_idx << " ]" << std::endl;
665+ }
666+ std::cout << " =============================================\n " << std::endl;
667+
668+ // Build HFA Tile Model input index mapping
669+ // This mapping allows accessing tile model inputs by semantic name
670+ LOG_INFO (" Building HFA Tile Model input index mapping..." );
671+
672+ // Parse tile model inputs by their tensor names
673+ // Expected input order: [past_acc, past_max, past_d, k_tile, v_tile, q, mask_tile]
674+ const auto & tile_inputs = tile_model->inputs ();
675+ for (std::size_t i = 0 ; i < tile_inputs.size (); ++i) {
676+ const auto & tensor_names = tile_inputs[i].get_names ();
677+ if (tensor_names.empty ()) {
678+ LOG_WARN (" Tile model input[" << i << " ] has no tensor name" );
679+ continue ;
680+ }
681+
682+ const std::string& name = *tensor_names.begin ();
683+
684+ // Map tensor name to enum ID
685+ if (name == " past_acc" ) {
686+ hfa._tile_param_index_map [HFATileInputId::PAST_ACC] = i;
687+ LOG_DEBUG (" Mapped PAST_ACC to tile input[" << i << " ]" );
688+ } else if (name == " past_max" ) {
689+ hfa._tile_param_index_map [HFATileInputId::PAST_MAX] = i;
690+ LOG_DEBUG (" Mapped PAST_MAX to tile input[" << i << " ]" );
691+ } else if (name == " past_d" ) {
692+ hfa._tile_param_index_map [HFATileInputId::PAST_D] = i;
693+ LOG_DEBUG (" Mapped PAST_D to tile input[" << i << " ]" );
694+ } else if (name == " k_tile" ) {
695+ hfa._tile_param_index_map [HFATileInputId::K_TILE] = i;
696+ LOG_DEBUG (" Mapped K_TILE to tile input[" << i << " ]" );
697+ } else if (name == " v_tile" ) {
698+ hfa._tile_param_index_map [HFATileInputId::V_TILE] = i;
699+ LOG_DEBUG (" Mapped V_TILE to tile input[" << i << " ]" );
700+ } else if (name == " q" ) {
701+ hfa._tile_param_index_map [HFATileInputId::Q] = i;
702+ LOG_DEBUG (" Mapped Q to tile input[" << i << " ]" );
703+ } else if (name == " mask_tile" ) {
704+ hfa._tile_param_index_map [HFATileInputId::MASK_TILE] = i;
705+ LOG_DEBUG (" Mapped MASK_TILE to tile input[" << i << " ]" );
706+ } else {
707+ LOG_WARN (" Unknown tile model input name: " << name);
708+ }
709+ }
710+
711+ // Print the tile input mapping
712+ std::cout << " \n ========== HFA Tile Model Input Mapping ==========\n " ;
713+ std::cout << " Total entries: " << hfa._tile_param_index_map .size () << " \n " ;
714+
715+ auto tile_input_id_to_string = [](HFATileInputId id) -> const char * {
716+ switch (id) {
717+ case HFATileInputId::PAST_ACC:
718+ return " PAST_ACC" ;
719+ case HFATileInputId::PAST_MAX:
720+ return " PAST_MAX" ;
721+ case HFATileInputId::PAST_D:
722+ return " PAST_D" ;
723+ case HFATileInputId::K_TILE:
724+ return " K_TILE" ;
725+ case HFATileInputId::V_TILE:
726+ return " V_TILE" ;
727+ case HFATileInputId::Q:
728+ return " Q" ;
729+ case HFATileInputId::MASK_TILE:
730+ return " MASK_TILE" ;
731+ default :
732+ return " UNKNOWN" ;
733+ }
734+ };
735+
736+ for (const auto & [input_id, input_idx] : hfa._tile_param_index_map ) {
737+ std::cout << " " << tile_input_id_to_string (input_id) << " -> input[" << input_idx << " ]" << std::endl;
738+ }
739+ std::cout << " ==================================================\n " << std::endl;
740+
575741 LOG_INFO (" Successfully created HostFlashAttention" );
576742 std::cout << " HostFlashAttention created with tile_size=" << hfa._tile_size
577743 << " , kv_cache_size=" << hfa._kv_cache_size << std::endl;
@@ -600,6 +766,7 @@ HostFlashAttention::HostFlashAttention(const function::HostFlashAttention& func_
600766 const auto & sdpa_attn = func_hfa._sdpa_attention ;
601767 const auto & original_model = func_hfa._original_model ;
602768
769+ // Build parameter info for past key/value tensors
603770 _sdpa_attention_info.params .reserve (sdpa_attn._inputs .size ());
604771 for (const auto & input : sdpa_attn._inputs ) {
605772 std::size_t p_idx = original_model->get_parameter_index (input.param );
@@ -608,8 +775,16 @@ HostFlashAttention::HostFlashAttention(const function::HostFlashAttention& func_
608775 _sdpa_attention_info.mask_idx = original_model->get_parameter_index (sdpa_attn._mask );
609776 _sdpa_attention_info.query_size = sdpa_attn.query_len ();
610777
778+ // Copy SDPA input index mapping from function HFA (already built in from() method)
779+ _sdpa_attention_info.sdpa_param_index_map = func_hfa._sdpa_param_index_map ;
780+
781+ // Copy HFA Tile Model input index mapping from function HFA
782+ _sdpa_attention_info.tile_param_index_map = func_hfa._tile_param_index_map ;
783+
611784 LOG_INFO (" Extracted HFA config: tile_size=" << _tile_size << " , kv_cache_size=" << _kv_cache_size);
612785 LOG_INFO (" Extracted " << _sdpa_attention_info.params .size () << " past KV parameters from original SDPA model" );
786+ LOG_INFO (" Copied SDPA input mapping with " << _sdpa_attention_info.sdpa_param_index_map .size () << " entries" );
787+ LOG_INFO (" Copied Tile input mapping with " << _sdpa_attention_info.tile_param_index_map .size () << " entries" );
613788
614789 // Note: _compiled_tile_model and _compiled_final_tile_model will be set later by
615790 // compile_host_flash_attention_model()
0 commit comments