From c60a5cfb1a44a27e1f45e8ac43b2d57726845e70 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 1 May 2021 08:25:52 -0400 Subject: [PATCH] [RPC] Make tracker jupyter friendly This PR uses the PopenWorker to handle the tracker start up and makes the tracker jupyter friendly. --- python/tvm/contrib/popen_pool.py | 22 ++++++-- python/tvm/exec/rpc_tracker.py | 24 --------- python/tvm/rpc/tracker.py | 88 +++++++++++++++++++++++--------- 3 files changed, 82 insertions(+), 52 deletions(-) diff --git a/python/tvm/contrib/popen_pool.py b/python/tvm/contrib/popen_pool.py index ecda995c7162..2f552034e9f8 100644 --- a/python/tvm/contrib/popen_pool.py +++ b/python/tvm/contrib/popen_pool.py @@ -153,10 +153,26 @@ def _start(self): self._reader = os.fdopen(main_read, "rb") self._writer = os.fdopen(main_write, "wb") - def join(self): - """Join the current process worker before it terminates""" + def join(self, timeout=None): + """Join the current process worker before it terminates. + + Parameters + ---------- + timeout: Optional[number] + Timeout value, block at most timeout seconds if it + is a positive number. + """ + if self._proc: + try: + self._proc.wait(timeout) + except subprocess.TimeoutExpired: + pass + + def is_alive(self): + """Check if the process is alive""" if self._proc: - self._proc.wait() + return self._proc.poll() is None + return False def send(self, fn, args=(), kwargs=None, timeout=None): """Send a new function task fn(*args, **kwargs) to the subprocess. diff --git a/python/tvm/exec/rpc_tracker.py b/python/tvm/exec/rpc_tracker.py index 4a1a964338ba..82ab13a9d3df 100644 --- a/python/tvm/exec/rpc_tracker.py +++ b/python/tvm/exec/rpc_tracker.py @@ -16,12 +16,8 @@ # under the License. # pylint: disable=redefined-outer-name, invalid-name """Tool to start RPC tracker""" -from __future__ import absolute_import - import logging import argparse -import multiprocessing -import sys from ..rpc.tracker import Tracker @@ -36,27 +32,7 @@ def main(args): parser.add_argument("--host", type=str, default="0.0.0.0", help="the hostname of the tracker") parser.add_argument("--port", type=int, default=9190, help="The port of the RPC") parser.add_argument("--port-end", type=int, default=9199, help="The end search port of the RPC") - parser.add_argument( - "--no-fork", - dest="fork", - action="store_false", - help="Use spawn mode to avoid fork. This option \ - is able to avoid potential fork problems with Metal, OpenCL \ - and ROCM compilers.", - ) parser.add_argument("--silent", action="store_true", help="Whether run in silent mode.") - - parser.set_defaults(fork=True) args = parser.parse_args() logging.basicConfig(level=logging.INFO) - if args.fork is False: - if sys.version_info[0] < 3: - raise RuntimeError("Python3 is required for spawn mode.") - multiprocessing.set_start_method("spawn") - else: - if not args.silent: - logging.info( - "If you are running ROCM/Metal, fork will cause " - "compiler internal error. Try to launch with arg ```--no-fork```" - ) main(args) diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py index e1c366e99b0d..8a5dd19a4743 100644 --- a/python/tvm/rpc/tracker.py +++ b/python/tvm/rpc/tracker.py @@ -41,14 +41,15 @@ """ # pylint: disable=invalid-name +import asyncio import heapq import logging import socket import threading -import multiprocessing import errno import struct import json +from tvm.contrib.popen_pool import PopenWorker try: from tornado import ioloop @@ -362,14 +363,55 @@ def run(self): def _tracker_server(listen_sock, stop_key): + asyncio.set_event_loop(asyncio.new_event_loop()) handler = TrackerServerHandler(listen_sock, stop_key) handler.run() +class PopenTrackerServerState(object): + """Internal PopenTrackerServer State""" + + current = None + + def __init__(self, host, port=9190, port_end=9199, silent=False): + if silent: + logger.setLevel(logging.WARN) + + sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM) + self.port = None + self.stop_key = base.random_key("tracker") + for my_port in range(port, port_end): + try: + sock.bind((host, my_port)) + self.port = my_port + break + except socket.error as sock_err: + if sock_err.errno in [98, 48]: + continue + raise sock_err + if not self.port: + raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end)) + logger.info("bind to %s:%d", host, self.port) + sock.listen(1) + self.thread = threading.Thread(target=_tracker_server, args=(sock, self.stop_key)) + self.thread.start() + self.host = host + + +def _popen_start_tracker_server(host, port=9190, port_end=9199, silent=False): + # This is a function that will be sent to the + # Popen worker to run on a separate process. + # Create and start the server in a different thread + state = PopenTrackerServerState(host, port, port_end, silent) + PopenTrackerServerState.current = state + # returns the port so that the main can get the port number. + return (state.port, state.stop_key) + + class Tracker(object): """Start RPC tracker on a seperate process. - Python implementation based on multi-processing. + Python implementation based on PopenWorker. Parameters ---------- @@ -390,27 +432,20 @@ def __init__(self, host, port=9190, port_end=9199, silent=False): if silent: logger.setLevel(logging.WARN) - sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM) - self.port = None - self.stop_key = base.random_key("tracker") - for my_port in range(port, port_end): - try: - sock.bind((host, my_port)) - self.port = my_port - break - except socket.error as sock_err: - if sock_err.errno in [98, 48]: - continue - raise sock_err - if not self.port: - raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end)) - logger.info("bind to %s:%d", host, self.port) - sock.listen(1) - self.proc = multiprocessing.Process(target=_tracker_server, args=(sock, self.stop_key)) - self.proc.start() + self.proc = PopenWorker() + # send the function + self.proc.send( + _popen_start_tracker_server, + [ + host, + port, + port_end, + silent, + ], + ) + # receive the port + self.port, self.stop_key = self.proc.recv() self.host = host - # close the socket on this process - sock.close() def _stop_tracker(self): sock = socket.socket(base.get_addr_family((self.host, self.port)), socket.SOCK_STREAM) @@ -427,11 +462,14 @@ def terminate(self): if self.proc: if self.proc.is_alive(): self._stop_tracker() - self.proc.join(1) + self.proc.join(0.1) if self.proc.is_alive(): logger.info("Terminating Tracker Server...") - self.proc.terminate() + self.proc.kill() self.proc = None def __del__(self): - self.terminate() + try: + self.terminate() + except TypeError: + pass