Skip to content

Commit

Permalink
Fix LocalBuilder on macos with python 3.8. (apache#6083)
Browse files Browse the repository at this point in the history
Python 3.8 changes the default way multiprocessing creates new processes
on macOS from forking to spawing. Spawning requires all objects to be
picklable. Nested functions and lambdas are not picklable, so this
commit fixes the one instance of nested functions in the codebase that
was causing issues.
  • Loading branch information
tkonolige authored and Trevor Morris committed Sep 2, 2020
1 parent 3e915d6 commit 22c4b75
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(self, timeout=10, n_parallel=None, build_func='default'):
build_func = ndk.create_shared
else:
raise ValueError("Invalid build_func" + build_func)
self.build_func = _wrap_build_func(build_func)
self.build_func = _WrappedBuildFunc(build_func)
self.executor = LocalExecutor(timeout=timeout)
self.tmp_dir = tempfile.mkdtemp()

Expand Down Expand Up @@ -390,25 +390,29 @@ def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_opti
return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)


def _wrap_build_func(build_func):
class _WrappedBuildFunc():
"""
Wrap build_func to a function that can be used in measure.
Note: this is a class instead of a closure so that it can be pickled when
using multiprocessing.
Parameters
----------
build_func : The compilation function
We expect fcompile to contain an attr "output_format"
Returns
-------
wrapped_build_func : function
wrapped_build_func : callable
The wrapped build function
"""
if not hasattr(build_func, "output_format"):
raise AttributeError("Expect build_func to have the attribute output_format.")
output_format = build_func.output_format
def __init__(self, build_func):
if not hasattr(build_func, "output_format"):
raise AttributeError("Expect build_func to have the attribute output_format.")
self.build_func = build_func

def _wrapped(measure_input, tmp_dir, **kwargs):
def __call__(self, measure_input, tmp_dir, **kwargs):
"""
Wrapped build func.
Expand All @@ -423,15 +427,13 @@ def _wrapped(measure_input, tmp_dir, **kwargs):
tic = time.time()
try:
filename = os.path.join(tmp_dir, "tmp_func_%0x.%s" % (
getrandbits(64), output_format))
getrandbits(64), self.build_func.output_format))
# TODO(tvm-team) consider linline _build_func_common
func, arg_info = _build_func_common(measure_input, **kwargs)
func.export_library(filename, build_func)
func.export_library(filename, self.build_func)
except Exception as e: # pylint: disable=broad-except
return BuildResult(None, None, e, time.time() - tic)
return BuildResult(filename, arg_info, None, time.time() - tic)
return _wrapped


def run_through_rpc(measure_input, build_result,
number, repeat, min_repeat_ms, cooldown_interval,
Expand Down

0 comments on commit 22c4b75

Please sign in to comment.