diff --git a/mmdeploy/apis/pytorch2torchscript.py b/mmdeploy/apis/pytorch2torchscript.py index 3e27ccdba2..c984892360 100644 --- a/mmdeploy/apis/pytorch2torchscript.py +++ b/mmdeploy/apis/pytorch2torchscript.py @@ -33,7 +33,6 @@ def torch2torchscript_impl(model: torch.nn.Module, deploy_cfg = load_config(deploy_cfg)[0] - # ir_cfg = get_ir_config(deploy_cfg) backend = get_backend(deploy_cfg).value patched_model = patch_model(model, cfg=deploy_cfg, backend=backend) diff --git a/mmdeploy/backend/torchscript/wrapper.py b/mmdeploy/backend/torchscript/wrapper.py index 93e57b5f51..668ab23aa0 100644 --- a/mmdeploy/backend/torchscript/wrapper.py +++ b/mmdeploy/backend/torchscript/wrapper.py @@ -35,9 +35,10 @@ class TorchscriptWrapper(BaseWrapper): >>> print(outputs) """ - def __init__(self, model: Union[str, torch.jit.RecursiveScriptModule], - input_names: Optional[Sequence[str]], - output_names: Optional[Sequence[str]]): + def __init__(self, + model: Union[str, torch.jit.RecursiveScriptModule], + input_names: Optional[Sequence[str]] = None, + output_names: Optional[Sequence[str]] = None): # load custom ops if exist custom_ops_path = get_ops_path() if osp.exists(custom_ops_path): @@ -60,10 +61,14 @@ def forward( """Run forward inference. Args: - inputs: The input name and tensor pairs. + inputs (torch.Tensor | Sequence[torch.Tensor] | Dict[str, + torch.Tensor]): The input tensor, or tensor sequence, or pairs + of input names and tensors. Return: - outputs: The output name and tensor pairs. + outputs (torch.Tensor | Sequence[torch.Tensor] | Dict[str, + torch.Tensor]): The output tensor, or tensor sequence, or pairs + of output names and tensors. """ is_dict_inputs = isinstance(inputs, Dict)