Skip to content

Commit 25c6976

Browse files
committed
address PR review
1 parent 1e132f3 commit 25c6976

File tree

2 files changed

+21
-13
lines changed

2 files changed

+21
-13
lines changed

onnxruntime/core/providers/openvino/ov_interface.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, s
365365
// Check whether there is an input_ids tensor and whether its element type is int64,
366366
// because the prefill_use_full_chat_history logic applies only to specific inputs and data types
367367
auto input_ids_opt = FindTensor("input_ids");
368-
if (gpu_or_npu && input_ids_opt.has_value() && input_ids_opt->get_element_type() != ov::element::i64) {
368+
if (gpu_or_npu && input_ids_opt.has_value() && input_ids_opt->get_element_type() == ov::element::i64) {
369369
prefill_use_full_chat_history = true;
370370
}
371371
}

onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,17 @@ bool ModelHasInputOutputNames(std::shared_ptr<ov::Model> model, const std::strin
5959
return false;
6060
}
6161

62+
// Selects a tensor name for `ov_model` from an ordered list of candidates.
//
// The candidates are probed front to back with ModelHasInputOutputNames(); the
// first one for which it returns true is returned, so the order of
// `candidate_names` is the priority order (earlier entries win).
//
// If no candidate is reported as present, the first candidate is returned as a
// default. If `candidate_names` is empty, an empty string is returned.
std::string GetInputOutputName(std::shared_ptr<ov::Model> ov_model,
                               const std::vector<std::string>& candidate_names) {
  // Nothing to probe and nothing to fall back to.
  if (candidate_names.empty()) {
    return "";
  }

  for (auto it = candidate_names.begin(); it != candidate_names.end(); ++it) {
    const std::string& candidate = *it;
    if (ModelHasInputOutputNames(ov_model, candidate)) {
      return candidate;
    }
  }

  // No candidate matched: fall back to the first entry in the list.
  return candidate_names.front();
}
72+
6273
void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
6374
std::vector<std::string>& not_kv_inputs,
6475
const std::vector<std::string>& key_value_input_names,
@@ -67,18 +78,15 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
6778
throw std::runtime_error("Model already has fused cache");
6879
}
6980

70-
std::string main_input_name = "inputs_embeds";
71-
if (ModelHasInputOutputNames(ov_model, "input_ids")) {
72-
main_input_name = "input_ids";
73-
}
74-
75-
if (ModelHasInputOutputNames(ov_model, "input_hidden_states")) {
76-
main_input_name = "input_hidden_states";
77-
}
81+
// Define input name candidates in priority order
82+
const std::vector<std::string> input_name_candidates = {
83+
"inputs_embeds", // Default fallback
84+
"input_ids", // Most common
85+
"input_hidden_states", // Alternative
86+
"/model/embed_tokens/Gather_output_0" // Specific model type
87+
};
7888

79-
if (ModelHasInputOutputNames(ov_model, "/model/embed_tokens/Gather_output_0")) {
80-
main_input_name = "/model/embed_tokens/Gather_output_0";
81-
}
89+
std::string main_input_name = GetInputOutputName(ov_model, input_name_candidates);
8290

8391
auto input_batch = ov_model->input(main_input_name).get_partial_shape()[0];
8492

@@ -131,7 +139,7 @@ void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
131139
std::vector<std::string> not_kv_inputs;
132140
const auto& params = model->get_parameters();
133141
bool found = false;
134-
for (auto i = 0; i < params.size(); i++) {
142+
for (size_t i = 0; i < params.size(); i++) {
135143
auto param_name = params.at(i)->output(0).get_any_name();
136144
if (param_name.find("key_values") != std::string::npos) {
137145
key_value_input_names.push_back(param_name);

0 commit comments

Comments
 (0)