Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RPC] Make tracker jupyter friendly via PopenWorker #7961

Merged
merged 2 commits into from
May 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions python/tvm/contrib/popen_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,26 @@ def _start(self):
self._reader = os.fdopen(main_read, "rb")
self._writer = os.fdopen(main_write, "wb")

def join(self, timeout=None):
    """Wait for the worker process to exit, optionally with a timeout.

    Parameters
    ----------
    timeout: Optional[number]
        Timeout value, block at most timeout seconds if it
        is a positive number. ``None`` (the default) waits until
        the process exits.

    Note
    ----
    A timeout is not reported to the caller: if the process is still
    running when the timeout elapses this method simply returns.
    Call ``is_alive`` afterwards to find out whether it exited.
    """
    if not self._proc:
        # No process was started (or it was already cleaned up).
        return
    try:
        self._proc.wait(timeout)
    except subprocess.TimeoutExpired:
        # Deliberate best-effort join: the caller decides what to do
        # with a process that outlived the timeout.
        pass

def is_alive(self):
    """Check if the process is alive.

    Returns
    -------
    alive: bool
        True when a process handle exists and its ``poll()`` reports
        no exit code yet; False when there is no process handle or the
        process has already exited.
    """
    if self._proc:
        # Only poll here: calling wait() first would block until the
        # process exits, after which this check could only return False.
        return self._proc.poll() is None
    return False

def send(self, fn, args=(), kwargs=None, timeout=None):
"""Send a new function task fn(*args, **kwargs) to the subprocess.
Expand Down
24 changes: 0 additions & 24 deletions python/tvm/exec/rpc_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,8 @@
# under the License.
# pylint: disable=redefined-outer-name, invalid-name
"""Tool to start RPC tracker"""
from __future__ import absolute_import

import logging
import argparse
import multiprocessing
import sys
from ..rpc.tracker import Tracker


Expand All @@ -38,27 +34,7 @@ def main(args):
)
parser.add_argument("--port", type=int, default=9190, help="The port of the RPC")
parser.add_argument("--port-end", type=int, default=9199, help="The end search port of the RPC")
parser.add_argument(
"--no-fork",
dest="fork",
action="store_false",
help="Use spawn mode to avoid fork. This option \
is able to avoid potential fork problems with Metal, OpenCL \
and ROCM compilers.",
)
parser.add_argument("--silent", action="store_true", help="Whether run in silent mode.")

parser.set_defaults(fork=True)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
if args.fork is False:
if sys.version_info[0] < 3:
raise RuntimeError("Python3 is required for spawn mode.")
multiprocessing.set_start_method("spawn")
else:
if not args.silent:
logging.info(
"If you are running ROCM/Metal, fork will cause "
"compiler internal error. Try to launch with arg ```--no-fork```"
)
main(args)
89 changes: 63 additions & 26 deletions python/tvm/rpc/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,15 @@
"""
# pylint: disable=invalid-name

import asyncio
import heapq
import logging
import socket
import threading
import multiprocessing
import errno
import struct
import json
from tvm.contrib.popen_pool import PopenWorker

try:
from tornado import ioloop
Expand Down Expand Up @@ -362,14 +363,55 @@ def run(self):


def _tracker_server(listen_sock, stop_key):
    """Serve tracker requests on ``listen_sock`` using a fresh event loop.

    Parameters
    ----------
    listen_sock: socket.socket
        The socket handed to ``TrackerServerHandler``.
    stop_key: str
        The key handed to ``TrackerServerHandler`` alongside the socket.
    """
    # Install a new asyncio loop as the current one before the handler
    # is constructed and run.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    TrackerServerHandler(listen_sock, stop_key).run()


class PopenTrackerServerState(object):
    """Internal PopenTrackerServer State.

    Binds a listening socket to the first free port in
    ``[port, port_end)`` and starts ``_tracker_server`` on it in a
    background thread.

    Parameters
    ----------
    host: str
        The host address to bind to.
    port: int
        First port to try.
    port_end: int
        End of the port search range (exclusive).
    silent: bool
        When true, raise the module logger level to WARN.

    Raises
    ------
    ValueError
        If no port in ``[port, port_end)`` could be bound.
    socket.error
        If binding fails for a reason other than the address being in use.
    """

    # The most recently created state; assigned by the code that creates it.
    current = None

    def __init__(self, host, port=9190, port_end=9199, silent=False):
        if silent:
            logger.setLevel(logging.WARN)

        sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
        self.port = None
        self.stop_key = base.random_key("tracker")
        try:
            for my_port in range(port, port_end):
                try:
                    sock.bind((host, my_port))
                except socket.error as sock_err:
                    # A busy port just means "try the next one"; anything
                    # else is a real failure and is propagated.
                    if sock_err.errno == errno.EADDRINUSE:
                        continue
                    raise
                self.port = my_port
                break
            if not self.port:
                raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
            logger.info("bind to %s:%d", host, self.port)
            sock.listen(1)
        except Exception:
            # Do not leak the socket when the server cannot be started.
            sock.close()
            raise
        # From here on the socket is handed to the server thread.
        self.thread = threading.Thread(target=_tracker_server, args=(sock, self.stop_key))
        self.thread.start()
        self.host = host


def _popen_start_tracker_server(host, port=9190, port_end=9199, silent=False):
    """Start the tracker server; meant to be sent to a Popen worker.

    The function is run on a separate process by the Popen worker and
    creates and starts the server in a different thread there.

    Returns
    -------
    result: Tuple[int, str]
        ``(port, stop_key)`` of the created server state, so that the
        main process can get the port number.
    """
    server_state = PopenTrackerServerState(host, port, port_end, silent)
    # Record the state on the class so it stays referenced after return.
    PopenTrackerServerState.current = server_state
    return (server_state.port, server_state.stop_key)


class Tracker(object):
"""Start RPC tracker on a separate process.

Python implementation based on multi-processing.
Python implementation based on PopenWorker.

Parameters
----------
Expand All @@ -389,28 +431,20 @@ class Tracker(object):
def __init__(self, host="0.0.0.0", port=9190, port_end=9199, silent=False):
    """Launch the tracker server in a PopenWorker process.

    Parameters
    ----------
    host: str
        The host address passed to the server.
    port: int
        First port the server tries to bind.
    port_end: int
        End of the port search range.
    silent: bool
        When true, raise the module logger level to WARN.
    """
    if silent:
        logger.setLevel(logging.WARN)
    # The worker process is the only place that binds and serves the
    # socket. Binding here as well would start a second, orphaned
    # server and push the worker onto a different port.
    self.proc = PopenWorker()
    # send the function
    self.proc.send(
        _popen_start_tracker_server,
        [
            host,
            port,
            port_end,
            silent,
        ],
    )
    # receive the port (and the key used to stop the server)
    self.port, self.stop_key = self.proc.recv()
    self.host = host

def _stop_tracker(self):
sock = socket.socket(base.get_addr_family((self.host, self.port)), socket.SOCK_STREAM)
Expand All @@ -427,11 +461,14 @@ def terminate(self):
if self.proc:
if self.proc.is_alive():
self._stop_tracker()
self.proc.join(1)
self.proc.join(0.1)
junrushao marked this conversation as resolved.
Show resolved Hide resolved
if self.proc.is_alive():
logger.info("Terminating Tracker Server...")
self.proc.terminate()
self.proc.kill()
self.proc = None

def __del__(self):
    """Terminate the tracker when the object is garbage collected.

    ``terminate`` is called exactly once, inside the guard, so a
    ``TypeError`` raised during cleanup is suppressed instead of
    escaping from the finalizer.
    """
    try:
        self.terminate()
    except TypeError:
        # NOTE(review): presumably this guards against names already
        # torn down at interpreter shutdown -- confirm.
        pass