Skip to content

Commit

Permalink
[TOOLS] JSON upgrader to upgrade serialized json. (#4730)
Browse files Browse the repository at this point in the history
During Unified IR refactor we will change the structure of IRs.
This will cause certain historical modules stored via json no longer
able to be loaded by the current version.

This PR introduces a backward compatible layer to try its best effort
to upgrade json from previous version(this case 0.6) to the current version.
We mainly aim to support update of high-level ir(relay).
  • Loading branch information
tqchen authored Jan 17, 2020
1 parent a5bb789 commit 67b97e5
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 2 deletions.
10 changes: 8 additions & 2 deletions python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from numbers import Integral as _Integral

from ._ffi.base import string_types
from ._ffi.base import string_types, TVMError
from ._ffi.object import register_object, Object
from ._ffi.object import convert_to_object as _convert_to_object
from ._ffi.object_generic import _scalar_type_inference
Expand All @@ -35,6 +35,7 @@
from . import schedule as _schedule
from . import container as _container
from . import tag as _tag
from . import json_compact

int8 = "int8"
int32 = "int32"
Expand Down Expand Up @@ -154,7 +155,12 @@ def load_json(json_str):
node : Object
The loaded tvm node.
"""
return _api_internal._load_json(json_str)

try:
return _api_internal._load_json(json_str)
except TVMError:
json_str = json_compact.upgrade_json(json_str)
return _api_internal._load_json(json_str)


def save_json(node):
Expand Down
93 changes: 93 additions & 0 deletions python/tvm/json_compact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tool to upgrade json from historical versions."""
import json

def create_updater(node_map, from_ver, to_ver):
"""Create an updater to update json loaded data.
Parameters
----------
node_map : Map[str, Function]
Map from type_key to updating function
from_ver : str
Prefix of version that we can accept,
to_ver : str
The target version.
Returns
-------
fupdater : function
The updater function
"""
def _updater(data):
assert data["attrs"]["tvm_version"].startswith(from_ver)
nodes = data["nodes"]
for idx, item in enumerate(nodes):
f = node_map.get(item["type_key"], None)
if f:
nodes[idx] = f(item, nodes)
data["attrs"]["tvm_version"] = to_ver
return data
return _updater


def create_updater_06_to_07():
"""Create an update to upgrade json from v0.6 to v0.7
Returns
-------
fupdater : function
The updater function
"""
def _ftype_var(item, nodes):
vindex = int(item["attrs"]["var"])
item["attrs"]["name_hint"] = nodes[vindex]["attrs"]["name"]
# set vindex to null
nodes[vindex]["type_key"] = ""
del item["attrs"]["var"]
return item

node_map = {
"relay.TypeVar": _ftype_var,
"relay.GlobalTypeVar": _ftype_var,
}
return create_updater(node_map, "0.6", "0.7")


def upgrade_json(json_str):
"""Update json from a historical version.
Parameters
----------
json_str : str
A historical json file.
Returns
-------
updated_json : str
The updated version.
"""
data = json.loads(json_str)
from_version = data["attrs"]["tvm_version"]
if from_version.startswith("0.6"):
data = create_updater_06_to_07()(data)
else:
raise ValueError("Cannot update from version %s" % from_version)
return json.dumps(data, indent=2)
47 changes: 47 additions & 0 deletions tests/python/relay/test_json_compact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
from tvm import relay
import json

def test_type_var():
# type var in 0.6
nodes = [
{"type_key": ""},
{"type_key": "relay.TypeVar",
"attrs": {"kind": "0", "span": "0", "var": "2"}},
{"type_key": "Variable",
"attrs": {"dtype": "int32", "name": "in0"}},
]
data = {
"root" : 1,
"nodes": nodes,
"attrs": {"tvm_version": "0.6.0"},
"b64ndarrays": [],
}
tvar = tvm.load_json(json.dumps(data))
assert isinstance(tvar, relay.TypeVar)
assert tvar.name_hint == "in0"
nodes[1]["type_key"] = "relay.GlobalTypeVar"
tvar = tvm.load_json(json.dumps(data))
assert isinstance(tvar, relay.GlobalTypeVar)
assert tvar.name_hint == "in0"


if __name__ == "__main__":
test_type_var()

0 comments on commit 67b97e5

Please sign in to comment.