diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py index 9d09df3d8e5f..9969dd80f5ed 100644 --- a/python/tvm/script/parser/core/evaluator.py +++ b/python/tvm/script/parser/core/evaluator.py @@ -19,6 +19,8 @@ import ast from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union +import tvm + from . import dispatch, doc from .error import ParserError @@ -173,18 +175,19 @@ def _visit(self, node: doc.AST) -> Any: isinstance(node, doc.Call) and hasattr(node.func, "attr") and node.func.attr not in ["reads", "writes", "match_buffer", "realize"] - ) or isinstance(node, (doc.BinOp, doc.UnaryOp, doc.Compare, doc.BoolOp)): + ) or isinstance(node, (doc.BinOp, doc.UnaryOp, doc.Compare, doc.BoolOp, doc.IfExp)): if isinstance(node, doc.BinOp): args = [node.left, node.right] elif isinstance(node, doc.UnaryOp): args = [node.operand] elif isinstance(node, doc.Compare): args = [node.left, *node.comparators] - else: - if isinstance(node, doc.Call): - args = node.args - elif isinstance(node, doc.BoolOp): - args = node.values + elif isinstance(node, doc.IfExp): + args = [node.test, node.body, node.orelse] + elif isinstance(node, doc.Call): + args = node.args + elif isinstance(node, doc.BoolOp): + args = node.values for arg in args: if isinstance(arg, doc.Subscript) and isinstance(arg.slice, (doc.Slice, doc.Tuple)): if isinstance(arg.slice, doc.Slice): @@ -256,6 +259,8 @@ def _visit(self, node: doc.AST) -> Any: value = self._eval_unary_op(fields) elif isinstance(node, doc.BinOp): value = self._eval_bin_op(fields) + elif isinstance(node, doc.IfExp): + value = self._eval_if_exp(fields) elif isinstance(node, doc.Slice): value = self._eval_slice(fields) else: @@ -364,6 +369,30 @@ def _eval_bin_op(self, fields: Dict[str, Any]) -> Any: ], ) + def _eval_if_exp(self, fields: Dict[str, Any]) -> Any: + """The doc AST if-else expression node evaluating method. + + Parameters + ---------- + fields : Dict[str, Any] + The dictionary of if-else expression information, + e.g., test, body, orelse. + + Returns + ------- + res : Any + The evaluation result. + """ + test = self._eval_expr(fields["test"]) + body = self._eval_expr(fields["body"]) + orelse = self._eval_expr(fields["orelse"]) + if isinstance(test, bool): + return body if test else orelse + elif isinstance(test, tvm.tir.PrimExpr) and test.dtype == "bool": + return tvm.tir.op.if_then_else(test, body, orelse) + else: + raise TypeError(f"Expected Python bool or TIR bool, but got {type(test)}") + def _eval_slice(self, fields: Dict[str, Any]) -> slice: """The doc AST slice node evaluating method. diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py index fd196be72a8c..d28e4680ae16 100644 --- a/tests/python/tvmscript/test_tvmscript_parser_tir.py +++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py @@ -612,5 +612,19 @@ def expected() -> None: tvm.ir.assert_structural_equal(func, expected) +def test_ifexp(): + @T.prim_func(private=True) + def func(A: T.buffer((128, 128), "float32")): + for i, j in T.grid(128, 128): + A[i, j] = i if i < j else j + + @T.prim_func(private=True) + def expected(A: T.buffer((128, 128), "float32")): + for i, j in T.grid(128, 128): + A[i, j] = T.if_then_else(i < j, i, j) + + tvm.ir.assert_structural_equal(func, expected) + + if __name__ == "__main__": tvm.testing.main()