diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 50b406b2c030d..083f87f89bc9d 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -360,7 +360,15 @@ class String : public ObjectRef { * \note If user passes const reference, it will trigger copy. If it's rvalue, * it will be moved into other. */ - explicit String(std::string other); + String(std::string other); // NOLINT(*) + + /*! + * \brief Construct a new String object + * + * \param other a char array. + */ + String(const char* other) // NOLINT(*) + : String(std::string(other)) {} /*! * \brief Change the value the reference object points to. diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 97436bae242c9..81b7b255266fd 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -53,6 +53,8 @@ class NodeIndexer : public AttrVisitor { std::vector node_list_{nullptr}; std::unordered_map tensor_index_; std::vector tensor_list_; + std::unordered_map string_index_; + std::vector string_list_; ReflectionVTable* reflection_ = ReflectionVTable::Global(); void Visit(const char* key, double* value) final {} @@ -102,7 +104,13 @@ class NodeIndexer : public AttrVisitor { for (const auto& kv : n->data) { MakeIndex(const_cast(kv.second.get())); } - } else if (!node->IsInstance()) { + } else if (node->IsInstance()) { + StringObj* ptr = static_cast(node); + if (string_index_.count(ptr)) return; + CHECK_EQ(string_index_.size(), string_list_.size()); + string_index_[ptr] = string_list_.size(); + string_list_.push_back(ptr->data); + } else { reflection_->VisitAttrs(node, this); } } @@ -337,11 +345,7 @@ class JSONAttrSetter : public AttrVisitor { n->data[node_->keys[i]] = ObjectRef(node_list_->at(node_->data[i])); } - } else if (node->IsInstance()) { - StringObj* n = static_cast(node); - auto saved = node_list_->at(node_->data[0]); - saved = runtime::GetObjectPtr(n); - } else { + } else if (!node->IsInstance()) { reflection_->VisitAttrs(node, this); } } @@ -355,6 +359,8 @@ struct JSONGraph { std::vector nodes; // base64 b64ndarrays of arrays std::vector b64ndarrays; + // strings + std::vector strings; // global attributes AttrMap attrs; @@ -363,6 +369,7 @@ struct JSONGraph { writer->WriteObjectKeyValue("root", root); writer->WriteObjectKeyValue("nodes", nodes); writer->WriteObjectKeyValue("b64ndarrays", b64ndarrays); + writer->WriteObjectKeyValue("strings", strings); if (attrs.size() != 0) { writer->WriteObjectKeyValue("attrs", attrs); } @@ -375,6 +382,7 @@ struct JSONGraph { helper.DeclareField("root", &root); helper.DeclareField("nodes", &nodes); helper.DeclareOptionalField("b64ndarrays", &b64ndarrays); + helper.DeclareOptionalField("strings", &strings); helper.DeclareOptionalField("attrs", &attrs); helper.ReadAllFields(reader); } @@ -403,6 +411,8 @@ struct JSONGraph { b64strm.Finish(); g.b64ndarrays.emplace_back(std::move(blob)); } + // serialize string + g.strings = std::move(indexer.string_list_); return g; } }; @@ -423,6 +433,7 @@ ObjectRef LoadJSON(std::string json_str) { jgraph.Load(&reader); std::vector > nodes; std::vector tensors; + std::vector strings; // load in tensors for (const std::string& blob : jgraph.b64ndarrays) { dmlc::MemoryStringStream mstrm(const_cast(&blob)); @@ -432,16 +443,25 @@ ObjectRef LoadJSON(std::string json_str) { CHECK(temp.Load(&b64strm)); tensors.emplace_back(temp); } + // load in strings + strings = std::move(jgraph.strings); ReflectionVTable* reflection = ReflectionVTable::Global(); // node 0 is always null nodes.reserve(jgraph.nodes.size()); + int string_idx = 0; for (const JSONNode& jnode : jgraph.nodes) { if (jnode.type_key.length() != 0) { - ObjectPtr node = - reflection->CreateInitObject(jnode.type_key, jnode.global_key); - nodes.emplace_back(node); + if (jnode.type_key == "runtime.String") { + String ref = String(strings[string_idx++]); + auto* ref_node = const_cast(ref.as()); + nodes.emplace_back(runtime::GetObjectPtr(ref_node)); + } else { + ObjectPtr node = + reflection->CreateInitObject(jnode.type_key, jnode.global_key); + nodes.emplace_back(node); + } } else { nodes.emplace_back(ObjectPtr()); } diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index bc0685b5d9053..6e6faf9274ef7 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -272,7 +272,7 @@ class RelayBuildModule : public runtime::ModuleNode { } Array pass_seqs; - Array entry_functions{runtime::String("main")}; + Array entry_functions{"main"}; pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); // Run all dialect legalization passes. diff --git a/src/relay/transforms/alter_op_layout.cc b/src/relay/transforms/alter_op_layout.cc index 59cf9f98288f9..aab0b3a30a7cf 100644 --- a/src/relay/transforms/alter_op_layout.cc +++ b/src/relay/transforms/alter_op_layout.cc @@ -125,8 +125,7 @@ Pass AlterOpLayout() { [=](Function f, IRModule m, PassContext pc) { return Downcast(relay::alter_op_layout::AlterOpLayout(f)); }; - return CreateFunctionPass(pass_func, 3, "AlterOpLayout", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 3, "AlterOpLayout", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout") diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 65e17fcefddb8..44ef35a285f56 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -226,7 +226,7 @@ Pass AnnotateTarget(const std::string& target) { return Downcast(relay::annotate_target::AnnotateTarget(f, target)); }; auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc", - {runtime::String("InferType")}); + {"InferType"}); return transform::Sequential({func_pass, InferType()}, "AnnotateTarget"); } diff --git a/src/relay/transforms/canonicalize_cast.cc b/src/relay/transforms/canonicalize_cast.cc index 4b35ba219b674..ebcbd578b5f0b 100644 --- a/src/relay/transforms/canonicalize_cast.cc +++ b/src/relay/transforms/canonicalize_cast.cc @@ -133,8 +133,7 @@ Pass CanonicalizeCast() { [=](Function f, IRModule m, PassContext pc) { return Downcast(CanonicalizeCast(f)); }; - return CreateFunctionPass(pass_func, 3, "CanonicalizeCast", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 3, "CanonicalizeCast", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast") diff --git a/src/relay/transforms/canonicalize_ops.cc b/src/relay/transforms/canonicalize_ops.cc index 44140a902a2fd..1d3111b29d7d5 100644 --- a/src/relay/transforms/canonicalize_ops.cc +++ b/src/relay/transforms/canonicalize_ops.cc @@ -74,8 +74,7 @@ Pass CanonicalizeOps() { [=](Function f, IRModule m, PassContext pc) { return Downcast(CanonicalizeOps(f)); }; - return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps") diff --git a/src/relay/transforms/combine_parallel_conv2d.cc b/src/relay/transforms/combine_parallel_conv2d.cc index 3c8eea04d28f3..af6b1353f5acf 100644 --- a/src/relay/transforms/combine_parallel_conv2d.cc +++ b/src/relay/transforms/combine_parallel_conv2d.cc @@ -220,8 +220,7 @@ Pass CombineParallelConv2D(uint64_t min_num_branches) { [=](Function f, IRModule m, PassContext pc) { return Downcast(CombineParallelConv2D(f, min_num_branches)); }; - return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D") diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc index 2dc8321e517b2..1278020ac7353 100644 --- a/src/relay/transforms/combine_parallel_dense.cc +++ b/src/relay/transforms/combine_parallel_dense.cc @@ -80,8 +80,7 @@ Pass CombineParallelDense(uint64_t min_num_branches) { [=](Function f, IRModule m, PassContext pc) { return Downcast(CombineParallelDense(f, min_num_branches)); }; - return CreateFunctionPass(pass_func, 4, "CombineParallelDense", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 4, "CombineParallelDense", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense") diff --git a/src/relay/transforms/combine_parallel_op_batch.cc b/src/relay/transforms/combine_parallel_op_batch.cc index f63f169be4086..361565ef11d76 100644 --- a/src/relay/transforms/combine_parallel_op_batch.cc +++ b/src/relay/transforms/combine_parallel_op_batch.cc @@ -193,8 +193,7 @@ Pass CombineParallelOpBatch(const std::string& op_name, batch_op_name, min_num_branches)); }; - return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.CombineParallelOpBatch") diff --git a/src/relay/transforms/convert_layout.cc b/src/relay/transforms/convert_layout.cc index d43a0851e0997..dbb2c38e3f274 100644 --- a/src/relay/transforms/convert_layout.cc +++ b/src/relay/transforms/convert_layout.cc @@ -133,9 +133,7 @@ Pass ConvertLayout(const std::string& desired_layout) { return Downcast(relay::convert_op_layout::ConvertLayout(f, desired_layout)); }; return CreateFunctionPass( - pass_func, 3, "ConvertLayout", - {runtime::String("InferType"), - runtime::String("CanonicalizeOps")}); + pass_func, 3, "ConvertLayout", {"InferType", "CanonicalizeOps"}); } TVM_REGISTER_GLOBAL("relay._transform.ConvertLayout").set_body_typed(ConvertLayout); diff --git a/src/relay/transforms/device_annotation.cc b/src/relay/transforms/device_annotation.cc index 9955ef6ee7d2c..908ba87a8c521 100644 --- a/src/relay/transforms/device_annotation.cc +++ b/src/relay/transforms/device_annotation.cc @@ -573,8 +573,7 @@ Pass RewriteAnnotatedOps(int fallback_device) { [=](Function f, IRModule m, PassContext pc) { return Downcast(relay::RewriteAnnotatedOps(f, fallback_device)); }; - return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation") diff --git a/src/relay/transforms/eliminate_common_subexpr.cc b/src/relay/transforms/eliminate_common_subexpr.cc index 696e83a7db538..68c59f5ea2ef1 100644 --- a/src/relay/transforms/eliminate_common_subexpr.cc +++ b/src/relay/transforms/eliminate_common_subexpr.cc @@ -91,8 +91,7 @@ Pass EliminateCommonSubexpr(PackedFunc fskip) { [=](Function f, IRModule m, PassContext pc) { return Downcast(EliminateCommonSubexpr(f, fskip)); }; - return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr") diff --git a/src/relay/transforms/fast_math.cc b/src/relay/transforms/fast_math.cc index 668982e561e88..8234dea5e075f 100644 --- a/src/relay/transforms/fast_math.cc +++ b/src/relay/transforms/fast_math.cc @@ -70,8 +70,7 @@ Pass FastMath() { [=](Function f, IRModule m, PassContext pc) { return Downcast(FastMath(f)); }; - return CreateFunctionPass(pass_func, 4, "FastMath", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 4, "FastMath", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.FastMath") diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index 11325f6526b84..cfe74bfd8ef16 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -960,8 +960,7 @@ Pass ForwardFoldScaleAxis() { return Downcast( relay::fold_scale_axis::ForwardFoldScaleAxis(f)); }; - return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis") @@ -973,8 +972,7 @@ Pass BackwardFoldScaleAxis() { return Downcast( relay::fold_scale_axis::BackwardFoldScaleAxis(f)); }; - return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis") diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index cdd29394a2047..f646042962f0c 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -980,8 +980,7 @@ Pass FuseOps(int fuse_opt_level) { int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level; return Downcast(FuseOps(f, opt_level, m)); }; - return CreateFunctionPass(pass_func, 1, "FuseOps", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 1, "FuseOps", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.FuseOps") diff --git a/src/relay/transforms/legalize.cc b/src/relay/transforms/legalize.cc index b7a0945951295..0cb1c7d9b51db 100644 --- a/src/relay/transforms/legalize.cc +++ b/src/relay/transforms/legalize.cc @@ -102,7 +102,7 @@ Pass Legalize(const std::string& legalize_map_attr_name) { [=](Function f, IRModule m, PassContext pc) { return Downcast(relay::legalize::Legalize(f, legalize_map_attr_name)); }; - return CreateFunctionPass(pass_func, 1, "Legalize", {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 1, "Legalize", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.Legalize").set_body_typed(Legalize); diff --git a/src/relay/transforms/simplify_inference.cc b/src/relay/transforms/simplify_inference.cc index b33799a26b430..d349fdddeeea8 100644 --- a/src/relay/transforms/simplify_inference.cc +++ b/src/relay/transforms/simplify_inference.cc @@ -204,8 +204,7 @@ Pass SimplifyInference() { [=](Function f, IRModule m, PassContext pc) { return Downcast(SimplifyInference(f)); }; - return CreateFunctionPass(pass_func, 0, "SimplifyInference", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 0, "SimplifyInference", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference") diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index f1198e7274016..063247db09b66 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -261,7 +261,7 @@ TEST(String, empty) { using namespace std; String s{"hello"}; CHECK_EQ(s.empty(), false); - s = ""; + s = std::string(""); CHECK_EQ(s.empty(), true); } diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index f182cbaae6315..5a71023e5d602 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -25,9 +25,7 @@ def check_json_roundtrip(node): json_str = tvm.ir.save_json(node) - print(node) back = tvm.ir.load_json(json_str) - print(back) assert tvm.ir.structural_equal(back, node, map_free_vars=True) @@ -99,11 +97,13 @@ def test_function(): type_params = tvm.runtime.convert([]) fn = relay.Function(params, body, ret_type, type_params) fn = fn.with_attr("test_attribute", "value") + fn = fn.with_attr("test_attribute1", "value1") assert fn.params == params assert fn.body == body assert fn.type_params == type_params assert fn.span == None assert fn.attrs["test_attribute"] == "value" + assert fn.attrs["test_attribute1"] == "value1" str(fn) check_json_roundtrip(fn)