From a936aa54955df0088008928b6399daecba05944d Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 12 Sep 2022 07:25:53 -0700 Subject: [PATCH 1/2] [BugFix] Enable emit global MatchShape --- src/relax/ir/block_builder.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index e1aafdff90..de457554a9 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -626,8 +626,6 @@ Var BlockBuilderNode::EmitMatchShape(const Expr& value, const Array& p Var BlockBuilderNode::EmitMatchShape(const MatchShape& binding) { BlockFrame* cur_frame = CurrentFrame(); if (binding->var.defined()) { - ICHECK(!cur_frame->is_dataflow || binding->var.as()) - << "EmitMatchShape can only be used for local bindings in a dataflow block."; ICHECK(cur_frame->is_dataflow || !binding->var.as()) << "cannot emit dataflow vars outside a dataflow block: " << binding->var->name_hint(); binding_table_[binding->var->vid] = binding->value; From a18309d392fc8eaf9736649d0c6afa54981fbe5e Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 13 Sep 2022 19:52:21 -0700 Subject: [PATCH 2/2] Add unit tests --- tests/python/relax/test_blockbuilder.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/python/relax/test_blockbuilder.py b/tests/python/relax/test_blockbuilder.py index 0da256f3b1..c5301abf9f 100644 --- a/tests/python/relax/test_blockbuilder.py +++ b/tests/python/relax/test_blockbuilder.py @@ -300,6 +300,30 @@ def test_emit_match_shape(): assert b1.var == lv1 +def test_emit_match_shape_binding_in_dataflow_block(): + bb = rx.BlockBuilder() + + x = rx.Var("x", type_annotation=rx.DynTensorType(-1, "float32")) + m = tir.Var("m", dtype="int32") + gv = rx.Var("gv") + match_shape = rx.MatchShape(x, (m,), gv) + + with bb.function("main", [x]): + with bb.dataflow(): + bb.match_shape_binding(match_shape) + bb.emit_output(gv) + bb.emit_func_output(x) + + func = bb.get()["main"] + block = func.body.blocks[0] + b0 = block.bindings[0] + assert isinstance(b0, rx.MatchShape) + + assert b0.value == x + assert b0.pattern[0] == m + assert b0.var == gv + + def test_normalize(): m = tir.Var("m", "int32") n = tir.Var("n", "int32")