Skip to content

Commit

Permalink
fix onnxruntime wrapper for gpu inference (#123)
Browse files Browse the repository at this point in the history
* fix onnxruntime wrapper for ort-gpu

* resolve comment

* fix lint
  • Loading branch information
RunningLeon authored Feb 8, 2022
1 parent 51fa2ff commit 9f9670e
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions mmdeploy/backend/onnxruntime/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class ORTWrapper(BaseWrapper):
>>> import torch
>>>
>>> onnx_file = 'model.onnx'
>>> model = ORTWrapper(onnx_file, -1)
>>> model = ORTWrapper(onnx_file, 'cpu')
>>> inputs = dict(input=torch.randn(1, 3, 224, 224, device='cpu'))
>>> outputs = model(inputs)
>>> print(outputs)
Expand Down Expand Up @@ -79,7 +79,9 @@ def forward(self, inputs: Dict[str,
input_tensor = input_tensor.contiguous()
if not self.is_cuda_available:
input_tensor = input_tensor.cpu()
element_type = input_tensor.numpy().dtype
# Avoid unnecessary data transfer between host and device
element_type = input_tensor.new_zeros(
1, device='cpu').numpy().dtype
self.io_binding.bind_input(
name=name,
device_type=self.device_type,
Expand Down

0 comments on commit 9f9670e

Please sign in to comment.