Skip to content

Commit b9a73f3

Browse files
authored
Merge pull request #821 from Kotomi-Du/make_stateful_phisilica
CVS-175736-[OVEP] Enable stateful mode for Phi-silica models
2 parents 323cfeb + 2041402 commit b9a73f3

File tree

2 files changed

+33
-5
lines changed

2 files changed

+33
-5
lines changed

onnxruntime/core/providers/openvino/ov_interface.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,11 @@ void OVInferRequest::Infer() {
361361
StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device)
362362
: OVInferRequest(std::move(infer_request)), target_device(device) {
363363
bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos));
364-
if (gpu_or_npu) {
364+
365+
// check if there is input_ids tensors and if the tensor type is int64,
366+
// because logic prefill_use_full_chat_history is only for specific inputs and data type
367+
auto input_ids_opt = FindTensor("input_ids");
368+
if (gpu_or_npu && input_ids_opt.has_value() && input_ids_opt->get_element_type() == ov::element::i64) {
365369
prefill_use_full_chat_history = true;
366370
}
367371
}

onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,17 @@ bool ModelHasInputOutputNames(std::shared_ptr<ov::Model> model, const std::strin
5959
return false;
6060
}
6161

62+
std::string GetInputOutputName(std::shared_ptr<ov::Model> ov_model,
63+
const std::vector<std::string>& candidate_names) {
64+
for (const auto& name : candidate_names) {
65+
if (ModelHasInputOutputNames(ov_model, name)) {
66+
return name;
67+
}
68+
}
69+
// Return the first candidate as default if none are found
70+
return candidate_names.empty() ? "" : candidate_names[0];
71+
}
72+
6273
void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
6374
std::vector<std::string>& not_kv_inputs,
6475
const std::vector<std::string>& key_value_input_names,
@@ -67,10 +78,15 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
6778
throw std::runtime_error("Model already has fused cache");
6879
}
6980

70-
std::string main_input_name = "inputs_embeds";
71-
if (ModelHasInputOutputNames(ov_model, "input_ids")) {
72-
main_input_name = "input_ids";
73-
}
81+
// Define input name candidates in priority order
82+
const std::vector<std::string> input_name_candidates = {
83+
"inputs_embeds", // Default fallback
84+
"input_ids", // Most common
85+
"input_hidden_states", // Alternative
86+
"/model/embed_tokens/Gather_output_0" // Specific model type
87+
};
88+
89+
std::string main_input_name = GetInputOutputName(ov_model, input_name_candidates);
7490

7591
auto input_batch = ov_model->input(main_input_name).get_partial_shape()[0];
7692

@@ -130,6 +146,14 @@ void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
130146
key_value_input_names.push_back(name);
131147
found = true;
132148
break;
149+
} else if (name.find("keys") != std::string::npos) {
150+
key_value_input_names.push_back(name);
151+
found = true;
152+
break;
153+
} else if (name.find("values") != std::string::npos) {
154+
key_value_input_names.push_back(name);
155+
found = true;
156+
break;
133157
}
134158
}
135159

0 commit comments

Comments
 (0)