Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
[BugFix] Enable emit global MatchShape (#246)
Browse files Browse the repository at this point in the history
Fix an incorrect check which disables emitting global MatchShape outside a dataflow block and mistakenly enables emitting dataflow MatchShape outside a dataflow block.
  • Loading branch information
MasterJH5574 authored and YuchenJin committed Jan 13, 2023
1 parent 1a427a2 commit d707003
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
2 changes: 0 additions & 2 deletions src/relax/ir/block_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -626,8 +626,6 @@ Var BlockBuilderNode::EmitMatchShape(const Expr& value, const Array<PrimExpr>& p
Var BlockBuilderNode::EmitMatchShape(const MatchShape& binding) {
BlockFrame* cur_frame = CurrentFrame();
if (binding->var.defined()) {
ICHECK(!cur_frame->is_dataflow || binding->var.as<DataflowVarNode>())
<< "EmitMatchShape can only be used for local bindings in a dataflow block.";
ICHECK(cur_frame->is_dataflow || !binding->var.as<DataflowVarNode>())
<< "cannot emit dataflow vars outside a dataflow block: " << binding->var->name_hint();
binding_table_[binding->var->vid] = binding->value;
Expand Down
24 changes: 24 additions & 0 deletions tests/python/relax/test_blockbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,30 @@ def test_emit_match_shape():
assert b1.var == lv1


def test_emit_match_shape_binding_in_dataflow_block():
bb = rx.BlockBuilder()

x = rx.Var("x", type_annotation=rx.DynTensorType(-1, "float32"))
m = tir.Var("m", dtype="int32")
gv = rx.Var("gv")
match_shape = rx.MatchShape(x, (m,), gv)

with bb.function("main", [x]):
with bb.dataflow():
bb.match_shape_binding(match_shape)
bb.emit_output(gv)
bb.emit_func_output(x)

func = bb.get()["main"]
block = func.body.blocks[0]
b0 = block.bindings[0]
assert isinstance(b0, rx.MatchShape)

assert b0.value == x
assert b0.pattern[0] == m
assert b0.var == gv


def test_normalize():
m = tir.Var("m", "int32")
n = tir.Var("n", "int32")
Expand Down

0 comments on commit d707003

Please sign in to comment.