diff --git a/examples/hitl/rearrange_v2/app_state_end_session.py b/examples/hitl/rearrange_v2/app_state_end_session.py index 7ac00a43fa..6129a515f6 100644 --- a/examples/hitl/rearrange_v2/app_state_end_session.py +++ b/examples/hitl/rearrange_v2/app_state_end_session.py @@ -13,8 +13,8 @@ from app_states import create_app_state_reset from s3_upload import ( generate_unique_session_id, - make_s3_filename, upload_file_to_s3, + validate_experiment_name, ) from session import Session from util import get_top_down_view @@ -71,27 +71,32 @@ def _end_session(self): # Finalize session. if self._session.error == "": - session.success = True + session.finished = True session.session_recorder.end_session(self._session.error) # Get data collection parameters. - try: - config = self._app_service.config - data_collection_config = config.rearrange_v2.data_collection - s3_path = data_collection_config.s3_path - s3_subdir = "complete" if session.success else "incomplete" - s3_path = os.path.join(s3_path, s3_subdir) - - # Use the port as a discriminator for when there are multiple concurrent servers. - output_folder_suffix = str(config.habitat_hitl.networking.port) - output_folder = f"output_{output_folder_suffix}" - - output_file_name = data_collection_config.output_file_name - output_file = f"{output_file_name}.json.gz" - - except Exception as e: - print(f"Invalid data collection config. Skipping S3 upload. {e}") - return + config = self._app_service.config + + # Set the base S3 path as the experiment name. 
+ s3_path: str = "default" + if len(session.connection_records) > 0: + experiment_name: Optional[str] = session.connection_records[0].get( + "experiment", None + ) + if validate_experiment_name(experiment_name): + s3_path = experiment_name + else: + print(f"Invalid experiment name: '{experiment_name}'") + + # Generate unique session ID + session_id = generate_unique_session_id( + session.episode_indices, session.connection_records + ) + s3_path = os.path.join(s3_path, session_id) + + # Use the port as a discriminator for when there are multiple concurrent servers. + output_folder_suffix = str(config.habitat_hitl.networking.port) + output_folder = f"output_{output_folder_suffix}" # Delete previous output directory if os.path.exists(output_folder): @@ -99,14 +104,25 @@ def _end_session(self): # Create new output directory os.makedirs(output_folder) - json_path = os.path.join(output_folder, output_file) - save_as_json_gzip(session.session_recorder, json_path) - # Generate unique session ID - session_id = generate_unique_session_id( - session.episode_indices, session.connection_records + # Create a session metadata file. + session_json_path = os.path.join(output_folder, "session.json.gz") + save_as_json_gzip( + session.session_recorder.get_session_output(), session_json_path ) + # Create one file per episode. + episode_outputs = session.session_recorder.get_episode_outputs() + for episode_output in episode_outputs: + episode_json_path = os.path.join( + output_folder, + f"{episode_output.episode.episode_index}.json.gz", + ) + save_as_json_gzip( + episode_output, + episode_json_path, + ) + # Upload output directory orig_file_names = [ f @@ -115,5 +131,8 @@ def _end_session(self): ] for orig_file_name in orig_file_names: local_file_path = os.path.join(output_folder, orig_file_name) - s3_file_name = make_s3_filename(session_id, orig_file_name) + s3_file_name = orig_file_name + print( + f"Uploading '{local_file_path}' to '{s3_path}' as '{s3_file_name}'." 
def _increment_episode(self):
    """
    Move the session on to its next episode, or mark the session as ended.

    `session.next_session_episode` is a position within
    `session.episode_indices`. While it is still inside the list, it is
    handed to `_set_episode` and then advanced by one. Once the list is
    exhausted, `_session_ended` is set instead.
    """
    active_session = self._session
    assert active_session.episode_indices is not None

    # Guard clause: every entry of `episode_indices` has been consumed.
    if active_session.next_session_episode >= len(
        active_session.episode_indices
    ):
        self._session_ended = True
        return

    # Load first, advance afterwards: the cursor only moves if
    # `_set_episode` returned normally.
    self._set_episode(active_session.next_session_episode)
    active_session.next_session_episode += 1
def validate_experiment_name(experiment_name: Optional[str]) -> bool:
    """
    Check whether the given experiment name is valid.

    A valid name is a non-empty string of at most 128 characters that only
    contains alphanumeric characters, underscores, dashes and periods, and
    that is not one of the relative path segments `.` or `..`.

    The name is read from a connection record and, once validated, is used as
    the base of the S3 upload path, so anything that is not a plain, non-empty
    path component is rejected here.

    :param experiment_name: Candidate name. `None` and non-string values are
        rejected rather than raising.
    :return: `True` if the name can be used as the base S3 path.
    """
    # The value comes out of a loosely-typed record: it may be missing or may
    # not be a string at all, in which case `len()` / iteration would raise.
    if not isinstance(experiment_name, str):
        return False

    # An empty name would pass the per-character check below (`all()` of an
    # empty sequence is True) and produce an empty base path.
    if not 0 < len(experiment_name) <= 128:
        return False

    # Only "." and ".." are refused; other all-period names are still accepted.
    if experiment_name in (".", ".."):
        return False

    authorized_chars = "_-."
    return all(c.isalnum() or c in authorized_chars for c in experiment_name)
@dataclass
class SessionRecord:
    """
    Data entry for a session.
    The session spans from the time the first user is connected to the time the last episode is completed.
    """

    # Episode indices handed to the recorder for this session.
    episode_indices: List[int]
    # Error reported through `SessionRecorder.end_session`; "" until then.
    session_error: str
    # Value of `timestamp()` when the recorder was created.
    start_timestamp: int
    # Value of `timestamp()` at the most recent recorded event
    # (frame recorded, episode ended or session ended).
    end_timestamp: int
    # Configuration handed to the recorder.
    config: Dict[str, Any]
    # Number of frames recorded over the whole session.
    frame_count: int
    # Connection record of each user, keyed by user index.
    connection_records: Dict[int, ConnectionRecord]


@dataclass
class UserRecord:
    """
    Data entry for a user.
    """

    # Key of the user in the session's connection records.
    user_index: int
    # Connection record associated with that user.
    connection_record: ConnectionRecord


@dataclass
class EpisodeRecord:
    """
    Data entry for an episode.
    """

    # Index of the episode; also used to name the per-episode output file.
    episode_index: int
    episode_id: str
    scene_id: str
    dataset: str
    user_index_to_agent_index_map: Dict[int, int]
    episode_info: Dict[str, Any]
    # Value of `timestamp()` when the episode was started.
    start_timestamp: int
    # Value of `timestamp()` at the most recent frame or at episode end.
    end_timestamp: int
    # The three fields below are filled in by `SessionRecorder.end_episode`.
    finished: bool
    task_percent_complete: float
    task_explanation: Optional[str]
    # Number of frames recorded for this episode.
    frame_count: int


@dataclass
class SessionOutput:
    """
    Content of the `session.json.gz` file.
    """

    session: SessionRecord
    users: List[UserRecord]
    episodes: List[EpisodeRecord]


@dataclass
class EpisodeOutput:
    """
    Content of a per-episode `{episode_index}.json.gz` file.

    The file is named after `episode.episode_index`, not `episode.episode_id`.
    """

    session: SessionRecord
    users: List[UserRecord]
    episode: EpisodeRecord
    # One entry per recorded frame of this episode, in recording order.
    frames: List[Dict[str, Any]]
def start_episode(
    self,
    episode_index: int,
    episode_id: str,
    scene_id: str,
    dataset: str,
    user_index_to_agent_index_map: Dict[int, int],
    episode_info: Dict[str, Any],
):
    """
    Open a new episode entry in the recording.

    A fresh `EpisodeRecord` is appended to `episode_records`, together with
    an empty frame list in `frames` at the same position. Both timestamps of
    the record start out equal, and the outcome fields hold placeholder
    values until the episode ends.
    """
    started_at = timestamp()

    new_record = EpisodeRecord(
        episode_index=episode_index,
        episode_id=episode_id,
        scene_id=scene_id,
        dataset=dataset,
        user_index_to_agent_index_map=user_index_to_agent_index_map,
        episode_info=episode_info,
        start_timestamp=started_at,
        end_timestamp=started_at,
        # Placeholders: nothing is known about the outcome yet.
        finished=False,
        task_percent_complete=0.0,
        task_explanation=None,
        frame_count=0,
    )
    self.episode_records.append(new_record)

    # Keep `frames` aligned with `episode_records`: one frame list per episode.
    self.frames.append([])
def record_frame(
    self,
    frame_data: Dict[str, Any],
):
    """
    Append one frame to the most recently started episode.

    The frame counters and end timestamps of both the session and that
    episode are updated with a single `timestamp()` reading. At least one
    episode must have been started beforehand.
    """
    episode_count = len(self.episode_records)
    assert episode_count > 0
    latest_index = episode_count - 1
    latest_episode = self.episode_records[latest_index]

    now = timestamp()

    session_record = self.session_record
    session_record.end_timestamp = now
    session_record.frame_count += 1

    latest_episode.end_timestamp = now
    latest_episode.frame_count += 1

    # `frames` is indexed like `episode_records`.
    self.frames[latest_index].append(frame_data)
def test_validate_experiment_name():
    """
    Check `validate_experiment_name` on missing names, the allowed character
    set and the maximum length.
    """
    # Missing name.
    assert validate_experiment_name(None) is False

    # Allowed characters: alphanumerics, underscore, dash and period.
    assert validate_experiment_name("test") is True
    assert validate_experiment_name("test_test-test.123") is True
    assert validate_experiment_name("test?") is False

    # Length limit: 128 characters are accepted, anything longer is rejected.
    max_length = 128
    assert validate_experiment_name("test" * 32) is True
    assert validate_experiment_name("t" * max_length) is True
    assert validate_experiment_name("t" * (max_length + 1)) is False
    assert validate_experiment_name("test" * 33) is False