Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HITL - Allow empty user masks and add users to dependency injection container. #1932

Merged
merged 3 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions habitat-hitl/habitat_hitl/_internal/hitl_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def local_end_episode(do_reset=False):
self._app_service = AppService(
config=config,
hitl_config=self._hitl_config,
users=users,
gui_input=gui_input,
remote_client_state=self._remote_client_state,
gui_drawer=gui_drawer,
Expand Down
7 changes: 7 additions & 0 deletions habitat-hitl/habitat_hitl/app_states/app_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from habitat_hitl.core.remote_client_state import RemoteClientState
from habitat_hitl.core.serialize_utils import BaseRecorder
from habitat_hitl.core.text_drawer import AbstractTextDrawer
from habitat_hitl.core.user_mask import Users
from habitat_hitl.environment.controllers.controller_abc import GuiController
from habitat_hitl.environment.episode_helper import EpisodeHelper

Expand All @@ -27,6 +28,7 @@ def __init__(
*,
config,
hitl_config,
users: Users,
gui_input: GuiInput,
remote_client_state: RemoteClientState,
gui_drawer: GuiDrawer,
Expand All @@ -45,6 +47,7 @@ def __init__(
):
self._config = config
self._hitl_config = hitl_config
self._users = users
self._gui_input = gui_input
self._remote_client_state = remote_client_state
self._gui_drawer = gui_drawer
Expand All @@ -69,6 +72,10 @@ def config(self):
def hitl_config(self):
return self._hitl_config

@property
def users(self) -> Users:
return self._users

@property
def gui_input(self) -> GuiInput:
return self._gui_input
Expand Down
2 changes: 1 addition & 1 deletion habitat-hitl/habitat_hitl/core/user_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class Users:
_max_user_count: int

def __init__(self, max_user_count: int) -> None:
assert max_user_count > 0
assert max_user_count >= 0
assert max_user_count <= Mask.MAX_VALUE
self._max_user_count = max_user_count

Expand Down
17 changes: 17 additions & 0 deletions habitat-hitl/test/test_user_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,20 @@


def test_hitl_user_mask():
# Test without any user.
zero_users = Users(0)
assert zero_users.max_user_count == 0
assert len(zero_users.to_index_list(Mask.ALL)) == 0
assert len(zero_users.to_index_list(Mask.NONE)) == 0
user_indices = zero_users.to_index_list(
Mask.from_index(0) | Mask.from_index(1)
)
assert 0 not in user_indices
assert 1 not in user_indices
user_indices = zero_users.to_index_list(Mask.all_except_index(0))
assert 0 not in user_indices

# Test without 4 users.
four_users = Users(4)
assert four_users.max_user_count == 4
assert len(four_users.to_index_list(Mask.ALL)) == 4
Expand All @@ -26,6 +40,7 @@ def test_hitl_user_mask():
assert 3 in user_indices
assert 4 not in user_indices

# Test without 6 users.
six_users = Users(6)
assert six_users.max_user_count == 6
assert len(six_users.to_index_list(Mask.ALL)) == 6
Expand All @@ -39,6 +54,7 @@ def test_hitl_user_mask():
assert 5 in user_indices
assert 6 not in user_indices

# Test without 2 users.
two_users = Users(2)
assert two_users.max_user_count == 2
assert len(two_users.to_index_list(Mask.ALL)) == 2
Expand All @@ -48,6 +64,7 @@ def test_hitl_user_mask():
assert 1 in user_indices
assert 2 not in user_indices

# Test without max users (32).
max_users = Users(32)
assert max_users.max_user_count == 32
assert len(max_users.to_index_list(Mask.ALL)) == 32
Expand Down