diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index a55fe6797d45..ff576d4ebb6a 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -260,6 +260,63 @@ TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived,
 TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs,
                                  arith::Analyzer* ana = nullptr);
 
+//-----------------------------------
+// General IR analysis
+//-----------------------------------
+/*!
+ * \brief Get all bound variables from expression expr.
+ *
+ * Bound variables are all variables that are declared in the expr.
+ * They only have meaning inside that expr, and can only be used in it.
+ *
+ * \param expr the expression.
+ *
+ * \return List of bound vars, in the PostDFS order in the expression.
+ */
+TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);
+
+/*!
+ * \brief Get free variables from expression expr.
+ *
+ * Free variables are variables that are not bound by a
+ * VarBinding or a function parameter in the context.
+ *
+ * \param expr the expression.
+ *
+ * \return List of free vars, in the PostDFS order in the expression.
+ */
+TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr);
+
+/*!
+ * \brief Get all variables from expression expr.
+ *
+ * \param expr the expression.
+ *
+ * \return List of all vars, in the PostDFS order in the expression.
+ */
+TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
+
+/*!
+ * \brief Get all global variables used in calls in expression expr.
+ *
+ * \param expr the expression.
+ *
+ * \return List of all global variables called in expr.
+ */
+TVM_DLL tvm::Array<GlobalVar> CalledGlobalVars(const Expr& expr);
+
+/*!
+ * \brief Get all global variables from expression expr.
+ *
+ * AllVars is a superset of BoundVars and FreeVars;
+ * the union of BoundVars and FreeVars is AllVars.
+ *
+ * \param expr the expression.
+ *
+ * \return List of all global variables, in the PostDFS order in the expression.
+ */
+TVM_DLL tvm::Array<GlobalVar> AllGlobalVars(const Expr& expr);
+
 /*!
  * \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps.
  *
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 0f973db290f8..1a525431dd48 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -47,6 +47,16 @@ def ToNonDataflow() -> tvm.ir.transform.Pass:
     return _ffi_api.ToNonDataflow()  # type: ignore
 
 
+def LambdaLift() -> tvm.ir.transform.Pass:
+    """A pass that lifts local functions into global functions.
+
+    Returns
+    -------
+    ret : tvm.ir.transform.Pass
+    """
+    return _ffi_api.LambdaLift()  # type: ignore
+
+
 def CallTIRRewrite() -> tvm.ir.transform.Pass:
     """Perform explicit tensor allocation for call_tir.
 
diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc
new file mode 100644
index 000000000000..33197308fa1b
--- /dev/null
+++ b/src/relax/analysis/analysis.cc
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file analysis.cc + * + * \brief Analysis functions for Relax. + */ + +#include +#include +#include + +namespace tvm { +namespace relax { + +template +struct InsertionSet { + std::unordered_set set; + std::vector data; + void Insert(const T& t) { + if (set.count(t) == 0) { + set.insert(t); + data.push_back(t); + } + } +}; + +class VarVisitor : protected ExprVisitor { + public: + Array Free(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : vars_.data) { + if (bound_vars_.set.count(v) == 0) { + ret.push_back(v); + } + } + return ret; + } + + Array Collect() { + Array ret; + for (const auto& v : bound_vars_.data) { + ret.push_back(v); + } + return ret; + } + + Array Bound(const Expr& expr) { + this->VisitExpr(expr); + return Collect(); + } + + Array All(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : vars_.data) { + ret.push_back(v); + } + return ret; + } + + Array AllGlobalVars(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : global_vars_.data) { + ret.push_back(v); + } + return ret; + } + + Array CalledGlobalVars(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : called_global_vars_.data) { + ret.push_back(v); + } + return ret; + } + + void MarkBounded(const Var& v) { + bound_vars_.Insert(v); + vars_.Insert(v); + } + + void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef(var)); } + + void VisitExpr_(const FunctionNode* op) final { + for (const auto& param : op->params) { + MarkBounded(param); + } + VisitExpr(op->body); + } + + void VisitExpr_(const GlobalVarNode* op) final { global_vars_.Insert(GetRef(op)); } + + void VisitExpr_(const CallNode* call_node) final { + VisitSpan(call_node->span); + VisitExpr(call_node->op); + + for (StructInfo sinfo_arg : call_node->sinfo_args) { + VisitExprDepStructInfoField(sinfo_arg); + } + + for (Expr arg : call_node->args) { + VisitExpr(arg); + } + + if (const GlobalVarNode* global_var_node = call_node->op.as()) { + called_global_vars_.Insert(GetRef(global_var_node)); + } + } + + void VisitBinding_(const VarBindingNode* binding) final { + MarkBounded(binding->var); + VisitExpr(binding->value); + VisitVarDef(binding->var); + } + + void VisitBinding_(const MatchCastNode* binding) final { + MarkBounded(binding->var); + ExprVisitor::VisitBinding_(binding); + } + + private: + InsertionSet vars_; + InsertionSet bound_vars_; + InsertionSet global_vars_; + InsertionSet called_global_vars_; +}; + +tvm::Array FreeVars(const Expr& expr) { return VarVisitor().Free(expr); } + +tvm::Array BoundVars(const Expr& expr) { return VarVisitor().Bound(expr); } + +tvm::Array AllVars(const Expr& expr) { return VarVisitor().All(expr); } + +tvm::Array AllGlobalVars(const Expr& expr) { return VarVisitor().AllGlobalVars(expr); } + +tvm::Array CalledGlobalVars(const Expr& expr) { + return VarVisitor().CalledGlobalVars(expr); +} + +TVM_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars); + +TVM_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body_typed(BoundVars); + 
+TVM_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars); + +TVM_REGISTER_GLOBAL("relax.analysis.all_global_vars").set_body_typed(AllGlobalVars); + +TVM_REGISTER_GLOBAL("relax.analysis.called_global_vars").set_body_typed(CalledGlobalVars); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc new file mode 100644 index 000000000000..f08499036b1c --- /dev/null +++ b/src/relax/transform/lambda_lift.cc @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/lambda_lift.cc + * \brief Lift local functions into global functions. + */ + +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relax { + +/* The goal of this class is to lift out any nested functions into top-level + * functions. + * + * We will lift a function out into a global which takes the set of the free + * vars and then return the new created function. + */ +class LambdaLifter : public ExprMutator { + public: + explicit LambdaLifter(const IRModule& module) : ExprMutator(module) { mod_ = module; } + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* call_node) final { + auto call = Downcast(ExprMutator::VisitExpr_(call_node)); + if (auto const* var = call_node->op.as()) { + bool has_closure = HasClosure(GetRef(var)); + auto val = builder_->LookupBinding(GetRef(var)); + // Call "relax.invoke_closure" to invoke closure + if (has_closure && val.as()) { + Var clo_arg = GetRef(var); + if (this->var_remap_.find(var->vid) != this->var_remap_.end()) { + clo_arg = this->var_remap_.at(var->vid); + } + return Call(invoke_closure_op_, {clo_arg, Tuple(call_node->args)}, {}, + {GetStructInfo(GetRef(call_node))}); + } + } + if (auto global_var_node = call_node->op.as()) { + String rec_name = global_var_node->name_hint; + auto global_var = GetRef(global_var_node); + auto it = lambda_map_.find(global_var); + if (it != lambda_map_.end()) { + // flatten nested call, e.g. 
call(y)(x) -> call(x, y)) + Array new_args; + for (const auto arg : call->args) { + new_args.push_back(arg); + } + if (const auto* nest_call = it->second.as()) { + for (const auto arg : nest_call->args) { + new_args.push_back(arg); + } + return Call(nest_call->op, new_args, call_node->attrs, call_node->sinfo_args); + } + return Call(it->second, call->args, call_node->attrs, call_node->sinfo_args); + } + } + return std::move(call); + } + + Expr VisitExpr_(const FunctionNode* func_node) final { + auto func = GetRef(func_node); + + // TODO(@yongwww): consider appending inner func name into the lifted func name + String lift_func_name = "lifted_func_" + std::to_string(lift_func_num_++); + auto global = GlobalVar(lift_func_name); + Array captured_vars = FreeVars(func); + recur_vars_ = CalledGlobalVars(func); + auto all_global_vars = AllGlobalVars(func); + + Array typed_captured_vars; + Map rebinding_map; + for (auto free_var : captured_vars) { + Var var = Var(free_var->name_hint(), GetStructInfo(free_var), free_var->span); + typed_captured_vars.push_back(var); + rebinding_map.Set(free_var, var); + } + + // recursive call + if (!recur_vars_.empty()) { + if (!captured_vars.empty()) { + Array fvs; + for (auto fv : captured_vars) { + fvs.push_back(fv); + } + lambda_map_.emplace(recur_vars_.back(), Call(global, fvs)); + } else { + if (recur_vars_.size() > 0) { + lambda_map_.emplace(recur_vars_.back(), global); + } + } + } + + tvm::Array params; + bool all_params_unchanged = true; + for (Var param : func_node->params) { + Var new_param = this->VisitVarDef(param); + params.push_back(new_param); + all_params_unchanged &= param.same_as(new_param); + } + + Expr body = this->VisitWithNewScope(func_node->body); + Expr visited_func; + + if (all_params_unchanged && body.same_as(func_node->body)) { + visited_func = GetRef(func_node); + } else if (const auto& body_sinfo = MatchStructInfo(body)) { + visited_func = Function(params, body, body_sinfo.value(), func_node->attrs); + } else { + visited_func = Function(params, body, func_node->ret_struct_info, func_node->attrs); + } + auto new_func = Downcast(visited_func); + + Function lifted_func; + bool is_closure = IsClosure(captured_vars); + if (!is_closure) { + lifted_func = Function( + /*params=*/new_func->params, + /*body=*/new_func->body, + /*ret_struct_info=*/new_func->ret_struct_info, + /*attrs=*/new_func->attrs, + /*span=*/new_func->span); + } else { + // Flatten the Closure + std::vector closure_params; + closure_params.reserve(func->params.size() + typed_captured_vars.size()); + for (size_t i = 0; i < func->params.size(); ++i) { + closure_params.emplace_back(func->params[i]); + } + for (size_t i = 0; i < typed_captured_vars.size(); ++i) { + closure_params.emplace_back(typed_captured_vars[i]); + } + + lifted_func = Function(/*params=*/closure_params, + /*body=*/Bind(new_func->body, rebinding_map), + /*ret_struct_info=*/new_func->ret_struct_info, + /*attrs=*/new_func->attrs, + /*span=*/func->span); + + Array param_types; + for (Var param : closure_params) { + CHECK(param->checked_type_.defined()) + << "relax.Function requires params to contain checked_type_"; + param_types.push_back(param->checked_type_); + } + } + + ICHECK(lifted_func.defined()); + + // Add the lifted function to the module. + UpdateStructInfo(global, GetStructInfo(lifted_func)); + builder_->UpdateFunction(global, lifted_func); + + if (!is_closure) { + return std::move(global); + } else { + // If we need to allocate a closure, + // we pass the variables in its environment here. 
+ Array fvs; + for (auto fv : captured_vars) { + fvs.push_back(fv); + } + // Call make_closure intrinsic + return Call(make_closure_op_, {global, Tuple(fvs)}, {}, {}); + } + } + + bool HasClosure(const Var& var) { + auto val = builder_->LookupBinding(var); + if (const auto* value = val.as()) { + IRModule ctx_mod = builder_->GetContextIRModule(); + ICHECK(ctx_mod->functions.size() > 0); + BaseFunc func = ctx_mod->Lookup(GetRef(value)); + if (const auto* func_node = func.as()) { + if (const auto* call_node = func_node->body.as()) { + if (call_node->op == make_closure_op_) { + return true; + } + } else if (const auto* seq_expr_node = func_node->body.as()) { + // the return var points to a make_closure intrinsic + if (const auto* var = seq_expr_node->body.as()) { + return HasClosure(GetRef(var)); + } + } + } + } else if (const auto* func_node = val.as()) { + if (const auto* call_node = func_node->body.as()) { + if (call_node->op == make_closure_op_) { + return true; + } + } + } else if (const auto* call_node = val.as()) { + // recursive call + auto op = call_node->op; + if (make_closure_op_ == op) { + return true; + } + if (const auto* lv = op.as()) { + return HasClosure(GetRef(lv)); + } + } + return false; + } + + bool IsClosure(const Array& captured_vars) { return captured_vars.size() > 0; } + + IRModule Lift() { + auto glob_funcs = mod_->functions; + for (auto pair : glob_funcs) { + if (auto* n = pair.second.as()) { + auto func = GetRef(n); + func = Function(func->params, VisitExpr(func->body), func->ret_struct_info, func->attrs); + builder_->UpdateFunction(pair.first, func); + } + } + return builder_->GetContextIRModule(); + } + + private: + std::unordered_map lambda_map_; + Array recur_vars_; + IRModule mod_; + size_t lift_func_num_ = 0; + /*! \brief Cache ops that would be used later to reduce lookup overhead. */ + const Op& make_closure_op_ = Op::Get("relax.make_closure"); + const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure"); +}; + +namespace transform { + +Pass LambdaLift() { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, PassContext pc) { return relax::LambdaLifter(m).Lift(); }; + return CreateModulePass(pass_func, 1, "LambdaLift", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.LambdaLift").set_body_typed(LambdaLift); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 5846f8116df2..24414f250cbc 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -22,6 +22,51 @@ namespace tvm { namespace relax { +/*! \brief Helper to implement bind params.*/ +class ExprBinder : public ExprMutator { + public: + explicit ExprBinder(const tvm::Map& args_map) : args_map_(args_map) {} + + Expr VisitExpr_(const VarNode* op) final { + auto id = GetRef(op); + auto it = args_map_.find(id); + if (it != args_map_.end()) { + return (*it).second; + } else { + return ExprMutator::VisitExpr_(op); + } + } + + private: + const tvm::Map& args_map_; +}; + +/*! 
+ * \brief Bind parameters in an expression.
+ * \param expr The expression in which to bind the parameters.
+ * \param args_map The map from each parameter Var to the Expr it binds to.
+ * \return The expression after binding.
+ */
+Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
+  if (const FunctionNode* func = expr.as<FunctionNode>()) {
+    Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
+    Array<Var> new_params;
+    for (size_t i = 0; i < func->params.size(); ++i) {
+      if (!args_map.count(func->params[i])) {
+        new_params.push_back(func->params[i]);
+      }
+    }
+    if (new_body.same_as(func->body) && new_params.size() == func->params.size()) {
+      return expr;
+    }
+    // The checked_type_ of the new function is deduced from the function body
+    // TODO(@relax-team): Should infer the shape from the body as well
+    return Function(new_params, new_body, NullOpt, func->attrs);
+  } else {
+    return ExprBinder(args_map).VisitExpr(expr);
+  }
+}
+
 bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank, bool permit_unknown_dtype) {
   const DynTensorTypeNode* tt = ty.as<DynTensorTypeNode>();
   if (!tt) {
diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py
new file mode 100644
index 000000000000..fbdb1fbdcea9
--- /dev/null
+++ b/tests/python/relax/test_transform_lambda_lift.py
@@ -0,0 +1,304 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+ +import pytest +import tvm +import tvm.testing +from tvm import relax +import tvm.script +from tvm.script import relax as R, tir as T +from tvm.relax import transform +from tvm.ir.base import assert_structural_equal + + +def _check_equal(x, y): + tvm.ir.assert_structural_equal(x, y) + tvm.ir.assert_structural_equal(y, x) + + xhash = tvm.ir.structural_hash(x, map_free_vars=True) + yhash = tvm.ir.structural_hash(y, map_free_vars=True) + assert xhash == yhash + + +def _check_save_roundtrip(x): + y = tvm.ir.load_json(tvm.ir.save_json(x)) + _check_equal(x, y) + + +def test_basic(): + # the target IRModule + @tvm.script.ir_module + class Expected: + @R.function + def lifted_func_0( + x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + s: R.Tensor((10, 5), "float32") = R.add(x2, y2) + return s + + @R.function + def main( + x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + inner = lifted_func_0 + gv1: R.Tensor((10, 5), "float32") = inner(x1, y1) + return gv1 + + @tvm.script.ir_module + class Before: + @R.function + def main( + x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + @R.function + def inner( + x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + s: R.Tensor((10, 5), "float32") = R.add(x2, y2) + return s + + gv1: R.Tensor((10, 5), "float32") = inner(x1, y1) + return gv1 + + before = Before + expected = Expected + # Perform Lambda Lifting + after = transform.LambdaLift()(before) + assert len(after.functions) == 2 + assert_structural_equal(after, expected, map_free_vars=True) + _check_save_roundtrip(after) + + +def test_closure(): + # the expected IRModule + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + outer_func = lifted_func_0 + in_call = outer_func(x) + res = R.invoke_closure(in_call, (y,), sinfo_args=(R.Tensor((2, 3), dtype="float32"))) + return res + + @R.function + def lifted_func_1(x1: R.Tensor((2, 3), "float32"), c1: R.Tensor((2, 3), "float32")): + r_1: R.Tensor((2, 3), "float32") = R.add(x1, c1) + return r_1 + + @R.function + def lifted_func_0(y: R.Tensor((2, 3), "float32")) -> R.Object: + inner_func = R.make_closure(lifted_func_1, (y,)) + return inner_func + + # IRModule to perform Lambda Lifting + @tvm.script.ir_module + class Before: + @R.function + def main( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + @R.function + def outer_func(c1: R.Tensor((2, 3), "float32")): + @R.function + def inner_func(x1: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + s: R.Tensor((2, 3), "float32") = R.add(x1, c1) + return s + + return inner_func + + in_call = outer_func(x) + res = in_call(y) + return res + + before = Before + after = transform.LambdaLift()(before) + expected = Expected + assert_structural_equal(after, expected, map_free_vars=True) + _check_save_roundtrip(after) + + +@pytest.mark.skip(reason="Need fix after parser switch over") +def test_recursive(): + # the expected IRModule + @tvm.script.ir_module + class Expected: + @R.function + def lifted_func_0( + i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32"), x: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + cond: R.Tensor((), "bool") = R.call_packed( + "test.vm.less", i, R.const(10), 
sinfo_args=(R.Tensor((), dtype="bool")) + ) + c: R.Tensor((), "int32") = R.const(1, dtype="int32") + if cond: + new_i: R.Tensor((), "int32") = R.add(i, c) + new_s: R.Tensor((2, 3), "float32") = R.add(s, x) + r = lifted_func_0(new_i, new_s, x) + else: + r = s + return r + + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor: + while_loop = R.make_closure(lifted_func_0, (x,)) + gv = R.invoke_closure( + while_loop, + (relax.const(0), x), + sinfo_args=(R.Tensor(ndim=2, dtype="float32")), + ) + return gv + + # the IRModule to apply lambda lifting + @tvm.script.ir_module + class Before: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor: + @R.function + def while_loop( + i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + cond: R.Tensor((), "bool") = R.call_packed( + "test.vm.less", i, R.const(10), sinfo_args=(R.Tensor((), dtype="bool")) + ) + c: R.Tensor((), "int32") = R.const(1, dtype="int32") + if cond: + new_i: R.Tensor((), "int32") = R.add(i, c) + new_s: R.Tensor((2, 3), "float32") = R.add(s, x) + r: R.Tensor((2, 3), "float32") = while_loop(new_i, new_s) + else: + r: R.Tensor((2, 3), "float32") = s + return r + + gv: R.Tensor((2, 3), "float32") = while_loop(relax.const(0), x) + return gv + + before = Before + expected = Expected + # Perform Lamda Lifting + after = transform.LambdaLift()(before) + assert len(after.functions) == 2 + + assert_structural_equal(after, expected, map_free_vars=True) + _check_save_roundtrip(after) + + +@pytest.mark.skip(reason="Need fix after parser switch over") +def test_multi_func(): + # expected IRModule + @tvm.script.ir_module + class Expected: + @R.function + def glob_func_1( + x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + inner = lifted_func_1 + gv1 = inner(x1, y1) + return gv1 + + @R.function + def glob_func_2( + x11: R.Tensor((10, 5), "float32"), y11: R.Tensor((10, 5), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + inner1 = lifted_func_0 + gv11 = inner1(x11, y11) + return gv11 + + @R.function + def lifted_func_0( + x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + s: R.Tensor((10, 5), "float32") = R.add(x2, y2) + return s + + @R.function + def lifted_func_1( + x21: R.Tensor((10, 5), "float32"), y21: R.Tensor((10, 5), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + s1: R.Tensor((10, 5), "float32") = R.add(x21, y21) + return s1 + + # the IRModule to apply lambda lifting + @tvm.script.ir_module + class Before: + @R.function + def glob_func_1( + x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + @R.function + def inner( + x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + s: R.Tensor((10, 5), "float32") = R.add(x2, y2) + return s + + gv1: R.Tensor((10, 5), "float32") = inner(x1, y1) + return gv1 + + @R.function + def glob_func_2( + x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + @R.function + def inner( + x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + s: R.Tensor((10, 5), "float32") = R.add(x2, y2) + return s + + gv1: R.Tensor((10, 5), "float32") = inner(x1, y1) + return gv1 + + before = Before + expected = Expected + # Perform Lamda Lifting + after = transform.LambdaLift()(before) + assert len(after.functions) == 4 + 
assert_structural_equal(after, expected, map_free_vars=True) + _check_save_roundtrip(after) + + +def test_no_local_func(): + @tvm.script.ir_module + class Before: + @T.prim_func + def sub( + A: T.Buffer[(16, 16), "float32"], + B: T.Buffer[(16, 16), "float32"], + C: T.Buffer[(16, 16), "float32"], + ) -> None: + for i, j in T.grid(16, 16): + with T.block("sub"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] - B[vi, vj] + + @R.function + def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor(dtype="float32", ndim=2)): + s = R.call_tir(sub, (c0, x), R.Tensor((16, 16), dtype="float32")) + return s + + before = Before + # Perform lambda lifting + after = transform.LambdaLift()(before) + # No local functions are lifted + assert_structural_equal(after, before, map_free_vars=True) + _check_save_roundtrip(after) + + +if __name__ == "__main__": + tvm.testing.main()
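
A minimal sketch of the intended Python-side usage, mirroring the `Before` module from `test_basic` above. It assumes that `tvm.relax.analysis` exposes `free_vars`/`bound_vars`/`all_vars` wrappers for the FFI globals registered in analysis.cc; those Python wrappers are not part of this diff.

# Sketch only: assumes Python wrappers for the relax.analysis FFI globals exist.
import tvm
from tvm import relax
from tvm.script import relax as R


@tvm.script.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((10, 5), "float32"), y: R.Tensor((10, 5), "float32")
    ) -> R.Tensor((10, 5), "float32"):
        @R.function
        def inner(
            a: R.Tensor((10, 5), "float32"), b: R.Tensor((10, 5), "float32")
        ) -> R.Tensor((10, 5), "float32"):
            s: R.Tensor((10, 5), "float32") = R.add(a, b)
            return s

        gv: R.Tensor((10, 5), "float32") = inner(x, y)
        return gv


# LambdaLift moves `inner` out of `main` into a new global function.
lifted = relax.transform.LambdaLift()(Module)
assert len(lifted.functions) == 2

# The analysis helpers walk a single expression in PostDFS order
# (assumed wrapper names; only the FFI registrations are added here).
main_fn = lifted["main"]
print(relax.analysis.bound_vars(main_fn))  # params plus vars bound in the body
print(relax.analysis.free_vars(main_fn))   # vars used but not bound in main_fn
print(relax.analysis.all_vars(main_fn))    # union of bound and free vars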