From 5ac44487e6f8c4d52bcedabd7f570f6d15609c62 Mon Sep 17 00:00:00 2001 From: Yuanjing Shi Date: Wed, 8 Sep 2021 10:08:15 -0700 Subject: [PATCH] [AutoTVM] Use popenpool in local_executor (#8851) * use popenpool in local_executor * move auto_tvm_common to tvm.testing * refactor * nit * remove LocalFutureNoFork * exception handling * handling two exceptions * handling error * add initiazlier --- python/tvm/autotvm/measure/__init__.py | 1 - python/tvm/autotvm/measure/local_executor.py | 157 ------------------ python/tvm/autotvm/measure/measure_methods.py | 86 +++++----- python/tvm/testing/__init__.py | 4 +- .../tvm/testing/autotvm.py | 1 + python/tvm/testing/popen_pool.py | 16 ++ tests/python/contrib/test_popen_pool.py | 36 ++++ .../python/unittest/test_autotvm_database.py | 2 +- .../python/unittest/test_autotvm_executor.py | 69 -------- .../unittest/test_autotvm_index_tuner.py | 2 +- tests/python/unittest/test_autotvm_measure.py | 9 +- tests/python/unittest/test_autotvm_record.py | 2 +- .../unittest/test_autotvm_xgboost_model.py | 2 +- 13 files changed, 112 insertions(+), 275 deletions(-) delete mode 100644 python/tvm/autotvm/measure/local_executor.py rename tests/python/unittest/test_autotvm_common.py => python/tvm/testing/autotvm.py (97%) delete mode 100644 tests/python/unittest/test_autotvm_executor.py diff --git a/python/tvm/autotvm/measure/__init__.py b/python/tvm/autotvm/measure/__init__.py index c4c0dc92b116..10b0843402ea 100644 --- a/python/tvm/autotvm/measure/__init__.py +++ b/python/tvm/autotvm/measure/__init__.py @@ -31,4 +31,3 @@ request_remote, ) from .executor import Executor -from .local_executor import LocalExecutor diff --git a/python/tvm/autotvm/measure/local_executor.py b/python/tvm/autotvm/measure/local_executor.py deleted file mode 100644 index a9aeb790c82a..000000000000 --- a/python/tvm/autotvm/measure/local_executor.py +++ /dev/null @@ -1,157 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Local based implementation of the executor using multiprocessing""" - -import signal - -from multiprocessing import Process, Queue - -try: - from queue import Empty -except ImportError: - from Queue import Empty - -try: - import psutil -except ImportError: - psutil = None - -from . import executor - - -def kill_child_processes(parent_pid, sig=signal.SIGTERM): - """kill all child processes recursively""" - try: - parent = psutil.Process(parent_pid) - children = parent.children(recursive=True) - except psutil.NoSuchProcess: - return - for process in children: - try: - process.send_signal(sig) - except psutil.NoSuchProcess: - return - - -def _execute_func(func, queue, args, kwargs): - """execute function and return the result or exception to a queue""" - try: - res = func(*args, **kwargs) - except Exception as exc: # pylint: disable=broad-except - res = exc - queue.put(res) - - -def call_with_timeout(queue, timeout, func, args, kwargs): - """A wrapper to support timeout of a function call""" - - # start a new process for timeout (cannot use thread because we have c function) - p = Process(target=_execute_func, args=(func, queue, args, kwargs)) - p.start() - p.join(timeout=timeout) - - queue.put(executor.TimeoutError()) - - kill_child_processes(p.pid) - p.terminate() - p.join() - - -class LocalFuture(executor.Future): - """Local wrapper for the future - - Parameters - ---------- - process: multiprocessing.Process - process for running this task - queue: multiprocessing.Queue - queue for receiving the result of this task - """ - - def __init__(self, process, queue): - self._done = False - self._process = process - self._queue = queue - - def done(self): - self._done = self._done or not self._queue.empty() - return self._done - - def get(self, timeout=None): - try: - res = self._queue.get(block=True, timeout=timeout) - except Empty: - raise executor.TimeoutError() - if self._process.is_alive(): - kill_child_processes(self._process.pid) - self._process.terminate() - self._process.join() - self._queue.close() - self._queue.join_thread() - self._done = True - del self._queue - del self._process - return res - - -class LocalFutureNoFork(executor.Future): - """Local wrapper for the future. - This is a none-fork version of LocalFuture. - Use this for the runtime that does not support fork (like cudnn) - """ - - def __init__(self, result): - self._result = result - - def done(self): - return True - - def get(self, timeout=None): - return self._result - - -class LocalExecutor(executor.Executor): - """Local executor that runs workers on the same machine with multiprocessing. - - Parameters - ---------- - timeout: float, optional - timeout of a job. If time is out. A TimeoutError will be returned (not raised) - do_fork: bool, optional - For some runtime systems that do not support fork after initialization - (e.g. cuda runtime, cudnn). Set this to False if you have used these runtime - before submitting jobs. - """ - - def __init__(self, timeout=None, do_fork=True): - self.timeout = timeout or executor.Executor.DEFAULT_TIMEOUT - self.do_fork = do_fork - - if self.do_fork: - if not psutil: - raise RuntimeError( - "Python package psutil is missing. " "please try `pip install psutil`" - ) - - def submit(self, func, *args, **kwargs): - if not self.do_fork: - return LocalFutureNoFork(func(*args, **kwargs)) - - queue = Queue(2) # Size of 2 to avoid a race condition with size 1. - process = Process(target=call_with_timeout, args=(queue, self.timeout, func, args, kwargs)) - process.start() - return LocalFuture(process, queue) diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index eab6822b63b8..42e046aefb4a 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -38,7 +38,9 @@ import tvm.ir.transform from tvm import nd from tvm import rpc as _rpc +from tvm.autotvm.env import AutotvmGlobalScope, reset_global_scope from tvm.contrib import ndk, nvcc, stackvm, tar +from tvm.contrib.popen_pool import PopenPoolExecutor from tvm.driver import build from tvm.error import TVMError from tvm.target import Target @@ -46,7 +48,6 @@ from ..env import AutotvmGlobalScope from ..task.space import InstantiationError from ..utils import get_const_tuple -from .local_executor import LocalExecutor from .measure import Builder, MeasureErrorNo, MeasureResult, Runner logger = logging.getLogger("autotvm") @@ -98,7 +99,9 @@ def __init__(self, timeout=10, n_parallel=None, build_func="default"): else: raise ValueError("Invalid build_func" + build_func) self.build_func = _WrappedBuildFunc(build_func) - self.executor = LocalExecutor(timeout=timeout) + self.executor = PopenPoolExecutor( + timeout=timeout, initializer=reset_global_scope, initargs=(AutotvmGlobalScope.current,) + ) self.tmp_dir = tempfile.mkdtemp() def build(self, measure_inputs): @@ -114,53 +117,52 @@ def build(self, measure_inputs): futures.append(ret) for future in futures: - res = future.get() - - if isinstance(res, Exception): - # timeout or fleet error, return MeasureResult directly - results.append( - MeasureResult( - (res,), MeasureErrorNo.BUILD_TIMEOUT, self.timeout, time.time() - ) - ) - elif res.error is not None: - # instantiation error - if isinstance(res.error, InstantiationError): - results.append( - MeasureResult( + try: + res = future.result() + if res.error is not None: + # instantiation error + if isinstance(res.error, InstantiationError): + res = MeasureResult( (res.error,), MeasureErrorNo.INSTANTIATION_ERROR, res.time_cost, time.time(), ) - ) - else: - if "InstantiationError" in str(res.error): - msg = str(res.error) - try: - msg = msg.split("\n")[-2].split(": ")[1] - except Exception: # pylint: disable=broad-except - pass - results.append( - MeasureResult( + + else: + if "InstantiationError" in str(res.error): + msg = str(res.error) + try: + msg = msg.split("\n")[-2].split(": ")[1] + except Exception: # pylint: disable=broad-except + pass + res = MeasureResult( (InstantiationError(msg),), MeasureErrorNo.INSTANTIATION_ERROR, res.time_cost, time.time(), ) - ) - else: # tvm error - results.append( - MeasureResult( + + else: # tvm error + res = MeasureResult( (res.error,), MeasureErrorNo.COMPILE_HOST, res.time_cost, time.time(), ) - ) - else: - # return BuildResult - results.append(res) + except TimeoutError as ex: + res = MeasureResult( + (ex,), MeasureErrorNo.BUILD_TIMEOUT, self.timeout, time.time() + ) + except ChildProcessError as ex: + res = MeasureResult( + (ex,), + MeasureErrorNo.RUNTIME_DEVICE, + self.timeout, + time.time(), + ) + + results.append(res) return results @@ -242,7 +244,11 @@ def __init__( self.cooldown_interval = cooldown_interval self.module_loader = module_loader - self.executor = LocalExecutor(timeout=timeout * (self.n_parallel + 1)) + self.executor = PopenPoolExecutor( + timeout=timeout * (self.n_parallel + 1), + initializer=reset_global_scope, + initargs=(AutotvmGlobalScope.current,), + ) @property def ref_input(self): @@ -337,15 +343,15 @@ def run(self, measure_inputs, build_results): futures.append(ret) for future in futures: - res = future.get() - if isinstance(res, Exception): # executor error or timeout + try: + res = future.result() + results.append(res) + except Exception as ex: # pylint: disable=broad-except results.append( MeasureResult( - (str(res),), MeasureErrorNo.RUN_TIMEOUT, self.timeout, time.time() + (str(ex),), MeasureErrorNo.RUN_TIMEOUT, self.timeout, time.time() ) ) - else: - results.append(res) return results diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py index 75349d8d5a14..d84846725ec4 100644 --- a/python/tvm/testing/__init__.py +++ b/python/tvm/testing/__init__.py @@ -25,6 +25,8 @@ from ._ffi_api import ErrorTest, FrontendTestModule, identity_cpp from .popen_pool import initializer, after_initializer, register_ffi, call_cpp_ffi -from .popen_pool import call_py_ffi, call_cpp_py_ffi +from .popen_pool import call_py_ffi, call_cpp_py_ffi, fast_summation, slow_summation +from .popen_pool import timeout_job from . import auto_scheduler +from . import autotvm diff --git a/tests/python/unittest/test_autotvm_common.py b/python/tvm/testing/autotvm.py similarity index 97% rename from tests/python/unittest/test_autotvm_common.py rename to python/tvm/testing/autotvm.py index 60f7d8bafb1b..6f7bb13fe6dc 100644 --- a/tests/python/unittest/test_autotvm_common.py +++ b/python/tvm/testing/autotvm.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name, missing-function-docstring, missing-class-docstring """Common utilities for testing autotvm""" import time diff --git a/python/tvm/testing/popen_pool.py b/python/tvm/testing/popen_pool.py index 20345a2218fe..b646d7a89e94 100644 --- a/python/tvm/testing/popen_pool.py +++ b/python/tvm/testing/popen_pool.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=invalid-name, missing-function-docstring """Common functions for popen_pool test cases""" +import time import tvm TEST_GLOBAL_STATE_1 = 0 @@ -57,3 +58,18 @@ def call_cpp_ffi(arg): def call_cpp_py_ffi(arg): return tvm.testing.identity_cpp(arg) + + +def fast_summation(n): + return n * (n + 1) // 2 + + +def slow_summation(n): + r = 0 + for i in range(0, n + 1): + r += i + return r + + +def timeout_job(n): + time.sleep(n * 1.5) diff --git a/tests/python/contrib/test_popen_pool.py b/tests/python/contrib/test_popen_pool.py index 9ebe4c11c118..b3a91e176a32 100644 --- a/tests/python/contrib/test_popen_pool.py +++ b/tests/python/contrib/test_popen_pool.py @@ -27,6 +27,9 @@ call_py_ffi, call_cpp_ffi, call_cpp_py_ffi, + fast_summation, + slow_summation, + timeout_job, ) @@ -104,8 +107,41 @@ def test_popen_ffi(): assert proc.recv() == initargs[0] +def test_popen_pool_executor_async(): + pool = PopenPoolExecutor() + f1 = pool.submit(slow_summation, 9999999) + f2 = pool.submit(fast_summation, 9999999) + t1 = 0 + t2 = 0 + while True: + if t1 == 0 and f1.done(): + t1 = time.time() + if t2 == 0 and f2.done(): + t2 = time.time() + if t1 != 0 and t2 != 0: + break + assert t2 < t1, "Expected fast async job to finish first!" + assert f1.result() == f2.result() + + +def test_popen_pool_executor_timeout(): + timeout = 0.5 + + pool = PopenPoolExecutor(timeout=timeout) + + f1 = pool.submit(timeout_job, timeout) + while not f1.done(): + pass + try: + res = f1.result() + except Exception as ex: + assert isinstance(ex, TimeoutError) + + if __name__ == "__main__": test_popen_worker() test_popen_pool_executor() test_popen_initializer() test_popen_ffi() + test_popen_pool_executor_async() + test_popen_pool_executor_timeout() diff --git a/tests/python/unittest/test_autotvm_database.py b/tests/python/unittest/test_autotvm_database.py index 197243ed47c0..d5980022811f 100644 --- a/tests/python/unittest/test_autotvm_database.py +++ b/tests/python/unittest/test_autotvm_database.py @@ -21,7 +21,7 @@ from tvm.autotvm import database from tvm.autotvm.record import encode, MeasureResult -from test_autotvm_common import get_sample_records +from tvm.testing.autotvm import get_sample_records def test_save_load(): diff --git a/tests/python/unittest/test_autotvm_executor.py b/tests/python/unittest/test_autotvm_executor.py deleted file mode 100644 index 9757576be9e3..000000000000 --- a/tests/python/unittest/test_autotvm_executor.py +++ /dev/null @@ -1,69 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Test local executor""" -import time - -from tvm.autotvm.measure import LocalExecutor, executor - - -def slow(n): - r = 0 - for i in range(0, n + 1): - r += i - return r - - -def fast(n): - return n * (n + 1) // 2 - - -def test_local_measure_async(): - ex = LocalExecutor() - f1 = ex.submit(slow, 9999999) - f2 = ex.submit(fast, 9999999) - t1 = 0 - t2 = 0 - while True: - if t1 == 0 and f1.done(): - t1 = time.time() - if t2 == 0 and f2.done(): - t2 = time.time() - if t1 != 0 and t2 != 0: - break - assert t2 < t1, "Expected fast async job to finish first!" - assert f1.get() == f2.get() - - -def timeout_job(n): - time.sleep(n * 1.5) - - -def test_timeout(): - timeout = 0.5 - - ex = LocalExecutor(timeout=timeout) - - f1 = ex.submit(timeout_job, timeout) - while not f1.done(): - pass - res = f1.get() - assert isinstance(res, executor.TimeoutError) - - -if __name__ == "__main__": - test_local_measure_async() - test_timeout() diff --git a/tests/python/unittest/test_autotvm_index_tuner.py b/tests/python/unittest/test_autotvm_index_tuner.py index c433d8fb7297..be89ee2506fc 100644 --- a/tests/python/unittest/test_autotvm_index_tuner.py +++ b/tests/python/unittest/test_autotvm_index_tuner.py @@ -17,7 +17,7 @@ """Test index based tuners""" import multiprocessing -from test_autotvm_common import DummyRunner, get_sample_task +from tvm.testing.autotvm import DummyRunner, get_sample_task from tvm import autotvm from tvm.autotvm.tuner import GridSearchTuner, RandomTuner diff --git a/tests/python/unittest/test_autotvm_measure.py b/tests/python/unittest/test_autotvm_measure.py index a89c69c37d64..3ef5cbdad635 100644 --- a/tests/python/unittest/test_autotvm_measure.py +++ b/tests/python/unittest/test_autotvm_measure.py @@ -17,13 +17,14 @@ """Test builder and runner""" import logging import multiprocessing -import time +import concurrent import numpy as np import tvm from tvm import te -from test_autotvm_common import DummyRunner, bad_matmul, get_sample_task +from tvm.autotvm.measure import executor +from tvm.testing.autotvm import DummyRunner, bad_matmul, get_sample_task from tvm import autotvm from tvm.autotvm.measure.measure import MeasureErrorNo, MeasureResult from tvm.autotvm import measure @@ -76,7 +77,9 @@ def submit(self, func, *args, **kwargs): self.ran_dummy_executor = True sig = Signature.from_callable(func) assert sig.bind(*args, **kwargs).arguments["ref_input"] == refinp - return measure.local_executor.LocalFutureNoFork(None) + dummy_future = concurrent.futures.Future() + dummy_future.set_result(None) + return dummy_future runner.executor = DummyExecutor() runner.run([None], [None]) diff --git a/tests/python/unittest/test_autotvm_record.py b/tests/python/unittest/test_autotvm_record.py index 51cc9074a4fe..65739df52cd9 100644 --- a/tests/python/unittest/test_autotvm_record.py +++ b/tests/python/unittest/test_autotvm_record.py @@ -25,7 +25,7 @@ from tvm.autotvm.measure import MeasureInput, MeasureResult, MeasureErrorNo from tvm.autotvm.record import encode, decode, ApplyHistoryBest, measure_str_key -from test_autotvm_common import get_sample_task +from tvm.testing.autotvm import get_sample_task def test_load_dump(): diff --git a/tests/python/unittest/test_autotvm_xgboost_model.py b/tests/python/unittest/test_autotvm_xgboost_model.py index 445cff8759ab..baecdaceab6d 100644 --- a/tests/python/unittest/test_autotvm_xgboost_model.py +++ b/tests/python/unittest/test_autotvm_xgboost_model.py @@ -25,7 +25,7 @@ from tvm.autotvm import MeasureInput, MeasureResult from tvm.autotvm.tuner.xgboost_cost_model import XGBoostCostModel -from test_autotvm_common import get_sample_task, get_sample_records +from tvm.testing.autotvm import get_sample_task, get_sample_records def test_fit():