Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PYTHON][RPC] Make rpc proxy jupyter friendly via PopenWorker. #7757

Merged
merged 2 commits into from
Mar 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion python/tvm/contrib/popen_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,10 @@ def kill(self):
except IOError:
pass
# kill all child processes recurisvely
kill_child_processes(self._proc.pid)
try:
kill_child_processes(self._proc.pid)
except TypeError:
pass
try:
self._proc.kill()
except OSError:
Expand Down Expand Up @@ -149,6 +152,11 @@ def _start(self):
self._reader = os.fdopen(main_read, "rb")
self._writer = os.fdopen(main_write, "wb")

def join(self):
"""Join the current process worker before it terminates"""
if self._proc:
self._proc.wait()

def send(self, fn, args=(), kwargs=None, timeout=None):
"""Send a new function task fn(*args, **kwargs) to the subprocess.

Expand Down
3 changes: 3 additions & 0 deletions python/tvm/exec/popen_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import threading
import traceback
import pickle
import logging
import cloudpickle

from tvm.contrib.popen_pool import StatusKind
Expand Down Expand Up @@ -49,6 +50,8 @@ def main():
reader = os.fdopen(int(sys.argv[1]), "rb")
writer = os.fdopen(int(sys.argv[2]), "wb")

logging.basicConfig(level=logging.INFO)
junrushao marked this conversation as resolved.
Show resolved Hide resolved

lock = threading.Lock()

def _respond(ret_value):
Expand Down
24 changes: 1 addition & 23 deletions python/tvm/exec/rpc_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,10 @@
# under the License.
# pylint: disable=redefined-outer-name, invalid-name
"""RPC web proxy, allows redirect to websocket based RPC servers(browsers)"""
from __future__ import absolute_import

import logging
import argparse
import multiprocessing
import sys
import os
from ..rpc.proxy import Proxy
from tvm.rpc.proxy import Proxy


def find_example_resource():
Expand Down Expand Up @@ -82,24 +78,6 @@ def main(args):
"--example-rpc", type=bool, default=False, help="Whether to switch on example rpc mode"
)
parser.add_argument("--tracker", type=str, default="", help="Report to RPC tracker")
parser.add_argument(
"--no-fork",
dest="fork",
action="store_false",
help="Use spawn mode to avoid fork. This option \
is able to avoid potential fork problems with Metal, OpenCL \
and ROCM compilers.",
)
parser.set_defaults(fork=True)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
if args.fork is False:
if sys.version_info[0] < 3:
raise RuntimeError("Python3 is required for spawn mode.")
multiprocessing.set_start_method("spawn")
else:
logging.info(
"If you are running ROCM/Metal, \
fork with cause compiler internal error. Try to launch with arg ```--no-fork```"
)
main(args)
125 changes: 98 additions & 27 deletions python/tvm/rpc/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@
the proxy server will forward the message between the client and server.
"""
# pylint: disable=unused-variable, unused-argument
from __future__ import absolute_import

import os
import asyncio
import logging
import socket
import multiprocessing
import threading
import errno
import struct
import time
Expand All @@ -43,6 +42,7 @@
"RPCProxy module requires tornado package %s. Try 'pip install tornado'." % error_msg
)

from tvm.contrib.popen_pool import PopenWorker
from . import _ffi_api
from . import base
from .base import TrackerCode
Expand Down Expand Up @@ -261,6 +261,7 @@ def __init__(
logging.info(pair)
self.app = tornado.web.Application(handlers)
self.app.listen(web_port)

self.sock = sock
self.sock.setblocking(0)
self.loop = ioloop.IOLoop.current()
Expand Down Expand Up @@ -471,6 +472,7 @@ def _proxy_server(
index_page,
resource_files,
):
asyncio.set_event_loop(asyncio.new_event_loop())
handler = ProxyServerHandler(
listen_sock,
listen_port,
Expand All @@ -484,6 +486,87 @@ def _proxy_server(
handler.run()


class PopenProxyServerState(object):
"""Internal PopenProxy State for Popen"""

current = None

def __init__(
self,
host,
port=9091,
port_end=9199,
web_port=0,
timeout_client=600,
timeout_server=600,
tracker_addr=None,
index_page=None,
resource_files=None,
):

sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
self.port = None
for my_port in range(port, port_end):
try:
sock.bind((host, my_port))
self.port = my_port
break
except socket.error as sock_err:
if sock_err.errno in [98, 48]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these the same on windows? @rkimball

continue
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logging.info("RPCProxy: client port bind to %s:%d", host, self.port)
sock.listen(1)
self.thread = threading.Thread(
target=_proxy_server,
args=(
sock,
self.port,
web_port,
timeout_client,
timeout_server,
tracker_addr,
index_page,
resource_files,
),
)
# start the server in a different thread
# so we can return the port directly
self.thread.start()


def _popen_start_server(
host,
port=9091,
port_end=9199,
web_port=0,
timeout_client=600,
timeout_server=600,
tracker_addr=None,
index_page=None,
resource_files=None,
):
# This is a function that will be sent to the
# Popen worker to run on a separate process.
# Create and start the server in a different thread
state = PopenProxyServerState(
host,
port,
port_end,
web_port,
timeout_client,
timeout_server,
tracker_addr,
index_page,
resource_files,
)
PopenProxyServerState.current = state
# returns the port so that the main can get the port number.
return state.port


class Proxy(object):
"""Start RPC proxy server on a seperate process.

Expand Down Expand Up @@ -532,43 +615,31 @@ def __init__(
index_page=None,
resource_files=None,
):
sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
self.port = None
for my_port in range(port, port_end):
try:
sock.bind((host, my_port))
self.port = my_port
break
except socket.error as sock_err:
if sock_err.errno in [98, 48]:
continue
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logging.info("RPCProxy: client port bind to %s:%d", host, self.port)
sock.listen(1)
self.proc = multiprocessing.Process(
target=_proxy_server,
args=(
sock,
self.port,
self.proc = PopenWorker()
# send the function
self.proc.send(
_popen_start_server,
[
host,
port,
port_end,
web_port,
timeout_client,
timeout_server,
tracker_addr,
index_page,
resource_files,
),
],
)
self.proc.start()
sock.close()
# receive the port
self.port = self.proc.recv()
self.host = host

def terminate(self):
"""Terminate the server process"""
if self.proc:
logging.info("Terminating Proxy Server...")
self.proc.terminate()
self.proc.kill()
self.proc = None

def __del__(self):
Expand Down
17 changes: 4 additions & 13 deletions tests/python/contrib/test_rpc_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,12 @@ def rpc_proxy_check():
from tvm.rpc import proxy

web_port = 8888
prox = proxy.Proxy("localhost", web_port=web_port)
prox = proxy.Proxy("127.0.0.1", web_port=web_port)

def check():
if not tvm.runtime.enabled("rpc"):
return

@tvm.register_func("rpc.test2.addone")
def addone(x):
return x + 1

@tvm.register_func("rpc.test2.strcat")
def addone(name, x):
return "%s:%d" % (name, x)

server = multiprocessing.Process(
target=proxy.websocket_proxy_server, args=("ws://localhost:%d/ws" % web_port, "x1")
)
Expand All @@ -60,10 +52,9 @@ def addone(name, x):
server.deamon = True
server.start()
client = rpc.connect(prox.host, prox.port, key="x1")
f1 = client.get_function("rpc.test2.addone")
assert f1(10) == 11
f2 = client.get_function("rpc.test2.strcat")
assert f2("abc", 11) == "abc:11"
f1 = client.get_function("testing.echo")
assert f1(10) == 10
assert f1("xyz") == "xyz"

check()
except ImportError:
Expand Down