Skip to content

Commit

Permalink
changing the onnxwrapper script for gpu issue (open-mmlab#532)
Browse files Browse the repository at this point in the history
* changing the onnxwrapper script

* gpu_issue

* Update wrapper.py

* Update wrapper.py

* Update runtime.txt

* Update runtime.txt

* Update wrapper.py
  • Loading branch information
sanjaypavo authored Jun 7, 2022
1 parent 182cc51 commit 2a0fcb6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
11 changes: 5 additions & 6 deletions mmdeploy/backend/onnxruntime/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,17 @@ def __init__(self,
logger.warning(f'The library of onnxruntime custom ops does \
not exist: {ort_custom_op_path}')
device_id = parse_device_id(device)
is_cuda_available = ort.get_device() == 'GPU'
providers = [('CUDAExecutionProvider', {'device_id': device_id})] \
if is_cuda_available else ['CPUExecutionProvider']
providers = ['CPUExecutionProvider'] \
if device == 'cpu' else \
[('CUDAExecutionProvider', {'device_id': device_id})]
sess = ort.InferenceSession(
onnx_file, session_options, providers=providers)
if output_names is None:
output_names = [_.name for _ in sess.get_outputs()]
self.sess = sess
self.io_binding = sess.io_binding()
self.device_id = device_id
self.is_cuda_available = is_cuda_available
self.device_type = 'cuda' if is_cuda_available else 'cpu'
self.device_type = 'cpu' if device == 'cpu' else 'cuda'
super().__init__(output_names)

def forward(self, inputs: Dict[str,
Expand All @@ -77,7 +76,7 @@ def forward(self, inputs: Dict[str,
for name, input_tensor in inputs.items():
# set io binding for inputs/outputs
input_tensor = input_tensor.contiguous()
if not self.is_cuda_available:
if self.device_type == 'cpu':
input_tensor = input_tensor.cpu()
# Avoid unnecessary data transfer between host and device
element_type = input_tensor.new_zeros(
Expand Down
1 change: 1 addition & 0 deletions requirements/runtime.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ h5py
matplotlib
numpy
onnx>=1.8.0
protobuf==3.20.0
six
terminaltables

0 comments on commit 2a0fcb6

Please sign in to comment.