Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[POC][IR] Initial stab at std::string->String upgrade #5438

Merged
merged 1 commit into from
Apr 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/tvm/ir/span.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class SourceName;
class SourceNameNode : public Object {
public:
/*! \brief The source name. */
std::string name;
String name;
// override attr visitor
void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }

Expand All @@ -64,7 +64,7 @@ class SourceName : public ObjectRef {
* \param name Name of the operator.
* \return SourceName valid throughout program lifetime.
*/
TVM_DLL static SourceName Get(const std::string& name);
TVM_DLL static SourceName Get(const String& name);

TVM_DEFINE_OBJECT_REF_METHODS(SourceName, ObjectRef, SourceNameNode);
};
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ class TypeVarNode : public TypeNode {
* this only acts as a hint to the user,
* and is not used for equality.
*/
std::string name_hint;
String name_hint;
/*! \brief The kind of type parameter */
TypeKind kind;

Expand Down Expand Up @@ -262,7 +262,7 @@ class TypeVar : public Type {
* \param name_hint The name of the type var.
* \param kind The kind of the type var.
*/
TVM_DLL TypeVar(std::string name_hint, TypeKind kind);
TVM_DLL TypeVar(String name_hint, TypeKind kind);

TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode);
};
Expand Down
27 changes: 24 additions & 3 deletions python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
# under the License.
"""Tool to upgrade json from historical versions."""
import json
import tvm.ir
import tvm.runtime


def create_updater(node_map, from_ver, to_ver):
"""Create an updater to update json loaded data.
Expand All @@ -41,8 +44,12 @@ def _updater(data):
nodes = data["nodes"]
for idx, item in enumerate(nodes):
f = node_map.get(item["type_key"], None)
if f:
nodes[idx] = f(item, nodes)
if isinstance(f, list):
for fpass in f:
item = fpass(item, nodes)
elif f:
item = f(item, nodes)
nodes[idx] = item
data["attrs"]["tvm_version"] = to_ver
return data
return _updater
Expand Down Expand Up @@ -84,12 +91,26 @@ def _update_global_key(item, _):
del item["global_key"]
return item

def _update_from_std_str(key):
def _convert(item, nodes):
str_val = item["attrs"][key]
jdata = json.loads(tvm.ir.save_json(tvm.runtime.String(str_val)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we bring the serialization of tvm:String out of save_json?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you elaborate a bit about what do you mean? Right now all tvm objects are serialized through save_json, what we are doing here is to get that particular entry and append to the end of the nodes.

root_idx = jdata["root"]
val = jdata["nodes"][root_idx]
sidx = len(nodes)
nodes.append(val)
item["attrs"][key] = '%d' % sidx
return item

return _convert


node_map = {
# Base IR
"SourceName": _update_global_key,
"EnvFunc": _update_global_key,
"relay.Op": _update_global_key,
"relay.TypeVar": _ftype_var,
"relay.TypeVar": [_ftype_var, _update_from_std_str("name_hint")],
"relay.GlobalTypeVar": _ftype_var,
"relay.Type": _rename("Type"),
"relay.TupleType": _rename("TupleType"),
Expand Down
18 changes: 11 additions & 7 deletions src/ir/span.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@

namespace tvm {

ObjectPtr<Object> GetSourceNameNode(const std::string& name) {
ObjectPtr<Object> GetSourceNameNode(const String& name) {
// always return pointer as the reference can change as map re-allocate.
// or use another level of indirection by creating a unique_ptr
static std::unordered_map<std::string, ObjectPtr<SourceNameNode> > source_map;
static std::unordered_map<String, ObjectPtr<SourceNameNode> > source_map;

auto sn = source_map.find(name);
if (sn == source_map.end()) {
Expand All @@ -41,7 +41,11 @@ ObjectPtr<Object> GetSourceNameNode(const std::string& name) {
}
}

SourceName SourceName::Get(const std::string& name) {
ObjectPtr<Object> GetSourceNameNodeByStr(const std::string& name) {
return GetSourceNameNode(name);
}

SourceName SourceName::Get(const String& name) {
return SourceName(GetSourceNameNode(name));
}

Expand All @@ -55,10 +59,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});

TVM_REGISTER_NODE_TYPE(SourceNameNode)
.set_creator(GetSourceNameNode)
.set_repr_bytes([](const Object* n) {
return static_cast<const SourceNameNode*>(n)->name;
});
.set_creator(GetSourceNameNodeByStr)
.set_repr_bytes([](const Object* n) -> std::string {
return static_cast<const SourceNameNode*>(n)->name;
});

Span SpanNode::make(SourceName source, int lineno, int col_offset) {
auto n = make_object<SpanNode>();
Expand Down
4 changes: 2 additions & 2 deletions src/ir/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});


TypeVar::TypeVar(std::string name, TypeKind kind) {
TypeVar::TypeVar(String name, TypeKind kind) {
ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>();
n->name_hint = std::move(name);
n->kind = std::move(kind);
Expand All @@ -76,7 +76,7 @@ TypeVar::TypeVar(std::string name, TypeKind kind) {
TVM_REGISTER_NODE_TYPE(TypeVarNode);

TVM_REGISTER_GLOBAL("ir.TypeVar")
.set_body_typed([](std::string name, int kind) {
.set_body_typed([](String name, int kind) {
return TypeVar(name, static_cast<TypeKind>(kind));
});

Expand Down