From 59376eeca3373b889b1e25ef9f1d4aa4ff0524f8 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 18 Apr 2024 08:50:55 -0700 Subject: [PATCH] [Relax] Allow specifying entry_funcs for BYOC (#16902) * [Relax] Allow specifying entry_funcs for BYOC --- include/tvm/relax/transform.h | 5 +- python/tvm/relax/transform/transform.py | 5 ++ src/relax/transform/fuse_ops.cc | 69 ++++++++++++++++--------- src/relax/transform/utils.h | 3 +- 4 files changed, 57 insertions(+), 25 deletions(-) diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 82cbf3d12d5f..c3a3c873c02b 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -492,12 +492,15 @@ TVM_DLL Pass Gradient(String func_name, Optional> require_grads = Nul * corresponding pattern name. For example, "dnnl" if the pattern name is "dnnl.conv2d_relu". * This must be True if the created composite functions are intended to be offloaded to * an external backend without using the MergeCompositeFunctions pass. + * \param entry_function_names The names of functions that should be considered as entry points. If + * not specified, all externally exposed functions will be considered as entry points. * \return The Pass. * * \note Only operates within dataflow blocks. ConvertToDataflow may need to be called first. */ TVM_DLL Pass FuseOpsByPattern(const tvm::Array& patterns, bool bind_constants = true, - bool annotate_codegen = false); + bool annotate_codegen = false, + const tvm::Array& entry_function_names = {}); /*! * \brief Group one or multiple composite functions created by FuseOpsByPattern into a new diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index fa18cc672b40..38e7994eb97f 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -890,6 +890,7 @@ def FuseOpsByPattern( patterns: List[Union[FusionPattern, Tuple]], bind_constants: bool = True, annotate_codegen: bool = False, + entry_functions: Optional[List[str]] = None, ) -> tvm.ir.transform.Pass: """Apply pattern matching to each function in the given module, and group matched expressions into a new function. @@ -919,6 +920,9 @@ def FuseOpsByPattern( This must be True if the created composite functions are intended to be offloaded to an external backend without using the MergeCompositeFunctions pass. + entry_functions : Optional[List[str]] + The set of entry functions to start from. + Returns ------- ret : tvm.transform.Pass @@ -938,6 +942,7 @@ def FuseOpsByPattern( converted_patterns, bind_constants, annotate_codegen, + entry_functions or [], ) # type: ignore diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 3e762778d849..ee96f9fa805a 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -690,8 +690,16 @@ class OperatorFusor : public ExprMutator { * \brief The main transformation on the IRModule * \return The new IRModule after transformation */ - IRModule Transform() { - for (const auto& gv : mod_->GetGlobalVars()) { + IRModule Transform(const Array& entry_function_names = {}) { + Array entry_functions; + if (entry_function_names.empty()) { + entry_functions = mod_->GetGlobalVars(); + } else { + for (const auto& name : entry_function_names) { + entry_functions.push_back(mod_->GetGlobalVar(name)); + } + } + for (const auto& gv : entry_functions) { const auto& func = mod_->Lookup(gv); // Only visit Relax function without attr kPrimitive. if (func->IsInstance() && !func->HasNonzeroAttr(attr::kPrimitive)) { @@ -1023,8 +1031,8 @@ IRModule FuseOps(IRModule mod, int opt_level, size_t max_fuse_depth) { IRModule MakeGroupedFunctions( IRModule mod, const std::unordered_map& partition, - bool lift_constants) { - return OperatorFusor(mod, partition, lift_constants).Transform(); + bool lift_constants, const Array& entry_function_names) { + return OperatorFusor(mod, partition, lift_constants).Transform(entry_function_names); } /*! \brief Create a "partitioning", a map from interior / leaf expr to its representative group, @@ -1269,26 +1277,39 @@ class CompositeFunctionAnnotator : public ExprMutator { }; IRModule FuseOpsByPattern(const tvm::Array& patterns, IRModule mod, - bool bind_constants, bool annotate_codegen) { + bool bind_constants, bool annotate_codegen, + Array entry_function_names) { support::Arena arena; + for (const auto& pattern : patterns) { - OperatorFusor::GroupMap group_map; - for (const auto& gv : mod->GetGlobalVars()) { - const auto& base_func = mod->Lookup(gv); - if (base_func->IsInstance()) { - continue; + Array entry_functions; + if (entry_function_names.size()) { + for (const auto& name : entry_function_names) { + auto gv = mod->GetGlobalVar(name); + auto func = mod->Lookup(gv); + ICHECK(func->IsInstance()) << "Entry function must be a relax function"; + entry_functions.push_back(Downcast(func)); } - const FunctionNode* function = base_func.as(); - if (function->GetAttr(attr::kPrimitive).defined() || - function->GetAttr(attr::kComposite).defined() || - function->GetAttr(attr::kCodegen).defined()) { - continue; + } else { + for (const auto& gv : mod->GetGlobalVars()) { + const auto& base_func = mod->Lookup(gv); + if (base_func->IsInstance()) { + continue; + } + const FunctionNode* function = base_func.as(); + if (function->GetAttr(attr::kPrimitive).defined() || + function->GetAttr(attr::kComposite).defined() || + function->GetAttr(attr::kCodegen).defined()) { + continue; + } + entry_functions.push_back(Downcast(base_func)); } - - auto map = PatternBasedPartitioner::Run(pattern->name, pattern->pattern, - pattern->annotation_patterns, - pattern->check.value_or(nullptr), base_func, &arena, - pattern->attrs_getter.value_or(nullptr)); + } + OperatorFusor::GroupMap group_map; + for (const auto& func : entry_functions) { + auto map = PatternBasedPartitioner::Run( + pattern->name, pattern->pattern, pattern->annotation_patterns, + pattern->check.value_or(nullptr), func, &arena, pattern->attrs_getter.value_or(nullptr)); for (const auto& [key, value] : map) { CHECK(!group_map.count(key)) << "ValueError: " @@ -1298,7 +1319,8 @@ IRModule FuseOpsByPattern(const tvm::Array& patterns, group_map.insert({key, value}); } } - mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/ !bind_constants); + mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/ !bind_constants, + entry_function_names); } if (annotate_codegen) { return CompositeFunctionAnnotator(mod).Run(); @@ -1358,10 +1380,11 @@ Pass FuseOps(int fuse_opt_level) { TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps); Pass FuseOpsByPattern(const tvm::Array& patterns, bool bind_constants, - bool annotate_codegen) { + bool annotate_codegen, const Array& entry_function_names) { runtime::TypedPackedFunc pass_func = // [=](IRModule m, PassContext pc) { - return relax::FuseOpsByPattern(patterns, m, bind_constants, annotate_codegen); + return relax::FuseOpsByPattern(patterns, m, bind_constants, annotate_codegen, + entry_function_names); }; return CreateModulePass(/*pass_function=*/pass_func, // /*opt_level=*/0, // diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index 1ad714972c2d..5755e118541f 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -137,12 +137,13 @@ inline std::string GetExtSymbol(const Function& func) { * \param partition A mapping from a subexpression to the containing group. * \param lift_constants Whether or not to lift bound constants to parameters of the * grouped function. + * \param entry_function_names The names of the entry functions. * \return A new module containing grouped functions. */ IRModule MakeGroupedFunctions( IRModule mod, const std::unordered_map& partition, - bool lift_constants = true); + bool lift_constants = true, const Array& entry_function_names = {}); /*! * \brief Check if the given StructInfo is a scalar tensor. The sinfo should be an instance of