From e83d9b99025129651de0c8acc75bb07dea9e6b1a Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 10 Jul 2021 02:24:27 +0000 Subject: [PATCH 1/2] [ONNX]Add batch_matmul to dense optimization --- python/tvm/relay/frontend/onnx.py | 25 ++++++++++++++-------- tests/python/frontend/onnx/test_forward.py | 7 +++--- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index f876b1d14fa1..706a060d1f0b 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -678,25 +678,32 @@ def _impl_v1(cls, inputs, attr, params): # When performing a batch matmul, we need to properly handle N-dim shapes. if a_rank > 2 or b_rank > 2: - def flatten_to_3d(x, x_shape): + def flatten_to_nd(x, x_shape, nd=3): ndims = infer_shape(x_shape)[0] newshape = _op.concatenate( [ _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype), - _op.strided_slice(x_shape, [ndims - 2], [ndims]), + _op.strided_slice(x_shape, [ndims - nd + 1], [ndims]), ], 0, ) out = _op.reshape(x, fold_constant(newshape)) return out - # Convert a and b into 3 dimensional tensors. - a = flatten_to_3d(inputs[0], a_shape) - b = flatten_to_3d(inputs[1], b_shape) - # Transpose matrix dimensions of b. - b = _op.transpose(b, [0, 2, 1]) - # Perform a batch matmul. - output = _op.nn.batch_matmul(a, b) + b_type = infer_type(inputs[1]) + # Convert to dense if the second matrix is 2d and non-dynamic + if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type): + a = flatten_to_nd(inputs[0], a_shape, 2) + b = _op.transpose(inputs[1]) + output = _op.nn.dense(a, b) + else: + # Convert a and b into 3 dimensional tensors. + a = flatten_to_nd(inputs[0], a_shape, 3) + b = flatten_to_nd(inputs[1], b_shape, 3) + # Transpose matrix dimensions of b. + b = _op.transpose(b, [0, 2, 1]) + # Perform a batch matmul. + output = _op.nn.batch_matmul(a, b) # Determine the output batch dimension. if a_rank > b_rank: out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2]) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index c5407697de46..2d9666706a5d 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1173,8 +1173,7 @@ def verify_batch_matmul(a_shape, b_shape, out_shape, target, dev): verify_with_ort_with_inputs(model, [a_array, b_array], use_vm=True, targets=[target]) -# TODO(mbrookhart, electriclilies): Add CUDA as a target once batch matmul is fixed -@tvm.testing.parametrize_targets("llvm") +@tvm.testing.uses_gpu def test_batch_matmul(target, dev): verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4), (2, 3, 4, 4), target, dev) verify_batch_matmul((2, 4, 3), (3, 4), (2, 4, 4), target, dev) @@ -1183,6 +1182,8 @@ def test_batch_matmul(target, dev): verify_batch_matmul((4, 3), (2, 3, 4), (2, 4, 4), target, dev) verify_batch_matmul((2, 4, 3), (1, 3, 4), (2, 4, 4), target, dev) verify_batch_matmul((1, 4, 3), (2, 3, 4), (2, 4, 4), target, dev) + verify_batch_matmul((4, 32, 16), (16, 32), (4, 32, 32), target, dev) + verify_batch_matmul((4, 32, 16, 32), (32, 16), (4, 32, 16, 16), target, dev) def verify_simple_dynamic_model(a_shape, b_shape, target, dev): @@ -1221,7 +1222,6 @@ def verify_model(ex, a_shape, b_shape): b_anys = [relay.Any()] * len(b_shape) mod, params = relay.frontend.from_onnx(model, {"a": a_anys, "b": b_anys}) - ex = relay.create_executor("vm", mod=mod, device=dev, target=target) verify_model(ex, a_shape, b_shape) verify_model(ex, [a * 2 for a in a_shape], [b * 2 for b in b_shape]) @@ -4955,3 +4955,4 @@ def test_qlinearadd(): test_reverse_sequence() test_eyelike() test_qlinearconv() + test_batch_matmul() From c541682668fe7c8c90d571dc767fb0db5c8c1856 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 12 Jul 2021 22:59:20 +0000 Subject: [PATCH 2/2] Add extra check to avoid unnecessary reshape --- python/tvm/relay/frontend/onnx.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 706a060d1f0b..5faaa32295d9 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -680,6 +680,8 @@ def _impl_v1(cls, inputs, attr, params): def flatten_to_nd(x, x_shape, nd=3): ndims = infer_shape(x_shape)[0] + if ndims == nd: + return x newshape = _op.concatenate( [ _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype),