Skip to content

Commit

Permalink
[RPC] Make tracker jupyter friendly
Browse files Browse the repository at this point in the history
This PR uses the PopenWorker to handle the tracker startup
and makes the tracker Jupyter-friendly.
  • Loading branch information
tqchen committed May 1, 2021
1 parent 6d555b6 commit c60a5cf
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 52 deletions.
22 changes: 19 additions & 3 deletions python/tvm/contrib/popen_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,26 @@ def _start(self):
self._reader = os.fdopen(main_read, "rb")
self._writer = os.fdopen(main_write, "wb")

def join(self, timeout=None):
    """Join the current process worker before it terminates.

    Parameters
    ----------
    timeout: Optional[number]
        Timeout value, block at most timeout seconds if it
        is a positive number. The value is forwarded unchanged to
        ``self._proc.wait``.

    Note
    ----
    Nothing is returned and a timeout is not reported as an error;
    callers that need to know whether the worker exited should check
    ``is_alive()`` afterwards.
    """
    if self._proc:
        try:
            self._proc.wait(timeout)
        except subprocess.TimeoutExpired:
            # Best-effort join: running out of time is an expected
            # outcome here, not a failure.
            pass

def is_alive(self):
    """Check if the process is alive.

    Returns
    -------
    alive: bool
        True when a process handle exists and its ``poll()`` reports
        no exit status yet, False otherwise (including when no process
        has been started).
    """
    if self._proc:
        # Only poll here. Waiting for the process to exit would block the
        # caller and make this check report False every time.
        return self._proc.poll() is None
    return False

def send(self, fn, args=(), kwargs=None, timeout=None):
"""Send a new function task fn(*args, **kwargs) to the subprocess.
Expand Down
24 changes: 0 additions & 24 deletions python/tvm/exec/rpc_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,8 @@
# under the License.
# pylint: disable=redefined-outer-name, invalid-name
"""Tool to start RPC tracker"""
from __future__ import absolute_import

import logging
import argparse
import multiprocessing
import sys
from ..rpc.tracker import Tracker


Expand All @@ -36,27 +32,7 @@ def main(args):
parser.add_argument("--host", type=str, default="0.0.0.0", help="the hostname of the tracker")
parser.add_argument("--port", type=int, default=9190, help="The port of the RPC")
parser.add_argument("--port-end", type=int, default=9199, help="The end search port of the RPC")
parser.add_argument(
"--no-fork",
dest="fork",
action="store_false",
help="Use spawn mode to avoid fork. This option \
is able to avoid potential fork problems with Metal, OpenCL \
and ROCM compilers.",
)
parser.add_argument("--silent", action="store_true", help="Whether run in silent mode.")

parser.set_defaults(fork=True)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
if args.fork is False:
if sys.version_info[0] < 3:
raise RuntimeError("Python3 is required for spawn mode.")
multiprocessing.set_start_method("spawn")
else:
if not args.silent:
logging.info(
"If you are running ROCM/Metal, fork will cause "
"compiler internal error. Try to launch with arg ```--no-fork```"
)
main(args)
88 changes: 63 additions & 25 deletions python/tvm/rpc/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,15 @@
"""
# pylint: disable=invalid-name

import asyncio
import heapq
import logging
import socket
import threading
import multiprocessing
import errno
import struct
import json
from tvm.contrib.popen_pool import PopenWorker

try:
from tornado import ioloop
Expand Down Expand Up @@ -362,14 +363,55 @@ def run(self):


def _tracker_server(listen_sock, stop_key):
    """Serve tracker requests on ``listen_sock`` until the handler returns.

    Parameters
    ----------
    listen_sock: socket.socket
        The already-bound, listening socket handed to the handler.
    stop_key: str
        Key passed through to ``TrackerServerHandler``.
    """
    # Install a fresh asyncio event loop for the calling thread before the
    # handler starts running.
    fresh_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    TrackerServerHandler(listen_sock, stop_key).run()


class PopenTrackerServerState(object):
    """Internal PopenTrackerServer State.

    Binds a listening socket to the first free port in
    ``[port, port_end)`` and starts ``_tracker_server`` on it in a
    background thread.

    Parameters
    ----------
    host: str
        The host address to bind to.
    port: int
        First port to try.
    port_end: int
        End of the port search range (exclusive).
    silent: bool
        When True, raise the module logger level to WARN.

    Raises
    ------
    ValueError
        If no port in ``[port, port_end)`` could be bound.
    """

    # Most recently created state; assigned by the code that builds one.
    current = None

    def __init__(self, host, port=9190, port_end=9199, silent=False):
        if silent:
            logger.setLevel(logging.WARN)

        sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
        self.port = None
        self.stop_key = base.random_key("tracker")
        try:
            for my_port in range(port, port_end):
                try:
                    sock.bind((host, my_port))
                except socket.error as sock_err:
                    # A port that is already taken just means "try the
                    # next one"; anything else is a real failure.
                    if sock_err.errno == errno.EADDRINUSE:
                        continue
                    raise
                self.port = my_port
                break
            if not self.port:
                raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
            logger.info("bind to %s:%d", host, self.port)
            sock.listen(1)
        except Exception:
            # Do not leak the socket when setup fails.
            sock.close()
            raise
        self.thread = threading.Thread(target=_tracker_server, args=(sock, self.stop_key))
        self.thread.start()
        self.host = host


def _popen_start_tracker_server(host, port=9190, port_end=9199, silent=False):
    """Start the tracker server; meant to be sent to a Popen worker.

    The function runs in the worker's separate process, where creating
    the state object starts the server in a different thread.

    Returns
    -------
    result: Tuple[int, str]
        ``(port, stop_key)`` of the started server, so that the main
        process can learn the port number.
    """
    server_state = PopenTrackerServerState(host, port, port_end, silent)
    # NOTE(review): presumably stored on the class so the state stays
    # referenced after this function returns - confirm.
    PopenTrackerServerState.current = server_state
    return server_state.port, server_state.stop_key


class Tracker(object):
"""Start RPC tracker on a separate process.
Python implementation based on multi-processing.
Python implementation based on PopenWorker.
Parameters
----------
Expand All @@ -390,27 +432,20 @@ def __init__(self, host, port=9190, port_end=9199, silent=False):
if silent:
logger.setLevel(logging.WARN)

sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
self.port = None
self.stop_key = base.random_key("tracker")
for my_port in range(port, port_end):
try:
sock.bind((host, my_port))
self.port = my_port
break
except socket.error as sock_err:
if sock_err.errno in [98, 48]:
continue
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logger.info("bind to %s:%d", host, self.port)
sock.listen(1)
self.proc = multiprocessing.Process(target=_tracker_server, args=(sock, self.stop_key))
self.proc.start()
self.proc = PopenWorker()
# send the function
self.proc.send(
_popen_start_tracker_server,
[
host,
port,
port_end,
silent,
],
)
# receive the port
self.port, self.stop_key = self.proc.recv()
self.host = host
# close the socket on this process
sock.close()

def _stop_tracker(self):
sock = socket.socket(base.get_addr_family((self.host, self.port)), socket.SOCK_STREAM)
Expand All @@ -427,11 +462,14 @@ def terminate(self):
if self.proc:
if self.proc.is_alive():
self._stop_tracker()
self.proc.join(1)
self.proc.join(0.1)
if self.proc.is_alive():
logger.info("Terminating Tracker Server...")
self.proc.terminate()
self.proc.kill()
self.proc = None

def __del__(self):
    """Terminate the tracker when the object is garbage collected.

    ``terminate`` is called exactly once, inside the guard, so a
    ``TypeError`` raised during teardown is suppressed instead of
    escaping from the finalizer.
    """
    try:
        self.terminate()
    except TypeError:
        # NOTE(review): presumably guards against names already torn
        # down at interpreter shutdown - confirm.
        pass

0 comments on commit c60a5cf

Please sign in to comment.