Fix finetune script and allow for not passing in goal images
dibyaghosh committed Dec 14, 2023
1 parent 1e644f6 commit 4c8fd96
Showing 3 changed files with 26 additions and 17 deletions.
29 changes: 16 additions & 13 deletions examples/02_finetune_new_observation_action.py
@@ -1,6 +1,10 @@
"""
This script demonstrates how to finetune Octo to a new observation space (single camera + proprio)
and new action space (bimanual) using a simulated ALOHA cube handover dataset (https://tonyzhaozh.github.io/aloha/).
To run this example, first download and extract the dataset from here: <TODO>
python examples/02_finetune_new_observation_action.py --pretrained_path=hf://rail-berkeley/octo-small --data_dir=...
"""
from absl import app, flags, logging
import flax
@@ -11,7 +15,7 @@
import wandb

from octo.data.dataset import make_single_dataset
-from octo.data.oxe.oxe_dataset_configs import ActionEncoding, StateEncoding
+from octo.data.utils.data_utils import NormalizationType
from octo.model.components.action_heads import L1ActionHead
from octo.model.components.tokenizers import LowdimObsTokenizer
from octo.model.octo_model import OctoModel
@@ -31,6 +35,8 @@
)
flags.DEFINE_string("data_dir", None, "Path to finetuning dataset, in RLDS format.")
flags.DEFINE_string("save_dir", None, "Directory for saving finetuning checkpoints.")
+flags.DEFINE_integer("batch_size", 128, "Batch size for finetuning.")
+
flags.DEFINE_bool(
    "freeze_transformer",
    False,
@@ -39,6 +45,10 @@


def main(_):
+    assert (
+        FLAGS.batch_size % jax.device_count() == 0
+    ), "Batch size must be divisible by device count."
+
    initialize_compilation_cache()
    # prevent tensorflow from using GPU memory since it's only used for data loading
    tf.config.set_visible_devices([], "GPU")
@@ -62,19 +72,12 @@ def main(_):
            image_obs_keys={"primary": "top"},
            state_obs_keys=["state"],
            language_key="language_instruction",
-            state_encoding=StateEncoding.JOINT_BIMANUAL,
-            action_encoding=ActionEncoding.JOINT_POS_BIMANUAL,
-            action_proprio_normalization_type="normal",
+            action_proprio_normalization_type=NormalizationType.NORMAL,
            absolute_action_mask=[True] * 14,
        ),
        traj_transform_kwargs=dict(
            window_size=1,
            future_action_window_size=49,  # so we get 50 actions for our action chunk
-            goal_relabeling_strategy="no_image_conditioning",  # train only language-conditioned policy
-            action_encoding=ActionEncoding.JOINT_POS_BIMANUAL,
+            task_augment_strategy="delete_task_conditioning",
+            task_augment_kwargs=dict(
+                keep_image_prob=0.0  # delete goal images in task definition
+            ),
        ),
        frame_transform_kwargs=dict(
            resize_size={"primary": (256, 256)},
    train_data_iter = (
        dataset.repeat()
        .unbatch()
-        .shuffle(100000)  # can reduce this if RAM consumption too high
-        .batch(128)
+        .shuffle(10000)  # can reduce this if RAM consumption too high
+        .batch(FLAGS.batch_size)
        .iterator()
    )

@@ -159,7 +162,7 @@ def loss_fn(params, batch, rng, train=True):
        bound_module = model.module.bind({"params": params}, rngs={"dropout": rng})
        transformer_embeddings = bound_module.octo_transformer(
            batch["observation"],
-            batch["tasks"],
+            batch["task"],
            batch["observation"]["pad_mask"],
            train=train,
        )
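Aside on the behavior change above: swapping goal_relabeling_strategy="no_image_conditioning" for task_augment_strategy="delete_task_conditioning" with keep_image_prob=0.0 is what lets the script train a language-only policy without goal images. A rough sketch of the intended effect (a simplified stand-in, not octo's actual transform; the function name, the "image_" key prefix, and the zeroing behavior are assumptions):

import numpy as np

def delete_task_conditioning_sketch(task, keep_image_prob, rng):
    # Hypothetical stand-in: with probability (1 - keep_image_prob), blank out
    # every goal-image entry in the task dict so that only the language
    # instruction is left to condition the policy.
    if rng.random() >= keep_image_prob:  # always fires when keep_image_prob=0.0
        return {
            k: np.zeros_like(v) if k.startswith("image_") else v
            for k, v in task.items()
        }
    return task

With keep_image_prob=0.0 every sample loses its goal image, which is why the tokenizer change further down has to tolerate missing task images. The new --batch_size flag also pairs with the divisibility assert above, since the global batch is presumably split evenly across jax.device_count() devices for data-parallel finetuning.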
4 changes: 0 additions & 4 deletions examples/03_eval_finetuned.py
@@ -10,16 +10,12 @@
Xvfb :1 -screen 0 1024x768x16 &
export DISPLAY=:1
"""
-import sys
-
from absl import app, flags, logging
import gym
import jax
import numpy as np
import wandb

-sys.path.append("/nfs/nfs2/users/karl/code/act")
-
from octo.model.octo_model import OctoModel
from octo.utils.gym_wrappers import HistoryWrapper, RHCWrapper, UnnormalizeActionProprio

10 changes: 10 additions & 0 deletions octo/model/components/tokenizers.py
@@ -2,6 +2,7 @@
import re
from typing import Dict, Optional, Sequence

+import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
@@ -115,6 +116,15 @@ def extract_inputs(keys, inputs, check_spatial=False):
        # stack all spatial observation and task inputs
        enc_inputs = extract_inputs(obs_stack_keys, observations, check_spatial=True)
        if tasks and self.task_stack_keys:
+            matched_obs_keys = regex_filter(self.task_stack_keys, observations.keys())
+            for k in matched_obs_keys:
+                if k not in tasks:
+                    logging.info(
+                        f"No task inputs matching {k} were found. Replacing with zero padding."
+                    )
+                    tasks = flax.core.copy(
+                        tasks, {k: jnp.zeros_like(observations[k][:, 0])}
+                    )
            task_stack_keys = regex_filter(self.task_stack_keys, sorted(tasks.keys()))
            if len(task_stack_keys) == 0:
                raise ValueError(
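The added fallback can be exercised on its own. A minimal sketch under assumed shapes (a batch of 2 with a 5-frame window of 64x64 RGB observations, a tokenized language instruction as the only task input, and the task_stack_keys match hard-coded):

import flax
import jax.numpy as jnp

observations = {"image_primary": jnp.ones((2, 5, 64, 64, 3))}  # [batch, window, H, W, C]
tasks = {"language_instruction": jnp.zeros((2, 16), dtype=jnp.int32)}  # no goal image passed in

for k in ["image_primary"]:  # stands in for regex_filter(self.task_stack_keys, ...)
    if k not in tasks:
        # pad with an all-zero goal frame shaped like a single observation timestep
        tasks = flax.core.copy(tasks, {k: jnp.zeros_like(observations[k][:, 0])})

print(tasks["image_primary"].shape)  # (2, 64, 64, 3)

Together with keep_image_prob=0.0 in the finetuning script, this zero-padding is what makes goal images optional end to end.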
