@@ -59,6 +59,17 @@ bool ModelHasInputOutputNames(std::shared_ptr<ov::Model> model, const std::strin
5959 return false ;
6060}
6161
62+ std::string GetInputOutputName (std::shared_ptr<ov::Model> ov_model,
63+ const std::vector<std::string>& candidate_names) {
64+ for (const auto & name : candidate_names) {
65+ if (ModelHasInputOutputNames (ov_model, name)) {
66+ return name;
67+ }
68+ }
69+ // Return the first candidate as default if none are found
70+ return candidate_names.empty () ? " " : candidate_names[0 ];
71+ }
72+
6273void FuseCacheReorder (std::shared_ptr<ov::Model> ov_model,
6374 std::vector<std::string>& not_kv_inputs,
6475 const std::vector<std::string>& key_value_input_names,
@@ -67,10 +78,15 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
6778 throw std::runtime_error (" Model already has fused cache" );
6879 }
6980
70- std::string main_input_name = " inputs_embeds" ;
71- if (ModelHasInputOutputNames (ov_model, " input_ids" )) {
72- main_input_name = " input_ids" ;
73- }
81+ // Define input name candidates in priority order
82+ const std::vector<std::string> input_name_candidates = {
83+ " inputs_embeds" , // Default fallback
84+ " input_ids" , // Most common
85+ " input_hidden_states" , // Alternative
86+ " /model/embed_tokens/Gather_output_0" // Specific model type
87+ };
88+
89+ std::string main_input_name = GetInputOutputName (ov_model, input_name_candidates);
7490
7591 auto input_batch = ov_model->input (main_input_name).get_partial_shape ()[0 ];
7692
@@ -130,6 +146,14 @@ void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
130146 key_value_input_names.push_back (name);
131147 found = true ;
132148 break ;
149+ } else if (name.find (" keys" ) != std::string::npos) {
150+ key_value_input_names.push_back (name);
151+ found = true ;
152+ break ;
153+ } else if (name.find (" values" ) != std::string::npos) {
154+ key_value_input_names.push_back (name);
155+ found = true ;
156+ break ;
133157 }
134158 }
135159
0 commit comments