diff --git a/python/tvm/relax/expr_functor.py b/python/tvm/relax/expr_functor.py index 49d3b14505ba..a40b81c233ef 100644 --- a/python/tvm/relax/expr_functor.py +++ b/python/tvm/relax/expr_functor.py @@ -20,8 +20,8 @@ import tvm from tvm.ir import Op -from tvm.meta_schedule.utils import derived_object from tvm.runtime import Object +from tvm.runtime.support import derived_object from ..ir.module import IRModule from . import _ffi_api @@ -31,7 +31,6 @@ BindingBlock, Call, Constant, - Id, DataflowBlock, DataflowVar, DataTypeImm, @@ -39,6 +38,7 @@ ExternFunc, Function, GlobalVar, + Id, If, MatchCast, PrimValue, diff --git a/python/tvm/runtime/support.py b/python/tvm/runtime/support.py index 149a66ef7b55..2669459d71a7 100644 --- a/python/tvm/runtime/support.py +++ b/python/tvm/runtime/support.py @@ -18,6 +18,7 @@ """Runtime support infra of TVM.""" import re +from typing import TypeVar import tvm.ffi @@ -67,3 +68,139 @@ def _regex_match(regex_pattern: str, match_against: str) -> bool: """ match = re.match(regex_pattern, match_against) return match is not None + + +T = TypeVar("T") + + +def derived_object(cls: type[T]) -> type[T]: + """A decorator to register derived subclasses for TVM objects. + + Parameters + ---------- + cls : type + The derived class to be registered. + + Returns + ------- + cls : type + The decorated TVM object. + + Example + ------- + .. code-block:: python + + @register_object("meta_schedule.PyRunner") + class _PyRunner(meta_schedule.Runner): + def __init__(self, f_run: Callable = None): + self.__init_handle_by_constructor__(_ffi_api.RunnerPyRunner, f_run) + + class PyRunner: + _tvm_metadata = { + "cls": _PyRunner, + "methods": ["run"] + } + def run(self, runner_inputs): + raise NotImplementedError + + @derived_object + class LocalRunner(PyRunner): + def run(self, runner_inputs): + ... + """ + + import functools # pylint: disable=import-outside-toplevel + import weakref # pylint: disable=import-outside-toplevel + + def _extract(inst: type, name: str): + """Extract function from intrinsic class.""" + + def method(*args, **kwargs): + return getattr(inst, name)(*args, **kwargs) + + for inherit_cls, base_cls in zip(cls.__mro__, cls.__mro__[1:]): + # extract functions that differ from the base class + if not hasattr(base_cls, name): + continue + if getattr(base_cls, name) is getattr(inherit_cls, name) and name != "__str__": + continue + return method + + # for task scheduler return None means calling default function + # otherwise it will trigger a TVMError of method not implemented + # on the c++ side when you call the method, __str__ not required + return None + + assert isinstance(cls.__base__, type) + if hasattr(cls, "_type") and cls._type == "TVMDerivedObject": # type: ignore + raise TypeError( + ( + f"Inheritance from a decorated object `{cls.__name__}` is not allowed. " + f"Please inherit from `{cls.__name__}._cls`." + ) + ) + assert hasattr( + cls, "_tvm_metadata" + ), "Please use the user-facing method overriding class, i.e., PyRunner." + + base = cls.__base__ + metadata = getattr(base, "_tvm_metadata") + fields = metadata.get("fields", []) + methods = metadata.get("methods", []) + + class TVMDerivedObject(metadata["cls"]): # type: ignore + """The derived object to avoid cyclic dependency.""" + + _cls = cls + _type = "TVMDerivedObject" + + def __init__(self, *args, **kwargs): + """Constructor.""" + self._inst = cls(*args, **kwargs) + + super().__init__( + # the constructor's parameters, builder, runner, etc. + *[getattr(self._inst, name) for name in fields], + # the function methods, init_with_tune_context, build, run, etc. + *[_extract(self._inst, name) for name in methods], + ) + + # for task scheduler hybrid funcs in c++ & python side + # using weakref to avoid cyclic dependency + self._inst._outer = weakref.ref(self) + + def __getattr__(self, name): + import inspect # pylint: disable=import-outside-toplevel + + try: + # fall back to instance attribute if there is not any + # return self._inst.__getattribute__(name) + result = self._inst.__getattribute__(name) + except AttributeError: + result = super(TVMDerivedObject, self).__getattr__(name) + + if inspect.ismethod(result): + + def method(*args, **kwargs): + return result(*args, **kwargs) + + # set __own__ to aviod implicit deconstruction + setattr(method, "__own__", self) + return method + + return result + + def __setattr__(self, name, value): + if name not in ["_inst", "key", "handle"]: + self._inst.__setattr__(name, value) + else: + super(TVMDerivedObject, self).__setattr__(name, value) + + functools.update_wrapper(TVMDerivedObject.__init__, cls.__init__) # type: ignore + TVMDerivedObject.__name__ = cls.__name__ + TVMDerivedObject.__doc__ = cls.__doc__ + TVMDerivedObject.__module__ = cls.__module__ + for key, value in cls.__dict__.items(): + if isinstance(value, (classmethod, staticmethod)): + setattr(TVMDerivedObject, key, value) + return TVMDerivedObject diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 5c4a2b91f5d7..120d652dd817 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -108,3 +108,4 @@ from . import stmt_functor from .build import build from .pipeline import get_tir_pipeline, get_default_tir_pipeline +from .functor import PyStmtExprVisitor, PyStmtExprMutator diff --git a/python/tvm/tir/functor.py b/python/tvm/tir/functor.py new file mode 100644 index 000000000000..06985f6645ec --- /dev/null +++ b/python/tvm/tir/functor.py @@ -0,0 +1,2051 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, arguments-differ +"""The expression and statement functor of TIR.""" +from typing import Callable + +import tvm +from tvm.ir import PrimExpr +from tvm.runtime import Object +from tvm.runtime.support import derived_object + +from . import _ffi_api +from .expr import ( + EQ, + GE, + GT, + LE, + LT, + NE, + Add, + And, + Broadcast, + BufferLoad, + Call, + Cast, + Div, + FloatImm, + FloorDiv, + FloorMod, + IntImm, + Let, + Max, + Min, + Mod, + Mul, + Not, + Or, + ProducerLoad, + Ramp, + Reduce, + Select, + Shuffle, + SizeVar, + StringImm, + Sub, + Var, +) +from .stmt import ( + Allocate, + AllocateConst, + AssertStmt, + AttrStmt, + Block, + BlockRealize, + BufferRealize, + BufferStore, + DeclBuffer, + Evaluate, + For, + IfThenElse, + LetStmt, + SeqStmt, + Stmt, + While, +) + +visitor = derived_object +""" +A decorator to wrap user-customized PyStmtExprVisitor as TVM object _PyStmtExprVisitor. + +Parameters +---------- +visitor_cls : PyStmtExprVisitor + The user-customized PyStmtExprVisitor. + +Returns +------- +cls : _PyStmtExprVisitor + The decorated TVM object _PyStmtExprVisitor(StmtExprVisitor on the C++ side). + +Example +------- +.. code-block:: python + + @tir.functor.stmt_expr_visitor + class MyStmtExprVisitor(PyStmtExprVisitor): + # customize visit function + def visit_call_(self, op: Call) -> None: + # just for demo purposes + ... + # myvisitor is now a special visitor that visit every Call with + # user-customized visit_call_ + myvisitor = MyStmtExprVisitor() + # apply myvisitor to PrimExpr and Stmt + myvisitor.visit_expr(expr) + myvisitor.visit_stmt(stmt) +""" + +mutator = derived_object +""" +A decorator to wrap user-customized PyStmtExprMutator as TVM object _PyStmtExprMutator. + +Parameters +---------- +mutator_cls : PyStmtExprMutator + The user-customized PyStmtExprMutator. + +Returns +------- +cls : _PyStmtExprMutator + The decorated TVM object _PyStmtExprMutator(StmtExprMutator on the C++ side). + +Example +------- +.. code-block:: python + + @tir.functor.stmt_expr_mutator + class MyStmtExprMutator(PyStmtExprMutator): + # customize rewrite function + def visit_add_(self, op: Add) -> PrimExpr: + # just for demo purposes + ... + + # mymutator is now a special mutator that rewrite every Add with + # user-customized visit_add_ + mymutator = MyStmtExprMutator() + # apply mymutator to PrimExpr and Stmt + mymutator.visit_expr(expr) + mymutator.visit_stmt(stmt) +""" + + +@tvm.ffi.register_object("tir.PyStmtExprVisitor") +class _PyStmtExprVisitor(Object): + """ + An internal wrapper to interface between C++ and Python StmtExprVisitor. + This is the TVM object that wraps PyStmtExprVisitor. + + Do not use this class directly. Use PyStmtExprVisitor instead. + + See also: PyStmtExprVisitor, stmt_expr_visitor + """ + + def __init__( + self, + f_visit_stmt: Callable = None, + f_visit_expr: Callable = None, + # Stmt + f_visit_let_stmt: Callable = None, + f_visit_attr_stmt: Callable = None, + f_visit_if_then_else: Callable = None, + f_visit_for: Callable = None, + f_visit_while: Callable = None, + f_visit_allocate: Callable = None, + f_visit_allocate_const: Callable = None, + f_visit_decl_buffer: Callable = None, + f_visit_buffer_store: Callable = None, + f_visit_buffer_realize: Callable = None, + f_visit_assert_stmt: Callable = None, + f_visit_seq_stmt: Callable = None, + f_visit_evaluate: Callable = None, + f_visit_block: Callable = None, + f_visit_block_realize: Callable = None, + # PrimExpr + f_visit_var: Callable = None, + f_visit_size_var: Callable = None, + f_visit_buffer_load: Callable = None, + f_visit_producer_load: Callable = None, + f_visit_let: Callable = None, + f_visit_call: Callable = None, + f_visit_add: Callable = None, + f_visit_sub: Callable = None, + f_visit_mul: Callable = None, + f_visit_div: Callable = None, + f_visit_mod: Callable = None, + f_visit_floor_div: Callable = None, + f_visit_floor_mod: Callable = None, + f_visit_min: Callable = None, + f_visit_max: Callable = None, + f_visit_eq: Callable = None, + f_visit_ne: Callable = None, + f_visit_lt: Callable = None, + f_visit_le: Callable = None, + f_visit_gt: Callable = None, + f_visit_ge: Callable = None, + f_visit_and: Callable = None, + f_visit_or: Callable = None, + f_visit_reduce: Callable = None, + f_visit_cast: Callable = None, + f_visit_not: Callable = None, + f_visit_select: Callable = None, + f_visit_ramp: Callable = None, + f_visit_broadcast: Callable = None, + f_visit_shuffle: Callable = None, + f_visit_int_imm: Callable = None, + f_visit_float_imm: Callable = None, + f_visit_string_imm: Callable = None, + ) -> None: + """Constructor.""" + self.__init_handle_by_constructor__( + _ffi_api.MakePyStmtExprVisitor, # type: ignore + f_visit_stmt, + f_visit_expr, + # Stmt + f_visit_let_stmt, + f_visit_attr_stmt, + f_visit_if_then_else, + f_visit_for, + f_visit_while, + f_visit_allocate, + f_visit_allocate_const, + f_visit_decl_buffer, + f_visit_buffer_store, + f_visit_buffer_realize, + f_visit_assert_stmt, + f_visit_seq_stmt, + f_visit_evaluate, + f_visit_block, + f_visit_block_realize, + # PrimExpr + f_visit_var, + f_visit_size_var, + f_visit_buffer_load, + f_visit_producer_load, + f_visit_let, + f_visit_call, + f_visit_add, + f_visit_sub, + f_visit_mul, + f_visit_div, + f_visit_mod, + f_visit_floor_div, + f_visit_floor_mod, + f_visit_min, + f_visit_max, + f_visit_eq, + f_visit_ne, + f_visit_lt, + f_visit_le, + f_visit_gt, + f_visit_ge, + f_visit_and, + f_visit_or, + f_visit_reduce, + f_visit_cast, + f_visit_not, + f_visit_select, + f_visit_ramp, + f_visit_broadcast, + f_visit_shuffle, + f_visit_int_imm, + f_visit_float_imm, + f_visit_string_imm, + ) + + +class PyStmtExprVisitor: + """ + A Python StmtExprVisitor to define custom visitor for both Stmt and PrimExpr. + + Users can customize any of the visit function. + """ + + _tvm_metadata = { + "cls": _PyStmtExprVisitor, + "methods": [ + "visit_stmt", + "visit_expr", + # Stmt + "visit_let_stmt_", + "visit_attr_stmt_", + "visit_if_then_else_", + "visit_for_", + "visit_while_", + "visit_allocate_", + "visit_allocate_const_", + "visit_decl_buffer_", + "visit_buffer_store_", + "visit_buffer_realize_", + "visit_assert_stmt_", + "visit_seq_stmt_", + "visit_evaluate_", + "visit_block_", + "visit_block_realize_", + # PrimExpr + "visit_var_", + "visit_size_var_", + "visit_buffer_load_", + "visit_producer_load_", + "visit_let_", + "visit_call_", + "visit_add_", + "visit_sub_", + "visit_mul_", + "visit_div_", + "visit_mod_", + "visit_floor_div_", + "visit_floor_mod_", + "visit_min_", + "visit_max_", + "visit_eq_", + "visit_ne_", + "visit_lt_", + "visit_le_", + "visit_gt_", + "visit_ge_", + "visit_and_", + "visit_or_", + "visit_reduce_", + "visit_cast_", + "visit_not_", + "visit_select_", + "visit_ramp_", + "visit_broadcast_", + "visit_shuffle_", + "visit_int_imm_", + "visit_float_imm_", + "visit_string_imm_", + ], + } + + def visit_stmt(self, stmt: Stmt) -> None: + """Visit a Stmt. + + Parameters + ---------- + stmt : Stmt + The Stmt to be visited. + """ + _ffi_api.PyStmtExprVisitorVisitStmt(self._outer(), stmt) # type: ignore + + def visit_expr(self, expr: PrimExpr) -> None: + """Visit a PrimExpr. + + Parameters + ---------- + expr : PrimExpr + The PrimExpr to be visited. + """ + _ffi_api.PyStmtExprVisitorVisitExpr(self._outer(), expr) # type: ignore + + def visit_attr_stmt_(self, op: AttrStmt) -> None: + """Visit AttrStmt. + Users can customize this function to overwrite VisitStmt_(const AttrStmtNode* op) + on the C++ side. + + Parameters + ---------- + op : AttrStmt + The AttrStmt to be visited. + """ + print("visit_attr_stmt_", op) + _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_if_then_else_(self, op: IfThenElse) -> None: + """Visit IfThenElse. + Users can customize this function to overwrite VisitStmt_(const IfThenElseNode* op) + on the C++ side. + + Parameters + ---------- + op : IfThenElse + The IfThenElse to be visited. + """ + print("visit_if_then_else_", op) + _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_let_stmt_(self, op: LetStmt) -> None: + """Visit LetStmt. + Users can customize this function to overwrite VisitStmt_(const LetStmtNode* op) + on the C++ side. + + Parameters + ---------- + op : LetStmt + The LetStmt to be visited. + """ + print("visit_let_stmt_", op) + _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_for_(self, op: For) -> None: + """Visit For. + Users can customize this function to overwrite VisitStmt_(const ForNode* op) + on the C++ side. + + Parameters + ---------- + op : For + The For to be visited. + """ + print("visit_for_", op) + _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_while_(self, op: While) -> None: + """Visit While. + Users can customize this function to overwrite VisitStmt_(const WhileNode* op) + on the C++ side. + + Parameters + ---------- + op : While + The While to be visited. + """ + print("visit_while_", op) + _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_allocate_(self, op: Allocate) -> None: + """Visit Allocate. + Users can customize this function to overwrite VisitStmt_(const AllocateNode* op) + on the C++ side. + + Parameters + ---------- + op : Allocate + The Allocate to be visited. + """ + print("visit_allocate_", op) + _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_allocate_const_(self, op: AllocateConst) -> None: + """Visit AllocateConst. + Users can customize this function to overwrite VisitStmt_(const AllocateConstNode* op) + on the C++ side. + + Parameters + ---------- + op : AllocateConst + The AllocateConst to be visited. + """ + print("visit_allocate_const_", op) + _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_decl_buffer_(self, op: DeclBuffer) -> None: + """Visit DeclBuffer. + Users can customize this function to overwrite VisitStmt_(const DeclBufferNode* op) + on the C++ side. + + Parameters + ---------- + op : DeclBuffer + The DeclBuffer to be visited. + """ + print("visit_decl_buffer_", op) + _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_buffer_store_(self, op: BufferStore) -> None: + """Visit BufferStore. + Users can customize this function to overwrite VisitStmt_(const BufferStoreNode* op) + on the C++ side. + + Parameters + ---------- + op : BufferStore + The BufferStore to be visited. + """ + print("visit_buffer_store_", op) + _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_buffer_realize_(self, op: BufferRealize) -> None: + """Visit BufferRealize. + Users can customize this function to overwrite VisitStmt_(const BufferRealizeNode* op) + on the C++ side. + + Parameters + ---------- + op : BufferRealize + The BufferRealize to be visited. + """ + print("visit_buffer_realize_", op) + _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_assert_stmt_(self, op: AssertStmt) -> None: + """Visit AssertStmt. + Users can customize this function to overwrite VisitStmt_(const AssertStmtNode* op) + on the C++ side. + + Parameters + ---------- + op : AssertStmt + The AssertStmt to be visited. + """ + print("visit_assert_stmt_", op) + _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_seq_stmt_(self, op: SeqStmt) -> None: + """Visit SeqStmt. + Users can customize this function to overwrite VisitStmt_(const SeqStmtNode* op) + on the C++ side. + + Parameters + ---------- + op : SeqStmt + The SeqStmt to be visited. + """ + print("visit_seq_stmt_", op) + _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_evaluate_(self, op: Evaluate) -> None: + """Visit Evaluate. + Users can customize this function to overwrite VisitStmt_(const EvaluateNode* op) + on the C++ side. + + Parameters + ---------- + op : Evaluate + The Evaluate to be visited. + """ + print("visit_evaluate_", op) + _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_block_(self, op: Block) -> None: + """Visit Block. + Users can customize this function to overwrite VisitStmt_(const BlockNode* op) + on the C++ side. + + Parameters + ---------- + op : Block + The Block to be visited. + """ + print("visit_block_", op) + _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_block_realize_(self, op: BlockRealize) -> None: + """Visit BlockRealize. + Users can customize this function to overwrite VisitStmt_(const BlockRealizeNode* op) + on the C++ side. + + Parameters + ---------- + op : BlockRealize + The BlockRealize to be visited. + """ + print("visit_block_realize_", op) + _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_var_(self, op: Var) -> None: + """Visit Var. + + Users can customize this function to overwrite VisitVar_(const VarNode* op) + on the C++ side. + + Parameters + ---------- + op : Var + The Var to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_size_var_(self, op: SizeVar) -> None: + """Visit SizeVar. + + Users can customize this function to overwrite VisitSizeVar_(const SizeVarNode* op) + on the C++ side. + + Parameters + ---------- + op : SizeVar + The SizeVar to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_buffer_load_(self, op: BufferLoad) -> None: + """Visit BufferLoad. + + Users can customize this function to overwrite VisitBufferLoad_(const BufferLoadNode* op) + on the C++ side. + + Parameters + ---------- + op : BufferLoad + The BufferLoad to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_producer_load_(self, op: ProducerLoad) -> None: + """Visit ProducerLoad. + + Users can customize this function to overwrite + VisitProducerLoad_(const ProducerLoadNode* op) on the C++ side. + + Parameters + ---------- + op : ProducerLoad + The ProducerLoad to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_let_(self, op: Let) -> None: + """Visit Let. + + Users can customize this function to overwrite VisitLet_(const LetNode* op) + on the C++ side. + + Parameters + ---------- + op : Let + The Let to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_call_(self, op: Call) -> None: + """Visit Call. + + Users can customize this function to overwrite VisitCall_(const CallNode* op) + on the C++ side. + + Parameters + ---------- + op : Call + The Call to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_add_(self, op: Add) -> None: + """Visit Add. + + Users can customize this function to overwrite VisitAdd_(const AddNode* op) + on the C++ side. + + Parameters + ---------- + op : Add + The Add to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_sub_(self, op: Sub) -> None: + """Visit Sub. + + Users can customize this function to overwrite VisitSub_(const SubNode* op) + on the C++ side. + + Parameters + ---------- + op : Sub + The Sub to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_mul_(self, op: Mul) -> None: + """Visit Mul. + + Users can customize this function to overwrite VisitMul_(const MulNode* op) + on the C++ side. + + Parameters + ---------- + op : Mul + The Mul to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_div_(self, op: Div) -> None: + """Visit Div. + + Users can customize this function to overwrite VisitDiv_(const DivNode* op) + on the C++ side. + + Parameters + ---------- + op : Div + The Div to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_mod_(self, op: Mod) -> None: + """Visit Mod. + + Users can customize this function to overwrite VisitMod_(const ModNode* op) + on the C++ side. + + Parameters + ---------- + op : Mod + The Mod to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_floor_div_(self, op: FloorDiv) -> None: + """Visit FloorDiv. + + Users can customize this function to overwrite VisitFloorDiv_(const FloorDivNode* op) + on the C++ side. + + Parameters + ---------- + op : FloorDiv + The FloorDiv to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_floor_mod_(self, op: FloorMod) -> None: + """Visit FloorMod. + + Users can customize this function to overwrite VisitFloorMod_(const FloorModNode* op) + on the C++ side. + + Parameters + ---------- + op : FloorMod + The FloorMod to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_min_(self, op: Min) -> None: + """Visit Min. + + Users can customize this function to overwrite VisitMin_(const MinNode* op) + on the C++ side. + + Parameters + ---------- + op : Min + The Min to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_max_(self, op: Max) -> None: + """Visit Max. + + Users can customize this function to overwrite VisitMax_(const MaxNode* op) + on the C++ side. + + Parameters + ---------- + op : Max + The Max to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_eq_(self, op: EQ) -> None: + """Visit EQ. + + Users can customize this function to overwrite VisitEQ_(const EQNode* op) + on the C++ side. + + Parameters + ---------- + op : EQ + The EQ to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_ne_(self, op: NE) -> None: + """Visit NE. + + Users can customize this function to overwrite VisitNE_(const NENode* op) + on the C++ side. + + Parameters + ---------- + op : NE + The NE to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_lt_(self, op: LT) -> None: + """Visit LT. + + Users can customize this function to overwrite VisitLT_(const LTNode* op) + on the C++ side. + + Parameters + ---------- + op : LT + The LT to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_le_(self, op: LE) -> None: + """Visit LE. + + Users can customize this function to overwrite VisitLE_(const LENode* op) + on the C++ side. + + Parameters + ---------- + op : LE + The LE to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_gt_(self, op: GT) -> None: + """Visit GT. + + Users can customize this function to overwrite VisitGT_(const GTNode* op) + on the C++ side. + + Parameters + ---------- + op : GT + The GT to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_ge_(self, op: GE) -> None: + """Visit GE. + + Users can customize this function to overwrite VisitGE_(const GENode* op) + on the C++ side. + + Parameters + ---------- + op : GE + The GE to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_and_(self, op: And) -> None: + """Visit And. + + Users can customize this function to overwrite VisitAnd_(const AndNode* op) + on the C++ side. + + Parameters + ---------- + op : And + The And to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_or_(self, op: Or) -> None: + """Visit Or. + + Users can customize this function to overwrite VisitOr_(const OrNode* op) + on the C++ side. + + Parameters + ---------- + op : Or + The Or to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_reduce_(self, op: Reduce) -> None: + """Visit Reduce. + + Users can customize this function to overwrite VisitReduce_(const ReduceNode* op) + on the C++ side. + + Parameters + ---------- + op : Reduce + The Reduce to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_cast_(self, op: Cast) -> None: + """Visit Cast. + + Users can customize this function to overwrite VisitCast_(const CastNode* op) + on the C++ side. + + Parameters + ---------- + op : Cast + The Cast to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_not_(self, op: Not) -> None: + """Visit Not. + + Users can customize this function to overwrite VisitNot_(const NotNode* op) + on the C++ side. + + Parameters + ---------- + op : Not + The Not to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_select_(self, op: Select) -> None: + """Visit Select. + + Users can customize this function to overwrite VisitSelect_(const SelectNode* op) + on the C++ side. + + Parameters + ---------- + op : Select + The Select to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_ramp_(self, op: Ramp) -> None: + """Visit Ramp. + + Users can customize this function to overwrite VisitRamp_(const RampNode* op) + on the C++ side. + + Parameters + ---------- + op : Ramp + The Ramp to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_broadcast_(self, op: Broadcast) -> None: + """Visit Broadcast. + + Users can customize this function to overwrite VisitBroadcast_(const BroadcastNode* op) + on the C++ side. + + Parameters + ---------- + op : Broadcast + The Broadcast to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_shuffle_(self, op: Shuffle) -> None: + """Visit Shuffle. + + Users can customize this function to overwrite VisitShuffle_(const ShuffleNode* op) + on the C++ side. + + Parameters + ---------- + op : Shuffle + The Shuffle to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_int_imm_(self, op: IntImm) -> None: + """Visit IntImm. + + Users can customize this function to overwrite VisitIntImm_(const IntImmNode* op) + on the C++ side. + + Parameters + ---------- + op : IntImm + The IntImm to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_float_imm_(self, op: FloatImm) -> None: + """Visit FloatImm. + + Users can customize this function to overwrite VisitFloatImm_(const FloatImmNode* op) + on the C++ side. + + Parameters + ---------- + op : FloatImm + The FloatImm to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_string_imm_(self, op: StringImm) -> None: + """Visit StringImm. + + Users can customize this function to overwrite VisitStringImm_(const StringImmNode* op) + on the C++ side. + + Parameters + ---------- + op : StringImm + The StringImm to be visited. + """ + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + +@tvm.ffi.register_object("tir.PyStmtExprMutator") +class _PyStmtExprMutator(Object): + """ + A TVM object to support customization of StmtExprMutator on the python side. + This is the decorated result returned from stmt_expr_mutator decorator. + + WARNING: This is NOT the user facing class for method overwriting inheritance. + + See also: stmt_expr_mutator, PyStmtExprMutator + """ + + def __init__( + self, + f_visit_stmt: Callable = None, + f_visit_expr: Callable = None, + # Stmt + f_visit_let_stmt: Callable = None, + f_visit_attr_stmt: Callable = None, + f_visit_if_then_else: Callable = None, + f_visit_for: Callable = None, + f_visit_while: Callable = None, + f_visit_allocate: Callable = None, + f_visit_allocate_const: Callable = None, + f_visit_decl_buffer: Callable = None, + f_visit_buffer_store: Callable = None, + f_visit_buffer_realize: Callable = None, + f_visit_assert_stmt: Callable = None, + f_visit_seq_stmt: Callable = None, + f_visit_evaluate: Callable = None, + f_visit_block: Callable = None, + f_visit_block_realize: Callable = None, + # PrimExpr + f_visit_var: Callable = None, + f_visit_size_var: Callable = None, + f_visit_buffer_load: Callable = None, + f_visit_producer_load: Callable = None, + f_visit_let: Callable = None, + f_visit_call: Callable = None, + f_visit_add: Callable = None, + f_visit_sub: Callable = None, + f_visit_mul: Callable = None, + f_visit_div: Callable = None, + f_visit_mod: Callable = None, + f_visit_floor_div: Callable = None, + f_visit_floor_mod: Callable = None, + f_visit_min: Callable = None, + f_visit_max: Callable = None, + f_visit_eq: Callable = None, + f_visit_ne: Callable = None, + f_visit_lt: Callable = None, + f_visit_le: Callable = None, + f_visit_gt: Callable = None, + f_visit_ge: Callable = None, + f_visit_and: Callable = None, + f_visit_or: Callable = None, + f_visit_reduce: Callable = None, + f_visit_cast: Callable = None, + f_visit_not: Callable = None, + f_visit_select: Callable = None, + f_visit_ramp: Callable = None, + f_visit_broadcast: Callable = None, + f_visit_shuffle: Callable = None, + f_visit_int_imm: Callable = None, + f_visit_float_imm: Callable = None, + f_visit_string_imm: Callable = None, + ) -> None: + """Constructor.""" + self.__init_handle_by_constructor__( + _ffi_api.MakePyStmtExprMutator, # type: ignore + f_visit_stmt, + f_visit_expr, + # Stmt + f_visit_let_stmt, + f_visit_attr_stmt, + f_visit_if_then_else, + f_visit_for, + f_visit_while, + f_visit_allocate, + f_visit_allocate_const, + f_visit_decl_buffer, + f_visit_buffer_store, + f_visit_buffer_realize, + f_visit_assert_stmt, + f_visit_seq_stmt, + f_visit_evaluate, + f_visit_block, + f_visit_block_realize, + # PrimExpr + f_visit_var, + f_visit_size_var, + f_visit_buffer_load, + f_visit_producer_load, + f_visit_let, + f_visit_call, + f_visit_add, + f_visit_sub, + f_visit_mul, + f_visit_div, + f_visit_mod, + f_visit_floor_div, + f_visit_floor_mod, + f_visit_min, + f_visit_max, + f_visit_eq, + f_visit_ne, + f_visit_lt, + f_visit_le, + f_visit_gt, + f_visit_ge, + f_visit_and, + f_visit_or, + f_visit_reduce, + f_visit_cast, + f_visit_not, + f_visit_select, + f_visit_ramp, + f_visit_broadcast, + f_visit_shuffle, + f_visit_int_imm, + f_visit_float_imm, + f_visit_string_imm, + ) + + +class PyStmtExprMutator: + """ + A Python StmtExprMutator to define custom mutator for both Stmt and PrimExpr. + + Users can customize any of the visit function. + """ + + _tvm_metadata = { + "cls": _PyStmtExprMutator, + "methods": [ + "visit_stmt", + "visit_expr", + # Stmt + "visit_let_stmt_", + "visit_attr_stmt_", + "visit_if_then_else_", + "visit_for_", + "visit_while_", + "visit_allocate_", + "visit_allocate_const_", + "visit_decl_buffer_", + "visit_buffer_store_", + "visit_buffer_realize_", + "visit_assert_stmt_", + "visit_seq_stmt_", + "visit_evaluate_", + "visit_block_", + "visit_block_realize_", + # PrimExpr + "visit_var_", + "visit_size_var_", + "visit_buffer_load_", + "visit_producer_load_", + "visit_let_", + "visit_call_", + "visit_add_", + "visit_sub_", + "visit_mul_", + "visit_div_", + "visit_mod_", + "visit_floor_div_", + "visit_floor_mod_", + "visit_min_", + "visit_max_", + "visit_eq_", + "visit_ne_", + "visit_lt_", + "visit_le_", + "visit_gt_", + "visit_ge_", + "visit_and_", + "visit_or_", + "visit_reduce_", + "visit_cast_", + "visit_not_", + "visit_select_", + "visit_ramp_", + "visit_broadcast_", + "visit_shuffle_", + "visit_int_imm_", + "visit_float_imm_", + "visit_string_imm_", + ], + } + + def visit_expr(self, expr: PrimExpr) -> PrimExpr: + """Visit PrimExpr. + Users can customize this function to overwrite VisitExpr(const PrimExpr& expr) + on the C++ side. + + Parameters + ---------- + expr : PrimExpr + The PrimExpr to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorVisitExpr(self._outer(), expr) # type: ignore + + def visit_stmt(self, stmt: Stmt) -> Stmt: + """Visit Stmt. + Users can customize this function to overwrite VisitStmt(const Stmt& stmt) + on the C++ side. + + Parameters + ---------- + stmt : Stmt + The Stmt to be visited. + + Returns + ------- + result : Stmt + The mutated Stmt. + """ + return _ffi_api.PyStmtExprMutatorVisitStmt(self._outer(), stmt) # type: ignore + + def visit_attr_stmt_(self, op: AttrStmt) -> Stmt: + """Visit AttrStmt. + Users can customize this function to overwrite VisitStmt_(const AttrStmtNode* op) + on the C++ side. + + Parameters + ---------- + op : AttrStmt + The AttrStmt to be visited. + + Returns + ------- + result : Stmt + The mutated Stmt. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_if_then_else_(self, op: IfThenElse) -> Stmt: + """Visit IfThenElse. + Users can customize this function to overwrite VisitStmt_(const IfThenElseNode* op) + on the C++ side. + + Parameters + ---------- + op : IfThenElse + The IfThenElse to be visited. + + Returns + ------- + result : Stmt + The mutated Stmt. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_let_stmt_(self, op: LetStmt) -> Stmt: + """Visit LetStmt. + Users can customize this function to overwrite VisitStmt_(const LetStmtNode* op) + on the C++ side. + + Parameters + ---------- + op : LetStmt + The LetStmt to be visited. + + Returns + ------- + result : Stmt + The mutated Stmt. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_for_(self, op: For) -> Stmt: + """Visit For. + Users can customize this function to overwrite VisitStmt_(const ForNode* op) + on the C++ side. + + Parameters + ---------- + op : For + The For to be visited. + + Returns + ------- + result : Stmt + The mutated Stmt. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_while_(self, op: While) -> Stmt: + """Visit While. + Users can customize this function to overwrite VisitStmt_(const WhileNode* op) + on the C++ side. + + Parameters + ---------- + op : While + The While to be visited. + + Returns + ------- + result : Stmt + The mutated Stmt. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_allocate_(self, op: Allocate) -> Stmt: + """Visit Allocate. + Users can customize this function to overwrite VisitStmt_(const AllocateNode* op) + on the C++ side. + + Parameters + ---------- + op : Allocate + The Allocate to be visited. + + Returns + ------- + result : Stmt + The mutated Stmt. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_allocate_const_(self, op: AllocateConst) -> Stmt: + """Visit AllocateConst. + Users can customize this function to overwrite VisitStmt_(const AllocateConstNode* op) + on the C++ side. + + Parameters + ---------- + op : AllocateConst + The AllocateConst to be visited. + + Returns + ------- + result : Stmt + The mutated Stmt. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_decl_buffer_(self, op: DeclBuffer) -> Stmt: + """Visit DeclBuffer. + Users can customize this function to overwrite VisitStmt_(const DeclBufferNode* op) + on the C++ side. + + Parameters + ---------- + op : DeclBuffer + The DeclBuffer to be visited. + + Returns + ------- + result : Stmt + The mutated Stmt. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_buffer_store_(self, op: BufferStore) -> Stmt: + """Visit BufferStore. + Users can customize this function to overwrite VisitStmt_(const BufferStoreNode* op) + on the C++ side. + + Parameters + ---------- + op : BufferStore + The BufferStore to be visited. + + Returns + ------- + result : Stmt + The mutated Stmt. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_buffer_realize_(self, op: BufferRealize) -> Stmt: + """Visit BufferRealize. + Users can customize this function to overwrite VisitStmt_(const BufferRealizeNode* op) + on the C++ side. + + Parameters + ---------- + op : BufferRealize + The BufferRealize to be visited. + + Returns + ------- + result : Stmt + The mutated Stmt. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_assert_stmt_(self, op: AssertStmt) -> Stmt: + """Visit AssertStmt. + Users can customize this function to overwrite VisitStmt_(const AssertStmtNode* op) + on the C++ side. + + Parameters + ---------- + op : AssertStmt + The AssertStmt to be visited. + + Returns + ------- + result : Stmt + The mutated Stmt. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_seq_stmt_(self, op: SeqStmt) -> Stmt: + """Visit SeqStmt. + Users can customize this function to overwrite VisitStmt_(const SeqStmtNode* op) + on the C++ side. + + Parameters + ---------- + op : SeqStmt + The SeqStmt to be visited. + + Returns + ------- + result : Stmt + The mutated Stmt. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_evaluate_(self, op: Evaluate) -> Stmt: + """Visit Evaluate. + Users can customize this function to overwrite VisitStmt_(const EvaluateNode* op) + on the C++ side. + + Parameters + ---------- + op : Evaluate + The Evaluate to be visited. + + Returns + ------- + result : Stmt + The mutated Stmt. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_block_(self, op: Block) -> Stmt: + """Visit Block. + Users can customize this function to overwrite VisitStmt_(const BlockNode* op) + on the C++ side. + + Parameters + ---------- + op : Block + The Block to be visited. + + Returns + ------- + result : Stmt + The mutated Stmt. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_block_realize_(self, op: BlockRealize) -> Stmt: + """Visit BlockRealize. + Users can customize this function to overwrite VisitStmt_(const BlockRealizeNode* op) + on the C++ side. + + Parameters + ---------- + op : BlockRealize + The BlockRealize to be visited. + + Returns + ------- + result : Stmt + The mutated Stmt. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore + + def visit_var_(self, op: Var) -> PrimExpr: + """Visit Var. + + Users can customize this function to overwrite VisitVar_(const VarNode* op) + on the C++ side. + + Parameters + ---------- + op : Var + The Var to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_size_var_(self, op: SizeVar) -> PrimExpr: + """Visit SizeVar. + + Users can customize this function to overwrite VisitSizeVar_(const SizeVarNode* op) + on the C++ side. + + Parameters + ---------- + op : SizeVar + The SizeVar to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_buffer_load_(self, op: BufferLoad) -> PrimExpr: + """Visit BufferLoad. + + Users can customize this function to overwrite VisitBufferLoad_(const BufferLoadNode* op) + on the C++ side. + + Parameters + ---------- + op : BufferLoad + The BufferLoad to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_producer_load_(self, op: ProducerLoad) -> PrimExpr: + """Visit ProducerLoad. + + Users can customize this function to overwrite + VisitProducerLoad_(const ProducerLoadNode* op) on the C++ side. + + Parameters + ---------- + op : ProducerLoad + The ProducerLoad to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_let_(self, op: Let) -> PrimExpr: + """Visit Let. + + Users can customize this function to overwrite VisitLet_(const LetNode* op) + on the C++ side. + + Parameters + ---------- + op : Let + The Let to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_call_(self, op: Call) -> PrimExpr: + """Visit Call. + + Users can customize this function to overwrite VisitCall_(const CallNode* op) + on the C++ side. + + Parameters + ---------- + op : Call + The Call to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_add_(self, op: Add) -> PrimExpr: + """Visit Add. + + Users can customize this function to overwrite VisitAdd_(const AddNode* op) + on the C++ side. + + Parameters + ---------- + op : Add + The Add to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_sub_(self, op: Sub) -> PrimExpr: + """Visit Sub. + + Users can customize this function to overwrite VisitSub_(const SubNode* op) + on the C++ side. + + Parameters + ---------- + op : Sub + The Sub to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_mul_(self, op: Mul) -> PrimExpr: + """Visit Mul. + + Users can customize this function to overwrite VisitMul_(const MulNode* op) + on the C++ side. + + Parameters + ---------- + op : Mul + The Mul to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_div_(self, op: Div) -> PrimExpr: + """Visit Div. + + Users can customize this function to overwrite VisitDiv_(const DivNode* op) + on the C++ side. + + Parameters + ---------- + op : Div + The Div to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_mod_(self, op: Mod) -> PrimExpr: + """Visit Mod. + + Users can customize this function to overwrite VisitMod_(const ModNode* op) + on the C++ side. + + Parameters + ---------- + op : Mod + The Mod to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_floor_div_(self, op: FloorDiv) -> PrimExpr: + """Visit FloorDiv. + + Users can customize this function to overwrite VisitFloorDiv_(const FloorDivNode* op) + on the C++ side. + + Parameters + ---------- + op : FloorDiv + The FloorDiv to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_floor_mod_(self, op: FloorMod) -> PrimExpr: + """Visit FloorMod. + + Users can customize this function to overwrite VisitFloorMod_(const FloorModNode* op) + on the C++ side. + + Parameters + ---------- + op : FloorMod + The FloorMod to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_min_(self, op: Min) -> PrimExpr: + """Visit Min. + + Users can customize this function to overwrite VisitMin_(const MinNode* op) + on the C++ side. + + Parameters + ---------- + op : Min + The Min to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_max_(self, op: Max) -> PrimExpr: + """Visit Max. + + Users can customize this function to overwrite VisitMax_(const MaxNode* op) + on the C++ side. + + Parameters + ---------- + op : Max + The Max to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_eq_(self, op: EQ) -> PrimExpr: + """Visit EQ. + + Users can customize this function to overwrite VisitEQ_(const EQNode* op) + on the C++ side. + + Parameters + ---------- + op : EQ + The EQ to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_ne_(self, op: NE) -> PrimExpr: + """Visit NE. + + Users can customize this function to overwrite VisitNE_(const NENode* op) + on the C++ side. + + Parameters + ---------- + op : NE + The NE to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_lt_(self, op: LT) -> PrimExpr: + """Visit LT. + + Users can customize this function to overwrite VisitLT_(const LTNode* op) + on the C++ side. + + Parameters + ---------- + op : LT + The LT to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_le_(self, op: LE) -> PrimExpr: + """Visit LE. + + Users can customize this function to overwrite VisitLE_(const LENode* op) + on the C++ side. + + Parameters + ---------- + op : LE + The LE to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_gt_(self, op: GT) -> PrimExpr: + """Visit GT. + + Users can customize this function to overwrite VisitGT_(const GTNode* op) + on the C++ side. + + Parameters + ---------- + op : GT + The GT to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_ge_(self, op: GE) -> PrimExpr: + """Visit GE. + + Users can customize this function to overwrite VisitGE_(const GENode* op) + on the C++ side. + + Parameters + ---------- + op : GE + The GE to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_and_(self, op: And) -> PrimExpr: + """Visit And. + + Users can customize this function to overwrite VisitAnd_(const AndNode* op) + on the C++ side. + + Parameters + ---------- + op : And + The And to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_or_(self, op: Or) -> PrimExpr: + """Visit Or. + + Users can customize this function to overwrite VisitOr_(const OrNode* op) + on the C++ side. + + Parameters + ---------- + op : Or + The Or to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_reduce_(self, op: Reduce) -> PrimExpr: + """Visit Reduce. + + Users can customize this function to overwrite VisitReduce_(const ReduceNode* op) + on the C++ side. + + Parameters + ---------- + op : Reduce + The Reduce to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_cast_(self, op: Cast) -> PrimExpr: + """Visit Cast. + + Users can customize this function to overwrite VisitCast_(const CastNode* op) + on the C++ side. + + Parameters + ---------- + op : Cast + The Cast to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_not_(self, op: Not) -> PrimExpr: + """Visit Not. + + Users can customize this function to overwrite VisitNot_(const NotNode* op) + on the C++ side. + + Parameters + ---------- + op : Not + The Not to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_select_(self, op: Select) -> PrimExpr: + """Visit Select. + + Users can customize this function to overwrite VisitSelect_(const SelectNode* op) + on the C++ side. + + Parameters + ---------- + op : Select + The Select to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_ramp_(self, op: Ramp) -> PrimExpr: + """Visit Ramp. + + Users can customize this function to overwrite VisitRamp_(const RampNode* op) + on the C++ side. + + Parameters + ---------- + op : Ramp + The Ramp to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_broadcast_(self, op: Broadcast) -> PrimExpr: + """Visit Broadcast. + + Users can customize this function to overwrite VisitBroadcast_(const BroadcastNode* op) + on the C++ side. + + Parameters + ---------- + op : Broadcast + The Broadcast to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_shuffle_(self, op: Shuffle) -> PrimExpr: + """Visit Shuffle. + + Users can customize this function to overwrite VisitShuffle_(const ShuffleNode* op) + on the C++ side. + + Parameters + ---------- + op : Shuffle + The Shuffle to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_int_imm_(self, op: IntImm) -> PrimExpr: + """Visit IntImm. + + Users can customize this function to overwrite VisitIntImm_(const IntImmNode* op) + on the C++ side. + + Parameters + ---------- + op : IntImm + The IntImm to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_float_imm_(self, op: FloatImm) -> PrimExpr: + """Visit FloatImm. + + Users can customize this function to overwrite VisitFloatImm_(const FloatImmNode* op) + on the C++ side. + + Parameters + ---------- + op : FloatImm + The FloatImm to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_string_imm_(self, op: StringImm) -> PrimExpr: + """Visit StringImm. + + Users can customize this function to overwrite VisitStringImm_(const StringImmNode* op) + on the C++ side. + + Parameters + ---------- + op : StringImm + The StringImm to be visited. + + Returns + ------- + result : PrimExpr + The mutated PrimExpr. + """ + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc index aca4c99e1197..7ac84ce894a4 100644 --- a/src/tir/analysis/buffer_access_lca_detector.cc +++ b/src/tir/analysis/buffer_access_lca_detector.cc @@ -147,7 +147,7 @@ class LCADetector : public StmtExprVisitor { auto do_collect_itervar_scope = [this](const IterVar& itervar, const PrimExpr& binding) -> const ScopeInfo* { const ScopeInfo* highest_scope = nullptr; - PostOrderVisit(binding, [this, &itervar, &highest_scope](const ObjectRef& obj) { + PostOrderVisit(binding, [this, &highest_scope](const ObjectRef& obj) { if (const VarNode* loop_var = obj.as()) { auto it = loop_scope_map_.find(loop_var); if (it == loop_scope_map_.end()) { diff --git a/src/tir/ir/py_functor.cc b/src/tir/ir/py_functor.cc new file mode 100644 index 000000000000..6152c99eaf7d --- /dev/null +++ b/src/tir/ir/py_functor.cc @@ -0,0 +1,859 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tir/ir/py_functor.cc + * \brief The python interface of ExprVisitor/ExprMutator, StmtVisitor/StmtMutator, + * StmtExprVisitor/StmtExprMutator. + */ + +#include +#include + +namespace tvm { +namespace tir { + +// ================================================ +// Helper Macros +// ================================================ +#define PY_EXPR_VISITOR_DISPATCH(OP, PY_FUNC) \ + void VisitExpr_(const OP* op) override { \ + if (PY_FUNC != nullptr) { \ + PY_FUNC(op); \ + } else { \ + StmtExprVisitor::VisitExpr_(op); \ + } \ + } + +#define IR_EXPR_VISITOR_DEFAULT_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self) { \ + self->StmtExprVisitor::VisitExpr_(static_cast(n.get())); \ + }); + +#define PY_STMT_VISITOR_DISPATCH(OP, PY_FUNC) \ + void VisitStmt_(const OP* op) override { \ + if (PY_FUNC != nullptr) { \ + PY_FUNC(op); \ + } else { \ + StmtExprVisitor::VisitStmt_(op); \ + } \ + } + +#define PY_STMT_VISITOR_DEFAULT_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self) { \ + self->StmtExprVisitor::VisitStmt_(static_cast(n.get())); \ + }); + +#define PY_EXPR_MUTATOR_DISPATCH(OP, PY_FUNC) \ + PrimExpr VisitExpr_(const OP* op) override { \ + if (PY_FUNC != nullptr) { \ + return PY_FUNC(op).cast(); \ + } else { \ + return StmtExprMutator::VisitExpr_(op); \ + } \ + } + +#define PY_EXPR_MUTATOR_DEFAULT_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self) { \ + return self->StmtExprMutator::VisitExpr_(static_cast(n.get())); \ + }); + +#define PY_STMT_MUTATOR_DISPATCH(OP, PY_FUNC) \ + Stmt VisitStmt_(const OP* op) override { \ + if (PY_FUNC != nullptr) { \ + return PY_FUNC(op).cast(); \ + } else { \ + return StmtExprMutator::VisitStmt_(op); \ + } \ + } + +#define PY_STMT_MUTATOR_DEFAULT_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self) { \ + return self->StmtExprMutator::VisitStmt_(static_cast(n.get())); \ + }); + +/*! \brief The python interface of StmtExprVisitor. */ +class PyStmtExprVisitorNode : public Object, public StmtExprVisitor { + private: + using TSelf = PyStmtExprVisitorNode; + using FExprType = tvm::NodeFunctor; + using FStmtType = tvm::NodeFunctor; + + public: + // Expression functions + /*! \brief The packed function to the `VisitExpr(const Expr& expr)` function. */ + ffi::Function f_visit_expr{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const VarNode* op)` function. */ + ffi::Function f_visit_var{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const SizeVarNode* op)` function. */ + ffi::Function f_visit_size_var{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const BufferLoadNode* op)` function. */ + ffi::Function f_visit_buffer_load{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ProducerLoadNode* op)` function. */ + ffi::Function f_visit_producer_load{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const LetNode* op)` function. */ + ffi::Function f_visit_let{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const CallNode* op)` function. */ + ffi::Function f_visit_call{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const AddNode* op)` function. */ + ffi::Function f_visit_add{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const SubNode* op)` function. */ + ffi::Function f_visit_sub{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const MulNode* op)` function. */ + ffi::Function f_visit_mul{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const DivNode* op)` function. */ + ffi::Function f_visit_div{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ModNode* op)` function. */ + ffi::Function f_visit_mod{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const FloorDivNode* op)` function. */ + ffi::Function f_visit_floor_div{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const FloorModNode* op)` function. */ + ffi::Function f_visit_floor_mod{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const MinNode* op)` function. */ + ffi::Function f_visit_min{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const MaxNode* op)` function. */ + ffi::Function f_visit_max{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const EQNode* op)` function. */ + ffi::Function f_visit_eq{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const NENode* op)` function. */ + ffi::Function f_visit_ne{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const LTNode* op)` function. */ + ffi::Function f_visit_lt{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const LENode* op)` function. */ + ffi::Function f_visit_le{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const GTNode* op)` function. */ + ffi::Function f_visit_gt{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const GENode* op)` function. */ + ffi::Function f_visit_ge{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const AndNode* op)` function. */ + ffi::Function f_visit_and{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const OrNode* op)` function. */ + ffi::Function f_visit_or{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ReduceNode* op)` function. */ + ffi::Function f_visit_reduce{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const CastNode* op)` function. */ + ffi::Function f_visit_cast{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const NotNode* op)` function. */ + ffi::Function f_visit_not{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const SelectNode* op)` function. */ + ffi::Function f_visit_select{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const RampNode* op)` function. */ + ffi::Function f_visit_ramp{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const BroadcastNode* op)` function. */ + ffi::Function f_visit_broadcast{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ShuffleNode* op)` function. */ + ffi::Function f_visit_shuffle{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const IntImmNode* op)` function. */ + ffi::Function f_visit_int_imm{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const FloatImmNode* op)` function. */ + ffi::Function f_visit_float_imm{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const StringImmNode* op)` function. */ + ffi::Function f_visit_string_imm{nullptr}; + + // Statement functions + /*! \brief The packed function to the `VisitStmt(const Stmt& stmt)` function. */ + ffi::Function f_visit_stmt{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const LetStmtNode* op)` function. */ + ffi::Function f_visit_attr_stmt{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const IfThenElseNode* op)` function. */ + ffi::Function f_visit_if_then_else{nullptr}; // NOLINT(readability/braces) + /*! \brief The packed function to the `VisitStmt_(const ForNode* op)` function. */ + ffi::Function f_visit_let_stmt{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const AttrStmtNode* op)` function. */ + ffi::Function f_visit_for{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const WhileNode* op)` function. */ + ffi::Function f_visit_while{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const AllocateNode* op)` function. */ + ffi::Function f_visit_allocate{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const AllocateConstNode* op)` function. */ + ffi::Function f_visit_allocate_const{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const DeclBufferNode* op)` function. */ + ffi::Function f_visit_decl_buffer{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const BufferStoreNode* op)` function. */ + ffi::Function f_visit_buffer_store{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const BufferRealizeNode* op)` function. */ + ffi::Function f_visit_buffer_realize{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const AssertStmtNode* op)` function. */ + ffi::Function f_visit_assert_stmt{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const SeqStmtNode* op)` function. */ + ffi::Function f_visit_seq_stmt{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const EvaluateNode* op)` function. */ + ffi::Function f_visit_evaluate{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const BlockNode* op)` function. */ + ffi::Function f_visit_block{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const BlockRealizeNode* op)` function. */ + ffi::Function f_visit_block_realize{nullptr}; + + using StmtExprVisitor::VisitExpr; + using StmtExprVisitor::VisitStmt; + + void DefaultVisitExpr(const PrimExpr& expr) { + static FExprType vtable = InitExprVTable(); + vtable(expr, this); + } + + void DefaultVisitStmt(const Stmt& stmt) { + static FStmtType vtable = InitStmtVTable(); + vtable(stmt, this); + } + + void VisitAttrs(AttrVisitor* v) {} + static constexpr const char* _type_key = "tir.PyStmtExprVisitor"; + TVM_DECLARE_BASE_OBJECT_INFO(PyStmtExprVisitorNode, Object); + + private: + // Statement functions + PY_STMT_VISITOR_DISPATCH(LetStmtNode, f_visit_let_stmt); + PY_STMT_VISITOR_DISPATCH(AttrStmtNode, f_visit_attr_stmt); + PY_STMT_VISITOR_DISPATCH(IfThenElseNode, f_visit_if_then_else); + PY_STMT_VISITOR_DISPATCH(ForNode, f_visit_for); + PY_STMT_VISITOR_DISPATCH(WhileNode, f_visit_while); + PY_STMT_VISITOR_DISPATCH(AllocateNode, f_visit_allocate); + PY_STMT_VISITOR_DISPATCH(AllocateConstNode, f_visit_allocate_const); + PY_STMT_VISITOR_DISPATCH(DeclBufferNode, f_visit_decl_buffer); + PY_STMT_VISITOR_DISPATCH(BufferStoreNode, f_visit_buffer_store); + PY_STMT_VISITOR_DISPATCH(BufferRealizeNode, f_visit_buffer_realize); + PY_STMT_VISITOR_DISPATCH(AssertStmtNode, f_visit_assert_stmt); + PY_STMT_VISITOR_DISPATCH(SeqStmtNode, f_visit_seq_stmt); + PY_STMT_VISITOR_DISPATCH(EvaluateNode, f_visit_evaluate); + PY_STMT_VISITOR_DISPATCH(BlockNode, f_visit_block); + PY_STMT_VISITOR_DISPATCH(BlockRealizeNode, f_visit_block_realize); + // Expression functions + PY_EXPR_VISITOR_DISPATCH(VarNode, f_visit_var); + PY_EXPR_VISITOR_DISPATCH(SizeVarNode, f_visit_size_var); + PY_EXPR_VISITOR_DISPATCH(BufferLoadNode, f_visit_buffer_load); + PY_EXPR_VISITOR_DISPATCH(ProducerLoadNode, f_visit_producer_load); + PY_EXPR_VISITOR_DISPATCH(LetNode, f_visit_let); + PY_EXPR_VISITOR_DISPATCH(CallNode, f_visit_call); + PY_EXPR_VISITOR_DISPATCH(AddNode, f_visit_add); + PY_EXPR_VISITOR_DISPATCH(SubNode, f_visit_sub); + PY_EXPR_VISITOR_DISPATCH(MulNode, f_visit_mul); + PY_EXPR_VISITOR_DISPATCH(DivNode, f_visit_div); + PY_EXPR_VISITOR_DISPATCH(ModNode, f_visit_mod); + PY_EXPR_VISITOR_DISPATCH(FloorDivNode, f_visit_floor_div); + PY_EXPR_VISITOR_DISPATCH(FloorModNode, f_visit_floor_mod); + PY_EXPR_VISITOR_DISPATCH(MinNode, f_visit_min); + PY_EXPR_VISITOR_DISPATCH(MaxNode, f_visit_max); + PY_EXPR_VISITOR_DISPATCH(EQNode, f_visit_eq); + PY_EXPR_VISITOR_DISPATCH(NENode, f_visit_ne); + PY_EXPR_VISITOR_DISPATCH(LTNode, f_visit_lt); + PY_EXPR_VISITOR_DISPATCH(LENode, f_visit_le); + PY_EXPR_VISITOR_DISPATCH(GTNode, f_visit_gt); + PY_EXPR_VISITOR_DISPATCH(GENode, f_visit_ge); + PY_EXPR_VISITOR_DISPATCH(AndNode, f_visit_and); + PY_EXPR_VISITOR_DISPATCH(OrNode, f_visit_or); + PY_EXPR_VISITOR_DISPATCH(ReduceNode, f_visit_reduce); + PY_EXPR_VISITOR_DISPATCH(CastNode, f_visit_cast); + PY_EXPR_VISITOR_DISPATCH(NotNode, f_visit_not); + PY_EXPR_VISITOR_DISPATCH(SelectNode, f_visit_select); + PY_EXPR_VISITOR_DISPATCH(RampNode, f_visit_ramp); + PY_EXPR_VISITOR_DISPATCH(BroadcastNode, f_visit_broadcast); + PY_EXPR_VISITOR_DISPATCH(ShuffleNode, f_visit_shuffle); + PY_EXPR_VISITOR_DISPATCH(IntImmNode, f_visit_int_imm); + PY_EXPR_VISITOR_DISPATCH(FloatImmNode, f_visit_float_imm); + PY_EXPR_VISITOR_DISPATCH(StringImmNode, f_visit_string_imm); + + private: + static FExprType InitExprVTable() { + FExprType vtable; + // Set dispatch + IR_EXPR_VISITOR_DEFAULT_DISPATCH(VarNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(SizeVarNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(BufferLoadNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(ProducerLoadNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(LetNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(CallNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(AddNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(SubNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(MulNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(DivNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(ModNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(FloorDivNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(FloorModNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(MinNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(MaxNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(EQNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(NENode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(LTNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(LENode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(GTNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(GENode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(AndNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(OrNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(ReduceNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(CastNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(NotNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(SelectNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(RampNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(ShuffleNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(BroadcastNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(IntImmNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(FloatImmNode); + IR_EXPR_VISITOR_DEFAULT_DISPATCH(StringImmNode); + vtable.Finalize(); + return vtable; + } + + static FStmtType InitStmtVTable() { + FStmtType vtable; + PY_STMT_VISITOR_DEFAULT_DISPATCH(LetStmtNode); + PY_STMT_VISITOR_DEFAULT_DISPATCH(AttrStmtNode); + PY_STMT_VISITOR_DEFAULT_DISPATCH(IfThenElseNode); + PY_STMT_VISITOR_DEFAULT_DISPATCH(ForNode); + PY_STMT_VISITOR_DEFAULT_DISPATCH(WhileNode); + PY_STMT_VISITOR_DEFAULT_DISPATCH(AllocateNode); + PY_STMT_VISITOR_DEFAULT_DISPATCH(AllocateConstNode); + PY_STMT_VISITOR_DEFAULT_DISPATCH(DeclBufferNode); + PY_STMT_VISITOR_DEFAULT_DISPATCH(BufferStoreNode); + PY_STMT_VISITOR_DEFAULT_DISPATCH(BufferRealizeNode); + PY_STMT_VISITOR_DEFAULT_DISPATCH(AssertStmtNode); + PY_STMT_VISITOR_DEFAULT_DISPATCH(SeqStmtNode); + PY_STMT_VISITOR_DEFAULT_DISPATCH(EvaluateNode); + PY_STMT_VISITOR_DEFAULT_DISPATCH(BlockNode); + PY_STMT_VISITOR_DEFAULT_DISPATCH(BlockRealizeNode); + vtable.Finalize(); + return vtable; + } +}; + +/*! + * \brief Managed reference to PyStmtExprVisitorNode. + * \sa PyStmtExprVisitorNode + */ +class PyStmtExprVisitor : public ObjectRef { + public: + TVM_DLL static PyStmtExprVisitor MakePyStmtExprVisitor(ffi::Function f_visit_stmt, // + ffi::Function f_visit_expr, // + ffi::Function f_visit_let_stmt, // + ffi::Function f_visit_attr_stmt, // + ffi::Function f_visit_if_then_else, // + ffi::Function f_visit_for, // + ffi::Function f_visit_while, // + ffi::Function f_visit_allocate, // + ffi::Function f_visit_allocate_const, // + ffi::Function f_visit_decl_buffer, // + ffi::Function f_visit_buffer_store, // + ffi::Function f_visit_buffer_realize, // + ffi::Function f_visit_assert_stmt, // + ffi::Function f_visit_seq_stmt, // + ffi::Function f_visit_evaluate, // + ffi::Function f_visit_block, // + ffi::Function f_visit_block_realize, // + ffi::Function f_visit_var, // + ffi::Function f_visit_size_var, // + ffi::Function f_visit_buffer_load, // + ffi::Function f_visit_producer_load, // + ffi::Function f_visit_let, // + ffi::Function f_visit_call, // + ffi::Function f_visit_add, // + ffi::Function f_visit_sub, // + ffi::Function f_visit_mul, // + ffi::Function f_visit_div, // + ffi::Function f_visit_mod, // + ffi::Function f_visit_floor_div, // + ffi::Function f_visit_floor_mod, // + ffi::Function f_visit_min, // + ffi::Function f_visit_max, // + ffi::Function f_visit_eq, // + ffi::Function f_visit_ne, // + ffi::Function f_visit_lt, // + ffi::Function f_visit_le, // + ffi::Function f_visit_gt, // + ffi::Function f_visit_ge, // + ffi::Function f_visit_and, // + ffi::Function f_visit_or, // + ffi::Function f_visit_reduce, // + ffi::Function f_visit_cast, // + ffi::Function f_visit_not, // + ffi::Function f_visit_select, // + ffi::Function f_visit_ramp, // + ffi::Function f_visit_broadcast, // + ffi::Function f_visit_shuffle, // + ffi::Function f_visit_int_imm, // + ffi::Function f_visit_float_imm, // + ffi::Function f_visit_string_imm) { + ObjectPtr n = make_object(); + n->f_visit_stmt = std::move(f_visit_stmt); + n->f_visit_expr = std::move(f_visit_expr); + // Set statement functions + n->f_visit_let_stmt = std::move(f_visit_let_stmt); + n->f_visit_attr_stmt = std::move(f_visit_attr_stmt); + n->f_visit_if_then_else = std::move(f_visit_if_then_else); + n->f_visit_for = std::move(f_visit_for); + n->f_visit_while = std::move(f_visit_while); + n->f_visit_allocate = std::move(f_visit_allocate); + n->f_visit_allocate_const = std::move(f_visit_allocate_const); + n->f_visit_decl_buffer = std::move(f_visit_decl_buffer); + n->f_visit_buffer_store = std::move(f_visit_buffer_store); + n->f_visit_buffer_realize = std::move(f_visit_buffer_realize); + n->f_visit_assert_stmt = std::move(f_visit_assert_stmt); + n->f_visit_seq_stmt = std::move(f_visit_seq_stmt); + n->f_visit_evaluate = std::move(f_visit_evaluate); + n->f_visit_block = std::move(f_visit_block); + n->f_visit_block_realize = std::move(f_visit_block_realize); + // Set expression functions + n->f_visit_var = std::move(f_visit_var); + n->f_visit_size_var = std::move(f_visit_size_var); + n->f_visit_buffer_load = std::move(f_visit_buffer_load); + n->f_visit_producer_load = std::move(f_visit_producer_load); + n->f_visit_let = std::move(f_visit_let); + n->f_visit_call = std::move(f_visit_call); + n->f_visit_add = std::move(f_visit_add); + n->f_visit_sub = std::move(f_visit_sub); + n->f_visit_mul = std::move(f_visit_mul); + n->f_visit_div = std::move(f_visit_div); + n->f_visit_mod = std::move(f_visit_mod); + n->f_visit_floor_div = std::move(f_visit_floor_div); + n->f_visit_floor_mod = std::move(f_visit_floor_mod); + n->f_visit_min = std::move(f_visit_min); + n->f_visit_max = std::move(f_visit_max); + n->f_visit_eq = std::move(f_visit_eq); + n->f_visit_ne = std::move(f_visit_ne); + n->f_visit_lt = std::move(f_visit_lt); + n->f_visit_le = std::move(f_visit_le); + n->f_visit_gt = std::move(f_visit_gt); + n->f_visit_ge = std::move(f_visit_ge); + n->f_visit_and = std::move(f_visit_and); + n->f_visit_or = std::move(f_visit_or); + n->f_visit_reduce = std::move(f_visit_reduce); + n->f_visit_cast = std::move(f_visit_cast); + n->f_visit_not = std::move(f_visit_not); + n->f_visit_select = std::move(f_visit_select); + n->f_visit_ramp = std::move(f_visit_ramp); + n->f_visit_broadcast = std::move(f_visit_broadcast); + n->f_visit_shuffle = std::move(f_visit_shuffle); + n->f_visit_int_imm = std::move(f_visit_int_imm); + n->f_visit_float_imm = std::move(f_visit_float_imm); + n->f_visit_string_imm = std::move(f_visit_string_imm); + return PyStmtExprVisitor(n); + } + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PyStmtExprVisitor, ObjectRef, + PyStmtExprVisitorNode); +}; + +/*! \brief The python interface of StmtExprMutator. */ +class PyStmtExprMutatorNode : public Object, public StmtExprMutator { + private: + using TSelf = PyStmtExprMutatorNode; + using FExprType = tvm::NodeFunctor; + using FStmtType = tvm::NodeFunctor; + + public: + // Expression functions + /*! \brief The packed function to the `VisitExpr(const Expr& expr)` function. */ + ffi::Function f_visit_expr{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const VarNode* op)` function. */ + ffi::Function f_visit_var{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const SizeVarNode* op)` function. */ + ffi::Function f_visit_size_var{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const BufferLoadNode* op)` function. */ + ffi::Function f_visit_buffer_load{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ProducerLoadNode* op)` function. */ + ffi::Function f_visit_producer_load{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const LetNode* op)` function. */ + ffi::Function f_visit_let{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const CallNode* op)` function. */ + ffi::Function f_visit_call{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const AddNode* op)` function. */ + ffi::Function f_visit_add{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const SubNode* op)` function. */ + ffi::Function f_visit_sub{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const MulNode* op)` function. */ + ffi::Function f_visit_mul{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const DivNode* op)` function. */ + ffi::Function f_visit_div{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ModNode* op)` function. */ + ffi::Function f_visit_mod{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const FloorDivNode* op)` function. */ + ffi::Function f_visit_floor_div{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const FloorModNode* op)` function. */ + ffi::Function f_visit_floor_mod{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const MinNode* op)` function. */ + ffi::Function f_visit_min{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const MaxNode* op)` function. */ + ffi::Function f_visit_max{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const EQNode* op)` function. */ + ffi::Function f_visit_eq{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const NENode* op)` function. */ + ffi::Function f_visit_ne{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const LTNode* op)` function. */ + ffi::Function f_visit_lt{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const LENode* op)` function. */ + ffi::Function f_visit_le{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const GTNode* op)` function. */ + ffi::Function f_visit_gt{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const GENode* op)` function. */ + ffi::Function f_visit_ge{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const AndNode* op)` function. */ + ffi::Function f_visit_and{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const OrNode* op)` function. */ + ffi::Function f_visit_or{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ReduceNode* op)` function. */ + ffi::Function f_visit_reduce{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const CastNode* op)` function. */ + ffi::Function f_visit_cast{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const NotNode* op)` function. */ + ffi::Function f_visit_not{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const SelectNode* op)` function. */ + ffi::Function f_visit_select{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const RampNode* op)` function. */ + ffi::Function f_visit_ramp{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const BroadcastNode* op)` function. */ + ffi::Function f_visit_broadcast{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ShuffleNode* op)` function. */ + ffi::Function f_visit_shuffle{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const IntImmNode* op)` function. */ + ffi::Function f_visit_int_imm{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const FloatImmNode* op)` function. */ + ffi::Function f_visit_float_imm{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const StringImmNode* op)` function. */ + ffi::Function f_visit_string_imm{nullptr}; + + // Statement functions + /*! \brief The packed function to the `VisitStmt(const Stmt& stmt)` function. */ + ffi::Function f_visit_stmt{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const LetStmtNode* op)` function. */ + ffi::Function f_visit_let_stmt{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const AttrStmtNode* op)` function. */ + ffi::Function f_visit_attr_stmt{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const IfThenElseNode* op)` function. */ + ffi::Function f_visit_if_then_else{nullptr}; // NOLINT(readability/braces) + /*! \brief The packed function to the `VisitStmt_(const ForNode* op)` function. */ + ffi::Function f_visit_for{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const WhileNode* op)` function. */ + ffi::Function f_visit_while{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const AllocateNode* op)` function. */ + ffi::Function f_visit_allocate{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const AllocateConstNode* op)` function. */ + ffi::Function f_visit_allocate_const{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const DeclBufferNode* op)` function. */ + ffi::Function f_visit_decl_buffer{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const BufferStoreNode* op)` function. */ + ffi::Function f_visit_buffer_store{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const BufferRealizeNode* op)` function. */ + ffi::Function f_visit_buffer_realize{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const AssertStmtNode* op)` function. */ + ffi::Function f_visit_assert_stmt{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const SeqStmtNode* op)` function. */ + ffi::Function f_visit_seq_stmt{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const EvaluateNode* op)` function. */ + ffi::Function f_visit_evaluate{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const BlockNode* op)` function. */ + ffi::Function f_visit_block{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const BlockRealizeNode* op)` function. */ + ffi::Function f_visit_block_realize{nullptr}; + + using StmtExprMutator::VisitExpr; + using StmtExprMutator::VisitStmt; + + void DefaultVisitExpr(const PrimExpr& expr) { + static FExprType vtable = InitExprVTable(); + vtable(expr, this); + } + + void DefaultVisitStmt(const Stmt& stmt) { + static FStmtType vtable = InitStmtVTable(); + vtable(stmt, this); + } + void VisitAttrs(AttrVisitor* v) {} + static constexpr const char* _type_key = "tir.PyStmtExprMutator"; + TVM_DECLARE_BASE_OBJECT_INFO(PyStmtExprMutatorNode, Object); + + private: + // Statement functions + PY_STMT_MUTATOR_DISPATCH(LetStmtNode, f_visit_let_stmt); + PY_STMT_MUTATOR_DISPATCH(AttrStmtNode, f_visit_attr_stmt); + PY_STMT_MUTATOR_DISPATCH(IfThenElseNode, f_visit_if_then_else); + PY_STMT_MUTATOR_DISPATCH(ForNode, f_visit_for); + PY_STMT_MUTATOR_DISPATCH(WhileNode, f_visit_while); + PY_STMT_MUTATOR_DISPATCH(AllocateNode, f_visit_allocate); + PY_STMT_MUTATOR_DISPATCH(AllocateConstNode, f_visit_allocate_const); + PY_STMT_MUTATOR_DISPATCH(DeclBufferNode, f_visit_decl_buffer); + PY_STMT_MUTATOR_DISPATCH(BufferStoreNode, f_visit_buffer_store); + PY_STMT_MUTATOR_DISPATCH(BufferRealizeNode, f_visit_buffer_realize); + PY_STMT_MUTATOR_DISPATCH(AssertStmtNode, f_visit_assert_stmt); + PY_STMT_MUTATOR_DISPATCH(SeqStmtNode, f_visit_seq_stmt); + PY_STMT_MUTATOR_DISPATCH(EvaluateNode, f_visit_evaluate); + PY_STMT_MUTATOR_DISPATCH(BlockNode, f_visit_block); + PY_STMT_MUTATOR_DISPATCH(BlockRealizeNode, f_visit_block_realize); + // Expression functions + PY_EXPR_MUTATOR_DISPATCH(VarNode, f_visit_var); + PY_EXPR_MUTATOR_DISPATCH(SizeVarNode, f_visit_size_var); + PY_EXPR_MUTATOR_DISPATCH(BufferLoadNode, f_visit_buffer_load); + PY_EXPR_MUTATOR_DISPATCH(ProducerLoadNode, f_visit_producer_load); + PY_EXPR_MUTATOR_DISPATCH(LetNode, f_visit_let); + PY_EXPR_MUTATOR_DISPATCH(CallNode, f_visit_call); + PY_EXPR_MUTATOR_DISPATCH(AddNode, f_visit_add); + PY_EXPR_MUTATOR_DISPATCH(SubNode, f_visit_sub); + PY_EXPR_MUTATOR_DISPATCH(MulNode, f_visit_mul); + PY_EXPR_MUTATOR_DISPATCH(DivNode, f_visit_div); + PY_EXPR_MUTATOR_DISPATCH(ModNode, f_visit_mod); + PY_EXPR_MUTATOR_DISPATCH(FloorDivNode, f_visit_floor_div); + PY_EXPR_MUTATOR_DISPATCH(FloorModNode, f_visit_floor_mod); + PY_EXPR_MUTATOR_DISPATCH(MinNode, f_visit_min); + PY_EXPR_MUTATOR_DISPATCH(MaxNode, f_visit_max); + PY_EXPR_MUTATOR_DISPATCH(EQNode, f_visit_eq); + PY_EXPR_MUTATOR_DISPATCH(NENode, f_visit_ne); + PY_EXPR_MUTATOR_DISPATCH(LTNode, f_visit_lt); + PY_EXPR_MUTATOR_DISPATCH(LENode, f_visit_le); + PY_EXPR_MUTATOR_DISPATCH(GTNode, f_visit_gt); + PY_EXPR_MUTATOR_DISPATCH(GENode, f_visit_ge); + PY_EXPR_MUTATOR_DISPATCH(AndNode, f_visit_and); + PY_EXPR_MUTATOR_DISPATCH(OrNode, f_visit_or); + PY_EXPR_MUTATOR_DISPATCH(ReduceNode, f_visit_reduce); + PY_EXPR_MUTATOR_DISPATCH(CastNode, f_visit_cast); + PY_EXPR_MUTATOR_DISPATCH(NotNode, f_visit_not); + PY_EXPR_MUTATOR_DISPATCH(SelectNode, f_visit_select); + PY_EXPR_MUTATOR_DISPATCH(RampNode, f_visit_ramp); + PY_EXPR_MUTATOR_DISPATCH(BroadcastNode, f_visit_broadcast); + PY_EXPR_MUTATOR_DISPATCH(ShuffleNode, f_visit_shuffle); + PY_EXPR_MUTATOR_DISPATCH(IntImmNode, f_visit_int_imm); + PY_EXPR_MUTATOR_DISPATCH(FloatImmNode, f_visit_float_imm); + PY_EXPR_MUTATOR_DISPATCH(StringImmNode, f_visit_string_imm); + + private: + private: + static FExprType InitExprVTable() { + FExprType vtable; + // Set dispatch + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(VarNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(SizeVarNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(BufferLoadNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(ProducerLoadNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(LetNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(CallNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(AddNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(SubNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(MulNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(DivNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(ModNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(FloorDivNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(FloorModNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(MinNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(MaxNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(EQNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(NENode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(LTNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(LENode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(GTNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(GENode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(AndNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(OrNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(ReduceNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(CastNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(NotNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(SelectNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(RampNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(ShuffleNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(BroadcastNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(IntImmNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(FloatImmNode); + PY_EXPR_MUTATOR_DEFAULT_DISPATCH(StringImmNode); + vtable.Finalize(); + return vtable; + } + + static FStmtType InitStmtVTable() { + FStmtType vtable; + PY_STMT_MUTATOR_DEFAULT_DISPATCH(LetStmtNode); + PY_STMT_MUTATOR_DEFAULT_DISPATCH(AttrStmtNode); + PY_STMT_MUTATOR_DEFAULT_DISPATCH(IfThenElseNode); + PY_STMT_MUTATOR_DEFAULT_DISPATCH(ForNode); + PY_STMT_MUTATOR_DEFAULT_DISPATCH(WhileNode); + PY_STMT_MUTATOR_DEFAULT_DISPATCH(AllocateNode); + PY_STMT_MUTATOR_DEFAULT_DISPATCH(AllocateConstNode); + PY_STMT_MUTATOR_DEFAULT_DISPATCH(DeclBufferNode); + PY_STMT_MUTATOR_DEFAULT_DISPATCH(BufferStoreNode); + PY_STMT_MUTATOR_DEFAULT_DISPATCH(BufferRealizeNode); + PY_STMT_MUTATOR_DEFAULT_DISPATCH(AssertStmtNode); + PY_STMT_MUTATOR_DEFAULT_DISPATCH(SeqStmtNode); + PY_STMT_MUTATOR_DEFAULT_DISPATCH(EvaluateNode); + PY_STMT_MUTATOR_DEFAULT_DISPATCH(BlockNode); + PY_STMT_MUTATOR_DEFAULT_DISPATCH(BlockRealizeNode); + vtable.Finalize(); + return vtable; + } +}; + +/*! \brief Managed reference to PyStmtExprMutatorNode. */ +class PyStmtExprMutator : public ObjectRef { + public: + /*! + * \brief Create a PyStmtExprMutator with customized methods on the python-side. + * \return The PyStmtExprMutator created. + */ + TVM_DLL static PyStmtExprMutator MakePyStmtExprMutator(ffi::Function f_visit_stmt, // + ffi::Function f_visit_expr, // + ffi::Function f_visit_let_stmt, // + ffi::Function f_visit_attr_stmt, // + ffi::Function f_visit_if_then_else, // + ffi::Function f_visit_for, // + ffi::Function f_visit_while, // + ffi::Function f_visit_allocate, // + ffi::Function f_visit_allocate_const, // + ffi::Function f_visit_decl_buffer, // + ffi::Function f_visit_buffer_store, // + ffi::Function f_visit_buffer_realize, // + ffi::Function f_visit_assert_stmt, // + ffi::Function f_visit_seq_stmt, // + ffi::Function f_visit_evaluate, // + ffi::Function f_visit_block, // + ffi::Function f_visit_block_realize, // + ffi::Function f_visit_var, // + ffi::Function f_visit_size_var, // + ffi::Function f_visit_buffer_load, // + ffi::Function f_visit_producer_load, // + ffi::Function f_visit_let, // + ffi::Function f_visit_call, // + ffi::Function f_visit_add, // + ffi::Function f_visit_sub, // + ffi::Function f_visit_mul, // + ffi::Function f_visit_div, // + ffi::Function f_visit_mod, // + ffi::Function f_visit_floor_div, // + ffi::Function f_visit_floor_mod, // + ffi::Function f_visit_min, // + ffi::Function f_visit_max, // + ffi::Function f_visit_eq, // + ffi::Function f_visit_ne, // + ffi::Function f_visit_lt, // + ffi::Function f_visit_le, // + ffi::Function f_visit_gt, // + ffi::Function f_visit_ge, // + ffi::Function f_visit_and, // + ffi::Function f_visit_or, // + ffi::Function f_visit_reduce, // + ffi::Function f_visit_cast, // + ffi::Function f_visit_not, // + ffi::Function f_visit_select, // + ffi::Function f_visit_ramp, // + ffi::Function f_visit_broadcast, // + ffi::Function f_visit_shuffle, // + ffi::Function f_visit_int_imm, // + ffi::Function f_visit_float_imm, // + ffi::Function f_visit_string_imm) { + ObjectPtr n = make_object(); + n->f_visit_stmt = std::move(f_visit_stmt); + n->f_visit_expr = std::move(f_visit_expr); + // Statement functions + n->f_visit_let_stmt = std::move(f_visit_let_stmt); + n->f_visit_attr_stmt = std::move(f_visit_attr_stmt); + n->f_visit_if_then_else = std::move(f_visit_if_then_else); + n->f_visit_for = std::move(f_visit_for); + n->f_visit_while = std::move(f_visit_while); + n->f_visit_allocate = std::move(f_visit_allocate); + n->f_visit_allocate_const = std::move(f_visit_allocate_const); + n->f_visit_decl_buffer = std::move(f_visit_decl_buffer); + n->f_visit_buffer_store = std::move(f_visit_buffer_store); + n->f_visit_buffer_realize = std::move(f_visit_buffer_realize); + n->f_visit_assert_stmt = std::move(f_visit_assert_stmt); + n->f_visit_seq_stmt = std::move(f_visit_seq_stmt); + n->f_visit_evaluate = std::move(f_visit_evaluate); + n->f_visit_block = std::move(f_visit_block); + n->f_visit_block_realize = std::move(f_visit_block_realize); + // Expression functions + n->f_visit_var = std::move(f_visit_var); + n->f_visit_size_var = std::move(f_visit_size_var); + n->f_visit_buffer_load = std::move(f_visit_buffer_load); + n->f_visit_producer_load = std::move(f_visit_producer_load); + n->f_visit_let = std::move(f_visit_let); + n->f_visit_call = std::move(f_visit_call); + n->f_visit_add = std::move(f_visit_add); + n->f_visit_sub = std::move(f_visit_sub); + n->f_visit_mul = std::move(f_visit_mul); + n->f_visit_div = std::move(f_visit_div); + n->f_visit_mod = std::move(f_visit_mod); + n->f_visit_floor_div = std::move(f_visit_floor_div); + n->f_visit_floor_mod = std::move(f_visit_floor_mod); + n->f_visit_min = std::move(f_visit_min); + n->f_visit_max = std::move(f_visit_max); + n->f_visit_eq = std::move(f_visit_eq); + n->f_visit_ne = std::move(f_visit_ne); + n->f_visit_lt = std::move(f_visit_lt); + n->f_visit_le = std::move(f_visit_le); + n->f_visit_gt = std::move(f_visit_gt); + n->f_visit_ge = std::move(f_visit_ge); + n->f_visit_and = std::move(f_visit_and); + n->f_visit_or = std::move(f_visit_or); + n->f_visit_reduce = std::move(f_visit_reduce); + n->f_visit_cast = std::move(f_visit_cast); + n->f_visit_not = std::move(f_visit_not); + n->f_visit_select = std::move(f_visit_select); + n->f_visit_ramp = std::move(f_visit_ramp); + n->f_visit_broadcast = std::move(f_visit_broadcast); + n->f_visit_shuffle = std::move(f_visit_shuffle); + n->f_visit_int_imm = std::move(f_visit_int_imm); + n->f_visit_float_imm = std::move(f_visit_float_imm); + n->f_visit_string_imm = std::move(f_visit_string_imm); + return PyStmtExprMutator(n); + } + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PyStmtExprMutator, ObjectRef, + PyStmtExprMutatorNode); +}; + +// ================================================ +// TVM Register +// ================================================ + +TVM_REGISTER_NODE_TYPE(PyStmtExprVisitorNode); +TVM_REGISTER_NODE_TYPE(PyStmtExprMutatorNode); + +TVM_FFI_REGISTER_GLOBAL("tir.MakePyStmtExprVisitor") + .set_body_typed(PyStmtExprVisitor::MakePyStmtExprVisitor); +TVM_FFI_REGISTER_GLOBAL("tir.MakePyStmtExprMutator") + .set_body_typed(PyStmtExprMutator::MakePyStmtExprMutator); + +// StmtExprVisitor +TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprVisitorDefaultVisitExpr") + .set_body_typed([](PyStmtExprVisitor visitor, const PrimExpr& expr) { + visitor->DefaultVisitExpr(expr); + }); +TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprVisitorDefaultVisitStmt") + .set_body_typed([](PyStmtExprVisitor visitor, const Stmt& stmt) { + visitor->DefaultVisitStmt(stmt); + }); +TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprVisitorVisitStmt") + .set_body_typed([](PyStmtExprVisitor visitor, const Stmt& stmt) { visitor->VisitStmt(stmt); }); +TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprVisitorVisitExpr") + .set_body_typed([](PyStmtExprVisitor visitor, const PrimExpr& expr) { + visitor->VisitExpr(expr); + }); + +// StmtExprMutator +TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprMutatorDefaultVisitExpr") + .set_body_typed([](PyStmtExprMutator mutator, const PrimExpr& expr) { + return mutator->DefaultVisitExpr(expr); + }); +TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprMutatorDefaultVisitStmt") + .set_body_typed([](PyStmtExprMutator mutator, const Stmt& stmt) { + return mutator->DefaultVisitStmt(stmt); + }); +TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprMutatorVisitExpr") + .set_body_typed([](PyStmtExprMutator mutator, const PrimExpr& expr) { + return mutator->VisitExpr(expr); + }); +TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprMutatorVisitStmt") + .set_body_typed([](PyStmtExprMutator mutator, const Stmt& stmt) { + return mutator->VisitStmt(stmt); + }); + +} // namespace tir +} // namespace tvm diff --git a/tests/python/tir-transform/test_tir_functor.py b/tests/python/tir-transform/test_tir_functor.py new file mode 100644 index 000000000000..e8463027dbe5 --- /dev/null +++ b/tests/python/tir-transform/test_tir_functor.py @@ -0,0 +1,436 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm import tir +from tvm.tir import ( + EQ, + LT, + Add, + Cast, + Evaluate, + FloatImm, + For, + IfThenElse, + IntImm, + Max, + Min, + Mul, + PyStmtExprMutator, + PyStmtExprVisitor, + StringImm, + Sub, + Var, +) + + +class ASTLog: + """Helper class to log AST""" + + def __init__(self) -> None: + self.log = [] + self.indent = "\t" + self.level = 0 + + def push_scope(self): + self.level += 1 + + def pop_scope(self): + self.level -= 1 + + def add(self, s: str): + self.log.append(self.indent * self.level + s) + + def __str__(self) -> str: + return "\n".join(self.log) + + +@tir.functor.visitor +class ASTPrinter(PyStmtExprVisitor): + """Print tir AST in structured format. The shape of Node is ignored.""" + + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_var_(self, op: Var) -> None: + self.log.add("Stmt: Var") + super().visit_var_(op) + + def visit_add_(self, op: Add) -> None: + self.log.add("Stmt: Add") + super().visit_add_(op) + + +@tir.functor.visitor +class SimpleExprCounter(PyStmtExprVisitor): + """Count expressions without recursion""" + + def __init__(self): + super().__init__() + self.var_count = 0 + self.add_count = 0 + self.mul_count = 0 + + def visit_var_(self, op: Var): + self.var_count += 1 + # Don't recursively visit children to avoid infinite recursion + + def visit_add_(self, op: Add): + self.add_count += 1 + # Visit children manually + super().visit_add_(op) + + def visit_mul_(self, op: Mul): + self.mul_count += 1 + # Visit children manually + super().visit_mul_(op) + + +@tir.functor.mutator +class VariableReplacer(PyStmtExprMutator): + """Replace variables with constants""" + + def __init__(self, replacements): + super().__init__() + self.replacements = replacements + + def visit_var_(self, op: Var): + if op.name in self.replacements: + return IntImm("int32", self.replacements[op.name]) + return op + + +@tir.functor.mutator +class AddToSubMutator(PyStmtExprMutator): + """Convert Add operations to Sub operations""" + + def visit_add_(self, op: Add): + # First mutate the operands + a = self.visit_expr(op.a) + b = self.visit_expr(op.b) + # Convert Add to Sub + return Sub(a, b) + + +@tir.functor.visitor +class SimpleStmtCounter(PyStmtExprVisitor): + """Count statements without recursion""" + + def __init__(self): + super().__init__() + self.for_count = 0 + self.if_count = 0 + self.evaluate_count = 0 + + def visit_for_(self, op: For): + self.for_count += 1 + super().visit_for_(op) + + def visit_if_then_else_(self, op: IfThenElse): + self.if_count += 1 + super().visit_if_then_else_(op) + + def visit_evaluate_(self, op: Evaluate): + self.evaluate_count += 1 + super().visit_evaluate_(op) + + +@tir.functor.mutator +class ForLoopUnroller(PyStmtExprMutator): + """Simple loop unroller for demonstration""" + + def __init__(self, unroll_factor=2): + super().__init__() + self.unroll_factor = unroll_factor + + def visit_for_(self, op: For): + # For demonstration, just return the original for now + # In a real implementation, we would unroll small loops + return super().visit_for_(op) + + +@tir.functor.visitor +class SimpleStmtExprVisitor(PyStmtExprVisitor): + """Visitor that handles both statements and expressions""" + + def __init__(self): + super().__init__() + self.expr_count = 0 + self.stmt_count = 0 + self.var_names = set() + + def visit_var_(self, op: Var): + self.var_names.add(op.name) + self.expr_count += 1 + + def visit_evaluate_(self, op: Evaluate): + self.stmt_count += 1 + # Visit the expression + self.visit_expr(op.value) + + +@tir.functor.mutator +class ComplexMutator(PyStmtExprMutator): + """Mutator that handles both statements and expressions""" + + def __init__(self): + super().__init__() + self.modifications = 0 + + def visit_add_(self, op: Add): + self.modifications += 1 + # Convert a + b to a * 2 + b for demonstration + a = self.visit_expr(op.a) + b = self.visit_expr(op.b) + return Add(Mul(a, IntImm("int32", 2)), b) + + +def test_basic_visitor(): + """Test the basic AST printer visitor""" + expr = Add(Var("x", dtype="int32"), Var("y", dtype="int32")) + printer = ASTPrinter() + printer.visit_expr(expr) + assert str(printer.log) == "\n".join(["Stmt: Add", "Stmt: Var", "Stmt: Var"]) + + +def test_simple_expr_counter(): + """Test simple expression counting visitor""" + x = Var("x", dtype="int32") + y = Var("y", dtype="int32") + + # Create simple expression: x + y + expr = Add(x, y) + + counter = SimpleExprCounter() + counter.visit_expr(expr) + + assert counter.var_count == 2 # x and y + assert counter.add_count == 1 # one add + + +def test_variable_replacer(): + """Test expression mutator that replaces variables""" + x = Var("x", dtype="int32") + y = Var("y", dtype="int32") + expr = Add(x, Mul(y, IntImm("int32", 3))) + + replacer = VariableReplacer({"x": 10, "y": 5}) + result = replacer.visit_expr(expr) + + # Should be Add(IntImm(10), Mul(IntImm(5), IntImm(3))) + assert isinstance(result, Add) + assert isinstance(result.a, IntImm) + assert result.a.value == 10 + assert isinstance(result.b, Mul) + assert isinstance(result.b.a, IntImm) + assert result.b.a.value == 5 + + +def test_add_to_sub_mutator(): + """Test mutator that converts Add to Sub""" + x = Var("x", dtype="int32") + y = Var("y", dtype="int32") + expr = Add(x, y) + + mutator = AddToSubMutator() + result = mutator.visit_expr(expr) + + assert isinstance(result, Sub) + assert isinstance(result.a, Var) + assert isinstance(result.b, Var) + assert result.a.name == "x" + assert result.b.name == "y" + + +def test_simple_stmt_counter(): + """Test statement visitor that counts statements""" + i = Var("i", dtype="int32") + + # Create a simple for loop + loop_body = Evaluate(IntImm("int32", 0)) + for_stmt = For(i, IntImm("int32", 0), IntImm("int32", 10), tir.ForKind.SERIAL, loop_body) + + counter = SimpleStmtCounter() + counter.visit_stmt(for_stmt) + + assert counter.for_count == 1 # One for loop + assert counter.evaluate_count == 1 # One evaluate in the body + + +def test_if_then_else_visitor(): + """Test visitor with if-then-else statements""" + x = Var("x", dtype="int32") + condition = EQ(x, IntImm("int32", 0)) + then_stmt = Evaluate(IntImm("int32", 1)) + else_stmt = Evaluate(IntImm("int32", 2)) + + if_stmt = IfThenElse(condition, then_stmt, else_stmt) + + counter = SimpleStmtCounter() + counter.visit_stmt(if_stmt) + + assert counter.if_count == 1 + assert counter.for_count == 0 + + +def test_simple_stmt_expr_visitor(): + """Test stmt_expr_visitor with mixed statements and expressions""" + x = Var("x", dtype="int32") + y = Var("y", dtype="int32") + + # Create an evaluate statement with an expression + expr = Add(x, y) + stmt = Evaluate(expr) + + visitor = SimpleStmtExprVisitor() + visitor.visit_stmt(stmt) + + assert visitor.stmt_count == 1 # One Evaluate statement + assert visitor.expr_count == 2 # Two variables + assert "x" in visitor.var_names + assert "y" in visitor.var_names + + +def test_complex_mutator(): + """Test stmt_expr_mutator""" + x = Var("x", dtype="int32") + y = Var("y", dtype="int32") + + # Expression with Add operations + expr = Add(x, y) + stmt = Evaluate(expr) + + mutator = ComplexMutator() + result = mutator.visit_stmt(stmt) + print(type(mutator)) + + assert mutator.modifications == 1 # One Add operation modified + assert isinstance(result, Evaluate) + + # Check that the expression was modified + modified_expr = result.value + assert isinstance(modified_expr, Add) + assert isinstance(modified_expr.a, Mul) # First operand should be multiplied by 2 + + +def test_different_expr_types(): + """Test visitor with various expression types""" + x = Var("x", dtype="int32") + + # Test different expression types individually + exprs = [ + IntImm("int32", 42), + FloatImm("float32", 3.14), + StringImm("hello"), + Cast("float32", x), + Min(x, IntImm("int32", 10)), + Max(x, IntImm("int32", 0)), + LT(x, IntImm("int32", 5)), + ] + + # Just test that we can create and visit each type + counter = SimpleExprCounter() + for expr in exprs: + try: + counter.visit_expr(expr) + except Exception as e: + # Some expressions might not be supported, that's ok + pass + + +def test_decorator_functionality(): + """Test that decorators work correctly""" + + # Test that decorated classes are properly wrapped + visitor = SimpleExprCounter() + assert hasattr(visitor, "_outer") # Should have the wrapper functionality + + mutator = VariableReplacer({}) + assert hasattr(mutator, "_outer") + + +def test_empty_expressions(): + """Test handling of simple expressions""" + counter = SimpleExprCounter() + + # Test with just a variable + x = Var("x", dtype="int32") + counter.visit_expr(x) + + assert counter.var_count == 1 + + # Test with just a constant + counter = SimpleExprCounter() + const = IntImm("int32", 5) + counter.visit_expr(const) + + # Constants don't increase var_count + assert counter.var_count == 0 + + +def test_stmt_mutator(): + """Test basic statement mutator functionality""" + x = Var("x", dtype="int32") + stmt = Evaluate(Add(x, IntImm("int32", 1))) + + unroller = ForLoopUnroller() + result = unroller.visit_stmt(stmt) + + # Should return the same statement (no actual unrolling implemented) + assert isinstance(result, Evaluate) + + +def test_nested_expressions(): + """Test with nested expressions""" + x = Var("x", dtype="int32") + y = Var("y", dtype="int32") + z = Var("z", dtype="int32") + + # Create nested expression: (x + y) * z + inner_add = Add(x, y) + expr = Mul(inner_add, z) + + counter = SimpleExprCounter() + counter.visit_expr(expr) + + assert counter.var_count == 3 # x, y, z + assert counter.add_count == 1 # one add + assert counter.mul_count == 1 # one mul + + +def test_simple_mutations(): + """Test simple expression mutations""" + x = Var("x", dtype="int32") + y = Var("y", dtype="int32") + + # Test multiple replacements + expr = Add(x, y) + replacer = VariableReplacer({"x": 1, "y": 2}) + result = replacer.visit_expr(expr) + + assert isinstance(result, Add) + assert isinstance(result.a, IntImm) + assert isinstance(result.b, IntImm) + assert result.a.value == 1 + assert result.b.value == 2 + + +if __name__ == "__main__": + test_basic_visitor() + tvm.testing.main()