[TIR] Use PopenPool instead of multiprocessing.pool (apache#8492)
Co-authored-by: Wuwei Lin <wuwei@apache.org>
2 people authored and ylc committed Sep 29, 2021
Parent: dff7bb8 · Commit: bb7e94e
Showing 20 changed files with 205 additions and 153 deletions.
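Editorial note: the core of this change is swapping multiprocessing.pool.ThreadPool for PopenPoolExecutor from tvm.contrib.popen_pool. Its map_with_error_catching returns a status/value pair per task instead of raising, and callers branch on StatusKind. Below is a minimal sketch of that pattern, not part of the commit; the square helper, worker count, and timeout are illustrative values.

from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind

def square(x):
    # Runs in a child process, so it must be picklable (top-level function).
    return x * x

executor = PopenPoolExecutor(4, timeout=10)  # 4 workers, 10 s per task
for res in executor.map_with_error_catching(square, [1, 2, 3]):
    if res.status == StatusKind.COMPLETE:
        print(res.value)                    # normal result
    elif res.status == StatusKind.TIMEOUT:
        print("worker timed out")           # task exceeded the 10 s budget
    else:                                   # StatusKind.EXCEPTION
        print("worker raised:", res.value)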
216 changes: 119 additions & 97 deletions python/tvm/auto_scheduler/measure.py
@@ -44,6 +44,7 @@
from tvm.ir import transform
from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
from tvm.contrib import tar, ndk
from tvm.contrib.popen_pool import PopenWorker, PopenPoolExecutor, StatusKind
from tvm.target import Target


@@ -599,7 +600,7 @@ class MeasureErrorNo(object):
UNKNOWN_ERROR = 8 # Unknown error


def _timed_func(inp_serialized, build_func, verbose):
def _local_build_worker(inp_serialized, build_func, verbose):
tic = time.time()
inp = MeasureInput.deserialize(inp_serialized)
task = inp.task
@@ -664,15 +665,13 @@ def local_build_worker(args):
)
build_func = BuildFunc.build_func

res = call_func_with_timeout(timeout, _timed_func, args=(inp, build_func, verbose))
if isinstance(res, TimeoutError):
if verbose >= 1:
print(".T", end="", flush=True) # Build timeout
res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout
elif isinstance(res, Exception):
try:
res = _local_build_worker(inp, build_func, verbose)
# pylint: disable=broad-except
except Exception:
if verbose >= 1:
print(".E", end="", flush=True) # Build error
res = None, [], MeasureErrorNo.COMPILE_HOST, str(res), timeout
res = None, [], MeasureErrorNo.COMPILE_HOST, make_traceback_info(), timeout

return res
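Editorial note: the worker above now catches broadly and returns a formatted traceback string rather than the exception object itself, since plain strings pickle reliably across the Popen process boundary. A hedged sketch of that convention; make_traceback_info is approximated with traceback.format_exc, and run_safely is a hypothetical helper:

import traceback

def make_traceback_info():
    # Stand-in for tvm.auto_scheduler.utils.make_traceback_info.
    return traceback.format_exc()

def run_safely(fn, *args):
    try:
        return fn(*args)
    # pylint: disable=broad-except
    except Exception:
        return make_traceback_info()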

@@ -701,9 +700,8 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo
res : List[BuildResult]
The build results of these MeasureInputs.
"""
# This pool is not doing computationally intensive work, so we can use threads
pool = multiprocessing.pool.ThreadPool(n_parallel)
tuple_res = pool.map(
executor = PopenPoolExecutor(n_parallel, timeout)
tuple_res = executor.map_with_error_catching(
local_build_worker,
[
(
@@ -715,13 +713,16 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo
for i in inputs
],
)
pool.terminate()
pool.join()
del pool

results = []
for res in tuple_res:
results.append(BuildResult(*res))
if res.status == StatusKind.COMPLETE:
results.append(BuildResult(*res.value))
else:
assert res.status == StatusKind.TIMEOUT
if verbose >= 1:
print(".T", end="", flush=True) # Build timeout
results.append(BuildResult(None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout))

return results

@@ -817,21 +818,66 @@ def prepare_input_map(args):
return tensor_input_map


def prepare_runner_args(inp, build_res):
"""This function prepares the pre-defined arguments in `TASK_INPUT_BUFFER_TABLE` for local/rpc
runner in main process
Parameters
----------
inp : MeasureInput
Measure input to be measured.
build_res : BuildResult
Build result to be measured.
Returns
-------
List[Optional[numpy.ndarray]] :
List of arguments for running the program. If the argument does not have a pre-defined input
buffer, None is added to the list as a placeholder.
"""
# pylint: disable=import-outside-toplevel
from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency

task_input_names = inp.task.task_input_names
tensor_input_map = prepare_input_map(build_res.args)
if not task_input_names:
tensor_input_map = {}
args = []
task_inputs_count = 0
for arg in build_res.args:
if arg in tensor_input_map:
tensor_name = tensor_input_map[arg]
if tensor_name in task_input_names:
task_input_buffer = get_task_input_buffer(inp.task.workload_key, tensor_name)
# convert tvm.NDArray to picklable numpy.ndarray
args.append(task_input_buffer.numpy())
task_inputs_count += 1
else:
raise ValueError(
"%s not found in task_inputs, " % (tensor_name)
+ "should provide with `SearchTask(..., task_inputs={...})`"
)
else:
args.append(None)
if task_inputs_count != len(task_input_names):
raise RuntimeError("task_inputs not fully matched, check if there's any unexpected error")
return args
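Editorial note: prepare_runner_args runs in the main process, so it can hand only picklable objects to the measurement subprocess. Registered task inputs are shipped as numpy arrays; every other argument becomes a None placeholder that the worker later replaces with a randomly filled device tensor (see _timed_eval_func below). A small illustration; the shapes and data are hypothetical:

import numpy as np

args = [np.ones((4, 4), dtype="float32"), None]  # one task input, one free slot
for idx, arg in enumerate(args):
    if arg is None:
        # The worker does this on-device via random_fill; numpy stands in here.
        args[idx] = np.random.uniform(size=(4, 4)).astype("float32")
print([a.shape for a in args])  # [(4, 4), (4, 4)]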


def _timed_eval_func(
inp_serialized,
build_res,
args,
number,
repeat,
min_repeat_ms,
cooldown_interval,
enable_cpu_cache_flush,
verbose,
):
# pylint: disable=import-outside-toplevel
from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency

inp = MeasureInput.deserialize(inp_serialized)
task_input_names = inp.task.task_input_names
tic = time.time()
error_no = 0
error_msg = None
@@ -862,33 +908,18 @@ def _timed_eval_func(
try:
random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True)
assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake"

tensor_input_map = prepare_input_map(build_res.args) if task_input_names else {}
args = []
task_inputs_count = 0
for arg in build_res.args:
if arg in tensor_input_map:
tensor_name = tensor_input_map[arg]
if tensor_name in task_input_names:
args.append(
ndarray.array(
get_task_input_buffer(inp.task.workload_key, tensor_name), dev
)
)
task_inputs_count += 1
else:
raise ValueError(
"%s not found in task_inputs, " % (tensor_name)
+ "should provide with `SearchTask(..., task_inputs={...})`"
)
else:
empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, dev)
assert len(args) == len(build_res.args)
# pylint: disable=consider-using-enumerate
for idx in range(len(args)):
if args[idx] is None:
build_res_arg = build_res.args[idx]
empty_array = ndarray.empty(
get_const_tuple(build_res_arg.shape), build_res_arg.dtype, dev
)
random_fill(empty_array)
args.append(empty_array)
if task_inputs_count != len(task_input_names):
raise RuntimeError(
"task_inputs not fully matched, check if there's any unexpected error"
)
args[idx] = empty_array
else:
args[idx] = ndarray.array(args[idx], dev)
dev.sync()
costs = time_f(*args).results
# pylint: disable=broad-except
@@ -968,6 +999,7 @@ def local_run(

measure_results = []
assert len(inputs) == len(build_results), "Measure input size should be equal to build results"
worker = PopenWorker()
for inp, build_res in zip(inputs, build_results):
if build_res.error_no != 0:
res = (
Expand All @@ -978,20 +1010,22 @@ def local_run(
time.time(),
)
else:
args = prepare_runner_args(inp, build_res)
res = call_func_with_timeout(
worker,
timeout,
_timed_eval_func,
args=(
inp.serialize(),
build_res,
args,
number,
repeat,
min_repeat_ms,
cooldown_interval,
enable_cpu_cache_flush,
verbose,
),
add_thread_wrapper=True,
)
if isinstance(res, TimeoutError):
if verbose >= 1:
@@ -1022,9 +1056,10 @@ def local_run(
return measure_results
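Editorial note: local_run now passes a single persistent PopenWorker to call_func_with_timeout instead of forking a fresh process per measurement. A minimal sketch of that primitive; the add helper and the timeout value are illustrative:

from tvm.contrib.popen_pool import PopenWorker

def add(a, b):
    return a + b

worker = PopenWorker()
worker.send(add, [1, 2], timeout=10)  # ship the call to the child process
print(worker.recv())                  # prints 3; raises TimeoutError on timeout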


def _timed_rpc_run(
def _rpc_run(
inp_serialized,
build_res,
args,
key,
host,
port,
@@ -1037,11 +1072,7 @@ def _timed_rpc_run(
enable_cpu_cache_flush,
verbose,
):
# pylint: disable=import-outside-toplevel
from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency

inp = MeasureInput.deserialize(inp_serialized)
task_input_names = inp.task.task_input_names
tic = time.time()
error_no = 0
error_msg = None
@@ -1080,32 +1111,18 @@ def _timed_rpc_run(
random_fill
), "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices"

tensor_input_map = prepare_input_map(build_res.args) if task_input_names else {}
args = []
task_inputs_count = 0
for arg in build_res.args:
if arg in tensor_input_map:
tensor_name = tensor_input_map[arg]
if tensor_name in task_input_names:
args.append(
ndarray.array(
get_task_input_buffer(inp.task.workload_key, tensor_name), dev
)
)
task_inputs_count += 1
else:
raise ValueError(
"%s not found in task_inputs, " % (tensor_name)
+ "should provide with `SearchTask(..., task_inputs={...})`"
)
else:
empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, dev)
assert len(args) == len(build_res.args)
# pylint: disable=consider-using-enumerate
for idx in range(len(args)):
if args[idx] is None:
build_res_arg = build_res.args[idx]
empty_array = ndarray.empty(
get_const_tuple(build_res_arg.shape), build_res_arg.dtype, dev
)
random_fill(empty_array)
args.append(empty_array)
if task_inputs_count != len(task_input_names):
logger.warning(
"task_inputs not fully matched, check if there's any unexpected error"
)
args[idx] = empty_array
else:
args[idx] = ndarray.array(args[idx], dev)
dev.sync()

# First run for check that the kernel is correct
@@ -1152,7 +1169,7 @@ def _rpc_run_worker(args):
res : MeasureResult
The measure result of this Runner thread.
"""
_, build_res, _, _, _, _, timeout, _, _, _, _, _, verbose = args
_, build_res, _, _, _, _, _, timeout, _, _, _, _, _, verbose = args
if build_res.error_no != MeasureErrorNo.NO_ERROR:
return (
(MAX_FLOAT,),
Expand All @@ -1162,24 +1179,16 @@ def _rpc_run_worker(args):
time.time(),
)

res = call_func_with_timeout(timeout, _timed_rpc_run, args=args)
if isinstance(res, TimeoutError):
if verbose >= 1:
print("*T", end="") # Run timeout
res = (
(MAX_FLOAT,),
MeasureErrorNo.RUN_TIMEOUT,
None,
build_res.time_cost + timeout,
time.time(),
)
elif isinstance(res, Exception):
try:
res = _rpc_run(*args)
# pylint: disable=broad-except
except Exception:
if verbose >= 1:
print("*E", end="") # Run error
res = (
(MAX_FLOAT,),
MeasureErrorNo.RUNTIME_DEVICE,
str(res),
make_traceback_info(),
build_res.time_cost + timeout,
time.time(),
)
@@ -1259,13 +1268,14 @@ def rpc_runner_run(
"""
assert len(inputs) == len(build_results), "Measure input size should be equal to build results"
# This pool is not doing computationally intensive work, so we can use threads
pool = multiprocessing.pool.ThreadPool(n_parallel)
tuple_res = pool.map(
executor = PopenPoolExecutor(n_parallel)
tuple_res = executor.map_with_error_catching(
_rpc_run_worker,
[
(
inp.serialize(),
build_res,
prepare_runner_args(inp, build_res),
key,
host,
port,
@@ -1281,13 +1291,25 @@ def rpc_runner_run(
for inp, build_res in zip(inputs, build_results)
],
)
pool.terminate()
pool.join()
del pool

results = []
for res in tuple_res:
results.append(MeasureResult(*res))
for i, res in enumerate(tuple_res):
if res.status == StatusKind.COMPLETE:
results.append(MeasureResult(*res.value))
else:
assert res.status == StatusKind.TIMEOUT
if verbose >= 1:
print("*T", end="") # Run timeout
build_res = build_results[i]
results.append(
MeasureResult(
(MAX_FLOAT,),
MeasureErrorNo.RUN_TIMEOUT,
None,
build_res.time_cost + timeout,
time.time(),
)
)

if verbose >= 1:
print("")