Skip to content

Commit

Permalink
fix string serilazation, add const char* constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Apr 9, 2020
1 parent 5a58136 commit 18de508
Show file tree
Hide file tree
Showing 20 changed files with 58 additions and 45 deletions.
10 changes: 9 additions & 1 deletion include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,15 @@ class String : public ObjectRef {
* \note If user passes const reference, it will trigger copy. If it's rvalue,
* it will be moved into other.
*/
explicit String(std::string other);
String(std::string other); // NOLINT(*)

/*!
* \brief Construct a new String object
*
* \param other a char array.
*/
String(const char* other) // NOLINT(*)
: String(std::string(other)) {}

/*!
* \brief Change the value the reference object points to.
Expand Down
38 changes: 29 additions & 9 deletions src/node/serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class NodeIndexer : public AttrVisitor {
std::vector<Object*> node_list_{nullptr};
std::unordered_map<DLTensor*, size_t> tensor_index_;
std::vector<DLTensor*> tensor_list_;
std::unordered_map<StringObj*, size_t> string_index_;
std::vector<std::string> string_list_;
ReflectionVTable* reflection_ = ReflectionVTable::Global();

void Visit(const char* key, double* value) final {}
Expand Down Expand Up @@ -102,7 +104,13 @@ class NodeIndexer : public AttrVisitor {
for (const auto& kv : n->data) {
MakeIndex(const_cast<Object*>(kv.second.get()));
}
} else if (!node->IsInstance<StringObj>()) {
} else if (node->IsInstance<StringObj>()) {
StringObj* ptr = static_cast<StringObj*>(node);
if (string_index_.count(ptr)) return;
CHECK_EQ(string_index_.size(), string_list_.size());
string_index_[ptr] = string_list_.size();
string_list_.push_back(ptr->data);
} else {
reflection_->VisitAttrs(node, this);
}
}
Expand Down Expand Up @@ -337,11 +345,7 @@ class JSONAttrSetter : public AttrVisitor {
n->data[node_->keys[i]]
= ObjectRef(node_list_->at(node_->data[i]));
}
} else if (node->IsInstance<StringObj>()) {
StringObj* n = static_cast<StringObj*>(node);
auto saved = node_list_->at(node_->data[0]);
saved = runtime::GetObjectPtr<StringObj>(n);
} else {
} else if (!node->IsInstance<StringObj>()) {
reflection_->VisitAttrs(node, this);
}
}
Expand All @@ -355,6 +359,8 @@ struct JSONGraph {
std::vector<JSONNode> nodes;
// base64 b64ndarrays of arrays
std::vector<std::string> b64ndarrays;
// strings
std::vector<std::string> strings;
// global attributes
AttrMap attrs;

Expand All @@ -363,6 +369,7 @@ struct JSONGraph {
writer->WriteObjectKeyValue("root", root);
writer->WriteObjectKeyValue("nodes", nodes);
writer->WriteObjectKeyValue("b64ndarrays", b64ndarrays);
writer->WriteObjectKeyValue("strings", strings);
if (attrs.size() != 0) {
writer->WriteObjectKeyValue("attrs", attrs);
}
Expand All @@ -375,6 +382,7 @@ struct JSONGraph {
helper.DeclareField("root", &root);
helper.DeclareField("nodes", &nodes);
helper.DeclareOptionalField("b64ndarrays", &b64ndarrays);
helper.DeclareOptionalField("strings", &strings);
helper.DeclareOptionalField("attrs", &attrs);
helper.ReadAllFields(reader);
}
Expand Down Expand Up @@ -403,6 +411,8 @@ struct JSONGraph {
b64strm.Finish();
g.b64ndarrays.emplace_back(std::move(blob));
}
// serialize string
g.strings = std::move(indexer.string_list_);
return g;
}
};
Expand All @@ -423,6 +433,7 @@ ObjectRef LoadJSON(std::string json_str) {
jgraph.Load(&reader);
std::vector<ObjectPtr<Object> > nodes;
std::vector<runtime::NDArray> tensors;
std::vector<std::string> strings;
// load in tensors
for (const std::string& blob : jgraph.b64ndarrays) {
dmlc::MemoryStringStream mstrm(const_cast<std::string*>(&blob));
Expand All @@ -432,16 +443,25 @@ ObjectRef LoadJSON(std::string json_str) {
CHECK(temp.Load(&b64strm));
tensors.emplace_back(temp);
}
// load in strings
strings = std::move(jgraph.strings);
ReflectionVTable* reflection = ReflectionVTable::Global();

// node 0 is always null
nodes.reserve(jgraph.nodes.size());

int string_idx = 0;
for (const JSONNode& jnode : jgraph.nodes) {
if (jnode.type_key.length() != 0) {
ObjectPtr<Object> node =
reflection->CreateInitObject(jnode.type_key, jnode.global_key);
nodes.emplace_back(node);
if (jnode.type_key == "runtime.String") {
String ref = String(strings[string_idx++]);
auto* ref_node = const_cast<StringObj*>(ref.as<StringObj>());
nodes.emplace_back(runtime::GetObjectPtr<Object>(ref_node));
} else {
ObjectPtr<Object> node =
reflection->CreateInitObject(jnode.type_key, jnode.global_key);
nodes.emplace_back(node);
}
} else {
nodes.emplace_back(ObjectPtr<Object>());
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ class RelayBuildModule : public runtime::ModuleNode {
}

Array<Pass> pass_seqs;
Array<runtime::String> entry_functions{runtime::String("main")};
Array<runtime::String> entry_functions{"main"};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));

// Run all dialect legalization passes.
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/alter_op_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,7 @@ Pass AlterOpLayout() {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f));
};
return CreateFunctionPass(pass_func, 3, "AlterOpLayout",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 3, "AlterOpLayout", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout")
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ Pass AnnotateTarget(const std::string& target) {
return Downcast<Function>(relay::annotate_target::AnnotateTarget(f, target));
};
auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc",
{runtime::String("InferType")});
{"InferType"});
return transform::Sequential({func_pass, InferType()}, "AnnotateTarget");
}

Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/canonicalize_cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@ Pass CanonicalizeCast() {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CanonicalizeCast(f));
};
return CreateFunctionPass(pass_func, 3, "CanonicalizeCast",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 3, "CanonicalizeCast", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast")
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/canonicalize_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ Pass CanonicalizeOps() {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CanonicalizeOps(f));
};
return CreateFunctionPass(pass_func, 3, "CanonicalizeOps",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps")
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/combine_parallel_conv2d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,7 @@ Pass CombineParallelConv2D(uint64_t min_num_branches) {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CombineParallelConv2D(f, min_num_branches));
};
return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D")
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/combine_parallel_dense.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ Pass CombineParallelDense(uint64_t min_num_branches) {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CombineParallelDense(f, min_num_branches));
};
return CreateFunctionPass(pass_func, 4, "CombineParallelDense",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 4, "CombineParallelDense", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense")
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/combine_parallel_op_batch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,7 @@ Pass CombineParallelOpBatch(const std::string& op_name,
batch_op_name,
min_num_branches));
};
return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.CombineParallelOpBatch")
Expand Down
4 changes: 1 addition & 3 deletions src/relay/transforms/convert_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,7 @@ Pass ConvertLayout(const std::string& desired_layout) {
return Downcast<Function>(relay::convert_op_layout::ConvertLayout(f, desired_layout));
};
return CreateFunctionPass(
pass_func, 3, "ConvertLayout",
{runtime::String("InferType"),
runtime::String("CanonicalizeOps")});
pass_func, 3, "ConvertLayout", {"InferType", "CanonicalizeOps"});
}

TVM_REGISTER_GLOBAL("relay._transform.ConvertLayout").set_body_typed(ConvertLayout);
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/device_annotation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -573,8 +573,7 @@ Pass RewriteAnnotatedOps(int fallback_device) {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::RewriteAnnotatedOps(f, fallback_device));
};
return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation")
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/eliminate_common_subexpr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ Pass EliminateCommonSubexpr(PackedFunc fskip) {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
};
return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr")
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/fast_math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ Pass FastMath() {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(FastMath(f));
};
return CreateFunctionPass(pass_func, 4, "FastMath",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 4, "FastMath", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.FastMath")
Expand Down
6 changes: 2 additions & 4 deletions src/relay/transforms/fold_scale_axis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -960,8 +960,7 @@ Pass ForwardFoldScaleAxis() {
return Downcast<Function>(
relay::fold_scale_axis::ForwardFoldScaleAxis(f));
};
return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis")
Expand All @@ -973,8 +972,7 @@ Pass BackwardFoldScaleAxis() {
return Downcast<Function>(
relay::fold_scale_axis::BackwardFoldScaleAxis(f));
};
return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis")
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -980,8 +980,7 @@ Pass FuseOps(int fuse_opt_level) {
int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
return Downcast<Function>(FuseOps(f, opt_level, m));
};
return CreateFunctionPass(pass_func, 1, "FuseOps",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 1, "FuseOps", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.FuseOps")
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/legalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ Pass Legalize(const std::string& legalize_map_attr_name) {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
};
return CreateFunctionPass(pass_func, 1, "Legalize", {runtime::String("InferType")});
return CreateFunctionPass(pass_func, 1, "Legalize", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.Legalize").set_body_typed(Legalize);
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/simplify_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,7 @@ Pass SimplifyInference() {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(SimplifyInference(f));
};
return CreateFunctionPass(pass_func, 0, "SimplifyInference",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 0, "SimplifyInference", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference")
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/container_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ TEST(String, empty) {
using namespace std;
String s{"hello"};
CHECK_EQ(s.empty(), false);
s = "";
s = std::string("");
CHECK_EQ(s.empty(), true);
}

Expand Down
4 changes: 2 additions & 2 deletions tests/python/relay/test_ir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@

def check_json_roundtrip(node):
json_str = tvm.ir.save_json(node)
print(node)
back = tvm.ir.load_json(json_str)
print(back)
assert tvm.ir.structural_equal(back, node, map_free_vars=True)


Expand Down Expand Up @@ -99,11 +97,13 @@ def test_function():
type_params = tvm.runtime.convert([])
fn = relay.Function(params, body, ret_type, type_params)
fn = fn.with_attr("test_attribute", "value")
fn = fn.with_attr("test_attribute1", "value1")
assert fn.params == params
assert fn.body == body
assert fn.type_params == type_params
assert fn.span == None
assert fn.attrs["test_attribute"] == "value"
assert fn.attrs["test_attribute1"] == "value1"
str(fn)
check_json_roundtrip(fn)

Expand Down

0 comments on commit 18de508

Please sign in to comment.