Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax] Stabilize relax pass mutation order #16883

Merged
merged 1 commit into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ class IRModuleNode : public Object {
TVM_DLL GlobalVar GetGlobalVar(const String& str) const;

/*!
* \brief Collect all global vars defined in this module.
* \brief Collect all global vars defined in this module, ordered by
* the global variable name.
* \returns An array of global vars
*/
TVM_DLL Array<GlobalVar> GetGlobalVars() const;
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/relax/frontend/nn/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,10 +475,10 @@ def export_tvm(
-------
irmodule : tvm.ir.IRModule
The converted tvm IR representation of the model.
params : Dict[str, tvm.nd.array]
A dictionary of parameters corresponding to the weights of
the model.
params : List[Tuple[str, Parameter]]
A list of Parameters corresponding to the weights of the model.
ext_mods : List[nn.ExternModule]
A list of ExternModules that are used in the model.
"""
# pylint: disable=import-outside-toplevel
from . import spec as _spec
Expand Down
4 changes: 4 additions & 0 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <tvm/node/structural_equal.h>
#include <tvm/runtime/registry.h>

#include <algorithm>
#include <fstream>
#include <sstream>
#include <unordered_set>
Expand Down Expand Up @@ -183,6 +184,9 @@ tvm::Array<GlobalVar> IRModuleNode::GetGlobalVars() const {
for (const auto& pair : global_var_map_) {
global_vars.push_back(pair.second);
}
std::sort(global_vars.begin(), global_vars.end(), [](const GlobalVar& lhs, const GlobalVar& rhs) {
return lhs->name_hint < rhs->name_hint;
});
return tvm::Array<GlobalVar>(global_vars);
}

Expand Down
3 changes: 2 additions & 1 deletion src/relax/transform/alter_op_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ class AlterOpImplMutator : public ExprMutator {
op_buffer_axis_separators__(axis_separators_) {}

IRModule Run() {
for (const auto& [gv, func] : mod_->functions) {
for (const auto& gv : mod_->GetGlobalVars()) {
const auto& func = mod_->Lookup(gv);
if (func->IsInstance<relax::FunctionNode>()) {
relax::Function update_func = Downcast<Function>(VisitExpr(func));
builder_->UpdateFunction(gv, update_func);
Expand Down
3 changes: 2 additions & 1 deletion src/relax/transform/dead_code_elimination.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ IRModule DeadCodeElimination(const IRModule& arg_mod, Array<runtime::String> ent
for (const auto& name : entry_function_names) {
entry_functions.insert(mod->GetGlobalVar(name));
}
for (const auto& [gv, func] : mod->functions) {
for (const auto& gv : mod->GetGlobalVars()) {
const auto& func = mod->Lookup(gv);
if (func.as<ExternFuncNode>() || func->GetLinkageType() == LinkageType::kExternal) {
entry_functions.insert(gv);
}
Expand Down
22 changes: 12 additions & 10 deletions src/relax/transform/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,8 @@ class OperatorFusor : public ExprMutator {
* \return The new IRModule after transformation
*/
IRModule Transform() {
for (const auto& [gv, func] : mod_->functions) {
for (const auto& gv : mod_->GetGlobalVars()) {
const auto& func = mod_->Lookup(gv);
// Only visit Relax function without attr kPrimitive.
if (func->IsInstance<relax::FunctionNode>() && !func->HasNonzeroAttr(attr::kPrimitive)) {
auto updated_func = Downcast<Function>(VisitExpr(func));
Expand Down Expand Up @@ -1196,9 +1197,9 @@ class CompositeFunctionAnnotator : public ExprMutator {

IRModule Run() {
auto mod = builder_->GetContextIRModule();
auto all_functions = mod->functions;
for (const auto& entry : all_functions) {
if (const auto* func = entry.second.as<FunctionNode>()) {
for (const auto& gv : mod->GetGlobalVars()) {
const auto& base_func = mod->Lookup(gv);
if (const auto* func = base_func.as<FunctionNode>()) {
if (func->GetAttr<String>(attr::kComposite).defined() ||
func->GetAttr<String>(attr::kCodegen).defined()) {
continue;
Expand All @@ -1208,7 +1209,7 @@ class CompositeFunctionAnnotator : public ExprMutator {
if (!new_body.same_as(func->body)) {
auto new_func = Function(func->params, new_body, func->ret_struct_info, func->is_pure,
func->attrs, func->span);
builder_->UpdateFunction(entry.first, new_func);
builder_->UpdateFunction(gv, new_func);
}
}
}
Expand Down Expand Up @@ -1272,11 +1273,12 @@ IRModule FuseOpsByPattern(const tvm::Array<transform::FusionPattern>& patterns,
support::Arena arena;
for (const auto& pattern : patterns) {
OperatorFusor::GroupMap group_map;
for (const auto& entry : mod->functions) {
if (entry.second->IsInstance<tir::PrimFuncNode>()) {
for (const auto& gv : mod->GetGlobalVars()) {
const auto& base_func = mod->Lookup(gv);
if (base_func->IsInstance<tir::PrimFuncNode>()) {
continue;
}
const FunctionNode* function = entry.second.as<FunctionNode>();
const FunctionNode* function = base_func.as<FunctionNode>();
if (function->GetAttr<Integer>(attr::kPrimitive).defined() ||
function->GetAttr<String>(attr::kComposite).defined() ||
function->GetAttr<String>(attr::kCodegen).defined()) {
Expand All @@ -1285,8 +1287,8 @@ IRModule FuseOpsByPattern(const tvm::Array<transform::FusionPattern>& patterns,

auto map = PatternBasedPartitioner::Run(pattern->name, pattern->pattern,
pattern->annotation_patterns,
pattern->check.value_or(nullptr), entry.second,
&arena, pattern->attrs_getter.value_or(nullptr));
pattern->check.value_or(nullptr), base_func, &arena,
pattern->attrs_getter.value_or(nullptr));
for (const auto& [key, value] : map) {
CHECK(!group_map.count(key))
<< "ValueError: "
Expand Down
3 changes: 2 additions & 1 deletion src/relax/transform/fuse_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,8 @@ class TIRFuseMutator : public ExprMutator {
static IRModule Transform(IRModule mod) {
// Collect all primitive relax functions
Map<GlobalVar, Function> primitive_relax;
for (const auto& [gvar, base_func] : mod->functions) {
for (const auto& gvar : mod->GetGlobalVars()) {
const auto& base_func = mod->Lookup(gvar);
// Only fuse primitive relax functions
if (base_func->HasNonzeroAttr(attr::kPrimitive)) {
if (auto func = base_func.as<relax::Function>()) {
Expand Down
8 changes: 5 additions & 3 deletions src/relax/transform/legalize_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,18 @@ class LegalizeMutator : public ExprMutator {
}

IRModule Transform() {
for (const auto& [gv, func] : mod_->functions) {
for (const auto& gv : mod_->GetGlobalVars()) {
const auto& func = mod_->Lookup(gv);
if (func->IsInstance<FunctionNode>()) {
auto updated_func = Downcast<Function>(this->VisitExpr(func));
builder_->UpdateFunction(gv, Downcast<BaseFunc>(updated_func));
}
}
// Fill the "kTarget" attribute of PrimFunc
for (const auto& [gv, func] : builder_->GetContextIRModule()->functions) {
const auto& mod = builder_->GetContextIRModule();
for (const auto& gv : mod->GetGlobalVars()) {
const tir::PrimFuncNode* prim_func;
if (tmap_.count(gv) && (prim_func = func.as<tir::PrimFuncNode>())) {
if (tmap_.count(gv) && (prim_func = mod->Lookup(gv).as<tir::PrimFuncNode>())) {
auto f = WithAttr(GetRef<tir::PrimFunc>(prim_func), tvm::attr::kTarget, tmap_[gv]);
builder_->UpdateFunction(gv, f);
}
Expand Down
Loading