diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index f876b1d14fa1d..097a85e33815a 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -670,68 +670,67 @@ class MatMul(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs))
+        inputs_0 = inputs[0]
+        inputs_1 = inputs[1]
+
         # Need to check input shape as batch matmul must be supported.
-        a_shape = shape_of(inputs[0])
-        a_rank = infer_shape(a_shape)[0]
-        b_shape = shape_of(inputs[1])
-        b_rank = infer_shape(b_shape)[0]
-        # When performing a batch matmul, we need to properly handle N-dim shapes.
-        if a_rank > 2 or b_rank > 2:
+        a_shape = infer_shape(inputs_0)
+        b_shape = infer_shape(inputs_1)
 
-            def flatten_to_3d(x, x_shape):
-                ndims = infer_shape(x_shape)[0]
-                newshape = _op.concatenate(
-                    [
-                        _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype),
-                        _op.strided_slice(x_shape, [ndims - 2], [ndims]),
-                    ],
-                    0,
-                )
-                out = _op.reshape(x, fold_constant(newshape))
-                return out
+        # When performing a batch matmul, we need to properly handle N-dim shapes.
+        if len(a_shape) > 2 and len(b_shape) > 2:
+            # Convert a into a 3 dimensional tensor.
+            need_reshape_output = False
+            if len(a_shape) != 3:
+                a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]])
+                need_reshape_output = True
+            else:
+                a = inputs_0
 
-            # Convert a and b into 3 dimensional tensors.
-            a = flatten_to_3d(inputs[0], a_shape)
-            b = flatten_to_3d(inputs[1], b_shape)
             # Transpose matrix dimensions of b.
-            b = _op.transpose(b, [0, 2, 1])
+            trans_axes = list(range(len(b_shape)))
+            trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
+            b = _op.transpose(inputs_1, trans_axes)
+
+            # Convert b into a 3 dimensional tensor. Note that the last two dimensions
+            # are transposed.
+            if len(b_shape) != 3:
+                b = _op.reshape(b, [-1, b_shape[-1], b_shape[-2]])
+
             # Perform a batch matmul.
             output = _op.nn.batch_matmul(a, b)
-            # Determine the output batch dimension.
-            if a_rank > b_rank:
-                out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
-            elif a_rank < b_rank:
-                out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
-            # If its unclear how broadcasting should be applied, the output
-            # shape is determined by choosing the maximum value from each input.
-            else:
-                out_batch = _op.concatenate(
-                    [
-                        _op.maximum(
-                            _op.strided_slice(a_shape, [i], [i + 1]),
-                            _op.strided_slice(b_shape, [i], [i + 1]),
-                        )
-                        for i in range(a_rank - 2)
-                    ],
-                    0,
-                )
+            # Reshape output to original dimensions.
-            final_shape = _op.concatenate(
-                [
-                    out_batch,
-                    _op.strided_slice(
-                        a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0] - 1]
-                    ),
-                    _op.strided_slice(
-                        b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]]
-                    ),
-                ],
-                0,
+            if need_reshape_output:
+                return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
+            return output
+        elif len(a_shape) > 2:
+            inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])
+
+        if len(b_shape) > 2:
+            trans_axes = list(range(len(b_shape)))
+            trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
+            input_1 = _op.reshape(_op.transpose(inputs_1, trans_axes), [-1, b_shape[-2]])
+        elif len(b_shape) == 2:
+            input_1 = _op.transpose(inputs_1, axes=(1, 0))
+        elif len(b_shape) == 1:
+            input_1 = _op.expand_dims(inputs_1, 0, 1)
+
+        out = _op.nn.dense(inputs_0, input_1)
+
+        if len(b_shape) == 1:
+            out = _op.squeeze(out, axis=[-1])
+
+        # Reshape output into an N dimensional tensor when a or b dim > 2
+        if len(a_shape) > 2:
+            out = _op.reshape(out, [*a_shape[:-1], b_shape[-1]])
+        elif len(b_shape) > 2:
+            out = _op.reshape(out, [a_shape[-2], -1, b_shape[-1]])
+            out = _op.reshape(
+                _op.transpose(out, [1, 0, 2]), [*b_shape[:-2], a_shape[-2], b_shape[-1]]
             )
-            return _op.reshape(output, fold_constant(final_shape))
-        # Otherwise a simple dense op will get the job done.
-        input_1_t = _op.transpose(inputs[1], axes=(1, 0))
-        return _op.nn.dense(inputs[0], input_1_t)
+
+        return out
 
 
 class Mod(OnnxOpConverter):
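The rewrite trades the old dynamic shape_of / strided_slice / fold_constant graph for plain Python shape lists from infer_shape, so every reshape target becomes a compile-time constant. The shape algebra of both branches can be sanity-checked outside TVM; the NumPy sketch below is not part of the patch and assumes, as the transposes in the patch suggest, that nn.batch_matmul(A, B) computes A @ B^T per batch and nn.dense(X, W) computes X @ W^T. All shapes are illustrative.

    import numpy as np

    # Batch path: both inputs have rank > 2, e.g. (2, 3, 4, 5) x (2, 3, 5, 6) -> (2, 3, 4, 6).
    a = np.random.rand(2, 3, 4, 5).astype("float32")
    b = np.random.rand(2, 3, 5, 6).astype("float32")

    # Flatten a to 3-D, as in _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]]).
    a3 = a.reshape(-1, a.shape[-2], a.shape[-1])                   # (6, 4, 5)

    # Swap the matrix dims of b, then flatten to 3-D, as in the patch.
    b3 = b.swapaxes(-2, -1).reshape(-1, b.shape[-1], b.shape[-2])  # (6, 6, 5)

    # Assumed batch_matmul semantics: out[i] = A[i] @ B[i].T over the batch dim.
    out = np.einsum("bik,bjk->bij", a3, b3)                        # (6, 4, 6)

    # Restore the leading batch dims: [*a_shape[:-2], M, N].
    out = out.reshape(*a.shape[:-2], a.shape[-2], b.shape[-1])
    assert np.allclose(out, a @ b, atol=1e-5)

    # Dense path with an N-D b: (4, 5) x (3, 5, 6) -> (3, 4, 6).
    a2 = np.random.rand(4, 5).astype("float32")
    bn = np.random.rand(3, 5, 6).astype("float32")

    w = bn.swapaxes(-2, -1).reshape(-1, bn.shape[-2])    # (18, 5), the flattened weight
    out2 = a2 @ w.T                                      # assumed dense semantics: X @ W.T -> (4, 18)
    out2 = out2.reshape(a2.shape[-2], -1, bn.shape[-1])  # (4, 3, 6), batch dim in the middle
    out2 = out2.transpose(1, 0, 2).reshape(*bn.shape[:-2], a2.shape[-2], bn.shape[-1])
    assert np.allclose(out2, a2 @ bn, atol=1e-5)
    print("shape algebra checks passed")

Note that the new batch branch flattens both operands to a single batch dimension, so unlike the removed out_batch logic it does not attempt to broadcast mismatched leading dimensions; matching batch shapes are assumed.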