diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index f7514c7685e6..5c5bd2673073 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -76,15 +76,6 @@ class FunctionNode : public BaseFuncNode { */ TVM_DLL FuncType func_type_annotation() const; - /*! - * \brief Check whether the function should use the TVM default compiler to build, or - * use other compilers. - * - * \return Whether the function will be compiled using the default compiler - * (e.g. those are used in the TVM stack). - */ - bool UseDefaultCompiler() const; - static constexpr const char* _type_key = "relay.Function"; TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode); }; diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index ccbe4dfc858a..1237c56163f9 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -616,7 +616,7 @@ class CompileEngineImpl : public CompileEngineNode { for (const auto& it : cache_) { auto src_func = it.first->source_func; CHECK(src_func.defined()); - if (!src_func->UseDefaultCompiler()) { + if (src_func->GetAttr(attr::kCompiler).defined()) { auto code_gen = src_func->GetAttr(attr::kCompiler); CHECK(code_gen.defined()) << "No external codegen is set"; if (ext_mods.find(code_gen->value) == ext_mods.end()) { @@ -690,7 +690,7 @@ class CompileEngineImpl : public CompileEngineNode { } // No need to lower external functions for now. We will invoke the external // codegen tool once and lower all functions together. - if (!key->source_func->UseDefaultCompiler()) { + if (key->source_func->GetAttr(attr::kCompiler).defined()) { auto cache_node = make_object(); const auto name_node = key->source_func->GetAttr(attr::kExternalSymbol); diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 032ebcd22d20..0587cd216e12 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -424,7 +424,7 @@ class GraphRuntimeCodegen auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower"); Target target; // Handle external function - if (!func->UseDefaultCompiler()) { + if (func->GetAttr(attr::kCompiler).defined()) { target = tvm::target::ext_dev(); CCacheKey key = (*pf0)(func, target); CachedFunc ext_func = (*pf1)(compile_engine_, key); @@ -490,7 +490,8 @@ class GraphRuntimeCodegen return {}; } std::vector VisitExpr_(const FunctionNode* op) override { - CHECK(!op->UseDefaultCompiler()) << "Only functions supported by custom codegen"; + CHECK(op->GetAttr(attr::kCompiler).defined()) + << "Only functions supported by custom codegen"; return {}; } std::vector VisitExpr_(const RefCreateNode* op) override { diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index caf429aeab49..2fc6567348d8 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -471,7 +471,7 @@ class VMFunctionCompiler : ExprFunctor { Target target; - if (!func->UseDefaultCompiler()) { + if (func->GetAttr(attr::kCompiler).defined()) { target = tvm::target::ext_dev(); } else { // Next generate the invoke instruction. @@ -489,7 +489,7 @@ class VMFunctionCompiler : ExprFunctor { auto cfunc = engine_->Lower(key); auto op_index = -1; - if (!func->UseDefaultCompiler()) { + if (func->GetAttr(attr::kCompiler).defined()) { op_index = context_->cached_funcs.size(); context_->cached_funcs.push_back(cfunc); } else { diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index 0eb6c1a0a68c..8327a6bb2168 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -122,17 +122,17 @@ struct PrimitiveInliner : ExprMutator { auto global = pair.first; auto base_func = pair.second; if (auto* n = base_func.as()) { - if (!n->UseDefaultCompiler()) continue; + if (n->GetAttr(attr::kCompiler).defined()) continue; auto func = GetRef(n); DLOG(INFO) << "Before inlining primitives: " << global << std::endl << AsText(func, false); func = Function(func->params, - VisitExpr(func->body), - func->ret_type, - func->type_params, - func->attrs); + VisitExpr(func->body), + func->ret_type, + func->type_params, + func->attrs); module_->Add(global, func, true); DLOG(INFO) << "After inlining primitives: " << global diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 987fdcb1d920..fd8c35152ff1 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -187,13 +187,13 @@ class LambdaLifter : public ExprMutator { auto glob_funcs = module_->functions; for (auto pair : glob_funcs) { if (auto* n = pair.second.as()) { - if (!n->UseDefaultCompiler()) continue; + if (n->GetAttr(attr::kCompiler).defined()) continue; auto func = GetRef(n); func = Function(func->params, - VisitExpr(func->body), - func->ret_type, - func->type_params, - func->attrs); + VisitExpr(func->body), + func->ret_type, + func->type_params, + func->attrs); module_->Add(pair.first, func, true); } } diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index d371edb31fca..b2516456cb27 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -55,11 +55,6 @@ FuncType FunctionNode::func_type_annotation() const { return FuncType(param_types, ret_type, this->type_params, {}); } -bool FunctionNode::UseDefaultCompiler() const { - tir::StringImm val = this->GetAttr(attr::kCompiler); - return !val.defined() || val->value == "default"; -} - TVM_REGISTER_NODE_TYPE(FunctionNode); TVM_REGISTER_GLOBAL("relay.ir.Function") diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index 919b06604efd..59c1750ae677 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -140,7 +140,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod, bool FunctionPassNode::SkipFunction(const Function& func) const { return func->GetAttr(attr::kSkipOptimization, 0)->value != 0 || - !(func->UseDefaultCompiler()); + (func->GetAttr(attr::kCompiler).defined()); } Pass CreateFunctionPass( diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc index 5f26d67d7ae4..9e118ba8f87c 100644 --- a/src/relay/transforms/inline.cc +++ b/src/relay/transforms/inline.cc @@ -125,13 +125,13 @@ class Inliner : ExprMutator { CHECK(fn) << "Expected to work on a Relay function."; auto func = Function(fn->params, - fn->body, - fn->ret_type, - fn->type_params, - fn->attrs); + fn->body, + fn->ret_type, + fn->type_params, + fn->attrs); // Inline the function body to the caller if this function uses default // compiler, i.e. no external codegen is needed. - if (func->UseDefaultCompiler()) { + if (!func->GetAttr(attr::kCompiler).defined()) { CHECK_EQ(func->params.size(), args.size()) << "Mismatch found in the number of parameters and call args"; // Bind the parameters with call args. diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index 327eb6274aea..e4722e2c3748 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -299,7 +299,7 @@ IRModule ToANormalForm(const IRModule& m) { for (const auto& it : funcs) { CHECK_EQ(FreeVars(it.second).size(), 0); if (const auto* n = it.second.as()) { - if (!n->UseDefaultCompiler()) continue; + if (n->GetAttr(attr::kCompiler).defined()) continue; } Expr ret = TransformF([&](const Expr& e) {