From 33ebc85acbadaef30a5e26765a8a06920ed45d19 Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Wed, 24 Jun 2020 13:49:43 +0200
Subject: [PATCH] Don't multiply by constant 1 uselessly in dense (#5911)

---
 python/tvm/relay/frontend/pytorch.py          |  4 ++--
 tests/python/frontend/pytorch/test_forward.py | 19 +++++++++++++++++++
 2 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 92373036d2f2..84b0907d877a 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -995,11 +995,11 @@ def _impl(inputs, input_types):
         beta = inputs[3]
         alpha = inputs[4]
 
-        if not isinstance(alpha, _expr.Expr):
+        if not isinstance(alpha, _expr.Expr) and alpha != 1:
             alpha = _create_typed_const(alpha, data_type)
             data *= alpha
 
-        if not isinstance(beta, _expr.Expr):
+        if not isinstance(beta, _expr.Expr) and beta != 1:
             beta = _create_typed_const(beta, data_type)
             weight *= beta
 
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 12d1260a4a50..0694fa5621ec 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -33,6 +33,18 @@
 
 sys.setrecursionlimit(10000)
 
+def list_ops(expr):
+    class OpLister(tvm.relay.ExprVisitor):
+        def visit_op(self, expr):
+            if expr not in self.node_set:
+                self.node_list.append(expr)
+            return super().visit_op(expr)
+        def list_nodes(self, expr):
+            self.node_set = {}
+            self.node_list = []
+            self.visit(expr)
+            return self.node_list
+    return OpLister().list_nodes(expr)
 
 def assert_shapes_match(tru, est):
     if tru.shape != est.shape:
@@ -1047,6 +1059,13 @@ def forward(self, *args):
     verify_model(Dense1().float().eval(), input_data=input_data)
     verify_model(Dense2().float().eval(), input_data=input_data)
 
+    trace = torch.jit.trace(Dense1(), [input_data])
+    mod, params = relay.frontend.from_pytorch(
+        trace,
+        [('input', input_shape)],
+    )
+    assert not any([op.name == "multiply" for op in list_ops(mod['main'])])
+
 def test_forward_dropout():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]