[Relay] Remove overwriting of matmul shapes when they are static (#13615)

In the Relay Matmul shape relation, we are a little overenthusiastic about unifying dynamic shapes: if one of the shapes is static, it does not need to be unified. This change rewrites only dynamic shapes to the required static constraints.

* Remove overwriting of matmul shapes when they are static

* Simplify nesting

* Add shape check to dense tests.
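
For context, a minimal sketch of the behavior this commit targets (illustrative only: the shapes, variable names, and the inline run_infer_type helper are assumptions, not part of the change). A static input shape should survive matmul/dense type inference untouched, while a dynamic (Any) axis can still be constrained by the other operand.

import tvm
from tvm import relay

def run_infer_type(expr):
    # Common test-style helper: run type inference and return the typed expression.
    mod = relay.transform.InferType()(tvm.IRModule.from_expr(expr))
    return mod["main"].body

# Static data tensor; the weight carries a dynamic reduction axis.
x = relay.var("x", relay.TensorType((4, 2), "float32"))
w = relay.var("w", relay.TensorType((3, relay.Any()), "float32"))
yy = run_infer_type(relay.nn.dense(x, w))

# The static (4, 2) input shape should be preserved rather than rewritten,
# and the output shape should infer to (4, 3).
print(yy.type_args[0].shape)
print(yy.checked_type.shape)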
Josh Fromm authored Dec 15, 2022
1 parent 7fd0cdb commit cc0f27a
Showing 2 changed files with 24 additions and 12 deletions.
33 changes: 21 additions & 12 deletions src/relay/op/nn/nn.h
@@ -113,23 +113,32 @@ bool MatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   std::vector<PrimExpr> B_shape(tensor_b->shape.begin(), tensor_b->shape.end());
   auto sa = A_shape.size();
   auto sb = B_shape.size();
+  size_t index_swap_A;
+  size_t index_swap_B;
   if (transpose_a && transpose_b) {
-    auto tmp = A_shape[sa - 2];
-    A_shape[sa - 2] = B_shape[sb - 1];
-    B_shape[sb - 1] = tmp;
+    index_swap_A = sa - 2;
+    index_swap_B = sb - 1;
   } else if (transpose_a) {
-    auto tmp = A_shape[sa - 2];
-    A_shape[sa - 2] = B_shape[sb - 2];
-    B_shape[sb - 2] = tmp;
+    index_swap_A = sa - 2;
+    index_swap_B = sb - 2;
   } else if (transpose_b) {
-    auto tmp = A_shape[sa - 1];
-    A_shape[sa - 1] = B_shape[sb - 1];
-    B_shape[sb - 1] = tmp;
+    index_swap_A = sa - 1;
+    index_swap_B = sb - 1;
   } else {
-    auto tmp = A_shape[sa - 1];
-    A_shape[sa - 1] = B_shape[sb - 2];
-    B_shape[sb - 2] = tmp;
+    index_swap_A = sa - 1;
+    index_swap_B = sb - 2;
   }
+
+  // Rewrite dynamic axes to static where constraints allow.
+  auto tmp = A_shape[index_swap_A];
+  if (A_shape[index_swap_A].as<tir::AnyNode>()) {
+    A_shape[index_swap_A] = B_shape[index_swap_B];
+  }
+  if (B_shape[index_swap_B].as<tir::AnyNode>()) {
+    B_shape[index_swap_B] = tmp;
+  }
+
+  // Update input types with new constrained shapes.
   reporter->Assign(types[0], TensorType(A_shape, tensor_a->dtype));
   reporter->Assign(types[1], TensorType(B_shape, tensor_b_dtype));
 }
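
To restate the new rule outside of C++, here is a rough Python sketch (an illustration under assumed names, not TVM's actual implementation): the reduction axes are selected from the transpose flags, mirroring index_swap_A/index_swap_B above, and only an axis that is dynamic (tir.Any) is overwritten with the other operand's value.

from tvm import tir

def unify_reduction_axes(a_shape, b_shape, transpose_a=False, transpose_b=True):
    # Pick the reduction axis of each operand (defaults match the dense layout).
    ia = len(a_shape) - 2 if transpose_a else len(a_shape) - 1
    ib = len(b_shape) - 1 if transpose_b else len(b_shape) - 2
    a_shape, b_shape = list(a_shape), list(b_shape)
    a_dim, b_dim = a_shape[ia], b_shape[ib]
    # Only a dynamic axis is rewritten; a static dim is never overwritten.
    if isinstance(a_dim, tir.Any):
        a_shape[ia] = b_dim
    if isinstance(b_dim, tir.Any):
        b_shape[ib] = a_dim
    return a_shape, b_shape

# E.g. data (4, Any) against dense weight (3, 2): the dynamic reduction axis
# of the data becomes 2, and the weight shape is left alone.
print(unify_reduction_axes([4, tir.Any()], [3, 2]))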
3 changes: 3 additions & 0 deletions tests/python/relay/test_op_level1.py
@@ -25,6 +25,7 @@
 import tvm.topi.testing
 from tvm.contrib.nvcc import have_fp16
 import tvm.testing
+from tvm.topi.utils import get_const_tuple

 executor_kind = tvm.testing.parameter("graph", "vm")

@@ -695,6 +696,8 @@ def test_dense(executor_kind):
     w = relay.var("w", relay.TensorType((k, n), dtype))
     y = relay.nn.dense(x, w)
     yy = run_infer_type(y)
+    # Confirm that input shape has not been rewritten to become dynamic.
+    assert get_const_tuple(yy.type_args[0].shape) == (4, 2)

     n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), 2
     x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
