Skip to content

Commit

Permalink
Replace UseDefaultCompiler with GetAttr (apache#5088)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics authored and Trevor Morris committed Apr 16, 2020
1 parent 7ae1f20 commit 0123e22
Show file tree
Hide file tree
Showing 10 changed files with 24 additions and 37 deletions.
9 changes: 0 additions & 9 deletions include/tvm/relay/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,6 @@ class FunctionNode : public BaseFuncNode {
*/
TVM_DLL FuncType func_type_annotation() const;

/*!
* \brief Check whether the function should use the TVM default compiler to build, or
* use other compilers.
*
* \return Whether the function will be compiled using the default compiler
* (e.g. those are used in the TVM stack).
*/
bool UseDefaultCompiler() const;

static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode);
};
Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ class CompileEngineImpl : public CompileEngineNode {
for (const auto& it : cache_) {
auto src_func = it.first->source_func;
CHECK(src_func.defined());
if (!src_func->UseDefaultCompiler()) {
if (src_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
auto code_gen = src_func->GetAttr<tir::StringImm>(attr::kCompiler);
CHECK(code_gen.defined()) << "No external codegen is set";
if (ext_mods.find(code_gen->value) == ext_mods.end()) {
Expand Down Expand Up @@ -690,7 +690,7 @@ class CompileEngineImpl : public CompileEngineNode {
}
// No need to lower external functions for now. We will invoke the external
// codegen tool once and lower all functions together.
if (!key->source_func->UseDefaultCompiler()) {
if (key->source_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
auto cache_node = make_object<CachedFuncNode>();
const auto name_node =
key->source_func->GetAttr<tir::StringImm>(attr::kExternalSymbol);
Expand Down
5 changes: 3 additions & 2 deletions src/relay/backend/graph_runtime_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ class GraphRuntimeCodegen
auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
Target target;
// Handle external function
if (!func->UseDefaultCompiler()) {
if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
target = tvm::target::ext_dev();
CCacheKey key = (*pf0)(func, target);
CachedFunc ext_func = (*pf1)(compile_engine_, key);
Expand Down Expand Up @@ -490,7 +490,8 @@ class GraphRuntimeCodegen
return {};
}
std::vector<GraphNodeRef> VisitExpr_(const FunctionNode* op) override {
CHECK(!op->UseDefaultCompiler()) << "Only functions supported by custom codegen";
CHECK(op->GetAttr<tir::StringImm>(attr::kCompiler).defined())
<< "Only functions supported by custom codegen";
return {};
}
std::vector<GraphNodeRef> VisitExpr_(const RefCreateNode* op) override {
Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {

Target target;

if (!func->UseDefaultCompiler()) {
if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
target = tvm::target::ext_dev();
} else {
// Next generate the invoke instruction.
Expand All @@ -489,7 +489,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
auto cfunc = engine_->Lower(key);

auto op_index = -1;
if (!func->UseDefaultCompiler()) {
if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc);
} else {
Expand Down
10 changes: 5 additions & 5 deletions src/relay/backend/vm/inline_primitives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,17 @@ struct PrimitiveInliner : ExprMutator {
auto global = pair.first;
auto base_func = pair.second;
if (auto* n = base_func.as<FunctionNode>()) {
if (!n->UseDefaultCompiler()) continue;
if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
auto func = GetRef<Function>(n);

DLOG(INFO) << "Before inlining primitives: " << global
<< std::endl << AsText(func, false);

func = Function(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
module_->Add(global, func, true);

DLOG(INFO) << "After inlining primitives: " << global
Expand Down
10 changes: 5 additions & 5 deletions src/relay/backend/vm/lambda_lift.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,13 @@ class LambdaLifter : public ExprMutator {
auto glob_funcs = module_->functions;
for (auto pair : glob_funcs) {
if (auto* n = pair.second.as<FunctionNode>()) {
if (!n->UseDefaultCompiler()) continue;
if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
auto func = GetRef<Function>(n);
func = Function(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
module_->Add(pair.first, func, true);
}
}
Expand Down
5 changes: 0 additions & 5 deletions src/relay/ir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,6 @@ FuncType FunctionNode::func_type_annotation() const {
return FuncType(param_types, ret_type, this->type_params, {});
}

bool FunctionNode::UseDefaultCompiler() const {
tir::StringImm val = this->GetAttr<tir::StringImm>(attr::kCompiler);
return !val.defined() || val->value == "default";
}

TVM_REGISTER_NODE_TYPE(FunctionNode);

TVM_REGISTER_GLOBAL("relay.ir.Function")
Expand Down
2 changes: 1 addition & 1 deletion src/relay/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,

bool FunctionPassNode::SkipFunction(const Function& func) const {
return func->GetAttr<Integer>(attr::kSkipOptimization, 0)->value != 0 ||
!(func->UseDefaultCompiler());
(func->GetAttr<tir::StringImm>(attr::kCompiler).defined());
}

Pass CreateFunctionPass(
Expand Down
10 changes: 5 additions & 5 deletions src/relay/transforms/inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,13 @@ class Inliner : ExprMutator {
CHECK(fn) << "Expected to work on a Relay function.";

auto func = Function(fn->params,
fn->body,
fn->ret_type,
fn->type_params,
fn->attrs);
fn->body,
fn->ret_type,
fn->type_params,
fn->attrs);
// Inline the function body to the caller if this function uses default
// compiler, i.e. no external codegen is needed.
if (func->UseDefaultCompiler()) {
if (!func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
CHECK_EQ(func->params.size(), args.size())
<< "Mismatch found in the number of parameters and call args";
// Bind the parameters with call args.
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/to_a_normal_form.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ IRModule ToANormalForm(const IRModule& m) {
for (const auto& it : funcs) {
CHECK_EQ(FreeVars(it.second).size(), 0);
if (const auto* n = it.second.as<FunctionNode>()) {
if (!n->UseDefaultCompiler()) continue;
if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
}
Expr ret =
TransformF([&](const Expr& e) {
Expand Down

0 comments on commit 0123e22

Please sign in to comment.