diff --git a/src/relax/ir/dataflow_block_rewriter.cc b/src/relax/ir/dataflow_block_rewriter.cc index fb08dfe96a17..88efad86cfdc 100644 --- a/src/relax/ir/dataflow_block_rewriter.cc +++ b/src/relax/ir/dataflow_block_rewriter.cc @@ -49,7 +49,7 @@ class MatcherUseDefAnalysis : public relax::ExprVisitor { // caller -> callee table. std::map> caller2callees; - const VarNode* cur_user_; + const VarNode* cur_user_ = nullptr; void VisitBinding_(const VarBindingNode* binding) override { // init diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index f67b0530ca87..03a3beb2f27e 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -1053,9 +1053,17 @@ def main( assert ctx.match_dfb(dfb) is None -def get_qkv_proj_rewriter( - inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2, matmul3 -): +def get_qkv_proj_rewriter(): + with PatternContext() as ctx: + inp_pat = wildcard() + Q_weight_pat = wildcard() + K_weight_pat = wildcard() + V_weight_pat = wildcard() + + matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat) + matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat) + matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat) + def qkv_proj_rewriter(matchings, _): inp = matchings[inp_pat] Q_weight = matchings[Q_weight_pat] @@ -1071,7 +1079,7 @@ def qkv_proj_rewriter(matchings, _): return {matchings[matmul1]: Q, matchings[matmul2]: K, matchings[matmul3]: V} - return qkv_proj_rewriter + return ctx, qkv_proj_rewriter def test_combine_matmul_twice(): @@ -1123,21 +1131,63 @@ def expected( R.output(out) return out - with PatternContext() as ctx: - inp_pat = wildcard() - Q_weight_pat = wildcard() - K_weight_pat = wildcard() - V_weight_pat = wildcard() + ctx, rewriter = get_qkv_proj_rewriter() + rewritten = rewrite_bindings(ctx, rewriter, qkv_x2) + tvm.ir.assert_structural_equal(rewritten, expected) - matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat) - matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat) - matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat) - rewriter = get_qkv_proj_rewriter( - inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2, matmul3 - ) - rewritten = rewrite_bindings(ctx, rewriter, qkv_x2) - tvm.ir.assert_structural_equal(rewritten, expected) +def test_dataflow_may_start_with_match_cast(): + """Inputs to rewrite_bindings may contain R.match_cast + + This is a regression test. In previous implementations, applying + `rewrite_bindings` when `R.match_cast` is the first binding of a + `R.dataflow` block would cause a segfault. + + """ + + @R.function(private=True) + def before( + x_untyped: R.Tensor, + w0_untyped: R.Tensor, + w1_untyped: R.Tensor, + w2_untyped: R.Tensor, + ): + with R.dataflow(): + x = R.match_cast(x_untyped, R.Tensor((2, 1024, 640), "float32")) + w0 = R.match_cast(w0_untyped, R.Tensor((640, 640), "float32")) + w1 = R.match_cast(w1_untyped, R.Tensor((640, 640), "float32")) + w2 = R.match_cast(w2_untyped, R.Tensor((640, 640), "float32")) + out_0 = R.matmul(x, w0) + out_1 = R.matmul(x, w1) + out_2 = R.matmul(x, w2) + out = (out_0, out_1, out_2) + R.output(out) + return out + + @R.function(private=True) + def expected( + x_untyped: R.Tensor, + w0_untyped: R.Tensor, + w1_untyped: R.Tensor, + w2_untyped: R.Tensor, + ): + with R.dataflow(): + x = R.match_cast(x_untyped, R.Tensor((2, 1024, 640), "float32")) + w0 = R.match_cast(w0_untyped, R.Tensor((640, 640), "float32")) + w1 = R.match_cast(w1_untyped, R.Tensor((640, 640), "float32")) + w2 = R.match_cast(w2_untyped, R.Tensor((640, 640), "float32")) + w_concat = R.concat((w0, w1, w2), axis=1) + out_concat = R.matmul(x, w_concat) + out_0 = R.strided_slice(out_concat, axes=[2], begin=[0], end=[640]) + out_1 = R.strided_slice(out_concat, axes=[2], begin=[640], end=[1280]) + out_2 = R.strided_slice(out_concat, axes=[2], begin=[1280], end=[1920]) + out = (out_0, out_1, out_2) + R.output(out) + return out + + ctx, rewriter = get_qkv_proj_rewriter() + rewritten = rewrite_bindings(ctx, rewriter, before) + tvm.ir.assert_structural_equal(rewritten, expected) def test_combine_matmul_emit_order(): @@ -1181,27 +1231,16 @@ def expected( R.output(out) return out - with PatternContext() as ctx: - inp_pat = wildcard() - Q_weight_pat = wildcard() - K_weight_pat = wildcard() - V_weight_pat = wildcard() + ctx, rewriter = get_qkv_proj_rewriter() - matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat) - matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat) - matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat) + rewritten = rewrite_bindings(ctx, rewriter, main) + tvm.ir.assert_structural_equal(rewritten, expected) - rewriter = get_qkv_proj_rewriter( - inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2, matmul3 - ) - rewritten = rewrite_bindings(ctx, rewriter, main) - tvm.ir.assert_structural_equal(rewritten, expected) - - # make sure it builds - mod = tvm.IRModule() - mod["main"] = rewritten + # make sure it builds + mod = tvm.IRModule() + mod["main"] = rewritten - rx.build(mod, target="llvm") + rx.build(mod, target="llvm") def test_combine_transposed_matmul_twice():