Skip to content

Commit

Permalink
[Parser][Printer] Switch to output annotation for dataflow blocks (ap…
Browse files Browse the repository at this point in the history
…ache#9)

* Relax pretty printer initial prototype

* call into TVMScriptPrinter for PrimFuncs

* most round-trip tests pass

* address comments

* implement relax.output syntax for dataflow block outputs

* remove leftover comments

* fix Var constructor on ShapeExpr annotation

* fix DataflowVar as well
  • Loading branch information
altanh authored and junrushao committed Feb 9, 2023
1 parent 6eb8d99 commit f04dbcb
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 31 deletions.
4 changes: 2 additions & 2 deletions python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
type_annotation: Optional[Type] = None,
span: Span = None,
) -> None:
if shape_annotation is not None:
if isinstance(shape_annotation, (list, tuple)):
shape_annotation = make_shape(shape_annotation)
self.__init_handle_by_constructor__(
_ffi_api.Var, name_hint, shape_annotation, type_annotation, span
Expand All @@ -88,7 +88,7 @@ def __init__(
type_annotation: Optional[Type] = None,
span: Span = None,
) -> None:
if shape_annotation is not None:
if isinstance(shape_annotation, (list, tuple)):
shape_annotation = make_shape(shape_annotation)
self.__init_handle_by_constructor__(
_ffi_api.DataflowVar, name_hint, shape_annotation, type_annotation, span
Expand Down
51 changes: 32 additions & 19 deletions python/tvm/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class SpecialOp(Enum):
MATCH_SHAPE = "relax.match_shape"
CALL_PACKED = "relax.call_packed"
DATAFLOW = "relax.dataflow"
DATAFLOW_OUTPUT = "relax.output"


class RelaxTransformer(Transformer):
Expand Down Expand Up @@ -660,32 +661,30 @@ def parse_dataflow(self, block: ast.Block) -> rx.DataflowBlock:
"""
assert len(block.stmts) > 0, "should never have an empty dataflow block"
bindings = []
output_vars = []

with self.new_scope():
# parse the return statement first to figure out which bindings assign normal Vars
# parse the output statement first to figure out which bindings assign normal Vars
output_stmt = block.stmts[-1]
if not isinstance(output_stmt, ast.Return):
self.report_error(
"dataflow blocks must end with returning the output variables",
output_stmt.span,
)
output_var_names = set()
unbound_output_vars = {}
output_vars = []

ret_val = output_stmt.value
if isinstance(ret_val, ast.Var):
ret_val = ast.Tuple(values=[ret_val], span=ret_val.span)

if not isinstance(ret_val, ast.Tuple) or any(
[not isinstance(f, ast.Var) for f in ret_val.values]
if (
isinstance(output_stmt, ast.UnassignedCall)
and self.transform_expr(output_stmt.call.func_name) == SpecialOp.DATAFLOW_OUTPUT
):
for var in output_stmt.call.params:
if not isinstance(var, ast.Var):
self.report_error(f"dataflow block outputs must be variables", var.span)
output_var_names.add(var.id.name)
unbound_output_vars[var.id.name] = var
else:
self.report_error(
"the returned values must be variables",
ret_val.span,
f"dataflow blocks must end with a {SpecialOp.DATAFLOW_OUTPUT.value} statement",
output_stmt.span,
)

# output variables are bound to normal (not data flow) Vars
output_var_names = {var.id.name for var in ret_val.values}

# output variables are bound to normal (not dataflow) Vars
for binding_stmt in block.stmts[:-1]:
if not isinstance(binding_stmt, (ast.Assign, ast.UnassignedCall)):
self.report_error(
Expand All @@ -704,6 +703,18 @@ def parse_dataflow(self, block: ast.Block) -> rx.DataflowBlock:
output_vars.append(var)
else:
output_vars.append(binding.var)
unbound_output_vars.pop(binding_stmt.lhs.id.name)

# check that the output variables are all bound locally
for unbound_var in unbound_output_vars.values():
self._diagnostic_context.emit(
"error",
"dataflow output variables must be bound locally in the block",
unbound_var.span,
)
# FIXME(@altanh): TVMDiagnosticCtx has hard-coded `emit` to always be an error and raise
# an exception on the first call
self._diagnostic_context.render()

# make output variables visible in parent scope
for v in output_vars:
Expand Down Expand Up @@ -769,8 +780,10 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr:
)
op = rx.ExternFunc(extern_func.value, self.to_tvm_span(extern_func.span))
args = [self.transform_expr(expr.params[1])]
else:
elif isinstance(op, (tvm.ir.Op, relay.Expr)):
args = [self.transform_expr(arg) for arg in expr.params]
else:
self.report_error(f"unsupported function in call: {op}", expr.func_name.span)
# TODO(@altanh): should we check for correct arity here eagerly, or defer to a pass?
return relay.Call(op, args, span=self.to_tvm_span(expr.span))

Expand Down
2 changes: 1 addition & 1 deletion src/relay/printer/relax_script_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::DataflowBlockNode* op) {
}
}
ICHECK(!return_vars.empty()) << "dataflow blocks should have at least one output variable";
body << "return " << Doc::Concat(return_vars, Doc::Text(", "));
body << "relax.output(" << Doc::Concat(return_vars, Doc::Text(", ")) << ")";
block << "with relax.dataflow():" << Doc::NewLine(4);
block << Doc::Indent(4, body) << Doc::NewLine();
return block;
Expand Down
43 changes: 36 additions & 7 deletions tests/python/relax/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,10 @@ def test_dataflow():
@rx.script
def foo(x: Tensor[_, _]):
with relax.dataflow():
# TODO: parse this
# nonlocal y, w
y = add(x, x)
z = multiply(y, x)
w = subtract(z, x)
return y, w
relax.output(y, w)
t = divide(y, w)
return t

Expand Down Expand Up @@ -295,7 +293,7 @@ def foo(x: Tensor[_, _]):
z = multiply(y, x)
relax.match_shape((n, m), z.shape)
w: Tensor[(n, m), _] = subtract(z, x)
return y, w
relax.output(y, w)
t: Tensor[(n, m), _] = divide(y, w)
return t

Expand All @@ -308,7 +306,7 @@ def foo(x: Tensor[_, _]):
y = add(x, x)
z = multiply(y, x)
w = subtract(z, x)
return y, w
relax.output(y, w)
t = divide(y, z)
return t

Expand All @@ -321,7 +319,7 @@ def foo(x: Tensor[_, _]):
y = add(x, x)
z = multiply(y, x)
w = subtract(z, x)
return y, w
relax.output(y, z)
t = divide(y, z)
return t

Expand All @@ -334,11 +332,42 @@ def foo(x: Tensor[_, _]):
y = add(x, x)
z = multiply(y, x)
w = subtract(z, x)
return y, w
relax.output(y, w)
t = divide(y, z)
return t


@pytest.mark.xfail
def test_dataflow_unbound_outputs():
@rx.script
def foo(x: Tensor[_, _]):
with relax.dataflow():
y = add(x, x)
z = multiply(y, x)
w = subtract(z, x)
relax.output(x, y, w, q)
t = divide(y, z)
return t


@pytest.mark.xfail
def test_invalid_special_op_dataflow():
@rx.script
def foo(x: Tensor):
y = add(x, x)
z = relax.dataflow()
return z


@pytest.mark.xfail
def test_invalid_special_op_output():
@rx.script
def foo(x: Tensor):
y = add(x, x)
z = relax.output(y)
return z


@pytest.mark.xfail
def test_func_no_return_fail():
@rx.script
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relax/test_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def foo(x: Tensor[_, _]):
y = add(x, x)
z = multiply(y, x)
w = subtract(z, x)
return y, w
relax.output(y, w)
t = divide(y, w)
return t

Expand All @@ -98,7 +98,7 @@ def foo(x: Tensor[_, _]):
z = multiply(y, x)
relax.match_shape((n, m), z.shape)
w: Tensor[(n, m), _] = subtract(z, x)
return y, w
relax.output(y, w)
t: Tensor[(n, m), _] = divide(y, w)
return t

Expand Down

0 comments on commit f04dbcb

Please sign in to comment.