Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoScheduler] Use PopenPool instead of multiprocessing.pool #8492

Merged
merged 38 commits into from
Aug 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
17a65ab
reapply changes
Aug 3, 2021
3ecfbb7
create tvm.testing
Aug 3, 2021
02bdfb9
reorganize testing utils
Aug 3, 2021
3100d99
remove wild card matching for auto_scheduler_common
Aug 3, 2021
4b37354
disable invalid name
Aug 3, 2021
8f659bf
nit
Aug 3, 2021
d7d9af6
address comments
Aug 3, 2021
7bb1aaa
linting
Aug 3, 2021
1b46d98
add module docstring
Aug 3, 2021
b339852
remove __init__.py in testing
Aug 3, 2021
51976f0
address Junru's comment
Aug 3, 2021
41182bb
add __init__.py
Aug 3, 2021
6432d6c
remove testing.py
Aug 3, 2021
b76d1aa
get subpackage to work
Aug 3, 2021
5a22719
avoid wild card matching
Aug 4, 2021
32c5e22
resolve self import
Aug 4, 2021
2598c36
more dependencies
Aug 4, 2021
3c172b5
fix ci issues
Aug 4, 2021
d94fb07
fixing ci issues
Aug 4, 2021
5cb628e
last xgboost error
Aug 4, 2021
2e06c35
revert changes to xgboost_cost_model
Aug 4, 2021
e961943
add ErrorTest and _ffi_api
Aug 5, 2021
0aa7ba1
reorg
Aug 5, 2021
108402c
format _ffi_api.py
Aug 5, 2021
43b2894
fix reimported
Aug 5, 2021
10aeba9
all changes
Aug 7, 2021
615450a
fix measure tests
Aug 9, 2021
1849e29
restore tvm.python.testing.py
Aug 9, 2021
ea27936
restore python.tvm.testing.utils.py
Aug 9, 2021
a6e06ee
remove task_input_buffer
Aug 9, 2021
9991d3f
linting and naming convention updated
Aug 9, 2021
1d7e25f
address comments and update __init__
Aug 9, 2021
10288a6
address comments
Aug 10, 2021
ed51463
two imports added
Aug 10, 2021
211932a
Persist PopenWorker
vinx13 Aug 10, 2021
a3274cd
linting
Aug 10, 2021
193709b
Use PopenpoolExecutor
vinx13 Aug 11, 2021
00d9476
linting
Aug 11, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 119 additions & 97 deletions python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from tvm.ir import transform
from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
from tvm.contrib import tar, ndk
from tvm.contrib.popen_pool import PopenWorker, PopenPoolExecutor, StatusKind
from tvm.target import Target


Expand Down Expand Up @@ -599,7 +600,7 @@ class MeasureErrorNo(object):
UNKNOWN_ERROR = 8 # Unknown error


def _timed_func(inp_serialized, build_func, verbose):
def _local_build_worker(inp_serialized, build_func, verbose):
tic = time.time()
inp = MeasureInput.deserialize(inp_serialized)
task = inp.task
Expand Down Expand Up @@ -664,15 +665,13 @@ def local_build_worker(args):
)
build_func = BuildFunc.build_func

res = call_func_with_timeout(timeout, _timed_func, args=(inp, build_func, verbose))
shingjan marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(res, TimeoutError):
if verbose >= 1:
print(".T", end="", flush=True) # Build timeout
res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout
elif isinstance(res, Exception):
try:
res = _local_build_worker(inp, build_func, verbose)
# pylint: disable=broad-except
except Exception:
if verbose >= 1:
print(".E", end="", flush=True) # Build error
res = None, [], MeasureErrorNo.COMPILE_HOST, str(res), timeout
res = None, [], MeasureErrorNo.COMPILE_HOST, make_traceback_info(), timeout

return res

Expand Down Expand Up @@ -701,9 +700,8 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo
res : List[BuildResult]
The build results of these MeasureInputs.
"""
# This pool is not doing computationally intensive work, so we can use threads
pool = multiprocessing.pool.ThreadPool(n_parallel)
tuple_res = pool.map(
executor = PopenPoolExecutor(n_parallel, timeout)
tuple_res = executor.map_with_error_catching(
local_build_worker,
[
(
Expand All @@ -715,13 +713,16 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo
for i in inputs
],
)
pool.terminate()
pool.join()
del pool

results = []
for res in tuple_res:
results.append(BuildResult(*res))
if res.status == StatusKind.COMPLETE:
results.append(BuildResult(*res.value))
else:
assert res.status == StatusKind.TIMEOUT
if verbose >= 1:
print(".T", end="", flush=True) # Build timeout
results.append(BuildResult(None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout))

return results

Expand Down Expand Up @@ -817,21 +818,66 @@ def prepare_input_map(args):
return tensor_input_map


def prepare_runner_args(inp, build_res):
"""This function prepares the pre-defined arguments in `TASK_INPUT_BUFFER_TABLE` for local/rpc
runner in main process

Parameters
----------
inp : MeasureInput
Measure input to be measured.

build_res : BuildResult
Build result to be measured.

Returns
-------
List[Optional[numpy.ndarray]] :
List of arguments for running the program. If the argument does not have a pre-defined input
buffer, None is added to the list as a placeholder.

"""
# pylint: disable=import-outside-toplevel
from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency

task_input_names = inp.task.task_input_names
tensor_input_map = prepare_input_map(build_res.args)
if not task_input_names:
tensor_input_map = {}
args = []
task_inputs_count = 0
for arg in build_res.args:
if arg in tensor_input_map:
tensor_name = tensor_input_map[arg]
if tensor_name in task_input_names:
task_input_buffer = get_task_input_buffer(inp.task.workload_key, tensor_name)
# convert tvm.NDArray to picklable numpy.ndarray
args.append(task_input_buffer.numpy())
junrushao marked this conversation as resolved.
Show resolved Hide resolved
task_inputs_count += 1
else:
raise ValueError(
"%s not found in task_inputs, " % (tensor_name)
+ "should provide with `SearchTask(..., task_inputs={...})`"
)
else:
args.append(None)
if task_inputs_count != len(task_input_names):
raise RuntimeError("task_inputs not fully matched, check if there's any unexpected error")
return args


def _timed_eval_func(
inp_serialized,
build_res,
args,
number,
repeat,
min_repeat_ms,
cooldown_interval,
enable_cpu_cache_flush,
verbose,
):
# pylint: disable=import-outside-toplevel
from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency

inp = MeasureInput.deserialize(inp_serialized)
task_input_names = inp.task.task_input_names
tic = time.time()
error_no = 0
error_msg = None
Expand Down Expand Up @@ -862,33 +908,18 @@ def _timed_eval_func(
try:
random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True)
assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake"

tensor_input_map = prepare_input_map(build_res.args) if task_input_names else {}
args = []
task_inputs_count = 0
for arg in build_res.args:
if arg in tensor_input_map:
tensor_name = tensor_input_map[arg]
if tensor_name in task_input_names:
args.append(
ndarray.array(
get_task_input_buffer(inp.task.workload_key, tensor_name), dev
)
)
task_inputs_count += 1
else:
raise ValueError(
"%s not found in task_inputs, " % (tensor_name)
+ "should provide with `SearchTask(..., task_inputs={...})`"
)
else:
empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, dev)
assert len(args) == len(build_res.args)
# pylint: disable=consider-using-enumerate
for idx in range(len(args)):
if args[idx] is None:
build_res_arg = build_res.args[idx]
empty_array = ndarray.empty(
get_const_tuple(build_res_arg.shape), build_res_arg.dtype, dev
)
random_fill(empty_array)
args.append(empty_array)
if task_inputs_count != len(task_input_names):
raise RuntimeError(
"task_inputs not fully matched, check if there's any unexpected error"
)
args[idx] = empty_array
else:
args[idx] = ndarray.array(args[idx], dev)
dev.sync()
costs = time_f(*args).results
# pylint: disable=broad-except
Expand Down Expand Up @@ -968,6 +999,7 @@ def local_run(

measure_results = []
assert len(inputs) == len(build_results), "Measure input size should be equal to build results"
worker = PopenWorker()
for inp, build_res in zip(inputs, build_results):
if build_res.error_no != 0:
res = (
Expand All @@ -978,20 +1010,22 @@ def local_run(
time.time(),
)
else:
args = prepare_runner_args(inp, build_res)
res = call_func_with_timeout(
worker,
timeout,
_timed_eval_func,
args=(
inp.serialize(),
build_res,
args,
number,
repeat,
min_repeat_ms,
cooldown_interval,
enable_cpu_cache_flush,
verbose,
),
add_thread_wrapper=True,
)
if isinstance(res, TimeoutError):
if verbose >= 1:
Expand Down Expand Up @@ -1022,9 +1056,10 @@ def local_run(
return measure_results


def _timed_rpc_run(
def _rpc_run(
inp_serialized,
build_res,
args,
key,
host,
port,
Expand All @@ -1037,11 +1072,7 @@ def _timed_rpc_run(
enable_cpu_cache_flush,
verbose,
):
# pylint: disable=import-outside-toplevel
from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency

inp = MeasureInput.deserialize(inp_serialized)
task_input_names = inp.task.task_input_names
tic = time.time()
error_no = 0
error_msg = None
Expand Down Expand Up @@ -1080,32 +1111,18 @@ def _timed_rpc_run(
random_fill
), "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices"

tensor_input_map = prepare_input_map(build_res.args) if task_input_names else {}
args = []
task_inputs_count = 0
for arg in build_res.args:
if arg in tensor_input_map:
tensor_name = tensor_input_map[arg]
if tensor_name in task_input_names:
args.append(
ndarray.array(
get_task_input_buffer(inp.task.workload_key, tensor_name), dev
)
)
task_inputs_count += 1
else:
raise ValueError(
"%s not found in task_inputs, " % (tensor_name)
+ "should provide with `SearchTask(..., task_inputs={...})`"
)
else:
empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, dev)
assert len(args) == len(build_res.args)
# pylint: disable=consider-using-enumerate
for idx in range(len(args)):
if args[idx] is None:
build_res_arg = build_res.args[idx]
empty_array = ndarray.empty(
get_const_tuple(build_res_arg.shape), build_res_arg.dtype, dev
)
random_fill(empty_array)
args.append(empty_array)
if task_inputs_count != len(task_input_names):
logger.warning(
"task_inputs not fully matched, check if there's any unexpected error"
)
args[idx] = empty_array
else:
args[idx] = ndarray.array(args[idx], dev)
dev.sync()

# First run for check that the kernel is correct
Expand Down Expand Up @@ -1152,7 +1169,7 @@ def _rpc_run_worker(args):
res : MeasureResult
The measure result of this Runner thread.
"""
_, build_res, _, _, _, _, timeout, _, _, _, _, _, verbose = args
_, build_res, _, _, _, _, _, timeout, _, _, _, _, _, verbose = args
if build_res.error_no != MeasureErrorNo.NO_ERROR:
return (
(MAX_FLOAT,),
Expand All @@ -1162,24 +1179,16 @@ def _rpc_run_worker(args):
time.time(),
)

res = call_func_with_timeout(timeout, _timed_rpc_run, args=args)
if isinstance(res, TimeoutError):
if verbose >= 1:
print("*T", end="") # Run timeout
res = (
(MAX_FLOAT,),
MeasureErrorNo.RUN_TIMEOUT,
None,
build_res.time_cost + timeout,
time.time(),
)
elif isinstance(res, Exception):
try:
res = _rpc_run(*args)
# pylint: disable=broad-except
except Exception:
if verbose >= 1:
print("*E", end="") # Run error
res = (
(MAX_FLOAT,),
MeasureErrorNo.RUNTIME_DEVICE,
str(res),
make_traceback_info(),
build_res.time_cost + timeout,
time.time(),
)
Expand Down Expand Up @@ -1259,13 +1268,14 @@ def rpc_runner_run(
"""
assert len(inputs) == len(build_results), "Measure input size should be equal to build results"
# This pool is not doing computationally intensive work, so we can use threads
pool = multiprocessing.pool.ThreadPool(n_parallel)
tuple_res = pool.map(
executor = PopenPoolExecutor(n_parallel)
tuple_res = executor.map_with_error_catching(
_rpc_run_worker,
[
(
inp.serialize(),
build_res,
prepare_runner_args(inp, build_res),
key,
host,
port,
Expand All @@ -1281,13 +1291,25 @@ def rpc_runner_run(
for inp, build_res in zip(inputs, build_results)
],
)
pool.terminate()
pool.join()
del pool

results = []
for res in tuple_res:
results.append(MeasureResult(*res))
for i, res in enumerate(tuple_res):
if res.status == StatusKind.COMPLETE:
results.append(MeasureResult(*res.value))
else:
assert res.status == StatusKind.TIMEOUT
if verbose >= 1:
print("*T", end="") # Run timeout
build_res = build_results[i]
results.append(
MeasureResult(
(MAX_FLOAT,),
MeasureErrorNo.RUN_TIMEOUT,
None,
build_res.time_cost + timeout,
time.time(),
)
)

if verbose >= 1:
print("")
Expand Down
Loading