From 284faf241f173270064124d61a620503556860e7 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 3 May 2021 16:41:02 -0400 Subject: [PATCH] [RPC] Make tracker jupyter friendly (#7961) This PR uses the PopenWorker to handle the tracker start up and makes the tracker jupyter friendly. --- python/tvm/contrib/popen_pool.py | 22 ++++++-- python/tvm/exec/rpc_tracker.py | 24 --------- python/tvm/rpc/tracker.py | 89 ++++++++++++++++++++++---------- 3 files changed, 82 insertions(+), 53 deletions(-) diff --git a/python/tvm/contrib/popen_pool.py b/python/tvm/contrib/popen_pool.py index ecda995c7162..2f552034e9f8 100644 --- a/python/tvm/contrib/popen_pool.py +++ b/python/tvm/contrib/popen_pool.py @@ -153,10 +153,26 @@ def _start(self): self._reader = os.fdopen(main_read, "rb") self._writer = os.fdopen(main_write, "wb") - def join(self): - """Join the current process worker before it terminates""" + def join(self, timeout=None): + """Join the current process worker before it terminates. + + Parameters + ---------- + timeout: Optional[number] + Timeout value, block at most timeout seconds if it + is a positive number. + """ + if self._proc: + try: + self._proc.wait(timeout) + except subprocess.TimeoutExpired: + pass + + def is_alive(self): + """Check if the process is alive""" if self._proc: - self._proc.wait() + return self._proc.poll() is None + return False def send(self, fn, args=(), kwargs=None, timeout=None): """Send a new function task fn(*args, **kwargs) to the subprocess. diff --git a/python/tvm/exec/rpc_tracker.py b/python/tvm/exec/rpc_tracker.py index 05809e044fec..091e95ad735f 100644 --- a/python/tvm/exec/rpc_tracker.py +++ b/python/tvm/exec/rpc_tracker.py @@ -16,12 +16,8 @@ # under the License. # pylint: disable=redefined-outer-name, invalid-name """Tool to start RPC tracker""" -from __future__ import absolute_import - import logging import argparse -import multiprocessing -import sys from ..rpc.tracker import Tracker @@ -38,27 +34,7 @@ def main(args): ) parser.add_argument("--port", type=int, default=9190, help="The port of the RPC") parser.add_argument("--port-end", type=int, default=9199, help="The end search port of the RPC") - parser.add_argument( - "--no-fork", - dest="fork", - action="store_false", - help="Use spawn mode to avoid fork. This option \ - is able to avoid potential fork problems with Metal, OpenCL \ - and ROCM compilers.", - ) parser.add_argument("--silent", action="store_true", help="Whether run in silent mode.") - - parser.set_defaults(fork=True) args = parser.parse_args() logging.basicConfig(level=logging.INFO) - if args.fork is False: - if sys.version_info[0] < 3: - raise RuntimeError("Python3 is required for spawn mode.") - multiprocessing.set_start_method("spawn") - else: - if not args.silent: - logging.info( - "If you are running ROCM/Metal, fork will cause " - "compiler internal error. Try to launch with arg ```--no-fork```" - ) main(args) diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py index 9dc4139f97ca..25ff15cad01d 100644 --- a/python/tvm/rpc/tracker.py +++ b/python/tvm/rpc/tracker.py @@ -41,14 +41,15 @@ """ # pylint: disable=invalid-name +import asyncio import heapq import logging import socket import threading -import multiprocessing import errno import struct import json +from tvm.contrib.popen_pool import PopenWorker try: from tornado import ioloop @@ -362,14 +363,55 @@ def run(self): def _tracker_server(listen_sock, stop_key): + asyncio.set_event_loop(asyncio.new_event_loop()) handler = TrackerServerHandler(listen_sock, stop_key) handler.run() +class PopenTrackerServerState(object): + """Internal PopenTrackerServer State""" + + current = None + + def __init__(self, host, port=9190, port_end=9199, silent=False): + if silent: + logger.setLevel(logging.WARN) + + sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM) + self.port = None + self.stop_key = base.random_key("tracker") + for my_port in range(port, port_end): + try: + sock.bind((host, my_port)) + self.port = my_port + break + except socket.error as sock_err: + if sock_err.errno in [errno.EADDRINUSE]: + continue + raise sock_err + if not self.port: + raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end)) + logger.info("bind to %s:%d", host, self.port) + sock.listen(1) + self.thread = threading.Thread(target=_tracker_server, args=(sock, self.stop_key)) + self.thread.start() + self.host = host + + +def _popen_start_tracker_server(host, port=9190, port_end=9199, silent=False): + # This is a function that will be sent to the + # Popen worker to run on a separate process. + # Create and start the server in a different thread + state = PopenTrackerServerState(host, port, port_end, silent) + PopenTrackerServerState.current = state + # returns the port so that the main can get the port number. + return (state.port, state.stop_key) + + class Tracker(object): """Start RPC tracker on a separate process. - Python implementation based on multi-processing. + Python implementation based on PopenWorker. Parameters ---------- @@ -389,28 +431,20 @@ class Tracker(object): def __init__(self, host="0.0.0.0", port=9190, port_end=9199, silent=False): if silent: logger.setLevel(logging.WARN) - - sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM) - self.port = None - self.stop_key = base.random_key("tracker") - for my_port in range(port, port_end): - try: - sock.bind((host, my_port)) - self.port = my_port - break - except socket.error as sock_err: - if sock_err.errno in [errno.EADDRINUSE]: - continue - raise sock_err - if not self.port: - raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end)) - logger.info("bind to %s:%d", host, self.port) - sock.listen(1) - self.proc = multiprocessing.Process(target=_tracker_server, args=(sock, self.stop_key)) - self.proc.start() + self.proc = PopenWorker() + # send the function + self.proc.send( + _popen_start_tracker_server, + [ + host, + port, + port_end, + silent, + ], + ) + # receive the port + self.port, self.stop_key = self.proc.recv() self.host = host - # close the socket on this process - sock.close() def _stop_tracker(self): sock = socket.socket(base.get_addr_family((self.host, self.port)), socket.SOCK_STREAM) @@ -427,11 +461,14 @@ def terminate(self): if self.proc: if self.proc.is_alive(): self._stop_tracker() - self.proc.join(1) + self.proc.join(0.1) if self.proc.is_alive(): logger.info("Terminating Tracker Server...") - self.proc.terminate() + self.proc.kill() self.proc = None def __del__(self): - self.terminate() + try: + self.terminate() + except TypeError: + pass