From 0d200e44c3f270f9ad49f808ac66e80e7a80e33d Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 16 Mar 2021 21:21:21 -0700 Subject: [PATCH] [Torch] Remove unnecessary reshapes for batch_matmul (#7675) * [Torch] Remove unnecessary reshapes for batch_matmul * lint * fix * reorder * lint --- python/tvm/relay/frontend/pytorch.py | 29 +++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index c709e2b4e7bd..fd0a07e35c15 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1094,8 +1094,7 @@ def instance_norm(self, inputs, input_types): data, gamma, beta, axis=1, epsilon=epsilon, center=center, scale=scale ) - @staticmethod - def get_dims(data): + def get_dims(self, data): import torch if isinstance(data, _expr.Expr): @@ -1575,15 +1574,31 @@ def matmul(self, inputs, input_types): # When performing a batch matmul, we need to properly handle N-dim shapes. if len(a_shape) > 2 or len(b_shape) > 2: - # Convert a and b into 3 dimensional tensors. - a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]]) - b = _op.reshape(inputs_1, [-1, b_shape[-2], b_shape[-1]]) + # Convert a into a 3 dimensional tensors. + need_reshape_output = False + if len(a_shape) != 3: + a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]]) + need_reshape_output = True + else: + a = inputs_0 + # Transpose matrix dimensions of b. - b = _op.transpose(b, [0, 2, 1]) + trans_axes = list(range(len(b_shape))) + trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2] + b = _op.transpose(inputs_1, trans_axes) + + # Convert b into a 3 dimensional tensor. Note that the last two dimensions + # are transposed. + if len(b_shape) != 3: + b = _op.reshape(b, [-1, b_shape[-1], b_shape[-2]]) + # Perform a batch matmul. output = _op.nn.batch_matmul(a, b) + # Reshape output to original dimensions. - return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]]) + if need_reshape_output: + return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]]) + return output # Otherwise a simple dense op will get the job done. if len(b_shape) == 1: