Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Add MatMulInteger16 contrib op #9186

Merged
merged 12 commits into from
Nov 24, 2021
Merged
99 changes: 84 additions & 15 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,22 @@ def get_scalar(x, params, dtype="float32"):
return _op.cast(x, dtype)


def flatten_to_nd(x, x_shape, nd=3):
"""Helper to flatten multi dimensional arrays to specific dimension"""
ndims = infer_shape(x_shape)[0]
if ndims == nd:
return x
newshape = _op.concatenate(
[
_expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype),
_op.strided_slice(x_shape, [ndims - nd + 1], [ndims]),
],
0,
)
out = _op.reshape(x, fold_constant(newshape))
return out


class OnnxOpConverter(object):
"""A helper class for holding onnx op converters."""

Expand Down Expand Up @@ -803,21 +819,6 @@ def _impl_v1(cls, inputs, attr, params):
b_rank = infer_shape(b_shape)[0]
# When performing a batch matmul, we need to properly handle N-dim shapes.
if a_rank > 2 or b_rank > 2:

def flatten_to_nd(x, x_shape, nd=3):
ndims = infer_shape(x_shape)[0]
if ndims == nd:
return x
newshape = _op.concatenate(
[
_expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype),
_op.strided_slice(x_shape, [ndims - nd + 1], [ndims]),
],
0,
)
out = _op.reshape(x, fold_constant(newshape))
return out

b_type = infer_type(inputs[1])
# Convert to dense if the second matrix is 2d and non-dynamic
if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type):
Expand Down Expand Up @@ -873,6 +874,73 @@ def flatten_to_nd(x, x_shape, nd=3):
return _op.nn.dense(inputs[0], input_1_t)


class MatMulInteger16(OnnxOpConverter):
"""Operator converter for MatMulInteger16 from Microsoft onnxruntime contrib opset."""

@classmethod
def _impl_v10(cls, inputs, attr, params):
assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs))
a_shape = shape_of(inputs[0])
a_rank = infer_shape(a_shape)[0]
b_shape = shape_of(inputs[1])
b_rank = infer_shape(b_shape)[0]
a_dtype = infer_type(inputs[0]).checked_type.dtype
b_dtype = infer_type(inputs[1]).checked_type.dtype
# Check input data types
assert a_dtype in ("int16", "uint16"), "MatMulInteger16: invalid dtype for first input"
assert b_dtype in ("int16", "uint16"), "MatMulInteger16: invalid dtype for second input"
out_dtype = "int32"
if a_dtype == "uint16" and b_dtype == "uint16":
out_dtype = "uint32"
if a_rank > 2 or b_rank > 2:
b_type = infer_type(inputs[1])
# Convert to dense if the second matrix is 2d and non-dynamic
if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type):
a = flatten_to_nd(inputs[0], a_shape, 2)
b = _op.transpose(inputs[1])
output = _op.nn.dense(a, b, out_dtype=out_dtype)
else:
# Convert a and b into 3 dimensional tensors.
a = flatten_to_nd(inputs[0], a_shape, 3)
b = flatten_to_nd(inputs[1], b_shape, 3)
# Perform a NN batch matmul.
output = _op.nn.batch_matmul(a, b, out_dtype=out_dtype, transpose_b=False)
# Determine the output batch dimension.
if a_rank > b_rank:
out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
elif a_rank < b_rank:
out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
# If its unclear how broadcasting should be applied, the output
# shape is determined by choosing the maximum value from each input.
else:
out_batch = _op.concatenate(
[
_op.maximum(
_op.strided_slice(a_shape, [i], [i + 1]),
_op.strided_slice(b_shape, [i], [i + 1]),
)
for i in range(a_rank - 2)
],
0,
)
# Reshape output to original dimensions.
final_shape = _op.concatenate(
[
out_batch,
_op.strided_slice(
a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0] - 1]
),
_op.strided_slice(
b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]]
),
],
0,
)
return _op.reshape(output, fold_constant(final_shape))
# Use relay matmul
return _op.nn.matmul(inputs[0], inputs[1], out_dtype=out_dtype)


class Mod(OnnxOpConverter):
"""Operator converter for Mod."""

Expand Down Expand Up @@ -4144,6 +4212,7 @@ def _get_convert_map(opset):
"Softsign": Softsign.get_converter(opset),
"Gemm": Gemm.get_converter(opset),
"MatMul": MatMul.get_converter(opset),
"MatMulInteger16": MatMulInteger16.get_converter(opset),
"Mod": Mod.get_converter(opset),
"Xor": Renamer("logical_xor"),
# defs/nn
Expand Down
42 changes: 42 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1282,6 +1282,47 @@ def verify_batch_matmul(a_shape, b_shape, out_shape, convert_config=None):
)


@tvm.testing.parametrize_targets
def test_matmulinteger16(target, dev):
def verify_matmulinteger16(a_shape, b_shape, out_shape):
a_dtype = "int16"
b_dtype = "int16"
low = -10
high = 10

a_proto = TensorProto.INT16
b_proto = TensorProto.INT16
out_proto = TensorProto.INT32
a_array = np.random.randint(low, high, size=a_shape).astype(a_dtype)
b_array = np.random.randint(low, high, size=b_shape).astype(b_dtype)

mul_node = helper.make_node("MatMulInteger16", ["a", "b"], ["out"], domain="com.microsoft")

graph = helper.make_graph(
[mul_node],
"matmuli16_test",
inputs=[
helper.make_tensor_value_info("a", a_proto, list(a_shape)),
helper.make_tensor_value_info("b", b_proto, list(b_shape)),
],
outputs=[helper.make_tensor_value_info("out", out_proto, list(out_shape))],
)

model = helper.make_model(graph, producer_name="matmuli16_test")
verify_with_ort_with_inputs(model, [a_array, b_array], target=target, dev=dev)

# 2D computation to verify matmul op
verify_matmulinteger16((4, 3), (3, 4), (4, 4))
verify_matmulinteger16((5, 7), (7, 8), (5, 8))
# Verify 3D matmul using batch_matmul op
verify_matmulinteger16((2, 4, 3), (1, 3, 4), (2, 4, 4))
verify_matmulinteger16((1, 4, 3), (2, 3, 4), (2, 4, 4))
# Test implicit broadcasting
verify_matmulinteger16((2, 3, 5, 3), (2, 3, 3, 5), (2, 3, 5, 5))
verify_matmulinteger16((2, 7, 3), (3, 7), (2, 7, 7))
verify_matmulinteger16((2, 3, 4, 3), (3, 4), (2, 3, 4, 4))


def verify_simple_dynamic_model(a_shape, b_shape, target, dev):
def verify_model(model, a_shape, b_shape):
a_array = np.random.uniform(size=a_shape).astype("float32")
Expand Down Expand Up @@ -5801,6 +5842,7 @@ def repeat(N, D):
test_onehot()
test_gemm()
test_matmul()
test_matmulinteger16()
test_gather()
test_gatherelements()
test_gather_nd()
Expand Down