diff --git a/mmdeploy/apis/utils/utils.py b/mmdeploy/apis/utils/utils.py index fc6fde73cb..f00136618c 100644 --- a/mmdeploy/apis/utils/utils.py +++ b/mmdeploy/apis/utils/utils.py @@ -2,7 +2,25 @@ import mmcv from mmdeploy.codebase import BaseTask, get_codebase_class, import_codebase -from mmdeploy.utils import get_codebase, get_task_type +from mmdeploy.utils import (get_backend, get_codebase, get_task_type, + parse_device_id) + + +def check_backend_device(deploy_cfg: mmcv.Config, device: str): + """Check if device is appropriate for the backend. + + Args: + deploy_cfg (str | mmcv.Config): Deployment config file. + device (str): A string specifying device type. + """ + backend = get_backend(deploy_cfg).value + device_id = parse_device_id(device) + mismatch = dict( + tensorrt=lambda id: id == -1, + openvino=lambda id: id > -1, + ) + if backend in mismatch and mismatch[backend](device_id): + raise ValueError(f'{device} is invalid for the backend {backend}') def build_task_processor(model_cfg: mmcv.Config, deploy_cfg: mmcv.Config, @@ -17,6 +35,7 @@ def build_task_processor(model_cfg: mmcv.Config, deploy_cfg: mmcv.Config, Returns: BaseTask: A task processor. """ + check_backend_device(deploy_cfg=deploy_cfg, device=device) codebase_type = get_codebase(deploy_cfg) import_codebase(codebase_type) codebase = get_codebase_class(codebase_type)