Skip to content

Commit

Permalink
[AutoScheduler] Fix deserization of workload registry entry (apache#8662
Browse files Browse the repository at this point in the history
)
  • Loading branch information
vinx13 authored and ylc committed Jan 13, 2022
1 parent 81949ad commit 8e24f89
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
4 changes: 3 additions & 1 deletion python/tvm/auto_scheduler/workload_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,9 @@ def deserialize_workload_registry_entry(data):
name, value = data
if name not in WORKLOAD_FUNC_REGISTRY:
# pylint: disable=assignment-from-no-return
WORKLOAD_FUNC_REGISTRY[name] = LoadJSON(value)
if not callable(value):
value = LoadJSON(value)
WORKLOAD_FUNC_REGISTRY[name] = value


def save_workload_func_registry(filename):
Expand Down
10 changes: 10 additions & 0 deletions tests/python/unittest/test_auto_scheduler_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,15 @@ def test_dag_measure_local_builder_runner():
assert mress[0].error_no == 0


def test_workload_serialization():
key = tvm.auto_scheduler.utils.get_func_name(matmul_auto_scheduler_test)
transfer_data = workload_registry.serialize_workload_registry_entry(key)
f_data = pickle.dumps(transfer_data)
f_new = pickle.loads(f_data)
del workload_registry.WORKLOAD_FUNC_REGISTRY[key]
workload_registry.deserialize_workload_registry_entry(f_new)


def test_measure_local_builder_rpc_runner():
if not tvm.testing.device_enabled("llvm"):
return
Expand Down Expand Up @@ -423,6 +432,7 @@ def foo():
test_workload_dis_factor()
test_measure_local_builder_runner()
test_dag_measure_local_builder_runner()
test_workload_serialization()
test_measure_local_builder_rpc_runner()
test_measure_target_host()
test_measure_special_inputs_map_by_name_local_runner()
Expand Down

0 comments on commit 8e24f89

Please sign in to comment.