diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 46b0d25b3d7d..8f6629a14f92 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -42,11 +42,11 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO * we support a meta-data section in the text format. * We allow the text format to refer to a node in the meta-data section. * - * The meta-data section is a json serialized string of an Array. + * The meta-data section is a json serialized string of an Map>. * Each element in the meta-data section can be referenced by the text format. * Each meta data node is printed in the following format. * - * meta.() + * meta[type-key-of-node>][] * * Specifically, consider the following IR(constructed by python). * @@ -63,7 +63,7 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO * * \code * - * fn (%x: Tensor[(meta.Variable(id=0),), float32]) { + * fn (%x: Tensor[(meta[Variable][0],), float32]) { * %x * } * # Meta data section is a json-serialized string @@ -74,7 +74,7 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO * * Note that we store tvm.var("n") in the meta data section. * Since it is stored in the index-0 in the meta-data section, - * we print it as meta.Variable(0). + * we print it as meta[Variable][0]. * * The text parser can recover this object by loading from the corresponding * location in the meta data section. @@ -91,18 +91,18 @@ class TextMetaDataContext { * \return A string representation of the meta node. */ std::string GetMetaNode(const NodeRef& node) { - std::ostringstream os; - auto it = meta_index_.find(node); - int64_t index; - if (it != meta_index_.end()) { - index = it->second; - } else { - index = static_cast(meta_data_.size()); - meta_data_.push_back(node); - meta_index_[node] = index; + auto it = meta_repr_.find(node); + if (it != meta_repr_.end()) { + return it->second; } - os << "meta." << node->type_key() << "(id=" << index << ")"; - return os.str(); + Array& mvector = + meta_data_[node->type_key()]; + int64_t index = static_cast(mvector.size()); + mvector.push_back(node); + std::ostringstream os; + os << "meta[" << node->type_key() << "][" << index << "]"; + meta_repr_[node] = os.str(); + return meta_repr_[node]; } /*! * \brief Get the metadata section in json format. @@ -110,7 +110,8 @@ class TextMetaDataContext { */ std::string GetMetaSection() const { if (meta_data_.size() == 0) return std::string(); - return SaveJSON(Array(meta_data_)); + return SaveJSON(Map( + meta_data_.begin(), meta_data_.end())); } /*! \return whether the meta data context is empty. */ @@ -120,9 +121,9 @@ class TextMetaDataContext { private: /*! \brief additional metadata stored in TVM json format */ - std::vector meta_data_; - /*! \brief map from meta data into its index */ - std::unordered_map meta_index_; + std::unordered_map > meta_data_; + /*! \brief map from meta data into its string representation */ + std::unordered_map meta_repr_; }; class TextPrinter : diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 624ef71ed870..07d2ad2af447 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -48,11 +48,11 @@ def test_meta_data(): f = relay.Function([x, w], z) text = f.astext() assert "channels=2" in text - assert "meta.Variable(id=0)" in text + assert "meta[Variable][0]" in text show(text) text = relay.const([1,2,3]).astext() - assert "meta.relay.Constant(id=0)" in text + assert "meta[relay.Constant][0]" in text show(text)