Skip to content

Commit

Permalink
Remove libp2p handlers when ConnectionHandler, DHT, and Decentralized…
Browse files Browse the repository at this point in the history
…Averager are shut down (learning-at-home#501)
  • Loading branch information
borzunov authored Aug 17, 2022
1 parent 2826147 commit 3267fc7
Show file tree
Hide file tree
Showing 16 changed files with 297 additions and 119 deletions.
9 changes: 5 additions & 4 deletions hivemind/averaging/averager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
from hivemind.compression import CompressionBase, CompressionInfo, NoCompression, deserialize_torch_tensor
from hivemind.dht import DHT, DHTID
from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
from hivemind.p2p import P2P, P2PContext, P2PDaemonError, P2PHandlerError, PeerID, ServicerBase
from hivemind.proto import averaging_pb2
from hivemind.utils import MPFuture, TensorDescriptor, get_logger
from hivemind.utils.asyncio import (
Expand Down Expand Up @@ -350,6 +349,9 @@ def shutdown(self) -> None:
logger.exception("Averager shutdown has no effect: the process is already not alive")

async def _shutdown(self, timeout: Optional[float]) -> None:
if not self.client_mode:
await self.remove_p2p_handlers(self._p2p, namespace=self.prefix)

remaining_tasks = set()
for group in self._running_groups.values():
remaining_tasks.update(group.finalize(cancel=True))
Expand Down Expand Up @@ -469,8 +471,7 @@ async def find_peers_or_notify_cancel():
asyncio.CancelledError,
asyncio.InvalidStateError,
P2PHandlerError,
DispatchFailure,
ControlFailure,
P2PDaemonError,
) as e:
if step.done() or not step.allow_retries or get_dht_time() >= step.deadline:
if not step.cancelled():
Expand Down
5 changes: 2 additions & 3 deletions hivemind/averaging/matchmaking.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from hivemind.averaging.group_info import GroupInfo
from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
from hivemind.dht import DHT, DHTID
from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
from hivemind.p2p import P2P, P2PContext, P2PDaemonError, P2PHandlerError, PeerID, ServicerBase
from hivemind.proto import averaging_pb2
from hivemind.utils import DHTExpiration, TimedStorage, get_dht_time, get_logger, timed_storage
from hivemind.utils.asyncio import anext, cancel_and_wait
Expand Down Expand Up @@ -239,7 +238,7 @@ async def _request_join_group(self, leader: PeerID) -> Optional[GroupInfo]:
except asyncio.TimeoutError:
logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
return None
except (P2PHandlerError, ControlFailure, DispatchFailure, StopAsyncIteration) as e:
except (P2PDaemonError, P2PHandlerError, StopAsyncIteration) as e:
logger.debug(f"{self} - failed to request potential leader {leader}:", exc_info=True)
return None

Expand Down
1 change: 1 addition & 0 deletions hivemind/dht/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def __init__(self, *, _initialized_with_create=False):
async def shutdown(self):
    """
    Process existing requests, close all connections and stop the server

    The node is marked as not alive first, then the protocol is shut down, and finally
    the P2P instance is shut down if this node is responsible for it
    (``self._should_shutdown_p2p``).

    :raises: whatever ``self.protocol.shutdown()`` or ``self.p2p.shutdown()`` raises; a failure
      in the protocol shutdown no longer prevents the P2P instance from being shut down
    """
    self.is_alive = False
    try:
        await self.protocol.shutdown()
    finally:
        # Run this even if the protocol shutdown failed: otherwise a P2P instance
        # that this node is supposed to stop would be left running.
        if self._should_shutdown_p2p:
            await self.p2p.shutdown()

Expand Down
6 changes: 5 additions & 1 deletion hivemind/dht/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def create(
self.record_validator = record_validator
self.authorizer = authorizer

if not client_mode:
if not self.client_mode:
await self.add_p2p_handlers(self.p2p, AuthRPCWrapper(self, AuthRole.SERVICER, self.authorizer))

self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes())
Expand All @@ -79,6 +79,10 @@ async def create(
self.node_info = dht_pb2.NodeInfo()
return self

async def shutdown(self) -> None:
    """
    Unregister this protocol's RPC handlers from ``self.p2p``.

    Does nothing in client mode.
    """
    if self.client_mode:
        return
    await self.remove_p2p_handlers(self.p2p)

def __init__(self, *, _initialized_with_create=False):
"""Internal init method. Please use DHTProtocol.create coroutine to spawn new protocol instances"""
assert _initialized_with_create, "Please use DHTProtocol.create coroutine to spawn new protocol instances"
Expand Down
55 changes: 47 additions & 8 deletions hivemind/moe/server/connection_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,36 +28,75 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
:param module_backends: a dict [UID -> ModuleBackend] with all active experts
"""

def __init__(self, dht: DHT, module_backends: Dict[str, ModuleBackend]):
def __init__(
self,
dht: DHT,
module_backends: Dict[str, ModuleBackend],
*,
balanced: bool = True,
shutdown_timeout: float = 3,
start: bool = False,
):
super().__init__()
self.dht, self.module_backends = dht, module_backends
self.balanced, self.shutdown_timeout = balanced, shutdown_timeout
self._p2p: Optional[P2P] = None

self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=False)
self.ready = MPFuture()

if start:
self.run_in_background(await_ready=True)

def run(self):
torch.set_num_threads(1)
loop = switch_to_uvloop()
stop = asyncio.Event()
loop.add_reader(self._inner_pipe.fileno(), stop.set)

async def _run():
try:
self._p2p = await self.dht.replicate_p2p()
await self.add_p2p_handlers(self._p2p, balanced=True)

# wait forever
await asyncio.Future()

await self.add_p2p_handlers(self._p2p, balanced=self.balanced)
self.ready.set_result(None)
except Exception as e:
logger.error("ConnectionHandler failed to start:", exc_info=True)
self.ready.set_exception(e)
return

self.ready.set_result(None)
try:
await stop.wait()
finally:
await self.remove_p2p_handlers(self._p2p)

try:
loop.run_until_complete(_run())
except KeyboardInterrupt:
logger.debug("Caught KeyboardInterrupt, shutting down")

def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
    """
    Launch the ConnectionHandler as a background process.

    :param await_ready: if True, block until the handler reports that it is ready
      to process incoming requests
    :param timeout: maximum number of seconds to wait for readiness; only used if :await_ready:
    """
    self.start()
    if not await_ready:
        return
    self.wait_until_ready(timeout)

def wait_until_ready(self, timeout: Optional[float] = None) -> None:
    """
    Block until ``self.ready`` is resolved.

    The future's value is discarded; if the future holds an exception, it is raised by ``result``.

    :param timeout: passed to ``self.ready.result`` as the maximum waiting time
    """
    ready_future = self.ready
    ready_future.result(timeout=timeout)

def shutdown(self):
    """
    Ask the handler process to stop, and terminate it if it does not exit in time.

    A "_shutdown" message is sent through ``self._outer_pipe``, then the process is joined
    for up to ``self.shutdown_timeout`` seconds. If it is still alive afterwards, it is
    terminated. If the process is not alive to begin with, only a warning is logged.
    """
    if not self.is_alive():
        logger.warning("ConnectionHandler shutdown had no effect, the process is already dead")
        return

    self._outer_pipe.send("_shutdown")
    self.join(self.shutdown_timeout)

    still_running = self.is_alive()
    if still_running:
        logger.warning(
            "ConnectionHandler did not shut down within the grace period; terminating it the hard way"
        )
        self.terminate()

async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
module_info = self.module_backends[request.uid].get_info()
return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(module_info))
Expand Down
4 changes: 2 additions & 2 deletions hivemind/p2p/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from hivemind.p2p.p2p_daemon import P2P, P2PContext, P2PDaemonError, P2PHandlerError
from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo
from hivemind.p2p.p2p_daemon import P2P, P2PContext
from hivemind.p2p.p2p_daemon_bindings import P2PDaemonError, P2PHandlerError, PeerID, PeerInfo
from hivemind.p2p.servicer import ServicerBase, StubBase
16 changes: 16 additions & 0 deletions hivemind/p2p/p2p_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,19 @@ async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2

await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type, balanced=balanced)

async def remove_protobuf_handler(
    self,
    name: str,
    *,
    stream_input: bool = False,
    stream_output: bool = False,
) -> None:
    """
    Remove a protobuf handler by name.

    :param name: name of the handler to remove
    :param stream_input: whether the handler takes a stream of requests
    :param stream_output: whether the handler returns a stream of responses

    A handler with neither flag set is removed as a unary handler through ``self._client``;
    a handler with either flag set is removed as a binary stream handler.
    """
    is_streaming = stream_input or stream_output
    if is_streaming:
        await self.remove_binary_stream_handler(name)
    else:
        await self._client.remove_unary_handler(name)

async def _add_protobuf_unary_handler(
self,
handle_name: str,
Expand Down Expand Up @@ -553,6 +566,9 @@ async def add_binary_stream_handler(
self._start_listening()
await self._client.stream_handler(name, handler, balanced)

async def remove_binary_stream_handler(self, name: str) -> None:
    """
    Remove the binary stream handler called :name: by delegating to ``self._client``.

    :param name: name of the stream handler to remove
    """
    daemon_client = self._client
    await daemon_client.remove_stream_handler(name)

async def call_binary_stream_handler(
self, peer_id: PeerID, handler_name: str
) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]:
Expand Down
2 changes: 2 additions & 0 deletions hivemind/p2p/p2p_daemon_bindings/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo
from hivemind.p2p.p2p_daemon_bindings.utils import P2PDaemonError, P2PHandlerError
77 changes: 55 additions & 22 deletions hivemind/p2p/p2p_daemon_bindings/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,14 @@
from multiaddr import Multiaddr, protocols

from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
from hivemind.p2p.p2p_daemon_bindings.utils import DispatchFailure, raise_if_failed, read_pbmsg_safe, write_pbmsg
from hivemind.p2p.p2p_daemon_bindings.utils import (
DispatchFailure,
P2PDaemonError,
P2PHandlerError,
raise_if_failed,
read_pbmsg_safe,
write_pbmsg,
)
from hivemind.proto import p2pd_pb2 as p2pd_pb
from hivemind.utils.logging import get_logger

Expand Down Expand Up @@ -249,20 +256,37 @@ async def _ensure_persistent_conn(self):
self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))

async def add_unary_handler(self, proto: str, handler: TUnaryHandler, balanced: bool = False):
async def add_unary_handler(self, proto: str, handler: TUnaryHandler, balanced: bool = False) -> None:
if proto in self.unary_handlers:
raise P2PDaemonError(f"Handler for protocol {proto} already registered")
self.unary_handlers[proto] = handler

call_id = uuid4()
req = p2pd_pb.PersistentConnectionRequest(
callId=call_id.bytes,
addUnaryHandler=p2pd_pb.AddUnaryHandlerRequest(proto=proto, balanced=balanced),
)

add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto, balanced=balanced)
req = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)
self._pending_calls[call_id] = asyncio.Future()
await self._pending_messages.put(req)
await self._pending_calls[call_id]

if self.unary_handlers.get(proto):
raise P2PDaemonError(f"Handler for protocol {proto} already registered")
self.unary_handlers[proto] = handler
async def remove_unary_handler(self, proto: str) -> None:
    """
    Ask the daemon to remove the unary handler for :proto: and forget it locally.

    :param proto: protocol name of the handler to remove
    :raises P2PDaemonError: if no handler is registered locally for :proto:
    """
    if proto not in self.unary_handlers:
        raise P2PDaemonError(f"Handler for protocol {proto} is not registered")

    call_id = uuid4()
    removal = p2pd_pb.RemoveUnaryHandlerRequest(proto=proto)
    request = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, removeUnaryHandler=removal)

    # Register the future before queueing the request, then wait for it to be resolved
    self._pending_calls[call_id] = asyncio.Future()
    await self._pending_messages.put(request)
    await self._pending_calls[call_id]

    # The local entry is dropped only after the awaited call has completed without raising
    self.unary_handlers.pop(proto)

async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
call_id = uuid4()
call_unary_req = p2pd_pb.CallUnaryRequest(
Expand Down Expand Up @@ -362,31 +386,40 @@ async def stream_open(
return stream_info, reader, writer

async def stream_handler(self, proto: str, handler_cb: StreamHandler, balanced: bool = False) -> None:
self.handlers[proto] = handler_cb

reader, writer = await self.daemon_connector.open_connection()

listen_path_maddr_bytes = self.listen_maddr.to_bytes()
stream_handler_req = p2pd_pb.StreamHandlerRequest(
addr=listen_path_maddr_bytes, proto=[proto], balanced=balanced
req = p2pd_pb.Request(
type=p2pd_pb.Request.STREAM_HANDLER,
streamHandler=p2pd_pb.StreamHandlerRequest(
addr=self.listen_maddr.to_bytes(),
proto=[proto],
balanced=balanced,
),
)
req = p2pd_pb.Request(type=p2pd_pb.Request.STREAM_HANDLER, streamHandler=stream_handler_req)
await write_pbmsg(writer, req)

resp = p2pd_pb.Response() # type: ignore
await read_pbmsg_safe(reader, resp)
writer.close()
raise_if_failed(resp)

# if success, add the handler to the dict
self.handlers[proto] = handler_cb

async def remove_stream_handler(self, proto: str) -> None:
reader, writer = await self.daemon_connector.open_connection()

class P2PHandlerError(Exception):
"""
Raised if remote handled a request with an exception
"""
req = p2pd_pb.Request(
type=p2pd_pb.Request.REMOVE_STREAM_HANDLER,
removeStreamHandler=p2pd_pb.RemoveStreamHandlerRequest(
addr=self.listen_maddr.to_bytes(),
proto=[proto],
),
)
await write_pbmsg(writer, req)

resp = p2pd_pb.Response() # type: ignore
await read_pbmsg_safe(reader, resp)
writer.close()
raise_if_failed(resp)

class P2PDaemonError(Exception):
"""
Raised if daemon failed to handle request
"""
del self.handlers[proto]
8 changes: 7 additions & 1 deletion hivemind/p2p/p2p_daemon_bindings/p2pclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,12 @@ async def listen(self) -> AsyncIterator["Client"]:
async with self.control.listen():
yield self

async def add_unary_handler(self, proto: str, handler: TUnaryHandler, balanced: bool = False):
async def add_unary_handler(self, proto: str, handler: TUnaryHandler, balanced: bool = False) -> None:
await self.control.add_unary_handler(proto, handler, balanced=balanced)

async def remove_unary_handler(self, proto: str) -> None:
    """
    Remove the unary handler for :proto: by delegating to ``self.control``.

    :param proto: protocol name of the handler to remove
    """
    control = self.control
    await control.remove_unary_handler(proto)

async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
return await self.control.call_unary_handler(peer_id, proto, data)

Expand Down Expand Up @@ -114,3 +117,6 @@ async def stream_handler(self, proto: str, handler_cb: StreamHandler, balanced:
:return:
"""
await self.control.stream_handler(proto=proto, handler_cb=handler_cb, balanced=balanced)

async def remove_stream_handler(self, proto: str) -> None:
    """
    Remove the stream handler for :proto: by delegating to ``self.control``.

    :param proto: protocol name of the handler to remove
    """
    control = self.control
    await control.remove_stream_handler(proto=proto)
16 changes: 14 additions & 2 deletions hivemind/p2p/p2p_daemon_bindings/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,23 @@
DEFAULT_MAX_BITS: int = 64


class ControlFailure(Exception):
class P2PHandlerError(Exception):
"""
Raised if remote handled a request with an exception
"""


class P2PDaemonError(Exception):
"""
Raised if daemon failed to handle request
"""


class ControlFailure(P2PDaemonError):
pass


class DispatchFailure(Exception):
class DispatchFailure(P2PDaemonError):
pass


Expand Down
14 changes: 14 additions & 0 deletions hivemind/p2p/servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,20 @@ async def add_p2p_handlers(
]
)

async def remove_p2p_handlers(self, p2p: P2P, *, namespace: Optional[str] = None) -> None:
    """
    Remove every RPC handler of this servicer from :p2p:.

    :param p2p: P2P instance to remove the handlers from
    :param namespace: passed to ``self._get_handle_name`` together with each method name
      to build the handler name to remove

    All removals are issued concurrently with ``asyncio.gather``.
    """
    self._collect_rpc_handlers()

    removals = []
    for handler in self._rpc_handlers:
        handle_name = self._get_handle_name(namespace, handler.method_name)
        removals.append(
            p2p.remove_protobuf_handler(
                handle_name,
                stream_input=handler.stream_input,
                stream_output=handler.stream_output,
            )
        )

    await asyncio.gather(*removals)

@classmethod
def get_stub(cls, p2p: P2P, peer: PeerID, *, namespace: Optional[str] = None) -> StubBase:
cls._collect_rpc_handlers()
Expand Down
Loading

0 comments on commit 3267fc7

Please sign in to comment.