Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Object] Restore the StrMap behavior in JSON/SHash/SEqual #5719

Merged
merged 1 commit into from
Jun 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions include/tvm/node/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ using runtime::ObjectRef;
using runtime::String;
using runtime::StringObj;

/*! \brief String-aware ObjectRef hash functor */
struct ObjectHash {
size_t operator()(const ObjectRef& a) const {
if (const auto* str = a.as<StringObj>()) {
Expand All @@ -59,6 +60,7 @@ struct ObjectHash {
}
};

/*! \brief String-aware ObjectRef equal functor */
struct ObjectEqual {
bool operator()(const ObjectRef& a, const ObjectRef& b) const {
if (a.same_as(b)) {
Expand Down Expand Up @@ -96,8 +98,7 @@ class MapNode : public Object {
* \tparam V The value NodeRef type.
*/
template <typename K, typename V,
typename = typename std::enable_if<std::is_base_of<ObjectRef, K>::value ||
std::is_base_of<std::string, K>::value>::type,
typename = typename std::enable_if<std::is_base_of<ObjectRef, K>::value>::type,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that it is better to directly change it to static_assert later,

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, but the difference is less fundamental imo

typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
class Map : public ObjectRef {
public:
Expand Down
1 change: 1 addition & 0 deletions python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def _convert(item, nodes):
"relay.PassContext": _rename("transform.PassContext"),
"relay.ModulePass": _rename("transform.ModulePass"),
"relay.Sequential": _rename("transform.Sequential"),
"StrMap": _rename("Map"),
# TIR
"Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")],
"SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")],
Expand Down
59 changes: 35 additions & 24 deletions src/node/container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -247,40 +247,51 @@ struct MapNodeTrait {
}

static void SHashReduce(const MapNode* key, SHashReducer hash_reduce) {
if (key->data.empty()) {
hash_reduce(uint64_t(0));
return;
}
if (key->data.begin()->first->IsInstance<StringObj>()) {
bool is_str_map = std::all_of(key->data.begin(), key->data.end(), [](const auto& v) {
return v.first->template IsInstance<StringObj>();
});
if (is_str_map) {
SHashReduceForSMap(key, hash_reduce);
} else {
SHashReduceForOMap(key, hash_reduce);
}
}

static bool SEqualReduceForOMap(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) {
for (const auto& kv : lhs->data) {
// Only allow equal checking if the keys are already mapped
// This resolves common use cases where we want to store
// Map<Var, Value> where Var is defined in the function
// parameters.
ObjectRef rhs_key = equal->MapLhsToRhs(kv.first);
if (!rhs_key.defined()) return false;
auto it = rhs->data.find(rhs_key);
if (it == rhs->data.end()) return false;
if (!equal(kv.second, it->second)) return false;
}
return true;
}

static bool SEqualReduceForSMap(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) {
for (const auto& kv : lhs->data) {
auto it = rhs->data.find(kv.first);
if (it == rhs->data.end()) return false;
if (!equal(kv.second, it->second)) return false;
}
return true;
}

static bool SEqualReduce(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) {
if (rhs->data.size() != lhs->data.size()) return false;
if (rhs->data.size() == 0) return true;
if (lhs->data.begin()->first->IsInstance<StringObj>()) {
for (const auto& kv : lhs->data) {
auto it = rhs->data.find(kv.first);
if (it == rhs->data.end()) return false;
if (!equal(kv.second, it->second)) return false;
}
} else {
for (const auto& kv : lhs->data) {
// Only allow equal checking if the keys are already mapped
// This resolves common use cases where we want to store
// Map<Var, Value> where Var is defined in the function
// parameters.
ObjectRef rhs_key = equal->MapLhsToRhs(kv.first);
if (!rhs_key.defined()) return false;
auto it = rhs->data.find(rhs_key);
if (it == rhs->data.end()) return false;
if (!equal(kv.second, it->second)) return false;
}
bool ls = std::all_of(lhs->data.begin(), lhs->data.end(),
[](const auto& v) { return v.first->template IsInstance<StringObj>(); });
bool rs = std::all_of(rhs->data.begin(), rhs->data.end(),
[](const auto& v) { return v.first->template IsInstance<StringObj>(); });
if (ls != rs) {
return false;
}
return true;
return (ls && rs) ? SEqualReduceForSMap(lhs, rhs, equal) : SEqualReduceForOMap(lhs, rhs, equal);
}
};

Expand Down
29 changes: 21 additions & 8 deletions src/node/serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,18 @@ class NodeIndexer : public AttrVisitor {
}
} else if (node->IsInstance<MapNode>()) {
MapNode* n = static_cast<MapNode*>(node);
for (const auto& kv : n->data) {
if (!kv.first->IsInstance<StringObj>()) {
bool is_str_map = std::all_of(n->data.begin(), n->data.end(), [](const auto& v) {
return v.first->template IsInstance<StringObj>();
});
if (is_str_map) {
for (const auto& kv : n->data) {
MakeIndex(const_cast<Object*>(kv.second.get()));
}
} else {
for (const auto& kv : n->data) {
MakeIndex(const_cast<Object*>(kv.first.get()));
MakeIndex(const_cast<Object*>(kv.second.get()));
}
MakeIndex(const_cast<Object*>(kv.second.get()));
}
} else {
// if the node already have repr bytes, no need to visit Attrs.
Expand Down Expand Up @@ -246,13 +253,19 @@ class JSONAttrGetter : public AttrVisitor {
}
} else if (node->IsInstance<MapNode>()) {
MapNode* n = static_cast<MapNode*>(node);
for (const auto& kv : n->data) {
if (const auto* str = kv.first.as<StringObj>()) {
node_->keys.push_back(std::string(str->data, str->size));
} else {
bool is_str_map = std::all_of(n->data.begin(), n->data.end(), [](const auto& v) {
return v.first->template IsInstance<StringObj>();
});
if (is_str_map) {
for (const auto& kv : n->data) {
node_->keys.push_back(Downcast<String>(kv.first));
node_->data.push_back(node_index_->at(const_cast<Object*>(kv.second.get())));
}
} else {
for (const auto& kv : n->data) {
node_->data.push_back(node_index_->at(const_cast<Object*>(kv.first.get())));
node_->data.push_back(node_index_->at(const_cast<Object*>(kv.second.get())));
}
node_->data.push_back(node_index_->at(const_cast<Object*>(kv.second.get())));
}
} else {
// recursively index normal object.
Expand Down
29 changes: 29 additions & 0 deletions tests/python/relay/test_json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,34 @@ def test_tir_var():
assert y.name == "y"


def test_str_map():
nodes = [
{'type_key': ''},
{'type_key': 'StrMap', 'keys': ['z', 'x'], 'data': [2, 3]},
{'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '2'}},
{'type_key': 'Max', 'attrs': {'a': '4', 'b': '10', 'dtype': 'int32'}},
{'type_key': 'Add', 'attrs': {'a': '5', 'b': '9', 'dtype': 'int32'}},
{'type_key': 'Add', 'attrs': {'a': '6', 'b': '8', 'dtype': 'int32'}},
{'type_key': 'tir.Var', 'attrs': {'dtype': 'int32', 'name': '7', 'type_annotation': '0'}},
{'type_key': 'runtime.String', 'repr_str': 'x'},
{'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '1'}},
{'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '2'}},
{'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '100'}}
]
data = {
"root" : 1,
"nodes": nodes,
"attrs": {"tvm_version": "0.6.0"},
"b64ndarrays": [],
}
x = tvm.ir.load_json(json.dumps(data))
assert(isinstance(x, tvm.ir.container.Map))
assert(len(x) == 2)
assert('x' in x)
assert('z' in x)
assert(bool(x['z'] == 2))


if __name__ == "__main__":
test_op()
test_type_var()
Expand All @@ -194,3 +222,4 @@ def test_tir_var():
test_func_tuple_type()
test_global_var()
test_tir_var()
test_str_map()