Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HITL - Session Data Improvements #2044

Merged
merged 3 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 44 additions & 25 deletions examples/hitl/rearrange_v2/app_state_end_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from app_states import create_app_state_reset
from s3_upload import (
generate_unique_session_id,
make_s3_filename,
upload_file_to_s3,
validate_experiment_name,
)
from session import Session
from util import get_top_down_view
Expand Down Expand Up @@ -71,42 +71,58 @@ def _end_session(self):

# Finalize session.
if self._session.error == "":
session.success = True
session.finished = True
session.session_recorder.end_session(self._session.error)

# Get data collection parameters.
try:
config = self._app_service.config
data_collection_config = config.rearrange_v2.data_collection
s3_path = data_collection_config.s3_path
s3_subdir = "complete" if session.success else "incomplete"
s3_path = os.path.join(s3_path, s3_subdir)

# Use the port as a discriminator for when there are multiple concurrent servers.
output_folder_suffix = str(config.habitat_hitl.networking.port)
output_folder = f"output_{output_folder_suffix}"

output_file_name = data_collection_config.output_file_name
output_file = f"{output_file_name}.json.gz"

except Exception as e:
print(f"Invalid data collection config. Skipping S3 upload. {e}")
return
config = self._app_service.config

# Set the base S3 path as the experiment name.
s3_path: str = "default"
if len(session.connection_records) > 0:
experiment_name: Optional[str] = session.connection_records[0].get(
"experiment", None
)
if validate_experiment_name(experiment_name):
s3_path = experiment_name
else:
print(f"Invalid experiment name: '{experiment_name}'")

# Generate unique session ID
session_id = generate_unique_session_id(
session.episode_indices, session.connection_records
)
s3_path = os.path.join(s3_path, session_id)

# Use the port as a discriminator for when there are multiple concurrent servers.
output_folder_suffix = str(config.habitat_hitl.networking.port)
output_folder = f"output_{output_folder_suffix}"

# Delete previous output directory
if os.path.exists(output_folder):
shutil.rmtree(output_folder)

# Create new output directory
os.makedirs(output_folder)
json_path = os.path.join(output_folder, output_file)
save_as_json_gzip(session.session_recorder, json_path)

# Generate unique session ID
session_id = generate_unique_session_id(
session.episode_indices, session.connection_records
# Create a session metadata file.
session_json_path = os.path.join(output_folder, "session.json.gz")
save_as_json_gzip(
session.session_recorder.get_session_output(), session_json_path
)

# Create one file per episode.
episode_outputs = session.session_recorder.get_episode_outputs()
for episode_output in episode_outputs:
episode_json_path = os.path.join(
output_folder,
f"{episode_output.episode.episode_index}.json.gz",
)
save_as_json_gzip(
episode_output,
episode_json_path,
)

# Upload output directory
orig_file_names = [
f
Expand All @@ -115,5 +131,8 @@ def _end_session(self):
]
for orig_file_name in orig_file_names:
local_file_path = os.path.join(output_folder, orig_file_name)
s3_file_name = make_s3_filename(session_id, orig_file_name)
s3_file_name = orig_file_name
print(
f"Uploading '{local_file_path}' to '{s3_path}' as '{s3_file_name}'."
)
upload_file_to_s3(local_file_path, s3_file_name, s3_path)
7 changes: 4 additions & 3 deletions examples/hitl/rearrange_v2/app_state_load_episode.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ def sim_update(self, dt: float, post_sim_update_dict):
def _increment_episode(self):
session = self._session
assert session.episode_indices is not None
if session.current_episode_index < len(session.episode_indices):
self._set_episode(session.current_episode_index)
session.current_episode_index += 1
if session.next_session_episode < len(session.episode_indices):
self._set_episode(session.next_session_episode)
session.next_session_episode += 1
else:
self._session_ended = True

Expand All @@ -102,6 +102,7 @@ def _set_episode(self, episode_index: int):

# Set the ID of the next episode to play in lab.
next_episode_index = session.episode_indices[episode_index]
session.current_episode_index = next_episode_index
print(f"Next episode index: {next_episode_index}.")
try:
app_service.episode_helper.set_next_episode_by_index(
Expand Down
13 changes: 9 additions & 4 deletions examples/hitl/rearrange_v2/rearrange_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,11 +547,14 @@ def on_enter(self):
].gui_agent_controller._agent_idx

episode = self._app_service.episode_helper.current_episode

self._session.session_recorder.start_episode(
episode.episode_id,
episode.scene_id,
episode.scene_dataset_config,
user_index_to_agent_index_map,
episode_index=self._session.current_episode_index,
episode_id=episode.episode_id,
scene_id=episode.scene_id,
dataset=episode.scene_dataset_config,
user_index_to_agent_index_map=user_index_to_agent_index_map,
episode_info=episode.info,
)

def on_exit(self):
Expand Down Expand Up @@ -759,6 +762,8 @@ def sim_update(self, dt: float, post_sim_update_dict):
self._metrics.get_task_percent_complete(),
)
self._session.session_recorder.record_frame(frame_data)
else:
self._session.session_recorder.record_frame({})

def _is_any_agent_policy_driven(self) -> bool:
"""
Expand Down
19 changes: 18 additions & 1 deletion examples/hitl/rearrange_v2/s3_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


import os
from typing import Dict, List
from typing import Dict, List, Optional

from util import timestamp

Expand Down Expand Up @@ -109,3 +109,20 @@ def make_s3_filename(session_id: str, orig_file_name: str) -> str:
s3_filename += "!"

return s3_filename


def validate_experiment_name(experiment_name: Optional[str]) -> bool:
"""
Check whether the given experiment name is valid.
"""
if experiment_name is None:
return False

if len(experiment_name) > 128:
return False

authorized_chars = ["_", "-", "."]
return all(
not (not c.isalnum() and c not in authorized_chars)
for c in experiment_name
)
26 changes: 23 additions & 3 deletions examples/hitl/rearrange_v2/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

class Session:
"""
Data for a single RearrangeV2 session.
Data for a single HITL session.
A session is defined as a sequence of episodes done by a fixed set of users.
"""

Expand All @@ -23,13 +23,33 @@ def __init__(
episode_indices: List[int],
connection_records: Dict[int, ConnectionRecord],
):
self.success = False
self.finished = False
"""Whether the session is finished."""

self.episode_indices = episode_indices
"""List of episode indices within the session."""

self.current_episode_index = 0
"""
Current episode index within the episode set.

If there are `1000` episodes, `current_episode_index` would be a value between `0` and `999` inclusively.
"""

self.next_session_episode = 0
0mdc marked this conversation as resolved.
Show resolved Hide resolved
"""
Next index of the `episode_indices` list (element index, not episode index).

If `episode_indices` contains the values `10`, `20` and `30`, `next_session_episode` would be either `0`, `1`, `2` or `3`.
"""

self.connection_records = connection_records
"""Connection records of each user."""

self.session_recorder = SessionRecorder(
config, connection_records, episode_indices
)
"""Utility for recording the session."""

self.error = "" # Use this to display error that causes termination
self.error = ""
"""Field that contains the display error that caused session termination."""
Loading