Skip to content

Commit

Permalink
[RPC][REFACTOR] Use PopenWorker to handle RPC Server. (apache#7889)
Browse files Browse the repository at this point in the history
Previously the rpc server relies multiprocessing to start a new process and does not work under jupyter.
It also have a popen mode that does ensure the socket start listening before returning the port number.

This PR switches the implementations use PopenWorker. The port number is returned after the socket
get binded, which resolves some of the RPC flaky issues(need sleep to wait the server to start).
It also makes the RPC server jupyter friendly.
  • Loading branch information
tqchen authored and Trevor Morris committed May 6, 2021
1 parent 6abb2c4 commit 2d41d7e
Show file tree
Hide file tree
Showing 10 changed files with 195 additions and 173 deletions.
1 change: 0 additions & 1 deletion python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,6 @@ def __init__(
port=self.tracker.port,
port_end=10000,
key=device_key,
use_popen=True,
silent=True,
tracker_addr=(self.tracker.host, self.tracker.port),
)
Expand Down
1 change: 0 additions & 1 deletion python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,6 @@ def set_task(self, task):
port=9000,
port_end=10000,
key=device_key,
use_popen=True,
silent=True,
tracker_addr=(tracker.host, tracker.port),
)
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/contrib/popen_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ def kill_child_processes(pid):

try:
parent = psutil.Process(pid)
children = parent.children(recursive=True)
except psutil.NoSuchProcess:
return

for process in parent.children(recursive=True):
for process in children:
try:
process.kill()
except psutil.NoSuchProcess:
Expand Down
20 changes: 6 additions & 14 deletions python/tvm/exec/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@
# under the License.
# pylint: disable=redefined-outer-name, invalid-name
"""Start an RPC server"""
from __future__ import absolute_import

import argparse
import multiprocessing
import sys
import logging
from .. import rpc

Expand Down Expand Up @@ -51,6 +47,7 @@ def main(args):
load_library=args.load_library,
custom_addr=args.custom_addr,
silent=args.silent,
no_fork=not args.fork,
)
server.proc.join()

Expand Down Expand Up @@ -85,14 +82,9 @@ def main(args):
parser.set_defaults(fork=True)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
if args.fork is False:
if sys.version_info[0] < 3:
raise RuntimeError("Python3 is required for spawn mode.")
multiprocessing.set_start_method("spawn")
else:
if not args.silent:
logging.info(
"If you are running ROCM/Metal, fork will cause "
"compiler internal error. Try to launch with arg ```--no-fork```"
)
if not args.fork is False and not args.silent:
logging.info(
"If you are running ROCM/Metal, fork will cause "
"compiler internal error. Try to launch with arg ```--no-fork```"
)
main(args)
6 changes: 3 additions & 3 deletions python/tvm/rpc/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def __init__(
self.thread.start()


def _popen_start_server(
def _popen_start_proxy_server(
host,
port=9091,
port_end=9199,
Expand Down Expand Up @@ -570,7 +570,7 @@ def _popen_start_server(
class Proxy(object):
"""Start RPC proxy server on a seperate process.
Python implementation based on multi-processing.
Python implementation based on PopenWorker.
Parameters
----------
Expand Down Expand Up @@ -618,7 +618,7 @@ def __init__(
self.proc = PopenWorker()
# send the function
self.proc.send(
_popen_start_server,
_popen_start_proxy_server,
[
host,
port,
Expand Down
206 changes: 110 additions & 96 deletions python/tvm/rpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,21 @@
import select
import struct
import logging
import threading
import multiprocessing
import subprocess
import time
import sys
import signal
import platform
import tvm._ffi

from tvm._ffi.base import py_str
from tvm._ffi.libinfo import find_lib_path
from tvm.runtime.module import load_module as _load_module
from tvm.contrib import utils
from tvm.contrib.popen_pool import PopenWorker
from . import _ffi_api
from . import base

# pylint: disable=unused-import
from . import testing
from .base import TrackerCode

logger = logging.getLogger("RPCServer")
Expand Down Expand Up @@ -296,13 +297,85 @@ def _connect_proxy_loop(addr, key, load_library):
time.sleep(retry_period)


def _popen(cmd):
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=os.environ)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Server invoke error:\n"
msg += out
raise RuntimeError(msg)
class PopenRPCServerState(object):
"""Internal PopenRPCServer State"""

current = None

def __init__(
self,
host,
port=9091,
port_end=9199,
is_proxy=False,
tracker_addr=None,
key="",
load_library=None,
custom_addr=None,
silent=False,
):

# start update
self.host = host
self.port = port
self.libs = []
self.custom_addr = custom_addr

if silent:
logger.setLevel(logging.ERROR)

if not is_proxy:
sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
self.port = None
for my_port in range(port, port_end):
try:
sock.bind((host, my_port))
self.port = my_port
break
except socket.error as sock_err:
if sock_err.errno in [98, 48]:
continue
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logger.info("bind to %s:%d", host, self.port)
sock.listen(1)
self.sock = sock
self.thread = threading.Thread(
target=_listen_loop,
args=(self.sock, self.port, key, tracker_addr, load_library, self.custom_addr),
)
self.thread.start()
else:
self.thread = threading.Thread(
target=_connect_proxy_loop, args=((host, port), key, load_library)
)
self.thread.start()


def _popen_start_rpc_server(
host,
port=9091,
port_end=9199,
is_proxy=False,
tracker_addr=None,
key="",
load_library=None,
custom_addr=None,
silent=False,
no_fork=False,
):
if no_fork:
multiprocessing.set_start_method("spawn")
# This is a function that will be sent to the
# Popen worker to run on a separate process.
# Create and start the server in a different thread
state = PopenRPCServerState(
host, port, port_end, is_proxy, tracker_addr, key, load_library, custom_addr, silent
)
PopenRPCServerState.current = state
# returns the port so that the main can get the port number.
return state.port


class Server(object):
Expand All @@ -328,11 +401,6 @@ class Server(object):
If this is true, the host and port actually corresponds to the
address of the proxy server.
use_popen : bool, optional
Whether to use Popen to start a fresh new process instead of fork.
This is recommended to switch on if we want to do local RPC demonstration
for GPU devices to avoid fork safety issues.
tracker_addr: Tuple (str, int) , optional
The address of RPC Tracker in tuple(host, ip) format.
If is not None, the server will register itself to the tracker.
Expand All @@ -348,6 +416,9 @@ class Server(object):
silent: bool, optional
Whether run this server in silent mode.
no_fork: bool, optional
Whether forbid fork in multiprocessing.
"""

def __init__(
Expand All @@ -356,101 +427,44 @@ def __init__(
port=9091,
port_end=9199,
is_proxy=False,
use_popen=False,
tracker_addr=None,
key="",
load_library=None,
custom_addr=None,
silent=False,
no_fork=False,
):
try:
if _ffi_api.ServerLoop is None:
raise RuntimeError("Please compile with USE_RPC=1")
except NameError:
raise RuntimeError("Please compile with USE_RPC=1")
self.proc = PopenWorker()
# send the function
self.proc.send(
_popen_start_rpc_server,
[
host,
port,
port_end,
is_proxy,
tracker_addr,
key,
load_library,
custom_addr,
silent,
no_fork,
],
)
# receive the port
self.port = self.proc.recv()
self.host = host
self.port = port
self.libs = []
self.custom_addr = custom_addr
self.use_popen = use_popen

if silent:
logger.setLevel(logging.ERROR)

if use_popen:
cmd = [
sys.executable,
"-m",
"tvm.exec.rpc_server",
"--host=%s" % host,
"--port=%s" % port,
"--port-end=%s" % port_end,
]
if tracker_addr:
assert key
cmd += ["--tracker=%s:%d" % tracker_addr, "--key=%s" % key]
if load_library:
cmd += ["--load-library", load_library]
if custom_addr:
cmd += ["--custom-addr", custom_addr]
if silent:
cmd += ["--silent"]

# prexec_fn is not thread safe and may result in deadlock.
# python 3.2 introduced the start_new_session parameter as
# an alternative to the common use case of
# prexec_fn=os.setsid. Once the minimum version of python
# supported by TVM reaches python 3.2 this code can be
# rewritten in favour of start_new_session. In the
# interim, stop the pylint diagnostic.
#
# pylint: disable=subprocess-popen-preexec-fn
if platform.system() == "Windows":
self.proc = subprocess.Popen(cmd, creationflags=subprocess.CREATE_NEW_PROCESS_GROUP)
else:
self.proc = subprocess.Popen(cmd, preexec_fn=os.setsid)
time.sleep(0.5)
elif not is_proxy:
sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
self.port = None
for my_port in range(port, port_end):
try:
sock.bind((host, my_port))
self.port = my_port
break
except socket.error as sock_err:
if sock_err.errno in [98, 48]:
continue
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logger.info("bind to %s:%d", host, self.port)
sock.listen(1)
self.sock = sock
self.proc = multiprocessing.Process(
target=_listen_loop,
args=(self.sock, self.port, key, tracker_addr, load_library, self.custom_addr),
)
self.proc.start()
else:
self.proc = multiprocessing.Process(
target=_connect_proxy_loop, args=((host, port), key, load_library)
)
self.proc.start()

def terminate(self):
"""Terminate the server process"""
if self.use_popen:
if self.proc:
if platform.system() == "Windows":
os.kill(self.proc.pid, signal.CTRL_C_EVENT)
else:
os.killpg(self.proc.pid, signal.SIGTERM)
self.proc = None
else:
if self.proc:
self.proc.terminate()
self.proc = None
if self.proc:
self.proc.kill()
self.proc = None

def __del__(self):
self.terminate()
Loading

0 comments on commit 2d41d7e

Please sign in to comment.