diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index f5497a4603bf..cf601ff5f11b 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -113,23 +113,32 @@ bool MatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, std::vector<IndexExpr> B_shape(tensor_b->shape.begin(), tensor_b->shape.end()); auto sa = A_shape.size(); auto sb = B_shape.size(); + size_t index_swap_A; + size_t index_swap_B; if (transpose_a && transpose_b) { - auto tmp = A_shape[sa - 2]; - A_shape[sa - 2] = B_shape[sb - 1]; - B_shape[sb - 1] = tmp; + index_swap_A = sa - 2; + index_swap_B = sb - 1; } else if (transpose_a) { - auto tmp = A_shape[sa - 2]; - A_shape[sa - 2] = B_shape[sb - 2]; - B_shape[sb - 2] = tmp; + index_swap_A = sa - 2; + index_swap_B = sb - 2; } else if (transpose_b) { - auto tmp = A_shape[sa - 1]; - A_shape[sa - 1] = B_shape[sb - 1]; - B_shape[sb - 1] = tmp; + index_swap_A = sa - 1; + index_swap_B = sb - 1; } else { - auto tmp = A_shape[sa - 1]; - A_shape[sa - 1] = B_shape[sb - 2]; - B_shape[sb - 2] = tmp; + index_swap_A = sa - 1; + index_swap_B = sb - 2; } + + // Rewrite dynamic axes to static where constraints allow. + auto tmp = A_shape[index_swap_A]; + if (A_shape[index_swap_A].as<tir::AnyNode>()) { + A_shape[index_swap_A] = B_shape[index_swap_B]; + } + if (B_shape[index_swap_B].as<tir::AnyNode>()) { + B_shape[index_swap_B] = tmp; + } + + // Update input types with new constrained shapes. 
reporter->Assign(types[0], TensorType(A_shape, tensor_a->dtype)); reporter->Assign(types[1], TensorType(B_shape, tensor_b->dtype)); } diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 30d9d88ad7cb..bd4e1b72c3cd 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -25,6 +25,7 @@ import tvm.topi.testing from tvm.contrib.nvcc import have_fp16 import tvm.testing +from tvm.topi.utils import get_const_tuple executor_kind = tvm.testing.parameter("graph", "vm") @@ -695,6 +696,8 @@ def test_dense(executor_kind): w = relay.var("w", relay.TensorType((k, n), dtype)) y = relay.nn.dense(x, w) yy = run_infer_type(y) + # Confirm that input shape has not been rewritten to become dynamic. + assert get_const_tuple(yy.type_args[0].shape) == (4, 2) n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), 2 x = relay.var("x", relay.TensorType((n, c, h, w), dtype))