
Feature/vla 2 #583

Open

mshukor wants to merge 31 commits into user/rcadene/2024_10_07_vla
Conversation


@mshukor mshukor commented Dec 16, 2024

What this does

  • Fix the reward 0 bug (due to not normalizing the targets; see the sketch after this list)
  • Support encoder and decoder for ACT
  • Support robot states as input to the action decoder
  • Some features related to loading hf models
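
For context, a minimal sketch of the targets-normalization fix (illustrative only; the real change is the single normalize_targets call shown in the diff further down, and the stats layout and function names here are simplified/hypothetical):

import torch
import torch.nn.functional as F


def normalize_targets(batch: dict[str, torch.Tensor], stats: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    # Mean/std-normalize the ground-truth action chunk with the dataset statistics,
    # so the targets live in the same space as the policy's (normalized) predictions.
    batch = dict(batch)
    batch["action"] = (batch["action"] - stats["mean"]) / (stats["std"] + 1e-8)
    return batch


def compute_loss(model: torch.nn.Module, batch: dict[str, torch.Tensor], stats: dict[str, torch.Tensor]) -> torch.Tensor:
    # Without this call the targets stay in raw joint space while the model predicts
    # normalized actions, so the loss is meaningless and the eval reward stays at 0.
    batch = normalize_targets(batch, stats)
    actions_hat = model(batch["observation.state"])
    return F.l1_loss(actions_hat, batch["action"])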

How it was tested


ENV=aloha
ENV_TASK=AlohaTransferCube-v0
dataset_repo_id=lerobot/aloha_sim_transfer_cube_human


policy=vla
LR=3e-5 #1e-5
LR_SCHEDULER=
USE_AMP=true
PRECISION=fp16

ASYNC_ENV=false

FEAT_SELECT=all_generated


VLM=google/paligemma2-3b-pt-224
VLM_NAME=paligemma2_3b
VLM_DIM=2304
NUM_IMG_TOKENS=598

USE_PROMPT_TEMPLATE=false

ACTION_DECODER=act_decoder

DIM_MODEL=512
LORA_R=4

PEFT_METHOD=lora


USE_ACTION_CONNECTOR=true

TASK_NAME=lerobot_${ENV}_transfer_cube_${policy}_${ACTION_DECODER}_${VLM_NAME}_${PEFT_METHOD}_feat_select_${FEAT_SELECT}

GPUS=1
EVAL_FREQ=5000 #51000 #10000 51000
OFFLINE_STEPS=100000 #25000 17000 12500 50000
TRAIN_BATCH_SIZE=8
EVAL_BATCH_SIZE=8

SAVE_FREQ=5000


MUJOCO_GL=egl python lerobot/scripts/train.py \
 hydra.job.name=base_distributed_aloha_transfer_cube \
 hydra.run.dir=$WORK/logs/lerobot/${TASK_NAME} \
 dataset_repo_id=$dataset_repo_id \
 policy=$policy \
 env=$ENV env.task=$ENV_TASK \
 training.offline_steps=$OFFLINE_STEPS training.batch_size=$TRAIN_BATCH_SIZE training.save_freq=$SAVE_FREQ \
 training.eval_freq=$EVAL_FREQ eval.n_episodes=50 eval.use_async_envs=$ASYNC_ENV eval.batch_size=$EVAL_BATCH_SIZE \
 training.lr=$LR training.lr_backbone=$LR \
 wandb.enable=false use_amp=$USE_AMP precision=$PRECISION \
 policy.vlm_backbone.feature_selection=$FEAT_SELECT policy.vlm_backbone.name=$VLM policy.action_decoder.dim_model=$DIM_MODEL \
 policy.use_prompt_template=$USE_PROMPT_TEMPLATE policy.num_img_tokens=$NUM_IMG_TOKENS policy.peft_config.r=$LORA_R policy.peft_method=$PEFT_METHOD \
 policy.use_action_connector=$USE_ACTION_CONNECTOR policy.vlm_backbone.hidden_size=$VLM_DIM  policy.action_decoder.name=$ACTION_DECODER 



How to checkout & try? (for the reviewer)

@danaaubakirova danaaubakirova self-requested a review December 17, 2024 18:08
Collaborator
@danaaubakirova left a comment

Thanks for this PR! Good catch on the bug and nicely done with adding observation.states as an input! If you could explain some design choices and the code works after testing/inference, this will be approved 👍

@@ -513,6 +514,49 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso
return actions, (mu, log_sigma_x2)


class ACTEncoderDecoder(nn.Module):
Collaborator

Could you please share the motivation behind introducing the ACTEncoderDecoder module?

Author

At first I wanted to check whether the 0 reward was coming from the action decoder being too small. More generally, I think we should compare different design choices for the action decoder, and this module makes that a bit easier for ACT (ACT decoder only vs. ACT encoder-decoder).

self.action_decoder = ACTDecoder(action_decoder_config)
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.action_decoder["dim_model"])
self.use_robot_state = "observation.state" in config.input_shapes
if "act" in self.action_decoder_name:
Collaborator

Is this related to the new ACTEncoderDecoder?

To me, this if/else block related to act looks a bit confusing; we could either rename it completely or structure it like this:

if "act" in self.action_decoder_name:
    action_decoder_config = OmegaConf.create(config.action_decoder)
    if self.action_decoder_name == "act_decoder":
        # Use standalone ACTDecoder
        self.action_decoder = ACTDecoder(action_decoder_config)
        self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.action_decoder["dim_model"])
    else:
        # Use ACTEncoderDecoder, decide whether to include the encoder
        use_encoder = "decoder" not in self.action_decoder_name
        self.action_decoder = ACTEncoderDecoder(action_decoder_config, use_encoder=use_encoder)

Or even better: if ACTDecoder is equivalent to ACTEncoderDecoder with use_encoder=False, we can drop ACTDecoder entirely and keep only the use_encoder option for ACT, to avoid confusion and redundancy.
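
For reference, a rough sketch of what that unified module could look like (indicative only; ACTEncoder and ACTDecoder are the existing modules from modeling_act, and their call signatures and the config fields are simplified here):

import torch
from torch import nn


class ACTEncoderDecoder(nn.Module):
    # Single module covering both variants: decoder-only (use_encoder=False)
    # and full encoder-decoder (use_encoder=True).
    def __init__(self, config, use_encoder: bool = True):
        super().__init__()
        self.use_encoder = use_encoder
        if use_encoder:
            # Optional transformer encoder over the VLM features / robot state tokens.
            self.encoder = ACTEncoder(config)
        self.decoder = ACTDecoder(config)
        self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)

    def forward(self, decoder_in: torch.Tensor, encoder_out: torch.Tensor) -> torch.Tensor:
        if self.use_encoder:
            encoder_out = self.encoder(encoder_out)
        return self.decoder(decoder_in + self.decoder_pos_embed.weight.unsqueeze(1), encoder_out)

With that, the caller would only need use_encoder = "decoder" not in self.action_decoder_name, as in the snippet above.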

Author

Yes, ACTEncoderDecoder is more general, but I kept ACTDecoder to be able to load your old checkpoints without changing their keys. If that is no longer needed (e.g. once we have trained a good VLA with this branch), we should keep only ACTEncoderDecoder.

@@ -1,14 +1,6 @@
# @package _global_
Collaborator

Adding this line helped avoid explicitly using the name attribute (e.g., policy.name = vla or env.name = aloha), which was causing various issues with override_dataset_stats. The exact reason for these issues is still unclear to me. If you have any idea why this happens, I’d be happy to learn more.

Author

Not sure, but when you add this line the config group's parameters are promoted to the global (top-level) namespace, so you can access them directly, e.g. policy.n_obs_steps instead of policy.vla.policy.n_obs_steps.
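
A small illustration of the effect (hypothetical config paths and keys; assumes the standard Hydra compose API):

from hydra import compose, initialize

# With `# @package _global_` at the top of policy/vla.yaml, its keys are merged at the
# root of the composed config; without it they stay nested under the config group.
with initialize(version_base=None, config_path="../lerobot/configs"):
    cfg = compose(config_name="default", overrides=["policy=vla"])
    print(cfg.policy.n_obs_steps)  # resolves directly thanks to the directive
    # Without it, the same value would live under the group path instead,
    # e.g. cfg.policy.vla.policy.n_obs_steps, which is what breaks override_dataset_stats.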

@@ -91,6 +92,7 @@ def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
if len(self.expected_image_keys) > 0:
batch = dict(batch)
batch["observation.images"] = [img for k in self.expected_image_keys for img in batch[k]]
batch = self.normalize_targets(batch)
Collaborator

I am still testing the code to verify that the reward 0 was indeed related to not normalizing targets.
