Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOOLS] JSON upgrader to upgrade serialized json. #4730

Merged
merged 1 commit into from
Jan 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from numbers import Integral as _Integral

from ._ffi.base import string_types
from ._ffi.base import string_types, TVMError
from ._ffi.object import register_object, Object
from ._ffi.object import convert_to_object as _convert_to_object
from ._ffi.object_generic import _scalar_type_inference
Expand All @@ -35,6 +35,7 @@
from . import schedule as _schedule
from . import container as _container
from . import tag as _tag
from . import json_compact

int8 = "int8"
int32 = "int32"
Expand Down Expand Up @@ -154,7 +155,12 @@ def load_json(json_str):
node : Object
The loaded tvm node.
"""
return _api_internal._load_json(json_str)

try:
return _api_internal._load_json(json_str)
except TVMError:
json_str = json_compact.upgrade_json(json_str)
return _api_internal._load_json(json_str)


def save_json(node):
Expand Down
93 changes: 93 additions & 0 deletions python/tvm/json_compact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tool to upgrade json from historical versions."""
import json

def create_updater(node_map, from_ver, to_ver):
"""Create an updater to update json loaded data.

Parameters
----------
node_map : Map[str, Function]
Map from type_key to updating function

from_ver : str
Prefix of version that we can accept,

to_ver : str
The target version.

Returns
-------
fupdater : function
The updater function
"""
def _updater(data):
assert data["attrs"]["tvm_version"].startswith(from_ver)
nodes = data["nodes"]
for idx, item in enumerate(nodes):
f = node_map.get(item["type_key"], None)
if f:
nodes[idx] = f(item, nodes)
data["attrs"]["tvm_version"] = to_ver
return data
return _updater


def create_updater_06_to_07():
"""Create an update to upgrade json from v0.6 to v0.7

Returns
-------
fupdater : function
The updater function
"""
def _ftype_var(item, nodes):
vindex = int(item["attrs"]["var"])
item["attrs"]["name_hint"] = nodes[vindex]["attrs"]["name"]
# set vindex to null
nodes[vindex]["type_key"] = ""
del item["attrs"]["var"]
return item

node_map = {
"relay.TypeVar": _ftype_var,
"relay.GlobalTypeVar": _ftype_var,
}
return create_updater(node_map, "0.6", "0.7")


def upgrade_json(json_str):
"""Update json from a historical version.

Parameters
----------
json_str : str
A historical json file.

Returns
-------
updated_json : str
The updated version.
"""
data = json.loads(json_str)
from_version = data["attrs"]["tvm_version"]
if from_version.startswith("0.6"):
data = create_updater_06_to_07()(data)
else:
raise ValueError("Cannot update from version %s" % from_version)
return json.dumps(data, indent=2)
47 changes: 47 additions & 0 deletions tests/python/relay/test_json_compact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
from tvm import relay
import json

def test_type_var():
# type var in 0.6
nodes = [
{"type_key": ""},
{"type_key": "relay.TypeVar",
"attrs": {"kind": "0", "span": "0", "var": "2"}},
{"type_key": "Variable",
"attrs": {"dtype": "int32", "name": "in0"}},
]
data = {
"root" : 1,
"nodes": nodes,
"attrs": {"tvm_version": "0.6.0"},
"b64ndarrays": [],
}
tvar = tvm.load_json(json.dumps(data))
assert isinstance(tvar, relay.TypeVar)
assert tvar.name_hint == "in0"
nodes[1]["type_key"] = "relay.GlobalTypeVar"
tvar = tvm.load_json(json.dumps(data))
assert isinstance(tvar, relay.GlobalTypeVar)
assert tvar.name_hint == "in0"


if __name__ == "__main__":
test_type_var()