Skip to content

Commit

Permalink
Merge pull request #14 from rail-berkeley/fix_proprio_norm
Browse files Browse the repository at this point in the history
add proprio keys to normalization dump filename
  • Loading branch information
kpertsch authored Sep 17, 2023
2 parents 311b109 + 67057ae commit 1fb1752
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 13 deletions.
10 changes: 9 additions & 1 deletion experiments/main/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,16 @@ def process_text(batch):
batch.pop("language_instruction")
return batch

train_data = make_dataset(**FLAGS.config.dataset_kwargs, train=True)
action_proprio_metadata = train_data.action_proprio_metadata
if save_dir is not None:
with tf.io.gfile.GFile(
os.path.join(save_dir, "action_proprio_metadata.json"), "w"
) as f:
json.dump(action_proprio_metadata, f)

train_data = (
make_dataset(**FLAGS.config.dataset_kwargs, train=True)
(train_data)
.unbatch()
.shuffle(FLAGS.config.shuffle_buffer_size)
.repeat()
Expand Down
5 changes: 1 addition & 4 deletions experiments/main/visualization_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,7 @@ class Visualizer:

def __post_init__(self):
self.dataset = make_dataset(**self.dataset_kwargs, train=False, shuffle=False)
builder = tfds.builder(
self.dataset_kwargs["name"], data_dir=self.dataset_kwargs["data_dir"]
)
self.action_proprio_stats = get_action_proprio_stats(builder, None)
self.action_proprio_stats = self.dataset.action_proprio_metadata
self.trajs, self.viz_trajs = [], []
self.visualized_trajs = False

Expand Down
26 changes: 20 additions & 6 deletions orca/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import hashlib
import inspect
import json
import logging
from functools import partial
from typing import Dict, List, Optional, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence, Union

import dlimp as dl
import numpy as np
Expand All @@ -16,10 +17,16 @@


def get_action_proprio_stats(
builder: DatasetBuilder, dataset: tf.data.Dataset
builder: DatasetBuilder,
dataset: tf.data.Dataset,
proprio_keys: List[str],
transform_fcn: Any,
) -> Dict[str, Dict[str, List[float]]]:
# get statistics file path --> embed unique hash that catches if dataset info changed
data_info_hash = hashlib.sha256(str(builder.info).encode("utf-8")).hexdigest()
# get statistics file path --> embed unique hash that catches if dataset info / keys / transform changed
transform_str = inspect.getsource(transform_fcn) if transform_fcn else ""
data_info_hash = hashlib.sha256(
(str(builder.info) + str(proprio_keys) + str(transform_str)).encode("utf-8")
).hexdigest()
path = tf.io.gfile.join(
builder.info.data_dir, f"action_proprio_stats_{data_info_hash}.json"
)
Expand Down Expand Up @@ -202,6 +209,7 @@ def make_dataset(
image_obs_keys: Union[str, List[str]] = [],
depth_obs_keys: Union[str, List[str]] = [],
state_obs_keys: Union[str, List[str]] = [],
action_proprio_metadata: Optional[dict] = None,
**kwargs,
) -> tf.data.Dataset:
"""Creates a dataset from the RLDS format.
Expand Down Expand Up @@ -272,11 +280,17 @@ def restructure(traj):
return traj

dataset = dataset.map(restructure)

action_proprio_metadata = get_action_proprio_stats(builder, dataset)
if action_proprio_metadata is None:
action_proprio_metadata = get_action_proprio_stats(
builder,
dataset,
state_obs_keys,
RLDS_TRAJECTORY_MAP_TRANSFORMS.get(name, None),
)

dataset = apply_common_transforms(
dataset, train=train, action_proprio_metadata=action_proprio_metadata, **kwargs
)
dataset.action_proprio_metadata = action_proprio_metadata

return dataset
4 changes: 2 additions & 2 deletions orca/data/dataset_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def r2_d2_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# every input feature is batched, ie has leading batch dimension
trajectory["action"] = tf.concat(
(
trajectory["action_dict"]["cartesian_position"],
trajectory["action_dict"]["gripper_position"],
trajectory["action_dict"]["cartesian_velocity"],
trajectory["action_dict"]["gripper_velocity"],
),
axis=-1,
)
Expand Down

0 comments on commit 1fb1752

Please sign in to comment.