Skip to content

Commit

Permalink
[Relax] Allow specifying entry_funcs for BYOC (#16902)
Browse files Browse the repository at this point in the history
* [Relax] Allow specifying entry_funcs for BYOC
  • Loading branch information
vinx13 authored Apr 18, 2024
1 parent 7dc0472 commit 59376ee
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 25 deletions.
5 changes: 4 additions & 1 deletion include/tvm/relax/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -492,12 +492,15 @@ TVM_DLL Pass Gradient(String func_name, Optional<Array<Var>> require_grads = Nul
* corresponding pattern name. For example, "dnnl" if the pattern name is "dnnl.conv2d_relu".
* This must be True if the created composite functions are intended to be offloaded to
* an external backend without using the MergeCompositeFunctions pass.
* \param entry_function_names The names of functions that should be considered as entry points. If
* not specified, all externally exposed functions will be considered as entry points.
* \return The Pass.
*
* \note Only operates within dataflow blocks. ConvertToDataflow may need to be called first.
*/
TVM_DLL Pass FuseOpsByPattern(const tvm::Array<FusionPattern>& patterns, bool bind_constants = true,
bool annotate_codegen = false);
bool annotate_codegen = false,
const tvm::Array<String>& entry_function_names = {});

/*!
* \brief Group one or multiple composite functions created by FuseOpsByPattern into a new
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,7 @@ def FuseOpsByPattern(
patterns: List[Union[FusionPattern, Tuple]],
bind_constants: bool = True,
annotate_codegen: bool = False,
entry_functions: Optional[List[str]] = None,
) -> tvm.ir.transform.Pass:
"""Apply pattern matching to each function in the given module, and group matched expressions
into a new function.
Expand Down Expand Up @@ -919,6 +920,9 @@ def FuseOpsByPattern(
This must be True if the created composite functions are intended to be offloaded to
an external backend without using the MergeCompositeFunctions pass.
entry_functions : Optional[List[str]]
The set of entry functions to start from.
Returns
-------
ret : tvm.transform.Pass
Expand All @@ -938,6 +942,7 @@ def FuseOpsByPattern(
converted_patterns,
bind_constants,
annotate_codegen,
entry_functions or [],
) # type: ignore


Expand Down
69 changes: 46 additions & 23 deletions src/relax/transform/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -690,8 +690,16 @@ class OperatorFusor : public ExprMutator {
* \brief The main transformation on the IRModule
* \return The new IRModule after transformation
*/
IRModule Transform() {
for (const auto& gv : mod_->GetGlobalVars()) {
IRModule Transform(const Array<String>& entry_function_names = {}) {
Array<GlobalVar> entry_functions;
if (entry_function_names.empty()) {
entry_functions = mod_->GetGlobalVars();
} else {
for (const auto& name : entry_function_names) {
entry_functions.push_back(mod_->GetGlobalVar(name));
}
}
for (const auto& gv : entry_functions) {
const auto& func = mod_->Lookup(gv);
// Only visit Relax function without attr kPrimitive.
if (func->IsInstance<relax::FunctionNode>() && !func->HasNonzeroAttr(attr::kPrimitive)) {
Expand Down Expand Up @@ -1023,8 +1031,8 @@ IRModule FuseOps(IRModule mod, int opt_level, size_t max_fuse_depth) {

IRModule MakeGroupedFunctions(
IRModule mod, const std::unordered_map<const Object*, GraphPartitioner::Group*>& partition,
bool lift_constants) {
return OperatorFusor(mod, partition, lift_constants).Transform();
bool lift_constants, const Array<String>& entry_function_names) {
return OperatorFusor(mod, partition, lift_constants).Transform(entry_function_names);
}

/*! \brief Create a "partitioning", a map from interior / leaf expr to its representative group,
Expand Down Expand Up @@ -1269,26 +1277,39 @@ class CompositeFunctionAnnotator : public ExprMutator {
};

IRModule FuseOpsByPattern(const tvm::Array<transform::FusionPattern>& patterns, IRModule mod,
bool bind_constants, bool annotate_codegen) {
bool bind_constants, bool annotate_codegen,
Array<String> entry_function_names) {
support::Arena arena;

for (const auto& pattern : patterns) {
OperatorFusor::GroupMap group_map;
for (const auto& gv : mod->GetGlobalVars()) {
const auto& base_func = mod->Lookup(gv);
if (base_func->IsInstance<tir::PrimFuncNode>()) {
continue;
Array<Function> entry_functions;
if (entry_function_names.size()) {
for (const auto& name : entry_function_names) {
auto gv = mod->GetGlobalVar(name);
auto func = mod->Lookup(gv);
ICHECK(func->IsInstance<FunctionNode>()) << "Entry function must be a relax function";
entry_functions.push_back(Downcast<Function>(func));
}
const FunctionNode* function = base_func.as<FunctionNode>();
if (function->GetAttr<Integer>(attr::kPrimitive).defined() ||
function->GetAttr<String>(attr::kComposite).defined() ||
function->GetAttr<String>(attr::kCodegen).defined()) {
continue;
} else {
for (const auto& gv : mod->GetGlobalVars()) {
const auto& base_func = mod->Lookup(gv);
if (base_func->IsInstance<tir::PrimFuncNode>()) {
continue;
}
const FunctionNode* function = base_func.as<FunctionNode>();
if (function->GetAttr<Integer>(attr::kPrimitive).defined() ||
function->GetAttr<String>(attr::kComposite).defined() ||
function->GetAttr<String>(attr::kCodegen).defined()) {
continue;
}
entry_functions.push_back(Downcast<Function>(base_func));
}

auto map = PatternBasedPartitioner::Run(pattern->name, pattern->pattern,
pattern->annotation_patterns,
pattern->check.value_or(nullptr), base_func, &arena,
pattern->attrs_getter.value_or(nullptr));
}
OperatorFusor::GroupMap group_map;
for (const auto& func : entry_functions) {
auto map = PatternBasedPartitioner::Run(
pattern->name, pattern->pattern, pattern->annotation_patterns,
pattern->check.value_or(nullptr), func, &arena, pattern->attrs_getter.value_or(nullptr));
for (const auto& [key, value] : map) {
CHECK(!group_map.count(key))
<< "ValueError: "
Expand All @@ -1298,7 +1319,8 @@ IRModule FuseOpsByPattern(const tvm::Array<transform::FusionPattern>& patterns,
group_map.insert({key, value});
}
}
mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/ !bind_constants);
mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/ !bind_constants,
entry_function_names);
}
if (annotate_codegen) {
return CompositeFunctionAnnotator(mod).Run();
Expand Down Expand Up @@ -1358,10 +1380,11 @@ Pass FuseOps(int fuse_opt_level) {
TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps);

Pass FuseOpsByPattern(const tvm::Array<FusionPattern>& patterns, bool bind_constants,
bool annotate_codegen) {
bool annotate_codegen, const Array<String>& entry_function_names) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = //
[=](IRModule m, PassContext pc) {
return relax::FuseOpsByPattern(patterns, m, bind_constants, annotate_codegen);
return relax::FuseOpsByPattern(patterns, m, bind_constants, annotate_codegen,
entry_function_names);
};
return CreateModulePass(/*pass_function=*/pass_func, //
/*opt_level=*/0, //
Expand Down
3 changes: 2 additions & 1 deletion src/relax/transform/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,13 @@ inline std::string GetExtSymbol(const Function& func) {
* \param partition A mapping from a subexpression to the containing group.
* \param lift_constants Whether or not to lift bound constants to parameters of the
* grouped function.
* \param entry_function_names The names of the entry functions.
* \return A new module containing grouped functions.
*/
IRModule MakeGroupedFunctions(
IRModule mod,
const std::unordered_map<const Object*, relay::GraphPartitioner::Group*>& partition,
bool lift_constants = true);
bool lift_constants = true, const Array<String>& entry_function_names = {});

/*!
* \brief Check if the given StructInfo is a scalar tensor. The sinfo should be an instance of
Expand Down

0 comments on commit 59376ee

Please sign in to comment.