Skip to content

Commit

Permalink
[TVMScript] Use // and % for FloorDiv/FloorMod (#9437)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy authored Nov 3, 2021
1 parent a6c948a commit c99f55f
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 10 deletions.
5 changes: 5 additions & 0 deletions python/tvm/script/tir/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ def floormod(x, y, span):
return tvm.tir.floormod(x, y, span)


@register
def truncmod(x, y, span):
return tvm.tir.truncmod(x, y, span)


@register
def abs(x, span):
return tvm.tir.abs(x, span)
Expand Down
14 changes: 4 additions & 10 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,8 @@ Doc TVMScriptPrinter::VisitExpr_(const VarNode* op, ExprPrecedence* out_preceden

TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(MulNode, " * ", ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(DivNode, " / ", ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(ModNode, " % ", ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorDivNode, " // ", ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorModNode, " % ", ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AddNode, " + ", ExprPrecedence::kAdditionSubtraction)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(SubNode, " - ", ExprPrecedence::kAdditionSubtraction)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LTNode, " < ", ExprPrecedence::kRelational)
Expand All @@ -590,17 +591,10 @@ TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(NENode, " != ", ExprPrecedence::kEquality)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AndNode, " and ", ExprPrecedence::kAnd)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OrNode, " or ", ExprPrecedence::kOr)

Doc TVMScriptPrinter::VisitExpr_(const FloorDivNode* op, ExprPrecedence* out_precedence) {
Doc TVMScriptPrinter::VisitExpr_(const ModNode* op, ExprPrecedence* out_precedence) {
*out_precedence = ExprPrecedence::kIdentity;
Doc doc;
doc << tir_prefix_ << ".floordiv(" << Print(op->a) << ", " << Print(op->b) << ")";
return doc;
}

Doc TVMScriptPrinter::VisitExpr_(const FloorModNode* op, ExprPrecedence* out_precedence) {
*out_precedence = ExprPrecedence::kIdentity;
Doc doc;
doc << tir_prefix_ << ".floormod(" << Print(op->a) << ", " << Print(op->b) << ")";
doc << tir_prefix_ << ".truncmod(" << Print(op->a) << ", " << Print(op->b) << ")";
return doc;
}

Expand Down
21 changes: 21 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3155,5 +3155,26 @@ def test_primfunc_with_multiple_commreducer():
tvm.ir.assert_structural_equal(func, rt_func, True)


@T.prim_func
def func_div_mod():
a = T.var("int32")
b = T.var("int32")
T.evaluate(a // b)
T.evaluate(a % b)
T.evaluate(a / b)
T.evaluate(T.truncmod(a, b))


def test_div_mod():
func = func_div_mod
rt_func = tvm.script.from_source(func.script())
tvm.ir.assert_structural_equal(func, rt_func, True)

assert isinstance(func.body[0].value, tvm.tir.FloorDiv)
assert isinstance(func.body[1].value, tvm.tir.FloorMod)
assert isinstance(func.body[2].value, tvm.tir.Div)
assert isinstance(func.body[3].value, tvm.tir.Mod)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit c99f55f

Please sign in to comment.