Skip to content

Commit

Permalink
add device backend check (#886)
Browse files Browse the repository at this point in the history
* add device backend check

* safe check

* only activated for tensorrt and openvino

* resolve comments
  • Loading branch information
AllentDan authored Aug 16, 2022
1 parent 3fa1582 commit 9fbfdd2
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion mmdeploy/apis/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,25 @@
import mmcv

from mmdeploy.codebase import BaseTask, get_codebase_class, import_codebase
from mmdeploy.utils import get_codebase, get_task_type
from mmdeploy.utils import (get_backend, get_codebase, get_task_type,
parse_device_id)


def check_backend_device(deploy_cfg: mmcv.Config, device: str):
"""Check if device is appropriate for the backend.
Args:
deploy_cfg (str | mmcv.Config): Deployment config file.
device (str): A string specifying device type.
"""
backend = get_backend(deploy_cfg).value
device_id = parse_device_id(device)
mismatch = dict(
tensorrt=lambda id: id == -1,
openvino=lambda id: id > -1,
)
if backend in mismatch and mismatch[backend](device_id):
raise ValueError(f'{device} is invalid for the backend {backend}')


def build_task_processor(model_cfg: mmcv.Config, deploy_cfg: mmcv.Config,
Expand All @@ -17,6 +35,7 @@ def build_task_processor(model_cfg: mmcv.Config, deploy_cfg: mmcv.Config,
Returns:
BaseTask: A task processor.
"""
check_backend_device(deploy_cfg=deploy_cfg, device=device)
codebase_type = get_codebase(deploy_cfg)
import_codebase(codebase_type)
codebase = get_codebase_class(codebase_type)
Expand Down

0 comments on commit 9fbfdd2

Please sign in to comment.