Skip to content

Commit

Permalink
[Parser][Printer] relax call_packed arity, return IRModule factory, p…
Browse files Browse the repository at this point in the history
…rint IRModule PrimFuncs (#17)
  • Loading branch information
altanh authored and junrushao committed Feb 5, 2023
1 parent 926c696 commit 30b5482
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 41 deletions.
29 changes: 16 additions & 13 deletions python/tvm/relax/parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import inspect
from typing import TypeVar, Generic, Union, Dict, List, Tuple, Optional
from typing import TypeVar, Generic, Union, Dict, List, Tuple, Optional, Callable
from io import StringIO
from enum import Enum

Expand Down Expand Up @@ -837,19 +837,14 @@ def parse_call(self, expr: ast.Call) -> Union[tir.PrimExpr, rx.Expr]:
op = self.transform_expr(expr.func_name)

if op == SpecialOp.CALL_PACKED:
if len(expr.params) != 2:
self.report_error(
op.value + " takes an extern function name and a tuple of arguments",
expr.span,
)
extern_func = expr.params[0]
if not (isinstance(extern_func, ast.Constant) and isinstance(extern_func.value, str)):
self.report_error(
"the first argument of " + op.value + " must be the extern function name",
extern_func.span,
)
op = rx.ExternFunc(extern_func.value, self.to_tvm_span(extern_func.span))
args = [self.transform_expr(expr.params[1])]
args = [self.transform_expr(arg) for arg in expr.params[1:]]

elif isinstance(op, ArithmeticOp):
args = [self.transform_expr(arg) for arg in expr.params]
Expand Down Expand Up @@ -1045,7 +1040,7 @@ def transform_block(self, block: ast.Block) -> rx.SeqExpr:
# self.tvm_diag_ctx.render()


def script(f) -> Union[rx.Function, tvm.IRModule]:
def script(f) -> Union[rx.Function, Callable[[], tvm.IRModule]]:
"""Parses the decorated Relax function or module (in Relax IR) to a Relax AST.
Parameters
Expand All @@ -1056,15 +1051,19 @@ def script(f) -> Union[rx.Function, tvm.IRModule]:
Returns
-------
Union[rx.Function, IRModule]
The parsed Relax function or IRModule.
The parsed Relax function or IRModule factory (which returns the parsed IRModule when
called).
"""
diag_ctx = tvm.script.diagnostics.TVMDiagnosticCtx()
ast = synr.to_ast(f, diag_ctx)
return RelaxTransformer().do_transform(ast, diag_ctx)
mod = RelaxTransformer().do_transform(ast, diag_ctx)
if isinstance(mod, tvm.IRModule):
return lambda: mod
return mod


def fromtext(source: str, source_name: str = "from_string") -> Union[rx.Function, tvm.IRModule]:
"""Parses the given input string (in the Relax text format) to a Relax function or IRModule.
"""Parses the given input string (in the Relax text format) to a Relax AST.
Parameters
----------
Expand All @@ -1076,12 +1075,16 @@ def fromtext(source: str, source_name: str = "from_string") -> Union[rx.Function
Returns
-------
Union[rx.Function, IRModule]
The parsed Relax function or IRModule.
The parsed Relax function or IRModule factory (which returns the parsed IRModule when
called).
"""
# TODO(@altanh): actually use source_name somewhere?
diag_ctx = tvm.script.diagnostics.TVMDiagnosticCtx()
ast = synr.to_ast(source, diag_ctx)
return RelaxTransformer().do_transform(ast, diag_ctx)
mod = RelaxTransformer().do_transform(ast, diag_ctx)
if isinstance(mod, tvm.IRModule):
return lambda: mod
return mod


def pretty_print(node):
Expand Down
40 changes: 26 additions & 14 deletions src/relay/printer/relax_script_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class RelaxScriptPrinter : public relax::IRFunctor<Doc(const ObjectRef&)>,
Doc VisitExpr_(const tir::FloorDivNode* op) override;

Doc PrintIRModule(const IRModule& mod);
Doc PrintPrimFunc(const String& name, const tir::PrimFunc& func);

Doc PrintIfStmt(const relax::Var& var, const relay::If& ite);
Doc PrintFunctionDef(const Doc& name, const relax::Function& func);
Expand Down Expand Up @@ -178,17 +179,19 @@ Doc RelaxScriptPrinter::VisitNode_(const relay::CallNode* op) {
// TODO(@altanh): how to support when func cannot be printed as Python expr?
// e.g. Function or If
Doc doc;
std::vector<Doc> args;

if (op->op.as<relax::ExternFuncNode>()) {
ICHECK_EQ(op->args.size(), 1) << "extern calls should only have one argument";
doc << "relax.call_packed(" << Print(op->op) << ", " << Print(op->args[0]);
doc << "relax.call_packed";
args.push_back(Print(op->op));
} else {
std::vector<Doc> args;
for (const Expr& arg : op->args) {
args.push_back(Print(arg));
}
doc << Print(op->op) << "(" << Doc::Concat(args, Doc::Text(", "));
doc << Print(op->op);
}

for (const Expr& arg : op->args) {
args.push_back(Print(arg));
}
doc << "(" << Doc::Concat(args, Doc::Text(", "));

std::vector<Doc> attrs = PrintAttrs(op->attrs);
if (!attrs.empty()) {
Expand Down Expand Up @@ -260,12 +263,7 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::VarBindingNode* op) {
} else if (const relax::FunctionNode* func = op->value.as<relax::FunctionNode>()) {
return PrintFunctionDef(Print(op->var), GetRef<relax::Function>(func));
} else if (const tir::PrimFuncNode* prim_func = op->value.as<tir::PrimFuncNode>()) {
// we need the mod for TVMScriptPrinter to properly print the function name - maybe it's worth
// refactoring to avoid this?
tir::PrimFunc prim_func_ref = GetRef<tir::PrimFunc>(prim_func);
IRModule mod;
mod->Add(relay::GlobalVar(op->var->name_hint()), prim_func_ref);
return tir::AsTVMScriptDoc(mod, false, prim_func_ref);
return PrintPrimFunc(op->var->name_hint(), GetRef<tir::PrimFunc>(prim_func));
} else {
Doc doc;
doc << Print(op->var) << PrintVarAnnotation(op->var);
Expand Down Expand Up @@ -426,11 +424,25 @@ Doc RelaxScriptPrinter::PrintIRModule(const IRModule& mod) {
Doc doc;
doc << "class Module:";
for (const std::pair<GlobalVar, BaseFunc>& pr : mod->functions) {
doc << Doc::Indent(4, Doc::NewLine() << Print(pr.second));
Doc func;
if (pr.second.as<tir::PrimFuncNode>()) {
func = PrintPrimFunc(pr.first->name_hint, Downcast<tir::PrimFunc>(pr.second));
} else {
func = Print(pr.second);
}
doc << Doc::Indent(4, Doc::NewLine() << func);
}
return doc;
}

Doc RelaxScriptPrinter::PrintPrimFunc(const String& name, const tir::PrimFunc& func) {
// we need the mod for TVMScriptPrinter to properly print the function name - maybe it's worth
// refactoring to avoid this?
IRModule mod;
mod->Add(relay::GlobalVar(name), func);
return tir::AsTVMScriptDoc(mod, false, func);
}

Doc RelaxScriptPrinter::PrintIfStmt(const relax::Var& var, const relay::If& ite) {
const relax::SeqExprNode* true_branch = ite->true_branch.as<relax::SeqExprNode>();
const relax::SeqExprNode* false_branch = ite->false_branch.as<relax::SeqExprNode>();
Expand Down
32 changes: 25 additions & 7 deletions tests/python/relax/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,9 +432,9 @@ def my_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:

def test_call_packed():
@rx.script
def f(x: Tensor[(3, 4), "float32"]):
def f(x: Tensor[(3, 3), "float32"]):
# test that we can intro dim vars
z: Tensor[(n, m), "float32"] = relax.call_packed("contrib.my_matmul", (x, x), mp=False)
z: Tensor[(n, m), "float32"] = relax.call_packed("contrib.my_matmul", x, x, mp=False)
return z

x = f.params[0]
Expand All @@ -444,7 +444,7 @@ def f(x: Tensor[(3, 4), "float32"]):
assert isinstance(z_bind.value.op, rx.ExternFunc)
assert z_bind.value.op.global_symbol == "contrib.my_matmul"
assert "mp" in z_bind.value.attrs and z_bind.value.attrs["mp"] == False
assert structural_equal(z_bind.value.args, [rx.Tuple([x, x])])
assert structural_equal(z_bind.value.args, [x, x])


def test_primexpr_arithmetic():
Expand Down Expand Up @@ -480,18 +480,36 @@ def f(x: Tensor):

def test_class_irmodule():
@rx.script
class my_module:
def f(x: Tensor[(n, m), _]) -> Tensor:
class MyModule:
@tvm.script.tir
def my_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, [128, 128])
B = tir.match_buffer(b, [128, 128])
C = tir.match_buffer(c, [128, 128])

with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
with tir.init():
C[vi, vj] = tir.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

def f(x: Tensor[(n, n), _]) -> Tensor:
return g(x)

def g(y: Tensor[(n, m), _]) -> Tensor:
return y
def g(y: Tensor[(n, n), _]) -> Tensor:
return relax.call_dps((n, n), my_matmul, (y, y))

def h(x, y, z):
_ = my_matmul(x, y, z)
return z

my_module = MyModule()
assert isinstance(my_module, tvm.IRModule)

var_f = my_module.get_global_var("f")
var_g = my_module.get_global_var("g")
var_my_matmul = my_module.get_global_var("my_matmul")
f = my_module[var_f]
g = my_module[var_g]

assert f.body.body.op == var_g
assert g.body.body.args[1] == var_my_matmul
32 changes: 25 additions & 7 deletions tests/python/relax/test_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

def check_roundtrip(f_pre):
f_post = rx.parser.fromtext(rx.parser.astext(f_pre))
if isinstance(f_pre, tvm.IRModule):
f_post = f_post()
assert_structural_equal(f_pre, f_post, map_free_vars=True)


Expand Down Expand Up @@ -140,9 +142,9 @@ def my_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:

def test_call_packed():
@rx.script
def foo(x: Tensor[(3, 4), "float32"]):
def foo(x: Tensor[(3, 3), "float32"]):
# test that we can intro dim vars
z: Tensor[(n, m), "float32"] = relax.call_packed("contrib.my_matmul", (x, x), mp=False)
z: Tensor[(n, m), "float32"] = relax.call_packed("contrib.my_matmul", x, x, mp=False)
return z

check_roundtrip(foo)
Expand All @@ -169,11 +171,27 @@ def foo(x: Tensor):

def test_class_irmodule():
@rx.script
class my_module:
def f(x: Tensor[(n, m), _]) -> Tensor:
class MyModule:
@tvm.script.tir
def my_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, [128, 128])
B = tir.match_buffer(b, [128, 128])
C = tir.match_buffer(c, [128, 128])

with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
with tir.init():
C[vi, vj] = tir.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

def f(x: Tensor[(n, n), _]) -> Tensor:
return g(x)

def g(y: Tensor[(n, m), _]) -> Tensor:
return y
def g(y: Tensor[(n, n), _]) -> Tensor:
return relax.call_dps((n, n), my_matmul, (y, y))

def h(x, y, z):
_ = my_matmul(x, y, z)
return z

check_roundtrip(my_module)
mod = MyModule()
check_roundtrip(mod)

0 comments on commit 30b5482

Please sign in to comment.