Skip to content

Commit 1e132f3

Browse files
committed
unify the code
1 parent 65bbecc commit 1e132f3

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

onnxruntime/core/providers/openvino/ov_interface.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -360,10 +360,14 @@ void OVInferRequest::Infer() {
360360

361361
// Constructs a stateful request: the infer request is moved into the
// OVInferRequest base and the device string is stored in target_device.
// prefill_use_full_chat_history is set to true only when every condition in
// the `if` below holds.
StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device)
362362
: OVInferRequest(std::move(infer_request)), target_device(device) {
363-
// bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos));
364-
// if (gpu_or_npu) {
365-
// prefill_use_full_chat_history = true;
366-
// }
363+
bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos));
364+
365+
// Look up the "input_ids" tensor and inspect its element type: the
// prefill_use_full_chat_history logic only applies to specific inputs and
366+
// a specific data type.
// NOTE(review): this comment originally said the check is for an int64
// tensor, but the condition below enables the flag when the element type is
// NOT ov::element::i64 -- confirm whether `!=` or `==` is intended.
367+
auto input_ids_opt = FindTensor("input_ids");
368+
if (gpu_or_npu && input_ids_opt.has_value() && input_ids_opt->get_element_type() != ov::element::i64) {
369+
prefill_use_full_chat_history = true;
370+
}
367371
}
368372

369373
void StatefulOVInferRequest::FillTensor(const std::string& tensor_name, const ov::element::Type& type,

0 commit comments

Comments
 (0)