From cc0f27a8b130f1dee5047f4fa6bfdbe82ed1a24a Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Wed, 14 Dec 2022 18:51:08 -0800
Subject: [PATCH] [Relay] Remove overwriting of matmul shapes when they are
 static (#13615)

In the Relay matmul shape relation, we were a little overenthusiastic
about unifying dynamic shapes: if one of the shapes is static, it does
not need to be unified. This change rewrites only dynamic shapes to the
static constraints required by the other operand.

* Remove overwriting of matmul shapes when they are static
* Simplify nesting
* Add shape check to dense tests.
---
 src/relay/op/nn/nn.h                 | 33 +++++++++++++++++++++------------
 tests/python/relay/test_op_level1.py |  3 +++
 2 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h
index f5497a4603bf..cf601ff5f11b 100644
--- a/src/relay/op/nn/nn.h
+++ b/src/relay/op/nn/nn.h
@@ -113,23 +113,32 @@ bool MatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     std::vector<PrimExpr> B_shape(tensor_b->shape.begin(), tensor_b->shape.end());
     auto sa = A_shape.size();
     auto sb = B_shape.size();
+    size_t index_swap_A;
+    size_t index_swap_B;
     if (transpose_a && transpose_b) {
-      auto tmp = A_shape[sa - 2];
-      A_shape[sa - 2] = B_shape[sb - 1];
-      B_shape[sb - 1] = tmp;
+      index_swap_A = sa - 2;
+      index_swap_B = sb - 1;
     } else if (transpose_a) {
-      auto tmp = A_shape[sa - 2];
-      A_shape[sa - 2] = B_shape[sb - 2];
-      B_shape[sb - 2] = tmp;
+      index_swap_A = sa - 2;
+      index_swap_B = sb - 2;
     } else if (transpose_b) {
-      auto tmp = A_shape[sa - 1];
-      A_shape[sa - 1] = B_shape[sb - 1];
-      B_shape[sb - 1] = tmp;
+      index_swap_A = sa - 1;
+      index_swap_B = sb - 1;
     } else {
-      auto tmp = A_shape[sa - 1];
-      A_shape[sa - 1] = B_shape[sb - 2];
-      B_shape[sb - 2] = tmp;
+      index_swap_A = sa - 1;
+      index_swap_B = sb - 2;
     }
+
+    // Rewrite dynamic axes to static where constraints allow.
+    auto tmp = A_shape[index_swap_A];
+    if (A_shape[index_swap_A].as<tir::AnyNode>()) {
+      A_shape[index_swap_A] = B_shape[index_swap_B];
+    }
+    if (B_shape[index_swap_B].as<tir::AnyNode>()) {
+      B_shape[index_swap_B] = tmp;
+    }
+
+    // Update input types with new constrained shapes.
     reporter->Assign(types[0], TensorType(A_shape, tensor_a->dtype));
     reporter->Assign(types[1], TensorType(B_shape, tensor_b_dtype));
   }
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index 30d9d88ad7cb..bd4e1b72c3cd 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -25,6 +25,7 @@
 import tvm.topi.testing
 from tvm.contrib.nvcc import have_fp16
 import tvm.testing
+from tvm.topi.utils import get_const_tuple
 
 executor_kind = tvm.testing.parameter("graph", "vm")
 
@@ -695,6 +696,8 @@ def test_dense(executor_kind):
         w = relay.var("w", relay.TensorType((k, n), dtype))
         y = relay.nn.dense(x, w)
         yy = run_infer_type(y)
+        # Confirm that input shape has not been rewritten to become dynamic.
+        assert get_const_tuple(yy.type_args[0].shape) == (4, 2)
 
         n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), 2
         x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
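
Illustration of the new behavior (a minimal sketch, not part of the patch;
the shapes and variable names below are chosen for the example, and it
assumes a TVM build that includes this change):

# Exercise the matmul/dense shape relation with one dynamic and one
# static operand through the public Relay API.
import tvm
from tvm import relay

# x has a dynamic reduction axis; the weight w is fully static.
x = relay.var("x", relay.TensorType((4, relay.Any()), "float32"))
w = relay.var("w", relay.TensorType((3, 2), "float32"))
y = relay.nn.dense(x, w)

mod = tvm.IRModule.from_expr(relay.Function([x, w], y))
mod = relay.transform.InferType()(mod)

# The dynamic axis of x is constrained to the static size taken from w,
# while w's static shape is left untouched instead of becoming dynamic.
print(mod["main"].params[0].checked_type)  # expect Tensor[(4, 2), float32]
print(mod["main"].params[1].checked_type)  # expect Tensor[(3, 2), float32]

Before this change, the relation swapped the two axes unconditionally, so
the static reduction axis of w would have been overwritten with Any; with
the fix, only the dynamic axis of x is rewritten.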