@@ -59,6 +59,17 @@ bool ModelHasInputOutputNames(std::shared_ptr<ov::Model> model, const std::strin
5959 return false ;
6060}
6161
62+ std::string GetInputOutputName (std::shared_ptr<ov::Model> ov_model,
63+ const std::vector<std::string>& candidate_names) {
64+ for (const auto & name : candidate_names) {
65+ if (ModelHasInputOutputNames (ov_model, name)) {
66+ return name;
67+ }
68+ }
69+ // Return the first candidate as default if none are found
70+ return candidate_names.empty () ? " " : candidate_names[0 ];
71+ }
72+
6273void FuseCacheReorder (std::shared_ptr<ov::Model> ov_model,
6374 std::vector<std::string>& not_kv_inputs,
6475 const std::vector<std::string>& key_value_input_names,
@@ -67,18 +78,15 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
6778 throw std::runtime_error (" Model already has fused cache" );
6879 }
6980
70- std::string main_input_name = " inputs_embeds" ;
71- if (ModelHasInputOutputNames (ov_model, " input_ids" )) {
72- main_input_name = " input_ids" ;
73- }
74-
75- if (ModelHasInputOutputNames (ov_model, " input_hidden_states" )) {
76- main_input_name = " input_hidden_states" ;
77- }
81+ // Define input name candidates in priority order
82+ const std::vector<std::string> input_name_candidates = {
83+ " inputs_embeds" , // Default fallback
84+ " input_ids" , // Most common
85+ " input_hidden_states" , // Alternative
86+ " /model/embed_tokens/Gather_output_0" // Specific model type
87+ };
7888
79- if (ModelHasInputOutputNames (ov_model, " /model/embed_tokens/Gather_output_0" )) {
80- main_input_name = " /model/embed_tokens/Gather_output_0" ;
81- }
89+ std::string main_input_name = GetInputOutputName (ov_model, input_name_candidates);
8290
8391 auto input_batch = ov_model->input (main_input_name).get_partial_shape ()[0 ];
8492
@@ -131,7 +139,7 @@ void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
131139 std::vector<std::string> not_kv_inputs;
132140 const auto & params = model->get_parameters ();
133141 bool found = false ;
134- for (auto i = 0 ; i < params.size (); i++) {
142+ for (size_t i = 0 ; i < params.size (); i++) {
135143 auto param_name = params.at (i)->output (0 ).get_any_name ();
136144 if (param_name.find (" key_values" ) != std::string::npos) {
137145 key_value_input_names.push_back (param_name);
0 commit comments