diff --git a/mmdeploy/backend/onnxruntime/wrapper.py b/mmdeploy/backend/onnxruntime/wrapper.py
index 51116716cd..4239853e2d 100644
--- a/mmdeploy/backend/onnxruntime/wrapper.py
+++ b/mmdeploy/backend/onnxruntime/wrapper.py
@@ -27,7 +27,7 @@ class ORTWrapper(BaseWrapper):
         >>> import torch
         >>>
         >>> onnx_file = 'model.onnx'
-        >>> model = ORTWrapper(onnx_file, -1)
+        >>> model = ORTWrapper(onnx_file, 'cpu')
         >>> inputs = dict(input=torch.randn(1, 3, 224, 224, device='cpu'))
         >>> outputs = model(inputs)
         >>> print(outputs)
@@ -79,7 +79,9 @@ def forward(self, inputs: Dict[str,
             input_tensor = input_tensor.contiguous()
             if not self.is_cuda_available:
                 input_tensor = input_tensor.cpu()
-            element_type = input_tensor.numpy().dtype
+            # Avoid unnecessary data transfer between host and device
+            element_type = input_tensor.new_zeros(
+                1, device='cpu').numpy().dtype
             self.io_binding.bind_input(
                 name=name,
                 device_type=self.device_type,
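
The second hunk changes how the numpy dtype for io binding is obtained: instead of materializing the whole input as a numpy array just to read its dtype, it builds a one-element CPU tensor of the same dtype and reads the dtype from that, so no device-to-host copy of the actual data is needed. A minimal standalone sketch of the idea follows; the helper name numpy_dtype_of is hypothetical and not part of mmdeploy.

    import numpy as np
    import torch


    def numpy_dtype_of(tensor: torch.Tensor) -> np.dtype:
        # Hypothetical helper illustrating the trick used in the patch:
        # new_zeros(1, device='cpu') creates a tiny CPU tensor that shares
        # the input's dtype, so .numpy().dtype can be read without copying
        # the real payload from device to host.
        return tensor.new_zeros(1, device='cpu').numpy().dtype


    t = torch.randn(1, 3, 224, 224)
    print(numpy_dtype_of(t))  # float32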