diff --git a/python/tvm/auto_scheduler/workload_registry.py b/python/tvm/auto_scheduler/workload_registry.py index cd8f8c9d1a3e8..885eb0d1d0f8d 100644 --- a/python/tvm/auto_scheduler/workload_registry.py +++ b/python/tvm/auto_scheduler/workload_registry.py @@ -245,7 +245,9 @@ def deserialize_workload_registry_entry(data): name, value = data if name not in WORKLOAD_FUNC_REGISTRY: # pylint: disable=assignment-from-no-return - WORKLOAD_FUNC_REGISTRY[name] = LoadJSON(value) + if not callable(value): + value = LoadJSON(value) + WORKLOAD_FUNC_REGISTRY[name] = value def save_workload_func_registry(filename): diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index d82cfd447a403..375f8167ff08c 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -293,6 +293,15 @@ def test_dag_measure_local_builder_runner(): assert mress[0].error_no == 0 +def test_workload_serialization(): + key = tvm.auto_scheduler.utils.get_func_name(matmul_auto_scheduler_test) + transfer_data = workload_registry.serialize_workload_registry_entry(key) + f_data = pickle.dumps(transfer_data) + f_new = pickle.loads(f_data) + del workload_registry.WORKLOAD_FUNC_REGISTRY[key] + workload_registry.deserialize_workload_registry_entry(f_new) + + def test_measure_local_builder_rpc_runner(): if not tvm.testing.device_enabled("llvm"): return @@ -423,6 +432,7 @@ def foo(): test_workload_dis_factor() test_measure_local_builder_runner() test_dag_measure_local_builder_runner() + test_workload_serialization() test_measure_local_builder_rpc_runner() test_measure_target_host() test_measure_special_inputs_map_by_name_local_runner()