diff --git a/mmdeploy/apis/pytorch2onnx.py b/mmdeploy/apis/pytorch2onnx.py index 9f8e91f368..63bcb6a963 100644 --- a/mmdeploy/apis/pytorch2onnx.py +++ b/mmdeploy/apis/pytorch2onnx.py @@ -14,6 +14,7 @@ def torch2onnx(img: Any, deploy_cfg: Union[str, mmengine.Config], model_cfg: Union[str, mmengine.Config], model_checkpoint: Optional[str] = None, + append_input: list = None, device: str = 'cuda:0'): """Convert PyTorch model to ONNX model. @@ -42,6 +43,7 @@ def torch2onnx(img: Any, model_cfg (str | mmengine.Config): Model config file or Config object. model_checkpoint (str): A checkpoint path of PyTorch model, defaults to `None`. + append_input (list): Additional inputs other than images, suitable for multimodal models such as text features of Grounded DINO. device (str): A string specifying device type, defaults to 'cuda:0'. """ @@ -68,6 +70,10 @@ def torch2onnx(img: Any, if isinstance(model_inputs, list) and len(model_inputs) == 1: model_inputs = model_inputs[0] + if isinstance(append_input, list): + temp = [model_inputs] + temp.extend(append_input) + model_inputs = temp data_samples = data['data_samples'] input_metas = {'data_samples': data_samples, 'mode': 'predict'}