Skip to content

Commit

Permalink
resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
AllentDan committed Mar 7, 2022
1 parent 25f9693 commit 8ed6b86
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
1 change: 0 additions & 1 deletion mmdeploy/apis/pytorch2torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def torch2torchscript_impl(model: torch.nn.Module,

deploy_cfg = load_config(deploy_cfg)[0]

# ir_cfg = get_ir_config(deploy_cfg)
backend = get_backend(deploy_cfg).value

patched_model = patch_model(model, cfg=deploy_cfg, backend=backend)
Expand Down
15 changes: 10 additions & 5 deletions mmdeploy/backend/torchscript/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ class TorchscriptWrapper(BaseWrapper):
>>> print(outputs)
"""

def __init__(self, model: Union[str, torch.jit.RecursiveScriptModule],
input_names: Optional[Sequence[str]],
output_names: Optional[Sequence[str]]):
def __init__(self,
model: Union[str, torch.jit.RecursiveScriptModule],
input_names: Optional[Sequence[str]] = None,
output_names: Optional[Sequence[str]] = None):
# load custom ops if exist
custom_ops_path = get_ops_path()
if osp.exists(custom_ops_path):
Expand All @@ -60,10 +61,14 @@ def forward(
"""Run forward inference.
Args:
inputs: The input name and tensor pairs.
inputs (torch.Tensor | Sequence[torch.Tensor] | Dict[str,
torch.Tensor]): The input tensor, or tensor sequence, or pairs
of input names and tensors.
Return:
outputs: The output name and tensor pairs.
outputs (torch.Tensor | Sequence[torch.Tensor] | Dict[str,
torch.Tensor]): The input tensor, or tensor sequence, or pairs
of input names and tensors.
"""

is_dict_inputs = isinstance(inputs, Dict)
Expand Down

0 comments on commit 8ed6b86

Please sign in to comment.