@@ -126,29 +126,16 @@ class OVInferRequest {
126126 OVTensorPtr GetTensor (const std::string& name);
127127 std::string GetInputTensorName (uint32_t index);
128128
129- // Set tensor described param_info and ort_ptr. Overrides shape in param_info with shape_override. Call infer req tensor if ort_ptr is last set.
129+ // Set tensor call infer req tensor if ort_ptr differs from last set ptr .
130130 void SetTensor (const std::string& name, const ov::element::Type& type, const ov::Shape& shape, void * ort_ptr) {
131131 auto & cached_binding = bindings_cache_[name];
132- if (cached_binding.ort_ptr != ort_ptr) {
133- auto tensor_ptr = std::make_shared<ov::Tensor>(type, shape, const_cast <void *>(ort_ptr));
134- SetTensor (name, tensor_ptr);
135- cached_binding = {tensor_ptr, ort_ptr};
136- } else if (ort_ptr == nullptr ) {
137- // a null ort_ptr is expected for a tensor that has 0 elements.
138- // for example, a tensor of shape=[1, 8, 0, 64], which is valid.
139- // So, we check to see if at least one shape entry is 0.
140- auto contains_zero = [](const ov::Shape& shape) {
141- for (auto & s : shape)
142- if (s == 0 ) return true ;
143- return false ;
144- };
145- if (contains_zero (shape)) {
146- // if there are zero elements (i.e. at least one shape entry is 0),
147- // then create and set the tensor anyway.
148- auto tensor_ptr = std::make_shared<ov::Tensor>(type, shape);
149- SetTensor (name, tensor_ptr);
150- cached_binding = {tensor_ptr, ort_ptr};
151- }
132+ if (cached_binding.ort_ptr != ort_ptr ||
133+ !cached_binding.tensor_ptr ||
134+ cached_binding.tensor_ptr ->get_shape () != shape) {
135+ cached_binding.tensor_ptr .reset ();
136+ auto ov_tensor = std::make_shared<ov::Tensor>(type, shape, const_cast <void *>(ort_ptr));
137+ ovInfReq.set_tensor (name, *ov_tensor);
138+ cached_binding = {ov_tensor, ort_ptr};
152139 }
153140 }
154141
0 commit comments