From f838e1f4ae4a88a7b408b7db46688083f770a23f Mon Sep 17 00:00:00 2001
From: Xuxue1 <1915998056@qq.com>
Date: Tue, 29 Dec 2020 22:08:17 +0800
Subject: [PATCH] [Torch] Support hard_swish op (#7174)

* imp_hardswish

* format

* fix

* hard_swish_inplace test case
---
 python/tvm/relay/frontend/pytorch.py          | 11 +++++++++++
 tests/python/frontend/pytorch/test_forward.py | 14 +++++++++++---
 2 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 94ee9282e4fa..8e69739544e5 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -790,6 +790,15 @@ def log_sigmoid(self, inputs, input_types):
         data = inputs[0]
         return _op.log(_op.tensor.sigmoid(data))
 
+    def hard_swish(self, inputs, input_types):
+        data = inputs[0]
+        dtype = input_types[0]
+
+        def _relu6(input_tensor):
+            return _op.tensor.clip(input_tensor, 0.0, 6.0)
+
+        return data * _relu6(data + _expr.const(3.0, dtype=dtype)) / _expr.const(6.0, dtype=dtype)
+
     def adaptive_avg_pool_2d(self, inputs, input_types):
         data = inputs[0]
         output_size = inputs[1]
@@ -2266,6 +2275,8 @@ def create_convert_map(self):
             "aten::bincount": self.bincount,
             "aten::scatter_add": self.scatter_add,
             "aten::__not__": self.logical_not,
+            "aten::hardswish_": self.hard_swish,
+            "aten::hardswish": self.hard_swish,
         }
 
     def update_convert_map(self, custom_map):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 04f08b903bf1..f76c697a2c81 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -181,14 +181,14 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at
             baseline_input = [inp.cuda() for inp in baseline_input]
 
     with torch.no_grad():
-        baseline_outputs = baseline_model(*baseline_input)
+        baseline_outputs = baseline_model(*[input.clone() for input in baseline_input])
 
     if isinstance(baseline_outputs, tuple):
         baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs)
     else:
         baseline_outputs = (baseline_outputs.cpu().numpy(),)
 
-    trace = torch.jit.trace(baseline_model, baseline_input)
+    trace = torch.jit.trace(baseline_model, [input.clone() for input in baseline_input])
 
     if isinstance(baseline_model, torch.nn.Module):
         trace = trace.float().eval()
@@ -200,7 +200,7 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at
     input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)]
     input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input]))
     mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map)
-    compiled_input = dict(zip(input_names, [inp.cpu().numpy() for inp in baseline_input]))
+    compiled_input = dict(zip(input_names, [inp.clone().cpu().numpy() for inp in baseline_input]))
 
     with tvm.transform.PassContext(opt_level=3):
         for target, ctx in tvm.testing.enabled_targets():
@@ -3437,6 +3437,13 @@ def test_fn(x, weights=None):
         verify_trace_model(test_fn, [inp, weights], targets)
 
 
+def test_hard_swish():
+    examples = [torch.rand(8).float(), torch.rand(8, 10).float(), torch.rand(1, 1, 10).float()]
+    for input in examples:
+        verify_model(torch.nn.Hardswish().eval(), input_data=input)
+        verify_model(torch.nn.Hardswish(inplace=True).eval(), input_data=input)
+
+
 if __name__ == "__main__":
     # some structural tests
     test_forward_traced_function()
@@ -3603,3 +3610,4 @@ def test_fn(x, weights=None):
 
     # Test convert torch script(jit) with specific inputs' types
     test_convert_torch_script_with_input_types()
+    test_hard_swish()
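
For context, the converter added above lowers `aten::hardswish` to `x * clip(x + 3, 0, 6) / 6`, i.e. `x * relu6(x + 3) / 6`. Below is a minimal standalone sketch (not part of the patch) that checks this formula numerically against `torch.nn.Hardswish`, the same reference module the new test uses:

```python
import numpy as np
import torch

def hard_swish_ref(x):
    # Same formula the Relay converter emits: x * relu6(x + 3) / 6
    return x * np.clip(x + 3.0, 0.0, 6.0) / 6.0

# Random float32 input, matching the shapes exercised in test_hard_swish
x = np.random.rand(8, 10).astype("float32")
expected = torch.nn.Hardswish()(torch.from_numpy(x)).numpy()
np.testing.assert_allclose(hard_swish_ref(x), expected, rtol=1e-5, atol=1e-5)
```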