diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index cb9ea6a043f4..a31c44a369f9 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1580,7 +1580,7 @@ def matmul(self, inputs, input_types):
         b_shape = self.infer_shape_with_prelude(inputs_1)
 
         # When performing a batch matmul, we need to properly handle N-dim shapes.
-        if len(a_shape) > 2 or len(b_shape) > 2:
+        if len(a_shape) > 2 and len(b_shape) > 2:
             # Convert a into a 3 dimensional tensors.
             need_reshape_output = False
             if len(a_shape) != 3:
@@ -1606,18 +1606,32 @@ def matmul(self, inputs, input_types):
             if need_reshape_output:
                 return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
             return output
+        elif len(a_shape) > 2:
+            inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])
 
-        # Otherwise a simple dense op will get the job done.
-        if len(b_shape) == 1:
-            input_1 = _op.expand_dims(inputs_1, 0, 1)
-        else:
+        if len(b_shape) > 2:
+            trans_axes = list(range(len(b_shape)))
+            trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
+            input_1 = _op.reshape(_op.transpose(inputs_1, trans_axes), [-1, b_shape[-2]])
+        elif len(b_shape) == 2:
             input_1 = _op.transpose(inputs_1, axes=(1, 0))
+        elif len(b_shape) == 1:
+            input_1 = _op.expand_dims(inputs_1, 0, 1)
 
         out = _op.nn.dense(inputs_0, input_1)
 
         if len(b_shape) == 1:
             out = _op.squeeze(out, axis=[-1])
 
+        # Reshape output into an N-dimensional tensor when a or b dim > 2
+        if len(a_shape) > 2:
+            out = _op.reshape(out, [*a_shape[:-1], b_shape[-1]])
+        elif len(b_shape) > 2:
+            out = _op.reshape(out, [a_shape[-2], -1, b_shape[-1]])
+            out = _op.reshape(
+                _op.transpose(out, [1, 0, 2]), [*b_shape[:-2], a_shape[-2], b_shape[-1]]
+            )
+
         return out
 
     def expand(self, inputs, input_types):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 9ec52987c354..bff5bb60e24f 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -162,7 +162,9 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40):
     return est
 
 
-def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, atol=1e-5):
+def verify_model(
+    model_name, input_data=[], custom_convert_map={}, rtol=1e-5, atol=1e-5, expected_ops=[]
+):
     """Assert that the output of a compiled model matches with that of its baseline."""
 
     if isinstance(model_name, str):
@@ -219,6 +221,20 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at
             assert_shapes_match(baseline_output, compiled_output)
             tvm.testing.assert_allclose(baseline_output, compiled_output, rtol=rtol, atol=atol)
+
+    if expected_ops:
+
+        def visit(op):
+            if isinstance(op, tvm.ir.op.Op):
+                if op.name in expected_ops:
+                    expected_ops.remove(op.name)
+
+        tvm.relay.analysis.post_order_visit(mod["main"].body, visit)
+
+        if expected_ops:
+            msg = "TVM Relay does not contain expected ops {}"
+            raise AssertionError(msg.format(expected_ops))
+
 
     del model_name
     del baseline_model
     torch.cuda.empty_cache()
@@ -3304,17 +3320,24 @@ def forward(self, *args):
     # matrix x matrix
     tensor1 = torch.randn(10, 4)
     tensor2 = torch.randn(4, 10)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
 
     # batched matrix x batched matrix
     tensor1 = torch.randn(10, 3, 4)
     tensor2 = torch.randn(10, 4, 5)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+    verify_model(
+        MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.batch_matmul"]
+    )
 
     # batched matrix x broadcasted matrix
     tensor1 = torch.randn(10, 3, 4)
     tensor2 = torch.randn(4, 5)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
+
+    # broadcasted matrix x batched matrix
+    tensor1 = torch.randn(10, 4)
+    tensor2 = torch.randn(3, 4, 5)
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
 
     # batched matrix x batched matrix
     tensor1 = torch.randn(1, 12, 14, 64)
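The broadcast cases added above avoid nn.batch_matmul by lowering torch.matmul to a single nn.dense: the N-dim operand is flattened to 2-D, one GEMM runs, and the trailing reshapes restore the batch dimensions. Below is a minimal numpy sketch of that lowering strategy, separate from the patch itself; dense stands in for relay.nn.dense (whose weight is laid out as [N, K]), and the two helper names are illustrative only, not TVM APIs.

import numpy as np


def dense(data, weight):
    # Stand-in for relay.nn.dense: [M, K] x [N, K] -> [M, N].
    return data @ weight.T


def matmul_a_ndim(a, b):
    # a is N-dim, b is 2-D: flatten a's batch dims, run one GEMM, restore them.
    out = dense(a.reshape(-1, a.shape[-1]), b.T)    # [prod(batch) * M, N]
    return out.reshape(*a.shape[:-1], b.shape[-1])  # [*batch, M, N]


def matmul_b_ndim(a, b):
    # a is 2-D, b is N-dim: swap b's last two axes so its rows form [N, K]
    # slices, flatten those into one weight matrix, run one GEMM, then
    # untangle the [M, prod(batch) * N] result back into [*batch, M, N].
    axes = list(range(b.ndim))
    axes[-2], axes[-1] = axes[-1], axes[-2]
    weight = b.transpose(axes).reshape(-1, b.shape[-2])  # [prod(batch) * N, K]
    out = dense(a, weight)                               # [M, prod(batch) * N]
    out = out.reshape(a.shape[-2], -1, b.shape[-1])      # [M, prod(batch), N]
    return out.transpose(1, 0, 2).reshape(*b.shape[:-2], a.shape[-2], b.shape[-1])


# Mirror the new test cases: both must agree with numpy's broadcasting matmul.
a, b = np.random.randn(10, 3, 4), np.random.randn(4, 5)
np.testing.assert_allclose(matmul_a_ndim(a, b), a @ b, rtol=1e-6, atol=1e-6)
a, b = np.random.randn(10, 4), np.random.randn(3, 4, 5)
np.testing.assert_allclose(matmul_b_ndim(a, b), a @ b, rtol=1e-6, atol=1e-6)

Both broadcast paths therefore produce nn.dense rather than nn.batch_matmul, which is exactly what the new expected_ops assertions in test_forward.py check.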