[Op][Fix] Fix legalizer on Matmul with symbolic input shape (apache#30)
MasterJH5574 authored Nov 24, 2022
1 parent f5dd68b commit dfe9d42
Showing 2 changed files with 63 additions and 14 deletions.
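
Why the fix is needed: the legalizer's old broadcast check evaluated `a_shape[a_idx] > 1` directly. That works for static dimensions, but when a dimension is a symbolic `tir.Var` the comparison builds a symbolic GT expression whose Python truth value cannot be taken, so legalization fails. The new code treats a dimension as broadcast only when it is statically 1 and assumes symbolic dimensions are longer than 1. Below is a minimal sketch, not part of this commit, illustrating the two checks under the TVM PrimExpr semantics of this era:

import tvm

sym = tvm.tir.Var("a", "int64")          # symbolic batch dimension
static_one = tvm.tir.IntImm("int64", 1)  # static broadcast dimension

# Old check: `dim > 1` on a tir.Var builds a symbolic GT node, and taking
# its Python truth value raises a ValueError instead of deciding broadcast.
try:
    _ = 0 if sym > 1 else 1
except ValueError as err:
    print(err)

# New check: treat a dimension as broadcast only when it is statically 1;
# symbolic dimensions are assumed to have length greater than 1.
def dim_is_one(dim):
    return bool(isinstance(dim, tvm.tir.IntImm) and dim == 1)

print(dim_is_one(sym))         # False: symbolic, assumed > 1
print(dim_is_one(static_one))  # True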
26 changes: 12 additions & 14 deletions python/tvm/relax/transform/op_legalizer.py
@@ -84,8 +84,7 @@ def gelu(x):
             x.shape,
             lambda *i: 0.5
             * x(*i)
-            * (1 + te.tanh(math.sqrt(2 / math.pi) *
-                           (x(*i) + 0.044715 * te.power(x(*i), 3)))),
+            * (1 + te.tanh(math.sqrt(2 / math.pi) * (x(*i) + 0.044715 * te.power(x(*i), 3)))),
         )
 
     return bb.call_te(gelu, args[0])
@@ -214,8 +213,7 @@ def layer_norm(x, gamma, beta, axis, eps):
         for dim in axis:
             shape_prod = shape_prod * x.shape[dim.value]
         mean = topi.sum(x, axis=axis, keepdims=True) / shape_prod
-        var = topi.sum((x - mean) * (x - mean), axis=axis,
-                       keepdims=True) / shape_prod
+        var = topi.sum((x - mean) * (x - mean), axis=axis, keepdims=True) / shape_prod
         return gamma * ((x - mean) / topi.sqrt(var + eps)) + beta
 
     return bb.call_te(layer_norm, args[0], args[1], args[2], axis=attrs.axis, eps=attrs.epsilon)
@@ -237,8 +235,7 @@ def _nn_matmul(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr):
         b_shape.append(1)
 
     is_a_larger = len(a_shape) > len(b_shape)
-    offset = len(a_shape) - \
-        len(b_shape) if is_a_larger else len(b_shape) - len(a_shape)
+    offset = len(a_shape) - len(b_shape) if is_a_larger else len(b_shape) - len(a_shape)
 
     def matmul(a, b):
         def matmul_compute(*idx_spatial):
@@ -254,12 +251,14 @@ def multiply_compute(idx_reduce):
                     else:
                         b_indices.append(idx_spatial[i])
                 for i in range(offset, len(output_shape) - (2 - a_prepended - b_appended)):
-                    a_idx = i if is_a_larger else i - offset
-                    b_idx = i if not is_a_larger else i - offset
-                    a_indices.append(
-                        idx_spatial[i] if a_shape[a_idx] > 1 else 0)
-                    b_indices.append(
-                        idx_spatial[i] if b_shape[b_idx] > 1 else 0)
+                    a_dim = a_shape[i if is_a_larger else i - offset]
+                    b_dim = b_shape[i if not is_a_larger else i - offset]
+                    # Since we have no knowledge of the symbolic dimension, we assume the
+                    # dimension has length greater than 1.
+                    a_dim_is_one = isinstance(a_dim, tvm.tir.IntImm) and a_dim == 1
+                    b_dim_is_one = isinstance(b_dim, tvm.tir.IntImm) and b_dim == 1
+                    a_indices.append(0 if a_dim_is_one else idx_spatial[i])
+                    b_indices.append(0 if b_dim_is_one else idx_spatial[i])
                 if not a_prepended:
                     a_indices.append(idx_spatial[-2 + b_appended])
                 a_indices.append(idx_reduce)
@@ -300,8 +299,7 @@ def _sum(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr):
 
 def _mean(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr):
     shape_prod = tvm.tir.const(1, "int32")
-    axis = attrs.axis if attrs.axis is not None else range(
-        0, len(args[0].shape))
+    axis = attrs.axis if attrs.axis is not None else range(0, len(args[0].shape))
     for dim in axis:
         shape_prod = shape_prod * args[0].shape[dim.value]
     sum_var = bb.emit_te(topi.sum, args[0], axis, attrs.keepdims)
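For intuition, here is a hypothetical pure-Python analogue of the batch-index mapping that the updated matmul_compute performs (all names below are illustrative, not from the diff): dimensions that are statically 1 read index 0 (broadcast), while all other dimensions, including symbolic ones, read the output index.

def batch_indices(out_idx, a_batch, b_batch):
    """out_idx: output batch indices; a_batch/b_batch: per-operand batch dims,
    given as ints or the string "sym" standing in for a symbolic dimension."""
    is_a_larger = len(a_batch) > len(b_batch)
    offset = abs(len(a_batch) - len(b_batch))
    a_idx, b_idx = [], []
    for i, o in enumerate(out_idx):
        if i < offset:
            # The extra leading dims exist only in the longer operand.
            (a_idx if is_a_larger else b_idx).append(o)
            continue
        a_dim = a_batch[i if is_a_larger else i - offset]
        b_dim = b_batch[i if not is_a_larger else i - offset]
        a_idx.append(0 if a_dim == 1 else o)  # "sym" != 1: assumed length > 1
        b_idx.append(0 if b_dim == 1 else o)
    return a_idx, b_idx

# x: (a, 3, 4) @ y: (1, 4, 5) has batch dims ("sym",) vs (1,):
print(batch_indices([7], ["sym"], [1]))  # ([7], [0])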
51 changes: 51 additions & 0 deletions tests/python/relax/test_op_legalizer.py
@@ -1734,6 +1734,57 @@ def matmul(
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_matmul_3_3_with_symbolic_broadcast_dim():
+    a = tvm.tir.Var("a", dtype="int64")
+
+    @I.ir_module
+    class Matmul:
+        @R.function
+        def main(
+            x: R.Tensor((a, 3, 4), "float32"), y: R.Tensor((1, 4, 5), "float32")
+        ) -> R.Tensor(None, "float32", ndim=3):
+            gv: R.Tensor((a, 3, 5), "float32") = R.matmul(x, y)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((a, 3, 4), "float32"), y: R.Tensor((1, 4, 5), "float32")
+        ) -> R.Tensor(None, "float32", ndim=3):
+            gv = R.call_tir(matmul, (x, y), (a, 3, 5), dtype="float32")
+            return gv
+
+        @T.prim_func
+        def matmul(
+            var_rxplaceholder: T.handle,
+            rxplaceholder: T.Buffer[(T.int64(1), T.int64(4), T.int64(5)), "float32"],
+            var_matmul: T.handle,
+        ):
+            T.func_attr({"global_symbol": "matmul", "tir.noalias": True})
+            a = T.var("int64")
+            rxplaceholder_1 = T.match_buffer(
+                var_rxplaceholder, [a, T.int64(3), T.int64(4)], dtype="float32"
+            )
+            matmul = T.match_buffer(var_matmul, [a, T.int64(3), T.int64(5)], dtype="float32")
+            for i0, i1, i2, i3 in T.grid(a, T.int64(3), T.int64(5), T.int64(4)):
+                with T.block("matmul"):
+                    i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3])
+                    T.reads(rxplaceholder_1[i0_1, i1_1, k], rxplaceholder[T.int64(0), k, i2_1])
+                    T.writes(matmul[i0_1, i1_1, i2_1])
+                    with T.init():
+                        matmul[i0_1, i1_1, i2_1] = T.float32(0)
+                    matmul[i0_1, i1_1, i2_1] = (
+                        matmul[i0_1, i1_1, i2_1]
+                        + rxplaceholder_1[i0_1, i1_1, k] * rxplaceholder[T.int64(0), k, i2_1]
+                    )
+
+    mod = OperatorLegalizer(Matmul).transform()
+    # TVMScript and Relax functions now have limited support for understanding symbolic
+    # variables, so at this moment we only compare the generated PrimFunc.
+    tvm.ir.assert_structural_equal(mod["matmul"], Expected["matmul"])
+
+
 def test_softmax():
     @I.ir_module
     class Softmax:
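To run just the new test (assuming the repository's usual pytest-based test setup):

python -m pytest tests/python/relax/test_op_legalizer.py::test_matmul_3_3_with_symbolic_broadcast_dim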
