diff --git a/docs/api/python/tir.rst b/docs/api/python/tir.rst index ea1ac669b273..dd08758e92ce 100644 --- a/docs/api/python/tir.rst +++ b/docs/api/python/tir.rst @@ -22,3 +22,12 @@ tvm.tir :imported-members: :exclude-members: PrimExpr, const :autosummary: + + + +tvm.tir.transform +----------------- +.. automodule:: tvm.tir.transform + :members: + :imported-members: + :autosummary: diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 85b393750913..44244df83ff6 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -150,7 +150,7 @@ class RelayExprNode : public BaseExprNode { /*! * \return The checked_type */ - const Type& checked_type() const; + inline const Type& checked_type() const; /*! * \brief Check if the inferred(checked) type of the Expr * is backed by a TTypeNode and return it. diff --git a/include/tvm/ir/type_functor.h b/include/tvm/ir/type_functor.h index 476538c5da36..55071911fb80 100644 --- a/include/tvm/ir/type_functor.h +++ b/include/tvm/ir/type_functor.h @@ -93,6 +93,7 @@ class TypeFunctor { virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const PrimTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const PointerTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitTypeDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); throw; // unreachable, written to stop compiler warning @@ -115,6 +116,7 @@ class TypeFunctor { TVM_TYPE_FUNCTOR_DISPATCH(TypeCallNode); TVM_TYPE_FUNCTOR_DISPATCH(TypeDataNode); TVM_TYPE_FUNCTOR_DISPATCH(PrimTypeNode); + TVM_TYPE_FUNCTOR_DISPATCH(PointerTypeNode); return vtable; } }; @@ -138,6 +140,7 @@ class TVM_DLL TypeVisitor : void VisitType_(const TypeCallNode* op) override; void VisitType_(const TypeDataNode* op) override; void VisitType_(const PrimTypeNode* op) override; + void VisitType_(const PointerTypeNode* op) override; }; /*! @@ -158,6 +161,7 @@ class TVM_DLL TypeMutator : Type VisitType_(const TypeCallNode* op) override; Type VisitType_(const TypeDataNode* op) override; Type VisitType_(const PrimTypeNode* op) override; + Type VisitType_(const PointerTypeNode* op) override; private: Array MutateArray(Array arr); diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h new file mode 100644 index 000000000000..514967788394 --- /dev/null +++ b/include/tvm/tir/transform.h @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/tir/transform.h + * \brief TIR specific transformation passes. + */ +#ifndef TVM_TIR_TRANSFORM_H_ +#define TVM_TIR_TRANSFORM_H_ + +#include +#include +#include + +#include + +namespace tvm { +namespace tir { +namespace transform { + +using tvm::transform::Pass; +using tvm::transform::PassNode; +using tvm::transform::PassInfo; +using tvm::transform::PassInfoNode; +using tvm::transform::PassContext; +using tvm::transform::PassContextNode; +using tvm::transform::Sequential; + +/* + * \brief Create a function pass that optimizes PrimFuncs. + * + * \param pass_func The packed function that contains the optimization. + * \param opt_level The optimization level of the function pass. + * \param name The name of the function pass. + * \param required The list of the passes that the function pass is dependent on. + * + * \return The created function pass. + */ +TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc< + PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func, + int opt_level, + const std::string& name, + const tvm::Array& required); + +/*! + * \brief Create PrimFuncPass to combine context calls in the host function. + * + * \return The pass. + */ +Pass CombineContextCall(); + +} // namespace transform +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_TRANSFORM_H_ diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index d4d389ad5f9a..f0d4d931a370 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -45,3 +45,4 @@ from . import ir_builder from . import ir_pass +from . import transform diff --git a/python/tvm/tir/transform/__init__.py b/python/tvm/tir/transform/__init__.py new file mode 100644 index 000000000000..5947f413ba99 --- /dev/null +++ b/python/tvm/tir/transform/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Namespace of all TIR transformations""" +# pylint: disable=wildcard-import, invalid-name + +from .function_pass import prim_func_pass, PrimFuncPass +from .transform import * diff --git a/python/tvm/tir/transform/_ffi_api.py b/python/tvm/tir/transform/_ffi_api.py new file mode 100644 index 000000000000..86f7bdf5dac3 --- /dev/null +++ b/python/tvm/tir/transform/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.tir.transform""" +import tvm._ffi + + +tvm._ffi._init_api("tir.transform", __name__) diff --git a/python/tvm/tir/transform/function_pass.py b/python/tvm/tir/transform/function_pass.py new file mode 100644 index 000000000000..93bb996084f4 --- /dev/null +++ b/python/tvm/tir/transform/function_pass.py @@ -0,0 +1,149 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TIR specific function pass support.""" +import inspect +import functools + +import tvm._ffi +from tvm.ir.transform import Pass, PassInfo + +from . import _ffi_api + + +@tvm._ffi.register_object("tir.PrimFuncPass") +class PrimFuncPass(Pass): + """A pass that works on each :py:func:`tvm.tir.PrimFunc` in a module. A function + pass class should be created through py:func:`tvm.tir.transform.function_pass`. + """ + + +def _wrap_class_function_pass(pass_cls, pass_info): + """Wrap a python class as function pass""" + class PyFunctionPass(PrimFuncPass): + """Internal wrapper class to create a class instance.""" + def __init__(self, *args, **kwargs): + # initialize handle in cass pass_cls creation failed.fg + self.handle = None + inst = pass_cls(*args, **kwargs) + # it is important not to capture self to + # avoid a cyclic dependency + def _pass_func(func, mod, ctx): + return inst.transform_function(func, mod, ctx) + self.__init_handle_by_constructor__( + _ffi_api.CreatePrimFuncPass, _pass_func, pass_info) + self._inst = inst + + def __getattr__(self, name): + # fall back to instance attribute if there is not any + return self._inst.__getattribute__(name) + + functools.update_wrapper(PyFunctionPass.__init__, pass_cls.__init__) + PyFunctionPass.__name__ = pass_cls.__name__ + PyFunctionPass.__doc__ = pass_cls.__doc__ + PyFunctionPass.__module__ = pass_cls.__module__ + return PyFunctionPass + + +def prim_func_pass(pass_func=None, opt_level=None, name=None, required=None): + """Decorate a function pass. + + This function returns a callback when pass_func + is provided. Otherwise, it returns the created function pass using the + given optimization function. + + Parameters + ---------- + pass_func : Optional[Callable[(PrimFunc, IRModule, PassContext) -> PrimFunc]] + The transformation function or class. + + opt_level : int + The optimization level of this module pass. + + name : Optional[str] + The name of the function pass. The name could be empty. In this case, the + name of the optimization function will be used as the pass name. + + required : Optional[List[str]] + The list of passes that the function pass is dependent on. + + Returns + ------- + create_function_pass : Union[Callable, FunctionPass] + + A decorator will be returned if pass_func is not provided, + otherwise return the decorated result. + The returned decorator has two behaviors depending on the input: + A new FunctionPass will be returned when we decorate a pass function. + A new FunctionPass class will be returned when we decorate a class type. + + Examples + -------- + The following code block decorates a function pass class. + + .. code-block:: python + + @tvm.tir.transform.prim_func_pass(opt_level=1) + class TestReplaceFunc: + def __init__(self, new_func): + self.new_func = new_func + + def transform_function(self, func, mod, ctx): + # just for demo purposes + # transform func to new_func + return self.new_func + + The following code creates a function pass by decorating + a user defined transform function. + + .. code-block:: python + + @tvm.tir.transform.prim_func_pass(opt_level=2) + def transform(func, mod, ctx): + # my transformations here. + return func + + function_pass = transform + assert isinstance(function_pass, transform.FunctionPass) + assert function_pass.info.opt_level == 2 + + # Given a module m, the optimization could be invoked as the follwoing: + updated_mod = function_pass(m) + # Now constant folding should have been applied to every function in + # the provided module m. And the updated module will be returned. + """ + + if opt_level is None: + raise ValueError("Please provide opt_level for the funtion pass.") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("Required is expected to be the type of " + + "list/tuple.") + + def create_function_pass(pass_arg): + """Internal function that creates a function pass""" + fname = name if name else pass_arg.__name__ + info = PassInfo(opt_level, fname, required) + if inspect.isclass(pass_arg): + return _wrap_class_function_pass(pass_arg, info) + if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for Module pass") + return _ffi_api.MakeFunctionPass(pass_arg, info) + + if pass_func: + return create_function_pass(pass_func) + return create_function_pass diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py new file mode 100644 index 000000000000..1eec94e054e6 --- /dev/null +++ b/python/tvm/tir/transform/transform.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Wrapping existing transformations.""" +# pylint: disable=invalid-name + +from . import _ffi_api + + +def CombineContextCall(): + """Combine context calls in the host function. + + Returns + ------- + fpass : tvm.ir.transform.Pass + The result pass + """ + return _ffi_api.CombineContextCall() diff --git a/src/ir/module.cc b/src/ir/module.cc index 45f39d5ade88..7d743942bb57 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -158,13 +158,13 @@ void IRModuleNode::Add(const GlobalVar& var, GetRef(ptr)); } - auto type = checked_func->checked_type(); + Type type = checked_func->checked_type(); CHECK(type.as() == nullptr); if (functions.find(var) != functions.end()) { CHECK(update) << "Already have definition for " << var->name_hint; - auto old_type = functions[var].as()->checked_type(); + auto old_type = functions[var]->checked_type(); CHECK(relay::AlphaEqual(type, old_type)) << "Module#update changes type, not possible in this mode."; } diff --git a/src/ir/type_functor.cc b/src/ir/type_functor.cc index cbd3538b066c..9d9167fa1c0f 100644 --- a/src/ir/type_functor.cc +++ b/src/ir/type_functor.cc @@ -93,6 +93,10 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) { void TypeVisitor::VisitType_(const PrimTypeNode* op) { } +void TypeVisitor::VisitType_(const PointerTypeNode* op) { + this->VisitType(op->element_type); +} + Type TypeMutator::VisitType(const Type& t) { return t.defined() ? TypeFunctor::VisitType(t) : t; } @@ -209,6 +213,16 @@ Type TypeMutator::VisitType_(const PrimTypeNode* op) { return GetRef(op); } +Type TypeMutator::VisitType_(const PointerTypeNode* op) { + Type element_type = VisitType(op->element_type); + + if (element_type.same_as(op->element_type)) { + return GetRef(op); + } else { + return PointerType(element_type); + } +} + // Implements bind. class TypeBinder : public TypeMutator { public: diff --git a/src/relay/analysis/alpha_equal.cc b/src/relay/analysis/alpha_equal.cc index 540284848d7c..28c768138be3 100644 --- a/src/relay/analysis/alpha_equal.cc +++ b/src/relay/analysis/alpha_equal.cc @@ -202,6 +202,22 @@ class AlphaEqualHandler: return LeafObjectEqual(GetRef(lhs), other); } + bool VisitType_(const PrimTypeNode* lhs, const Type& other) final { + if (const PrimTypeNode* rhs = other.as()) { + return lhs->dtype == rhs->dtype; + } else { + return false; + } + } + + bool VisitType_(const PointerTypeNode* lhs, const Type& other) final { + if (const PointerTypeNode* rhs = other.as()) { + return TypeEqual(lhs->element_type, rhs->element_type); + } else { + return false; + } + } + bool VisitType_(const TypeVarNode* lhs, const Type& other) final { if (const TypeVarNode* rhs = other.as()) { if (lhs->kind != rhs->kind) return false; diff --git a/src/target/codegen.cc b/src/target/codegen.cc index 7dc23b6f9fd1..e9ff23495a3b 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -310,6 +310,9 @@ TVM_REGISTER_GLOBAL("target.Build") } }); +TVM_REGISTER_GLOBAL("testing.LoweredFuncsToIRModule") +.set_body_typed(ToIRModule); + // Export two auxiliary function to the runtime namespace. TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC") .set_body_typed(PackImportsToC); diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 7464e3ad4370..0891c47ab58c 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -23,10 +23,12 @@ */ #include #include +#include namespace tvm { namespace tir { +// Get the function type of a PrimFunc PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, @@ -43,6 +45,7 @@ PrimFunc::PrimFunc(Array params, n->ret_type = std::move(ret_type); n->buffer_map = std::move(buffer_map); n->attrs = std::move(attrs); + n->checked_type_ = n->func_type_annotation(); data_ = std::move(n); } diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc new file mode 100644 index 000000000000..f991e908ca02 --- /dev/null +++ b/src/tir/ir/transform.cc @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/ir/transform.cc + * \brief TIR specific transformation passes. + */ +#include +#include +#include + + +namespace tvm { +namespace tir { +namespace transform { + + +/*! + * \brief Function level pass that applies transformations to all + * TIR functions within the module. + */ +class PrimFuncPassNode : public PassNode { + public: + /* \brief The pass meta data.*/ + PassInfo pass_info; + + /*! \brief The pass function called on each. */ + runtime::TypedPackedFunc pass_func; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pass_info", &pass_info); + } + + /*! + * \brief Run a function pass on given pass context. + * + * \param mod The module that an optimization pass is applied on. + * \param pass_ctx The context that an optimization pass executes on. + * + * \return Return the updated module. + */ + IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final; + + /*! + * \brief Get the pass information/meta data. + */ + PassInfo Info() const override { return pass_info; } + + static constexpr const char* _type_key = "tir.PrimFuncPass"; + TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncPassNode, PassNode); +}; + +class PrimFuncPass : public Pass { + public: + /*! + * \brief The constructor + * \param pass_func The packed function which implements a pass. + * \param pass_info The pass info. + */ + TVM_DLL PrimFuncPass( + runtime::TypedPackedFunc pass_func, + PassInfo pass_info); + + TVM_DEFINE_OBJECT_REF_METHODS(PrimFuncPass, Pass, PrimFuncPassNode); +}; + +PrimFuncPass::PrimFuncPass( + runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { + auto n = make_object(); + n->pass_func = std::move(pass_func); + n->pass_info = std::move(pass_info); + data_ = std::move(n); +} + +// Perform Module -> Module optimizations at the PrimFunc level. +IRModule PrimFuncPassNode::operator()(const IRModule& mod, + const PassContext& pass_ctx) const { + const PassInfo& pass_info = Info(); + CHECK(mod.defined()); + pass_ctx.Trace(mod, pass_info, true); + // Execute the pass function and return a new module. + IRModule updated_mod = IRModule( + mod->functions, mod->type_definitions, mod->Imports()); + std::vector > updates; + for (const auto& it : updated_mod->functions) { + // only picks up relay::PrimFunc + if (auto* n = it.second.as()) { + PrimFunc func = GetRef(n); + auto updated_func = + pass_func(func, updated_mod, pass_ctx); + updates.push_back({it.first, updated_func}); + } + } + for (const auto& pair : updates) { + updated_mod->Add(pair.first, pair.second, true); + } + pass_ctx.Trace(updated_mod, pass_info, false); + return updated_mod; +} + +Pass CreatePrimFuncPass( + const runtime::TypedPackedFunc& pass_func, + int opt_level, + const std::string& name, + const tvm::Array& required) { + PassInfo pass_info = PassInfo(opt_level, name, required); + return PrimFuncPass(pass_func, pass_info); +} + +TVM_REGISTER_NODE_TYPE(PrimFuncPassNode); + +TVM_REGISTER_GLOBAL("tir.transform.CreatePrimFuncPass") +.set_body_typed([](runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { + return PrimFuncPass(pass_func, pass_info); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + const PassInfo info = node->Info(); + p->stream << "PrimFuncPass(" << info->name + << ", opt_level=" << info->opt_level << ")"; +}); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/pass/combine_context_call.cc b/src/tir/transforms/combine_context_call.cc similarity index 88% rename from src/tir/pass/combine_context_call.cc rename to src/tir/transforms/combine_context_call.cc index 5f043bc8ac73..ed352c1e5c25 100644 --- a/src/tir/pass/combine_context_call.cc +++ b/src/tir/transforms/combine_context_call.cc @@ -25,7 +25,11 @@ #include #include #include +#include +#include + #include + #include namespace tvm { @@ -114,5 +118,20 @@ LoweredFunc CombineContextCall(LoweredFunc f) { return LoweredFunc(n); } +namespace transform { + +Pass CombineContextCall() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = ContextCallCombiner().Combine(n->body); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "CombineContextCall", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.CombineContextCall") +.set_body_typed(CombineContextCall); + +} // namespace transform } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_tir_pass_combine_context_call.py b/tests/python/unittest/test_tir_transform_combine_context_call.py similarity index 84% rename from tests/python/unittest/test_tir_pass_combine_context_call.py rename to tests/python/unittest/test_tir_transform_combine_context_call.py index e51d4d874ec9..e76fb33a2c63 100644 --- a/tests/python/unittest/test_tir_pass_combine_context_call.py +++ b/tests/python/unittest/test_tir_transform_combine_context_call.py @@ -37,9 +37,15 @@ def device_context(dev_id): ("int32", "fadd", device_context(0), A)) body = ib.get() f = tvm.tir.ir_pass.MakeAPI(body, "func", [dev_type, n], 2, True) - f = tvm.tir.ir_pass.CombineContextCall(f) - assert f.body.value.dtype == "handle" - assert f.body.body.value.dtype == "handle" + + # temp adapter to convert loweredFunc to IRModule + # to test passes in the new style. + mod = tvm.testing.LoweredFuncsToIRModule([f]) + + mod = tvm.tir.transform.CombineContextCall()(mod) + + assert mod["func"].body.value.dtype == "handle" + assert mod["func"].body.body.value.dtype == "handle" if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_transform_prim_func_pass.py b/tests/python/unittest/test_tir_transform_prim_func_pass.py new file mode 100644 index 000000000000..87aecd178909 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_prim_func_pass.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import te + + +def test_prim_func_pass(): + @tvm.tir.transform.prim_func_pass(opt_level=1) + class TestReplaceFunc: + """Simple test function to replace one argument to another.""" + def __init__(self, new_func): + self.new_func = new_func + + def transform_function(self, func, mod, ctx): + return self.new_func + + x = te.var('x') + y = te.var('y') + b = tvm.tir.decl_buffer((x,), "float32") + stmt = tvm.tir.LetStmt( + x, 10, tvm.tir.Evaluate(x + 1)); + + func = tvm.tir.PrimFunc( + [x, y, b], stmt) + + new_func = tvm.tir.PrimFunc( + [x, y, b], tvm.tir.Evaluate(0)) + + mod = tvm.IRModule({"main": func}) + mod = TestReplaceFunc(new_func)(mod) + + assert tvm.tir.ir_pass.Equal(mod["main"].body, new_func.body) + + +if __name__ == "__main__": + test_prim_func_pass()