
Refactor onnx batch_matmul to use dense for certain shapes
Ubuntu committed Jul 9, 2021
1 parent 683c5eb commit e368f29
Showing 1 changed file with 53 additions and 54 deletions.

python/tvm/relay/frontend/onnx.py
@@ -670,68 +670,67 @@ class MatMul(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs))
-        # Need to check input shape as batch matmul must be supported.
-        a_shape = shape_of(inputs[0])
-        a_rank = infer_shape(a_shape)[0]
-        b_shape = shape_of(inputs[1])
-        b_rank = infer_shape(b_shape)[0]
-        # When performing a batch matmul, we need to properly handle N-dim shapes.
-        if a_rank > 2 or b_rank > 2:
-
-            def flatten_to_3d(x, x_shape):
-                ndims = infer_shape(x_shape)[0]
-                newshape = _op.concatenate(
-                    [
-                        _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype),
-                        _op.strided_slice(x_shape, [ndims - 2], [ndims]),
-                    ],
-                    0,
-                )
-                out = _op.reshape(x, fold_constant(newshape))
-                return out
-
-            # Convert a and b into 3 dimensional tensors.
-            a = flatten_to_3d(inputs[0], a_shape)
-            b = flatten_to_3d(inputs[1], b_shape)
-            # Transpose matrix dimensions of b.
-            b = _op.transpose(b, [0, 2, 1])
+        inputs_0 = inputs[0]
+        inputs_1 = inputs[1]
+
+        a_shape = infer_shape(inputs_0)
+        b_shape = infer_shape(inputs_1)
+
+        # When performing a batch matmul, we need to properly handle N-dim shapes.
+        if len(a_shape) > 2 and len(b_shape) > 2:
+            # Convert a into a 3 dimensional tensor.
+            need_reshape_output = False
+            if len(a_shape) != 3:
+                a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]])
+                need_reshape_output = True
+            else:
+                a = inputs_0
+
+            # Transpose matrix dimensions of b.
+            trans_axes = list(range(len(b_shape)))
+            trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
+            b = _op.transpose(inputs_1, trans_axes)
+
+            # Convert b into a 3 dimensional tensor. Note that the last two dimensions
+            # are transposed.
+            if len(b_shape) != 3:
+                b = _op.reshape(b, [-1, b_shape[-1], b_shape[-2]])
+
             # Perform a batch matmul.
             output = _op.nn.batch_matmul(a, b)
-            # Determine the output batch dimension.
-            if a_rank > b_rank:
-                out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
-            elif a_rank < b_rank:
-                out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
-            # If it's unclear how broadcasting should be applied, the output
-            # shape is determined by choosing the maximum value from each input.
-            else:
-                out_batch = _op.concatenate(
-                    [
-                        _op.maximum(
-                            _op.strided_slice(a_shape, [i], [i + 1]),
-                            _op.strided_slice(b_shape, [i], [i + 1]),
-                        )
-                        for i in range(a_rank - 2)
-                    ],
-                    0,
-                )
+
             # Reshape output to original dimensions.
-            final_shape = _op.concatenate(
-                [
-                    out_batch,
-                    _op.strided_slice(
-                        a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0] - 1]
-                    ),
-                    _op.strided_slice(
-                        b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]]
-                    ),
-                ],
-                0,
-            )
-            return _op.reshape(output, fold_constant(final_shape))
-        # Otherwise a simple dense op will get the job done.
-        input_1_t = _op.transpose(inputs[1], axes=(1, 0))
-        return _op.nn.dense(inputs[0], input_1_t)
+            if need_reshape_output:
+                return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
+            return output
+        elif len(a_shape) > 2:
+            inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])
+
+        if len(b_shape) > 2:
+            trans_axes = list(range(len(b_shape)))
+            trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
+            input_1 = _op.reshape(_op.transpose(inputs_1, trans_axes), [-1, b_shape[-2]])
+        elif len(b_shape) == 2:
+            input_1 = _op.transpose(inputs_1, axes=(1, 0))
+        elif len(b_shape) == 1:
+            input_1 = _op.expand_dims(inputs_1, 0, 1)
+
+        out = _op.nn.dense(inputs_0, input_1)
+
+        if len(b_shape) == 1:
+            out = _op.squeeze(out, axis=[-1])
+
+        # Reshape output into an N dimensional tensor when a or b dim > 2.
+        if len(a_shape) > 2:
+            out = _op.reshape(out, [*a_shape[:-1], b_shape[-1]])
+        elif len(b_shape) > 2:
+            out = _op.reshape(out, [a_shape[-2], -1, b_shape[-1]])
+            out = _op.reshape(
+                _op.transpose(out, [1, 0, 2]), [*b_shape[:-2], a_shape[-2], b_shape[-1]]
+            )
+
+        return out
 
 
 class Mod(OnnxOpConverter):
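For readers checking the shape algebra, the sketch below re-derives each branch of the new converter in plain NumPy (NumPy stands in for the Relay ops: nn.dense(x, w) computes x @ w.T, and nn.batch_matmul(x, y) expects y with its last two axes pre-transposed). The concrete shapes are illustrative assumptions, not values taken from the commit.

import numpy as np

# Both inputs rank > 2: flatten to 3-D and use batch_matmul.
# b is reshaped with its last two axes swapped, matching the converter.
a = np.random.rand(2, 3, 4, 5)
b = np.random.rand(2, 3, 5, 6)
a3 = a.reshape(-1, a.shape[-2], a.shape[-1])
b3 = b.swapaxes(-2, -1).reshape(-1, b.shape[-1], b.shape[-2])
out = a3 @ b3.swapaxes(-2, -1)          # batch_matmul(a3, b3)
out = out.reshape(*a.shape[:-2], a.shape[-2], b.shape[-1])
assert np.allclose(out, np.matmul(a, b))

# Only a has rank > 2: one dense on the flattened a, then reshape back.
a = np.random.rand(2, 3, 4, 5)
b = np.random.rand(5, 6)
a_2d = a.reshape(-1, a.shape[-1])       # reshape(inputs_0, [-1, a_shape[-1]])
input_1 = b.T                           # transpose(inputs_1, axes=(1, 0))
out = a_2d @ input_1.T                  # dense(inputs_0, input_1)
out = out.reshape(*a.shape[:-1], b.shape[-1])
assert np.allclose(out, np.matmul(a, b))

# Only b has rank > 2: fold b's batch dims into the dense weights, then
# untangle the batch axis with a transpose before the final reshape.
a = np.random.rand(4, 5)
b = np.random.rand(2, 3, 5, 6)
input_1 = b.swapaxes(-2, -1).reshape(-1, b.shape[-2])
out = a @ input_1.T                     # dense output: (M, batch * N)
out = out.reshape(a.shape[-2], -1, b.shape[-1])
out = out.transpose(1, 0, 2).reshape(*b.shape[:-2], a.shape[-2], b.shape[-1])
assert np.allclose(out, np.matmul(a, b))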
