Skip to content

Commit

Permalink
[AutoTVM] Use popenpool in local_executor (apache#8851)
Browse files Browse the repository at this point in the history
* use popenpool in local_executor

* move auto_tvm_common to tvm.testing

* refactor

* nit

* remove LocalFutureNoFork

* exception handling

* handling two exceptions

* handling error

* add initiazlier
  • Loading branch information
Yuanjing Shi authored and ylc committed Jan 13, 2022
1 parent 562a8ff commit 5ac4448
Show file tree
Hide file tree
Showing 13 changed files with 112 additions and 275 deletions.
1 change: 0 additions & 1 deletion python/tvm/autotvm/measure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,3 @@
request_remote,
)
from .executor import Executor
from .local_executor import LocalExecutor
157 changes: 0 additions & 157 deletions python/tvm/autotvm/measure/local_executor.py

This file was deleted.

86 changes: 46 additions & 40 deletions python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,16 @@
import tvm.ir.transform
from tvm import nd
from tvm import rpc as _rpc
from tvm.autotvm.env import AutotvmGlobalScope, reset_global_scope
from tvm.contrib import ndk, nvcc, stackvm, tar
from tvm.contrib.popen_pool import PopenPoolExecutor
from tvm.driver import build
from tvm.error import TVMError
from tvm.target import Target

from ..env import AutotvmGlobalScope
from ..task.space import InstantiationError
from ..utils import get_const_tuple
from .local_executor import LocalExecutor
from .measure import Builder, MeasureErrorNo, MeasureResult, Runner

logger = logging.getLogger("autotvm")
Expand Down Expand Up @@ -98,7 +99,9 @@ def __init__(self, timeout=10, n_parallel=None, build_func="default"):
else:
raise ValueError("Invalid build_func" + build_func)
self.build_func = _WrappedBuildFunc(build_func)
self.executor = LocalExecutor(timeout=timeout)
self.executor = PopenPoolExecutor(
timeout=timeout, initializer=reset_global_scope, initargs=(AutotvmGlobalScope.current,)
)
self.tmp_dir = tempfile.mkdtemp()

def build(self, measure_inputs):
Expand All @@ -114,53 +117,52 @@ def build(self, measure_inputs):
futures.append(ret)

for future in futures:
res = future.get()

if isinstance(res, Exception):
# timeout or fleet error, return MeasureResult directly
results.append(
MeasureResult(
(res,), MeasureErrorNo.BUILD_TIMEOUT, self.timeout, time.time()
)
)
elif res.error is not None:
# instantiation error
if isinstance(res.error, InstantiationError):
results.append(
MeasureResult(
try:
res = future.result()
if res.error is not None:
# instantiation error
if isinstance(res.error, InstantiationError):
res = MeasureResult(
(res.error,),
MeasureErrorNo.INSTANTIATION_ERROR,
res.time_cost,
time.time(),
)
)
else:
if "InstantiationError" in str(res.error):
msg = str(res.error)
try:
msg = msg.split("\n")[-2].split(": ")[1]
except Exception: # pylint: disable=broad-except
pass
results.append(
MeasureResult(

else:
if "InstantiationError" in str(res.error):
msg = str(res.error)
try:
msg = msg.split("\n")[-2].split(": ")[1]
except Exception: # pylint: disable=broad-except
pass
res = MeasureResult(
(InstantiationError(msg),),
MeasureErrorNo.INSTANTIATION_ERROR,
res.time_cost,
time.time(),
)
)
else: # tvm error
results.append(
MeasureResult(

else: # tvm error
res = MeasureResult(
(res.error,),
MeasureErrorNo.COMPILE_HOST,
res.time_cost,
time.time(),
)
)
else:
# return BuildResult
results.append(res)
except TimeoutError as ex:
res = MeasureResult(
(ex,), MeasureErrorNo.BUILD_TIMEOUT, self.timeout, time.time()
)
except ChildProcessError as ex:
res = MeasureResult(
(ex,),
MeasureErrorNo.RUNTIME_DEVICE,
self.timeout,
time.time(),
)

results.append(res)

return results

Expand Down Expand Up @@ -242,7 +244,11 @@ def __init__(
self.cooldown_interval = cooldown_interval
self.module_loader = module_loader

self.executor = LocalExecutor(timeout=timeout * (self.n_parallel + 1))
self.executor = PopenPoolExecutor(
timeout=timeout * (self.n_parallel + 1),
initializer=reset_global_scope,
initargs=(AutotvmGlobalScope.current,),
)

@property
def ref_input(self):
Expand Down Expand Up @@ -337,15 +343,15 @@ def run(self, measure_inputs, build_results):
futures.append(ret)

for future in futures:
res = future.get()
if isinstance(res, Exception): # executor error or timeout
try:
res = future.result()
results.append(res)
except Exception as ex: # pylint: disable=broad-except
results.append(
MeasureResult(
(str(res),), MeasureErrorNo.RUN_TIMEOUT, self.timeout, time.time()
(str(ex),), MeasureErrorNo.RUN_TIMEOUT, self.timeout, time.time()
)
)
else:
results.append(res)

return results

Expand Down
4 changes: 3 additions & 1 deletion python/tvm/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from ._ffi_api import ErrorTest, FrontendTestModule, identity_cpp

from .popen_pool import initializer, after_initializer, register_ffi, call_cpp_ffi
from .popen_pool import call_py_ffi, call_cpp_py_ffi
from .popen_pool import call_py_ffi, call_cpp_py_ffi, fast_summation, slow_summation
from .popen_pool import timeout_job

from . import auto_scheduler
from . import autotvm
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, missing-function-docstring, missing-class-docstring
"""Common utilities for testing autotvm"""
import time

Expand Down
16 changes: 16 additions & 0 deletions python/tvm/testing/popen_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=invalid-name, missing-function-docstring
"""Common functions for popen_pool test cases"""
import time
import tvm

TEST_GLOBAL_STATE_1 = 0
Expand Down Expand Up @@ -57,3 +58,18 @@ def call_cpp_ffi(arg):

def call_cpp_py_ffi(arg):
return tvm.testing.identity_cpp(arg)


def fast_summation(n):
return n * (n + 1) // 2


def slow_summation(n):
r = 0
for i in range(0, n + 1):
r += i
return r


def timeout_job(n):
time.sleep(n * 1.5)
Loading

0 comments on commit 5ac4448

Please sign in to comment.