From c99f55f0a5ceeda51893d3181ca480737fc9afe3 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Thu, 4 Nov 2021 01:12:09 +0800 Subject: [PATCH] [TVMScript] Use // and % for FloorDiv/FloorMod (#9437) --- python/tvm/script/tir/intrin.py | 5 +++++ src/printer/tvmscript_printer.cc | 14 ++++--------- .../unittest/test_tvmscript_roundtrip.py | 21 +++++++++++++++++++ 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py index 2e800355bef6..d31e93c72b15 100644 --- a/python/tvm/script/tir/intrin.py +++ b/python/tvm/script/tir/intrin.py @@ -121,6 +121,11 @@ def floormod(x, y, span): return tvm.tir.floormod(x, y, span) +@register +def truncmod(x, y, span): + return tvm.tir.truncmod(x, y, span) + + @register def abs(x, span): return tvm.tir.abs(x, span) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index d82ad74fd5c3..f43c8272c083 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -578,7 +578,8 @@ Doc TVMScriptPrinter::VisitExpr_(const VarNode* op, ExprPrecedence* out_preceden TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(MulNode, " * ", ExprPrecedence::kMultiplicationDivision) TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(DivNode, " / ", ExprPrecedence::kMultiplicationDivision) -TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(ModNode, " % ", ExprPrecedence::kMultiplicationDivision) +TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorDivNode, " // ", ExprPrecedence::kMultiplicationDivision) +TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorModNode, " % ", ExprPrecedence::kMultiplicationDivision) TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AddNode, " + ", ExprPrecedence::kAdditionSubtraction) TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(SubNode, " - ", ExprPrecedence::kAdditionSubtraction) TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LTNode, " < ", ExprPrecedence::kRelational) @@ -590,17 +591,10 @@ TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(NENode, " != ", ExprPrecedence::kEquality) TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AndNode, " and ", ExprPrecedence::kAnd) TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OrNode, " or ", ExprPrecedence::kOr) -Doc TVMScriptPrinter::VisitExpr_(const FloorDivNode* op, ExprPrecedence* out_precedence) { +Doc TVMScriptPrinter::VisitExpr_(const ModNode* op, ExprPrecedence* out_precedence) { *out_precedence = ExprPrecedence::kIdentity; Doc doc; - doc << tir_prefix_ << ".floordiv(" << Print(op->a) << ", " << Print(op->b) << ")"; - return doc; -} - -Doc TVMScriptPrinter::VisitExpr_(const FloorModNode* op, ExprPrecedence* out_precedence) { - *out_precedence = ExprPrecedence::kIdentity; - Doc doc; - doc << tir_prefix_ << ".floormod(" << Print(op->a) << ", " << Print(op->b) << ")"; + doc << tir_prefix_ << ".truncmod(" << Print(op->a) << ", " << Print(op->b) << ")"; return doc; } diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 93b052ee1d96..4e1308b030f1 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3155,5 +3155,26 @@ def test_primfunc_with_multiple_commreducer(): tvm.ir.assert_structural_equal(func, rt_func, True) +@T.prim_func +def func_div_mod(): + a = T.var("int32") + b = T.var("int32") + T.evaluate(a // b) + T.evaluate(a % b) + T.evaluate(a / b) + T.evaluate(T.truncmod(a, b)) + + +def test_div_mod(): + func = func_div_mod + rt_func = tvm.script.from_source(func.script()) + tvm.ir.assert_structural_equal(func, rt_func, True) + + assert isinstance(func.body[0].value, tvm.tir.FloorDiv) + assert isinstance(func.body[1].value, tvm.tir.FloorMod) + assert isinstance(func.body[2].value, tvm.tir.Div) + assert isinstance(func.body[3].value, tvm.tir.Mod) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))