Skip to content

Commit

Permalink
[VTA] Recover rpc server support
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Jul 30, 2021
1 parent ff71773 commit 8ebbcf0
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
26 changes: 26 additions & 0 deletions python/tvm/rpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,9 +365,14 @@ def _popen_start_rpc_server(
custom_addr=None,
silent=False,
no_fork=False,
server_init_callback=None,
):
if no_fork:
multiprocessing.set_start_method("spawn")

if server_init_callback:
server_init_callback()

# This is a function that will be sent to the
# Popen worker to run on a separate process.
# Create and start the server in a different thread
Expand Down Expand Up @@ -420,6 +425,25 @@ class Server(object):
no_fork: bool, optional
Whether forbid fork in multiprocessing.
server_init_callback: Callable, optional
Additional initialization function when starting the server.
Note
----
The RPC server only sees functions in the tvm namespace.
To bring additional custom functions to the server env, you can use server_init_callback.
.. code:: python
def server_init_callback():
import tvm
# must import mypackage here
import mypackage
tvm.register_func("function", mypackage.func)
server = rpc.Server(host, server_init_callback=server_init_callback)
"""

def __init__(
Expand All @@ -434,6 +458,7 @@ def __init__(
custom_addr=None,
silent=False,
no_fork=False,
server_init_callback=None,
):
try:
if _ffi_api.ServerLoop is None:
Expand All @@ -455,6 +480,7 @@ def __init__(
custom_addr,
silent,
no_fork,
server_init_callback,
],
)
# receive the port
Expand Down
16 changes: 14 additions & 2 deletions vta/python/vta/exec/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from ..libinfo import find_libvta


@tvm.register_func("tvm.rpc.server.start", override=True)
def server_start():
"""VTA RPC server extension."""
# pylint: disable=unused-variable
Expand Down Expand Up @@ -148,8 +147,21 @@ def main():
else:
tracker_addr = None

# register the initialization callback
def server_init_callback():
# pylint: disable=redefined-outer-name, reimported, import-outside-toplevel, import-self
import tvm
import vta.exec.rpc_server

tvm.register_func("tvm.rpc.server.start", vta.exec.rpc_server.server_start, override=True)

server = rpc.Server(
args.host, args.port, args.port_end, key=args.key, tracker_addr=tracker_addr
args.host,
args.port,
args.port_end,
key=args.key,
tracker_addr=tracker_addr,
server_init_callback=server_init_callback,
)
server.proc.join()

Expand Down

0 comments on commit 8ebbcf0

Please sign in to comment.