Skip to content

Commit

Permalink
move string to tvm namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Apr 8, 2020
1 parent c4a2140 commit 28cb645
Show file tree
Hide file tree
Showing 27 changed files with 42 additions and 34 deletions.
2 changes: 2 additions & 0 deletions include/tvm/node/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@

namespace tvm {

using runtime::String;
using runtime::StringObj;
using runtime::Object;
using runtime::ObjectPtr;
using runtime::ObjectRef;
Expand Down
10 changes: 7 additions & 3 deletions src/node/serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class NodeIndexer : public AttrVisitor {
for (const auto& kv : n->data) {
MakeIndex(const_cast<Object*>(kv.second.get()));
}
} else if (!node->IsInstance<runtime::StringObj>()) {
} else if (!node->IsInstance<StringObj>()) {
reflection_->VisitAttrs(node, this);
}
}
Expand Down Expand Up @@ -242,7 +242,7 @@ class JSONAttrGetter : public AttrVisitor {
node_->data.push_back(
node_index_->at(const_cast<Object*>(kv.second.get())));
}
} else if (node->IsInstance<runtime::StringObj>()) {
} else if (node->IsInstance<StringObj>()) {
node_->data.push_back(node_index_->at(node));
} else {
// recursively index normal object.
Expand Down Expand Up @@ -337,7 +337,11 @@ class JSONAttrSetter : public AttrVisitor {
n->data[node_->keys[i]]
= ObjectRef(node_list_->at(node_->data[i]));
}
} else if (!node->IsInstance<runtime::StringObj>()) {
} else if (node->IsInstance<StringObj>()) {
StringObj* n = static_cast<StringObj*>(node);
auto saved = node_list_->at(node_->data[0]);
saved = runtime::GetObjectPtr<StringObj>(n);
} else {
reflection_->VisitAttrs(node, this);
}
}
Expand Down
10 changes: 5 additions & 5 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -617,14 +617,14 @@ class CompileEngineImpl : public CompileEngineNode {
for (const auto& it : cache_) {
auto src_func = it.first->source_func;
CHECK(src_func.defined());
if (src_func->GetAttr<runtime::String>(attr::kCompiler).defined()) {
auto code_gen = src_func->GetAttr<runtime::String>(attr::kCompiler);
if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
CHECK(code_gen.defined()) << "No external codegen is set";
std::string code_gen_name = code_gen.operator std::string();
if (ext_mods.find(code_gen_name) == ext_mods.end()) {
ext_mods[code_gen_name] = IRModule({}, {});
}
auto symbol_name = src_func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(symbol_name.defined()) << "No external symbol is set for:\n"
<< AsText(src_func, false);
auto gv = GlobalVar(std::string(symbol_name));
Expand Down Expand Up @@ -692,10 +692,10 @@ class CompileEngineImpl : public CompileEngineNode {
}
// No need to lower external functions for now. We will invoke the external
// codegen tool once and lower all functions together.
if (key->source_func->GetAttr<runtime::String>(attr::kCompiler).defined()) {
if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
auto cache_node = make_object<CachedFuncNode>();
const auto name_node =
key->source_func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(name_node.defined())
<< "External function has not been attached a name yet.";
cache_node->func_name = std::string(name_node);
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/contrib/codegen_c/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class CSourceModuleCodegenBase {
*/
std::string GetExtSymbol(const Function& func) const {
const auto name_node =
func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(name_node.defined()) << "Fail to retrieve external symbol.";
return std::string(name_node);
}
Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/graph_runtime_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ class GraphRuntimeCodegen
auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
Target target;
// Handle external function
if (func->GetAttr<runtime::String>(attr::kCompiler).defined()) {
if (func->GetAttr<String>(attr::kCompiler).defined()) {
target = tvm::target::ext_dev();
CCacheKey key = (*pf0)(func, target);
CachedFunc ext_func = (*pf1)(compile_engine_, key);
Expand Down Expand Up @@ -482,7 +482,7 @@ class GraphRuntimeCodegen
return {};
}
std::vector<GraphNodeRef> VisitExpr_(const FunctionNode* op) override {
CHECK(op->GetAttr<runtime::String>(attr::kCompiler).defined())
CHECK(op->GetAttr<String>(attr::kCompiler).defined())
<< "Only functions supported by custom codegen";
return {};
}
Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {

Target target;

if (func->GetAttr<runtime::String>(attr::kCompiler).defined()) {
if (func->GetAttr<String>(attr::kCompiler).defined()) {
target = tvm::target::ext_dev();
} else {
// Next generate the invoke instruction.
Expand All @@ -493,7 +493,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
auto cfunc = engine_->Lower(key);

auto op_index = -1;
if (func->GetAttr<runtime::String>(attr::kCompiler).defined()) {
if (func->GetAttr<String>(attr::kCompiler).defined()) {
op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc);
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/vm/inline_primitives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ struct PrimitiveInliner : ExprMutator {
auto global = pair.first;
auto base_func = pair.second;
if (auto* n = base_func.as<FunctionNode>()) {
if (n->GetAttr<runtime::String>(attr::kCompiler).defined()) continue;
if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
auto func = GetRef<Function>(n);

DLOG(INFO) << "Before inlining primitives: " << global
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/vm/lambda_lift.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class LambdaLifter : public ExprMutator {
auto glob_funcs = module_->functions;
for (auto pair : glob_funcs) {
if (auto* n = pair.second.as<FunctionNode>()) {
if (n->GetAttr<runtime::String>(attr::kCompiler).defined()) continue;
if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
auto func = GetRef<Function>(n);
func = Function(func->params,
VisitExpr(func->body),
Expand Down
2 changes: 1 addition & 1 deletion src/relay/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,

bool FunctionPassNode::SkipFunction(const Function& func) const {
return func->GetAttr<Integer>(attr::kSkipOptimization, 0)->value != 0 ||
(func->GetAttr<runtime::String>(attr::kCompiler).defined());
(func->GetAttr<String>(attr::kCompiler).defined());
}

Pass CreateFunctionPass(
Expand Down
4 changes: 2 additions & 2 deletions src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class AnnotateTargetWrapper : public ExprMutator {
// handle composite functions
Function func = Downcast<Function>(call->op);
CHECK(func.defined());
auto comp_name = func->GetAttr<runtime::String>(attr::kComposite);
auto comp_name = func->GetAttr<String>(attr::kComposite);
if (comp_name.defined()) {
std::string comp_name_str = comp_name;
size_t i = comp_name_str.find('.');
Expand Down Expand Up @@ -148,7 +148,7 @@ class AnnotateTargetWrapper : public ExprMutator {
Function func;
Expr new_body;
// don't step into composite functions
if (fn->GetAttr<runtime::String>(attr::kComposite).defined()) {
if (fn->GetAttr<String>(attr::kComposite).defined()) {
func = GetRef<Function>(fn);
new_body = func->body;
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class Inliner : ExprMutator {
fn->attrs);
// Inline the function body to the caller if this function uses default
// compiler, i.e. no external codegen is needed.
if (!func->GetAttr<runtime::String>(attr::kCompiler).defined()) {
if (!func->GetAttr<String>(attr::kCompiler).defined()) {
CHECK_EQ(func->params.size(), args.size())
<< "Mismatch found in the number of parameters and call args";
// Bind the parameters with call args.
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/merge_composite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ class MergeCompositeWrapper : public ExprMutator {
if (call->op->IsInstance<FunctionNode>()) {
Function func = Downcast<Function>(call->op);
CHECK(func.defined());
auto name_node = func->GetAttr<runtime::String>(attr::kComposite);
auto name_node = func->GetAttr<String>(attr::kComposite);
// don't step into existing composite functions
if (name_node.defined() && name_node != "") {
tvm::Array<tvm::relay::Expr> new_args;
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/to_a_normal_form.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ IRModule ToANormalForm(const IRModule& m) {
for (const auto& it : funcs) {
CHECK_EQ(FreeVars(it.second).size(), 0);
if (const auto* n = it.second.as<FunctionNode>()) {
if (n->GetAttr<runtime::String>(attr::kCompiler).defined()) continue;
if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
}
Expr ret =
TransformF([&](const Expr& e) {
Expand Down
2 changes: 1 addition & 1 deletion src/target/build_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ ExtractFuncInfo(const IRModule& mod) {
info.thread_axis_tags.push_back(thread_axis[i]->thread_tag);
}
}
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
fmap[static_cast<std::string>(global_symbol)] = info;
}
return fmap;
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ void CodeGenCPU::Init(const std::string& module_name,
void CodeGenCPU::AddFunction(const PrimFunc& f) {
CodeGenLLVM::AddFunction(f);
if (f_tvm_register_system_symbol_ != nullptr) {
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
export_system_symbols_.emplace_back(
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
llvm::FunctionType* ftype = llvm::FunctionType::get(
ret_void ? t_void_ : t_int_, param_types, false);

auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
CHECK(module_->getFunction(static_cast<std::string>(global_symbol)) == nullptr)
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
<< "Can only lower IR Module with PrimFuncs";
auto f = Downcast<PrimFunc>(kv.second);
if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined());
entry_func = global_symbol;
}
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
// reserve keywords
ReserveKeywordsAsUnique();

auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
GetUniqueName("_");

// add to alloc buffer type.
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";

Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_opengl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ void CodeGenOpenGL::AddFunction(const PrimFunc& f) {
arg_kinds.push_back(kind);
}

auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenOpenGL: Expect PrimFunc to have the global_symbol attribute";

Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_vhls.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) {
code = (*f)(code).operator std::string();
}

auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
runtime::String func_name(global_symbol);
Expand Down
2 changes: 1 addition & 1 deletion src/target/spirv/build_vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ runtime::Module BuildSPIRV(IRModule mod) {
CHECK(calling_conv.defined() &&
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute";

Expand Down
2 changes: 1 addition & 1 deletion src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const PrimFunc& f) {
builder_->MakeInst(spv::OpReturn);
builder_->MakeInst(spv::OpFunctionEnd);

auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute";

Expand Down
2 changes: 1 addition & 1 deletion src/target/stackvm/codegen_stackvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ runtime::Module BuildStackVM(const IRModule& mod) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenStackVM: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute";
std::string f_name = global_symbol;
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {

PrimFunc MakePackedAPI(PrimFunc&& func,
int num_unpacked_args) {
auto global_symbol = func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";
std::string name_hint = global_symbol;
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/split_host_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ PrimFunc SplitHostDevice(PrimFunc&& func, IRModuleNode* device_mod) {
auto target = func->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined())
<< "SplitHostDevice: Require the target attribute";
auto global_symbol = func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute";

Expand Down
2 changes: 2 additions & 0 deletions tests/python/relay/test_ir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@

def check_json_roundtrip(node):
json_str = tvm.ir.save_json(node)
print(node)
back = tvm.ir.load_json(json_str)
print(back)
assert tvm.ir.structural_equal(back, node, map_free_vars=True)


Expand Down

0 comments on commit 28cb645

Please sign in to comment.