diff --git a/python/tvm/relax/transform/fuse_transpose_matmul.py b/python/tvm/relax/transform/fuse_transpose_matmul.py index 1d2324a28b3e..141f926cd3f8 100644 --- a/python/tvm/relax/transform/fuse_transpose_matmul.py +++ b/python/tvm/relax/transform/fuse_transpose_matmul.py @@ -41,7 +41,8 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR "transpose_matmul_fuse", *_pattern(), ), - ] + ], + bind_constants=False, )(mod) transpose_matmul_codegen = _TransposeMatmulFuser(mod) for g_var, func in mod.functions_items(): diff --git a/tests/python/relax/test_transform_fuse_transpose_matmul.py b/tests/python/relax/test_transform_fuse_transpose_matmul.py index 4b2b1fff8aba..446102dcbbc6 100644 --- a/tests/python/relax/test_transform_fuse_transpose_matmul.py +++ b/tests/python/relax/test_transform_fuse_transpose_matmul.py @@ -22,6 +22,7 @@ from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tir as T +import numpy as np def test_transform_fuse_transpose_matmul(): @@ -78,5 +79,58 @@ def main( tvm.ir.assert_structural_equal(after, Expected) +def test_transform_fuse_transpose_matmul_const(): + w = relax.const(np.random.uniform(-1e-3, 1e-3, (128, 256)), "float32") + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor((128, 256), "float32"), + ) -> R.Tensor((128, 128), "float32"): + with R.dataflow(): + wT = R.permute_dims(w, [1, 0]) + o = R.matmul(x, wT) + R.output(o) + return o + + @I.ir_module + class Expected: + @T.prim_func(private=True) + def NT_matmul( + x: T.Buffer((T.int64(128), T.int64(256)), "float32"), + w: T.Buffer((T.int64(128), T.int64(256)), "float32"), + NT_matmul: T.Buffer((T.int64(128), T.int64(128)), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(256)): + with T.block("NT_matmul"): + v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + T.reads(x[v_i0, v_k], w[v_i1, v_k]) + 
T.writes(NT_matmul[v_i0, v_i1]) + with T.init(): + NT_matmul[v_i0, v_i1] = T.float32(0) + NT_matmul[v_i0, v_i1] = NT_matmul[v_i0, v_i1] + x[v_i0, v_k] * w[v_i1, v_k] + + @R.function + def main(x: R.Tensor((128, 256), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"): + cls = Expected + with R.dataflow(): + gv = R.call_tir( + cls.NT_matmul, (x, w), out_sinfo=R.Tensor((128, 128), dtype="float32") + ) + R.output(gv) + return gv + + after = tvm.ir.transform.Sequential( + [ + relax.transform.FuseTransposeMatmul(), + relax.transform.FuseTIR(), # Only used to remove the unused primitive function + ] + )(Before) + tvm.ir.assert_structural_equal(after, Expected) + + if __name__ == "__main__": tvm.testing.main()