Skip to content

Commit

Permalink
[Torch] Remove unnecessary reshapes for batch_matmul (apache#7675)
Browse files Browse the repository at this point in the history
* [Torch] Remove unnecessary reshapes for batch_matmul

* lint

* fix

* reorder

* lint
  • Loading branch information
comaniac authored and Trevor Morris committed May 6, 2021
1 parent 0f2bd24 commit 0d200e4
Showing 1 changed file with 22 additions and 7 deletions.
29 changes: 22 additions & 7 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,8 +1094,7 @@ def instance_norm(self, inputs, input_types):
data, gamma, beta, axis=1, epsilon=epsilon, center=center, scale=scale
)

@staticmethod
def get_dims(data):
def get_dims(self, data):
import torch

if isinstance(data, _expr.Expr):
Expand Down Expand Up @@ -1575,15 +1574,31 @@ def matmul(self, inputs, input_types):

# When performing a batch matmul, we need to properly handle N-dim shapes.
if len(a_shape) > 2 or len(b_shape) > 2:
# Convert a and b into 3 dimensional tensors.
a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]])
b = _op.reshape(inputs_1, [-1, b_shape[-2], b_shape[-1]])
# Convert a into a 3 dimensional tensor.
need_reshape_output = False
if len(a_shape) != 3:
a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]])
need_reshape_output = True
else:
a = inputs_0

# Transpose matrix dimensions of b.
b = _op.transpose(b, [0, 2, 1])
trans_axes = list(range(len(b_shape)))
trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
b = _op.transpose(inputs_1, trans_axes)

# Convert b into a 3 dimensional tensor. Note that the last two dimensions
# are transposed.
if len(b_shape) != 3:
b = _op.reshape(b, [-1, b_shape[-1], b_shape[-2]])

# Perform a batch matmul.
output = _op.nn.batch_matmul(a, b)

# Reshape output to original dimensions.
return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
if need_reshape_output:
return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
return output

# Otherwise a simple dense op will get the job done.
if len(b_shape) == 1:
Expand Down

0 comments on commit 0d200e4

Please sign in to comment.