Skip to content

Commit

Permalink
[RELAY] TextPrinter: Use Map Format (#2553)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Feb 3, 2019
1 parent e2970b2 commit e0af5c2
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 21 deletions.
39 changes: 20 additions & 19 deletions src/relay/ir/text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO
* we support a meta-data section in the text format.
* We allow the text format to refer to a node in the meta-data section.
*
* The meta-data section is a json serialized string of an Array<NodeRef>.
* The meta-data section is a json serialized string of an Map<string, Array<NodeRef>>.
* Each element in the meta-data section can be referenced by the text format.
* Each meta data node is printed in the following format.
*
* meta.<type-key-of-node>(<index-in-meta-section>)
* meta[type-key-of-node>][<index-in-meta-section>]
*
* Specifically, consider the following IR(constructed by python).
*
Expand All @@ -63,7 +63,7 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO
*
* \code
*
* fn (%x: Tensor[(meta.Variable(id=0),), float32]) {
* fn (%x: Tensor[(meta[Variable][0],), float32]) {
* %x
* }
* # Meta data section is a json-serialized string
Expand All @@ -74,7 +74,7 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO
*
* Note that we store tvm.var("n") in the meta data section.
* Since it is stored in the index-0 in the meta-data section,
* we print it as meta.Variable(0).
* we print it as meta[Variable][0].
*
* The text parser can recover this object by loading from the corresponding
* location in the meta data section.
Expand All @@ -91,26 +91,27 @@ class TextMetaDataContext {
* \return A string representation of the meta node.
*/
std::string GetMetaNode(const NodeRef& node) {
std::ostringstream os;
auto it = meta_index_.find(node);
int64_t index;
if (it != meta_index_.end()) {
index = it->second;
} else {
index = static_cast<int64_t>(meta_data_.size());
meta_data_.push_back(node);
meta_index_[node] = index;
auto it = meta_repr_.find(node);
if (it != meta_repr_.end()) {
return it->second;
}
os << "meta." << node->type_key() << "(id=" << index << ")";
return os.str();
Array<NodeRef>& mvector =
meta_data_[node->type_key()];
int64_t index = static_cast<int64_t>(mvector.size());
mvector.push_back(node);
std::ostringstream os;
os << "meta[" << node->type_key() << "][" << index << "]";
meta_repr_[node] = os.str();
return meta_repr_[node];
}
/*!
* \brief Get the metadata section in json format.
* \return the meta datastring.
*/
std::string GetMetaSection() const {
if (meta_data_.size() == 0) return std::string();
return SaveJSON(Array<NodeRef>(meta_data_));
return SaveJSON(Map<std::string, NodeRef>(
meta_data_.begin(), meta_data_.end()));
}

/*! \return whether the meta data context is empty. */
Expand All @@ -120,9 +121,9 @@ class TextMetaDataContext {

private:
/*! \brief additional metadata stored in TVM json format */
std::vector<NodeRef> meta_data_;
/*! \brief map from meta data into its index */
std::unordered_map<NodeRef, int64_t, NodeHash, NodeEqual> meta_index_;
std::unordered_map<std::string, Array<NodeRef> > meta_data_;
/*! \brief map from meta data into its string representation */
std::unordered_map<NodeRef, std::string, NodeHash, NodeEqual> meta_repr_;
};

class TextPrinter :
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relay/test_ir_text_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ def test_meta_data():
f = relay.Function([x, w], z)
text = f.astext()
assert "channels=2" in text
assert "meta.Variable(id=0)" in text
assert "meta[Variable][0]" in text
show(text)

text = relay.const([1,2,3]).astext()
assert "meta.relay.Constant(id=0)" in text
assert "meta[relay.Constant][0]" in text
show(text)


Expand Down

0 comments on commit e0af5c2

Please sign in to comment.