Feature/vla 2 #583
base: user/rcadene/2024_10_07_vla
Conversation
Thanks for this PR! Good catch on the bug, and nicely done adding observation.states as an input! If you could explain some design choices and confirm the code works after testing/inference, this will be approved 👍
@@ -513,6 +514,49 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor]]:
        return actions, (mu, log_sigma_x2)


class ACTEncoderDecoder(nn.Module):
Could you please share the motivation behind introducing the ACTEncoderDecoder module?
At first I wanted to see whether the 0 reward was coming from the small action decoder. I think we should compare different design choices for the action decoder, and this makes it easier to do so for ACT (ACT decoder only vs. ACT encoder-decoder).
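To make the comparison concrete, here is a minimal sketch (not the PR's exact code) of what an ACTEncoderDecoder wrapper can look like; ACTEncoder and ACTDecoder are assumed to be the existing classes defined earlier in this modeling file, and the signatures are illustrative:

# Sketch only: ACTEncoder/ACTDecoder are the existing classes from this file; signatures are illustrative.
from torch import nn

class ACTEncoderDecoder(nn.Module):
    def __init__(self, config, use_encoder: bool = True):
        super().__init__()
        # Optional self-attention encoder over the conditioning tokens, plus the usual ACT decoder.
        self.encoder = ACTEncoder(config) if use_encoder else None
        self.decoder = ACTDecoder(config)

    def forward(self, decoder_in, encoder_in):
        # With use_encoder=False this reduces to the decoder-only path, so the
        # "ACT decoder only" vs. "ACT encoder-decoder" comparison becomes a single flag.
        memory = self.encoder(encoder_in) if self.encoder is not None else encoder_in
        return self.decoder(decoder_in, memory)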
self.action_decoder = ACTDecoder(action_decoder_config)
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.action_decoder["dim_model"])
self.use_robot_state = "observation.state" in config.input_shapes
if "act" in self.action_decoder_name:
Is this related to the new ACTEncoderDecoder?
To me this if/else block around ACT looks a bit confusing; we could either rename it completely or structure it this way:
if "act" in self.action_decoder_name:
action_decoder_config = OmegaConf.create(config.action_decoder)
if self.action_decoder_name == "act_decoder":
# Use standalone ACTDecoder
self.action_decoder = ACTDecoder(action_decoder_config)
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.action_decoder["dim_model"])
else:
# Use ACTEncoderDecoder, decide whether to include the encoder
use_encoder = "decoder" not in self.action_decoder_name
self.action_decoder = ACTEncoderDecoder(action_decoder_config, use_encoder=use_encoder)
Or even better: if ACTDecoder is equivalent to ACTEncoderDecoder with use_encoder=False, then we can remove the use of ACTDecoder completely and keep only the use_encoder option for ACT, to avoid confusion and redundancy.
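If that equivalence holds, the whole branch could collapse to something like this (sketch only; the naming convention for self.action_decoder_name is assumed, and positional embeddings are assumed to be handled inside ACTEncoderDecoder):

# Sketch of the collapsed branch if ACTDecoder is dropped entirely.
if "act" in self.action_decoder_name:
    action_decoder_config = OmegaConf.create(config.action_decoder)
    use_encoder = self.action_decoder_name != "act_decoder"  # "act_decoder" keeps decoder-only behaviour
    self.action_decoder = ACTEncoderDecoder(action_decoder_config, use_encoder=use_encoder)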
Yes, ACTEncoderDecoder is more general, but I kept ACTDecoder to be able to load your old checkpoints without changing their keys. If this is not needed anymore (e.g. we trained a good VLA with this branch), we should keep only ACTEncoderDecoder.
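For reference, the checkpoint concern comes down to parameter key prefixes. A toy illustration with hypothetical key names (not the actual checkpoints):

# Hypothetical key names: wrapping the decoder nests its weights one level deeper, so old
# checkpoints saved with ACTDecoder would no longer match without remapping.
old_key = "action_decoder.layers.0.self_attn.in_proj_weight"          # saved with ACTDecoder
new_key = "action_decoder.decoder.layers.0.self_attn.in_proj_weight"  # saved with ACTEncoderDecoder(use_encoder=False)

def remap_old_checkpoint(state_dict: dict) -> dict:
    # Sketch of a remapping that would let old checkpoints load if ACTDecoder were removed later.
    return {
        (k.replace("action_decoder.", "action_decoder.decoder.", 1) if k.startswith("action_decoder.") else k): v
        for k, v in state_dict.items()
    }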
@@ -1,14 +1,6 @@
# @package _global_
Adding this line helped avoid explicitly using the name attribute (e.g., policy.name=vla or env.name=aloha), which was causing various issues with override_dataset_stats. The exact reason for these issues is still unclear to me. If you have any idea why this happens, I'd be happy to learn more.
Not sure, but when you add this line you make the parameters global, and hence you can access them directly, e.g. policy.n_obs_steps instead of policy.vla.policy.n_obs_steps.
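A toy illustration of that difference (not the real LeRobot config tree; the paths are taken from the example above):

# What "# @package _global_" changes in a Hydra group config: without it, the file's keys stay
# nested under the config group; with it, they are merged at the top level, so overrides like
# policy.n_obs_steps resolve directly.
from omegaconf import OmegaConf

nested = OmegaConf.create({"policy": {"vla": {"policy": {"n_obs_steps": 2}}}})  # without @package _global_
flat = OmegaConf.create({"policy": {"n_obs_steps": 2}})                         # with @package _global_

print(nested.policy.vla.policy.n_obs_steps)  # deep path required
print(flat.policy.n_obs_steps)               # direct access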
@@ -91,6 +92,7 @@ def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        if len(self.expected_image_keys) > 0:
            batch = dict(batch)
            batch["observation.images"] = [img for k in self.expected_image_keys for img in batch[k]]
        batch = self.normalize_targets(batch)
I am still testing the code to verify that the reward 0 was indeed related to not normalizing targets.
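For context, a rough sketch of what normalize_targets is expected to do (assumed mean/std normalization, mirroring LeRobot's Normalize module; the stats keys are illustrative):

import torch

def normalize_targets(batch: dict[str, torch.Tensor], stats: dict) -> dict[str, torch.Tensor]:
    # Put the ground-truth actions on the same normalized scale the policy is trained to predict.
    # Skipping this would compute the loss against raw-scale actions while inference unnormalizes
    # the predictions, which could plausibly explain the degenerate (0 reward) rollouts.
    batch = dict(batch)
    mean, std = stats["action"]["mean"], stats["action"]["std"]
    batch["action"] = (batch["action"] - mean) / (std + 1e-8)
    return batch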
What this does
How it was tested
How to checkout & try? (for the reviewer)