Skip to content

Commit

Permalink
Add dedicated kick queue for HITL. (facebookresearch#1931)
Browse files Browse the repository at this point in the history
  • Loading branch information
0mdc authored and dannymcy committed Jun 26, 2024
1 parent aa8438d commit 0559866
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
# LICENSE file in the root directory of this source tree.

from multiprocessing import Queue
from typing import List, Optional
from typing import Any, List, Optional

from habitat_hitl.core.types import (
ClientState,
ConnectionRecord,
DataDict,
DisconnectionRecord,
Keyframe,
)
Expand All @@ -27,12 +26,16 @@ def __init__(self, networking_config) -> None:
self._client_state_queue: Queue[ClientState] = Queue()
self._connection_record_queue: Queue[ConnectionRecord] = Queue()
self._disconnection_record_queue: Queue[DisconnectionRecord] = Queue()
self._kick_signal_queue: Queue[int] = Queue()

def send_keyframe_to_networking_thread(self, keyframe: Keyframe) -> None:
"""Send a keyframe (outgoing data) to the networking thread."""
# Acquire the semaphore to ensure the simulation doesn't advance too far ahead
self._keyframe_queue.put(keyframe)

def send_kick_signal_to_networking_thread(self, user_index: int) -> None:
self._kick_signal_queue.put(user_index)

def send_client_state_to_main_thread(
self, client_state: ClientState
) -> None:
Expand Down Expand Up @@ -63,7 +66,7 @@ def get_single_queued_keyframe(self) -> Optional[Keyframe]:
return keyframe

@staticmethod
def _dequeue_all(queue: Queue) -> List[DataDict]:
def _dequeue_all(queue: Queue) -> List[Any]:
"""Dequeue all items from a queue."""
items = []

Expand All @@ -88,3 +91,7 @@ def get_queued_connection_records(self) -> List[ConnectionRecord]:
def get_queued_disconnection_records(self) -> List[DisconnectionRecord]:
"""Dequeue all disconnection records."""
return self._dequeue_all(self._disconnection_record_queue)

def get_queued_kick_signals(self) -> List[int]:
"""Dequeue all kick signals."""
return self._dequeue_all(self._kick_signal_queue)
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,21 @@ def is_okay_to_send_keyframes(self) -> bool:
and not self._waiting_for_app_ready
)

def _check_kick_client(self):
kicked_user_indices = (
self._interprocess_record.get_queued_kick_signals()
)
# TODO: We only support 1 user at the moment.
if 0 in kicked_user_indices:
for socket in self._connected_clients.values():
# Don't await this; we want to keep checking keyframes.
# Beware that the connection will remain alive for some time after this.
asyncio.create_task(socket.close())

async def check_keyframe_queue(self) -> None:
# this runs continuously even when there is no client connection
while True:
self._check_kick_client()
inc_keyframes = self._interprocess_record.get_queued_keyframes()

if len(inc_keyframes):
Expand All @@ -165,15 +177,6 @@ async def check_keyframe_queue(self) -> None:
if "message" in inc_keyframes[0]:
message_dict = inc_keyframes[0]["message"]

# for kickClient, we require the requester to include the connection_id. This ensures we don't kick the wrong client. E.g. the requester recently requested to kick an idle client, but NetworkManager already dropped that client and received a new client connection.
if "kickClient" in message_dict:
connection_id = message_dict["kickClient"]
if connection_id in self._connected_clients:
print(f"kicking client {connection_id}")
websocket = self._connected_clients[connection_id]
# Don't await this; we want to keep checking keyframes. Beware this means the connection will remain alive for some time after this.
asyncio.create_task(websocket.close())

# See hitl_defaults.yaml wait_for_app_ready_signal and ClientMessageManager.signal_app_ready
if (
self._waiting_for_app_ready
Expand Down
7 changes: 4 additions & 3 deletions habitat-hitl/habitat_hitl/core/client_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from habitat_hitl.app_states.app_service import AppService
from habitat_hitl.core.average_helper import AverageHelper
from habitat_hitl.core.user_mask import Mask


class ClientHelper:
Expand All @@ -18,6 +19,7 @@ class ClientHelper:

def __init__(self, app_service: AppService):
self._app_service = app_service
self._remote_client_state = app_service.remote_client_state
self._frame_counter = 0

self._client_frame_latency_avg_helper = AverageHelper(
Expand Down Expand Up @@ -73,9 +75,8 @@ def _update_idle_kick(self, is_user_idle_this_frame: bool) -> None:
self._show_idle_kick_warning = True

if self._idle_frame_counter > max_idle_frames:
self._app_service.client_message_manager.signal_kick_client(
self._client_connection_id
)
# TODO: We only support 1 user at the moment.
self._remote_client_state.kick(Mask.from_index(0))
self._idle_frame_counter = None
else:
# reset counter whenever the client isn't idle
Expand Down
10 changes: 0 additions & 10 deletions habitat-hitl/habitat_hitl/core/client_message_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,16 +165,6 @@ def signal_app_ready(self, destination_mask: Mask = Mask.ALL):
message = self._messages[user_index]
message["isAppReady"] = True

def signal_kick_client(
self, connection_id: int, destination_mask: Mask = Mask.ALL
):
r"""
Signal NetworkManager to kick a client identified by connection_id. See also RemoteClientState.get_new_connection_records()[i]["connectionId"]. Sloppy: this is a message to NetworkManager, not the client.
"""
for user_index in self._users.indices(destination_mask):
message = self._messages[user_index]
message["kickClient"] = connection_id

def set_server_keyframe_id(
self, keyframe_id: int, destination_mask: Mask = Mask.ALL
):
Expand Down
9 changes: 9 additions & 0 deletions habitat-hitl/habitat_hitl/core/remote_client_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,3 +433,12 @@ def on_frame_end(self) -> None:

def clear_history(self) -> None:
self._recent_client_states.clear()

def kick(self, user_mask: Mask) -> None:
"""
Immediately kick the users matching the specified user mask.
"""
for user_index in self._users.indices(user_mask):
self._interprocess_record.send_kick_signal_to_networking_thread(
user_index
)

0 comments on commit 0559866

Please sign in to comment.