add predict model
czyuyue committed Oct 20, 2024
1 parent 742c753 commit b61f8a6
Showing 15 changed files with 813 additions and 67 deletions.
32 changes: 16 additions & 16 deletions README.md
@@ -29,21 +29,21 @@ You can find all scripted/human demos for simulated environments [here](https://d

conda create -n aloha python=3.8.10
conda activate aloha
pip install torchvision
pip install torch
pip install pyquaternion
pip install pyyaml
pip install rospkg
pip install pexpect
pip install mujoco==2.3.7
pip install dm_control==1.0.14
pip install opencv-python
pip install matplotlib
pip install einops
pip install packaging
pip install h5py
pip install ipython
cd act/detr && pip install -e .
uv pip install torchvision
uv pip install torch
uv pip install pyquaternion
uv pip install pyyaml
uv pip install rospkg
uv pip install pexpect
uv pip install mujoco==2.3.7
uv pip install dm_control==1.0.14
uv pip install opencv-python
uv pip install matplotlib
uv pip install einops
uv pip install packaging
uv pip install h5py
uv pip install ipython
cd act/detr && uv pip install -e .

### Example Usages

@@ -68,7 +68,7 @@ To visualize the episode after it is collected, run
python3 visualize_episodes.py --dataset_dir <data save dir> --episode_idx 0

To train ACT:

# Transfer Cube task
python3 imitate_episodes.py \
--task_name sim_transfer_cube_scripted \
2 changes: 1 addition & 1 deletion constants.py
@@ -1,7 +1,7 @@
import pathlib

### Task parameters
DATA_DIR = '<put your data dir here>'
DATA_DIR = '/localdata/yy/datasets/aloha'
SIM_TASK_CONFIGS = {
'sim_transfer_cube_scripted':{
'dataset_dir': DATA_DIR + '/sim_transfer_cube_scripted',
27 changes: 24 additions & 3 deletions detr/main.py
@@ -4,7 +4,7 @@

import numpy as np
import torch
from .models import build_ACT_model, build_CNNMLP_model
from .models import build_ACT_model, build_CNNMLP_model, build_prediction_model

import IPython
e = IPython.embed
@@ -64,16 +64,19 @@ def get_args_parser():
parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
parser.add_argument('--temporal_agg', action='store_true')

## newly added by yuyue
parser.add_argument('--query_freq', action='store', type=int, help='query_freq', required=False)
parser.add_argument('--decay_rate', action='store', type=float, help='decay_rate', required=False)
return parser


def build_ACT_model_and_optimizer(args_override):
parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
args = parser.parse_args()

for k, v in args_override.items():
setattr(args, k, v)

# print(args.backbone, " args.backbone\n")
# exit(0)
model = build_ACT_model(args)
model.cuda()

@@ -89,6 +92,24 @@ def build_ACT_model_and_optimizer(args_override):

return model, optimizer

def build_ACT_prediction_model_and_optimizer(args_override):
parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
args = parser.parse_args()
for k, v in args_override.items():
setattr(args, k, v)
model = build_prediction_model(args)
model.cuda()
param_dicts = [
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
{
"params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
"lr": args.lr_backbone,
},
]
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
weight_decay=args.weight_decay)

return model, optimizer

def build_CNNMLP_model_and_optimizer(args_override):
parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
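For context, `build_ACT_prediction_model_and_optimizer` mirrors the existing `build_ACT_model_and_optimizer`: it re-parses the DETR argument defaults, applies the `args_override` dict, and returns the prediction model together with an AdamW optimizer that gives backbone parameters their own learning rate. Below is a minimal sketch of a policy wrapper that could sit on top of it, in the style of the repository's existing `ACTPolicy`; the class name is hypothetical, the override keys are only examples, and the call assumes it runs inside the training script so the command-line arguments the parser expects are already present.

```python
import torch.nn as nn
from detr.main import build_ACT_prediction_model_and_optimizer


class PredictionPolicy(nn.Module):
    """Hypothetical wrapper modelled on ACTPolicy; not part of this commit."""

    def __init__(self, args_override):
        super().__init__()
        # args_override typically carries keys such as lr, lr_backbone,
        # backbone, and hidden_dim (assumed here, not taken from this diff).
        model, optimizer = build_ACT_prediction_model_and_optimizer(args_override)
        self.model = model          # DynamicLatentModel, already moved to CUDA
        self.optimizer = optimizer  # AdamW with a separate backbone param group

    def forward(self, qpos, image):
        # Returns a chunk of predicted image features (see detr_vae.py below).
        return self.model(qpos, image)

    def configure_optimizers(self):
        return self.optimizer
```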
7 changes: 5 additions & 2 deletions detr/models/__init__.py
@@ -1,9 +1,12 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .detr_vae import build as build_vae
from .detr_vae import build_cnnmlp as build_cnnmlp

from .detr_vae import build_prediction_model as build_P
def build_ACT_model(args):
return build_vae(args)

def build_CNNMLP_model(args):
return build_cnnmlp(args)
return build_cnnmlp(args)

def build_prediction_model(args):
return build_P(args)
2 changes: 2 additions & 0 deletions detr/models/backbone.py
@@ -119,4 +119,6 @@ def build_backbone(args):
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
model = Joiner(backbone, position_embedding)
model.num_channels = backbone.num_channels
print(model.num_channels,end=" num channels ???????????????????????????\n")
print(args,end=" backbone args ???????????????????????????\n")
return model
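For reference, a small sketch of what the `Joiner` returned by `build_backbone` yields when called on a raw image tensor, which is how both `DETRVAE` and the new `DynamicLatentModel` consume it in `detr_vae.py` below. The `argparse.Namespace` fields and the input resolution are assumptions for illustration; the real values come from `get_args_parser()`.

```python
import argparse
import torch
from detr.models.backbone import build_backbone

# Assumed DETR-style fields; adjust to whatever the training script passes.
args = argparse.Namespace(backbone='resnet18', dilation=False, masks=False,
                          lr_backbone=1e-5, hidden_dim=512,
                          position_embedding='sine')
backbone = build_backbone(args)

img = torch.randn(1, 3, 480, 640)       # assumed camera resolution
features, pos = backbone(img)           # lists, one entry per returned layer
print(features[0].shape, pos[0].shape)  # e.g. (1, 512, 15, 20) each for ResNet-18
print(backbone.num_channels)            # the value the debug print above reports
```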
106 changes: 104 additions & 2 deletions detr/models/detr_vae.py
@@ -6,7 +6,7 @@
from torch import nn
from torch.autograd import Variable
from .backbone import build_backbone
from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer
from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer, Transformer

import numpy as np

@@ -56,6 +56,10 @@ def __init__(self, backbones, transformer, encoder, state_dim, num_queries, came
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
self.backbones = nn.ModuleList(backbones)
self.input_proj_robot_state = nn.Linear(14, hidden_dim)
print(self.backbones, backbones,end=" backbones\n")
# for param in self.backbones[0].parameters():
# assert param.requires_grad == False
# print(param.requires_grad,end=" param\n")
else:
# input_dim = 14 + 7 # robot_state + env_state
self.input_proj_robot_state = nn.Linear(14, hidden_dim)
@@ -83,6 +87,7 @@ def forward(self, qpos, image, env_state, actions=None, is_pad=None):
actions: batch, seq, action_dim
"""
is_training = actions is not None # train or val
print(qpos.shape, " qpos\n")
bs, _ = qpos.shape
### Obtain latent z from action sequence
if is_training:
@@ -112,7 +117,7 @@ def forward(self, qpos, image, env_state, actions=None, is_pad=None):
mu = logvar = None
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
latent_input = self.latent_out_proj(latent_sample)

print(latent_input.shape, " latent_input\n")
if self.backbones is not None:
# Image observation features and position embeddings
all_cam_features = []
@@ -122,20 +127,33 @@ def forward(self, qpos, image, env_state, actions=None, is_pad=None):
features = features[0] # take the last layer feature
pos = pos[0]
all_cam_features.append(self.input_proj(features))
print(features.shape, " features\n")
print(all_cam_features[-1].shape, " pos\n")
all_cam_pos.append(pos)
# proprioception features
print(self.camera_names,end=" camera_names\n")
proprio_input = self.input_proj_robot_state(qpos)
# fold camera dimension into width dimension
src = torch.cat(all_cam_features, axis=3)
pos = torch.cat(all_cam_pos, axis=3)
print(src.shape, " src\n")
print(pos.shape, " pos\n")
print(latent_input.shape, " latent_input\n")
print(proprio_input.shape, " proprio_input\n")
print(self.additional_pos_embed.weight.shape, " additional_pos_embed\n")
print(self.query_embed.weight.shape, " query_embed\n")
print(self.transformer, " transformer\n")
hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
else:
qpos = self.input_proj_robot_state(qpos)
env_state = self.input_proj_env_state(env_state)
transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2

hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
a_hat = self.action_head(hs)
is_pad_hat = self.is_pad_head(hs)
print(hs.shape, " hs\n")
print(a_hat.shape, " a_hat\n")
return a_hat, is_pad_hat, [mu, logvar]


@@ -276,3 +294,87 @@ def build_cnnmlp(args):

return model

class DynamicLatentModel(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args
self.model = None
self.optimizer = None
self.latent_dim = 512
self.state_dim = 14
self.chunk_size = 100
self.transformer = Transformer(d_model=self.latent_dim,
nhead=8,
num_encoder_layers=4,
num_decoder_layers=7,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False)
self.camera_names = ['top']
# Embedding layer for state
self.embed = nn.Embedding(self.state_dim, self.latent_dim)
self.query_embed = nn.Embedding(self.chunk_size, self.latent_dim)
# Backbone
print(args, "args\n")
backbones = []
backbone = build_backbone(args)
backbones = [backbone]
self.backbones = nn.ModuleList(backbones)
self.input_proj = nn.Conv2d(self.backbones[0].num_channels, self.latent_dim, kernel_size=1)
self.input_proj_robot_state = nn.Linear(self.state_dim, self.latent_dim)

self.features_head = nn.Linear(self.latent_dim, self.backbones[0].num_channels)
# self.is_pad_head = nn.Linear(self.latent_dim, 1)
self.additional_pos_embed = nn.Embedding(1, self.latent_dim)

def forward(self, qpos, image):
## input image, qpos
## output: latent of chunk of images [batch_size, chunk_size, latent_dim]
# [batch_size, latent_dim]
all_cam_features = []
all_cam_pos = []
for cam_id, cam_name in enumerate(self.camera_names):
features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
features = features[0] # take the last layer feature
pos = pos[0]
all_cam_features.append(self.input_proj(features))
# print(features.shape, " features\n")
# print(all_cam_features[-1].shape, " pos\n")
all_cam_pos.append(pos)
# proprioception features
# print(self.camera_names,end=" camera_names\n")
proprio_input = self.input_proj_robot_state(qpos)
# fold camera dimension into width dimension
src = torch.cat(all_cam_features, axis=3)
pos = torch.cat(all_cam_pos, axis=3)
proprio_input = self.input_proj_robot_state(qpos)
# print(src.shape, " src\n")
# print(pos.shape, " pos\n")
# print(latent_input.shape, " latent_input\n")
# print(proprio_input.shape, " proprio_input\n")
# print(self.additional_pos_embed.weight.shape, " additional_pos_embed\n")
# print(self.query_embed.weight.shape, " query_embed\n")
# print(self.transformer, " transformer\n")
hs = self.transformer(src, None, self.query_embed.weight, pos, None, proprio_input, self.additional_pos_embed.weight)
# print(hs.shape, " hs\n") # 100 512 8-> 8 100 512
hs = hs.permute(2, 0, 1)
features = self.features_head(hs)
return features
def get_features(self, image):
all_cam_features = []
all_cam_pos = []
for cam_id, cam_name in enumerate(self.camera_names):
features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
features = features[0] # take the last layer feature
pos = pos[0]
all_cam_features.append(self.input_proj(features))
all_cam_pos.append(pos)
src = torch.cat(all_cam_features, axis=3)
pos = torch.cat(all_cam_pos, axis=3)
feat = torch.nn.functional.adaptive_avg_pool2d(src, (1, 1)).squeeze(-1).squeeze(-1)
print(feat.shape, "feat")
return feat
def build_prediction_model(args):
model = DynamicLatentModel(args)
return model
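A sketch of the expected tensor flow through the new `DynamicLatentModel`. The backbone arguments, batch size, image resolution, and device handling below are illustrative assumptions, not values taken from this commit.

```python
import argparse
import torch
from detr.models.detr_vae import build_prediction_model

# Assumed backbone/position-encoding fields; real values come from get_args_parser().
args = argparse.Namespace(backbone='resnet18', dilation=False, masks=False,
                          lr_backbone=1e-5, hidden_dim=512,
                          position_embedding='sine')
model = build_prediction_model(args).cuda()

batch_size = 8
qpos = torch.randn(batch_size, 14).cuda()               # proprioception (state_dim = 14)
image = torch.randn(batch_size, 1, 3, 480, 640).cuda()  # single 'top' camera

pred = model(qpos, image)
# After the (100, 512, bs) -> (bs, 100, 512) permute in forward and the features_head
# projection to backbone.num_channels, pred should come out as
# (batch_size, chunk_size, num_channels), i.e. roughly (8, 100, 512) for ResNet-18.

target = model.get_features(image)  # pooled per-image features, (batch_size, latent_dim)
```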
24 changes: 18 additions & 6 deletions detr/models/transformer.py
@@ -49,6 +49,7 @@ def _reset_parameters(self):
def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None):
# TODO flatten only when input has H and W
if len(src.shape) == 4: # has H and W
# print(src.shape,end=" src\n")
# flatten NxCxHxW to HWxNxC
bs, c, h, w = src.shape
src = src.flatten(2).permute(2, 0, 1)
@@ -58,9 +59,14 @@ def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_

additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)

addition_input = torch.stack([latent_input, proprio_input], axis=0)
src = torch.cat([addition_input, src], axis=0)
if latent_input is not None and proprio_input is not None:
addition_input = torch.stack([latent_input, proprio_input], axis=0)
src = torch.cat([addition_input, src], axis=0)
if latent_input is None and proprio_input is not None:
# print("proprio_input",proprio_input.shape)
# print("src",src.shape)
addition_input = proprio_input.unsqueeze(0)
src = torch.cat([addition_input, src], axis=0)
else:
assert len(src.shape) == 3
# flatten NxHWxC to HWxNxC
@@ -70,9 +76,13 @@ def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)

tgt = torch.zeros_like(query_embed)
# print(tgt.shape,end=" tgt\n ????")
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
# print(src.shape,end=" src\n")
# print(memory.shape,end=" memory\n")
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
pos=pos_embed, query_pos=query_embed)
# print(hs.shape,end=" hs after decoder\n")
hs = hs.transpose(1, 2)
return hs

@@ -117,7 +127,9 @@ def forward(self, tgt, memory,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
output = tgt

# print(output.shape,end=" start ------\n")
# print(query_pos.shape,end=" query_pos\n")
# print(memory.shape,end=" memory\n")
intermediate = []

for layer in self.layers:
Expand All @@ -137,8 +149,8 @@ def forward(self, tgt, memory,

if self.return_intermediate:
return torch.stack(intermediate)

return output.unsqueeze(0)
# print(output.shape ,end=" end ------\n")
return output


class TransformerEncoderLayer(nn.Module):
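The new branch in `Transformer.forward` is what lets `DynamicLatentModel` reuse this encoder-decoder without a CVAE latent: when `latent_input` is None and `proprio_input` is given, only the single proprioception token is prepended to the flattened image tokens, which is why the prediction model's `additional_pos_embed` has just one entry. A minimal shape check under assumed dimensions; the positional tensor is given a singleton batch dimension here, matching what the repository's sine embedding produces.

```python
import torch
from detr.models.transformer import Transformer

# Assumed sizes for illustration; d_model must match the channel dim of src.
bs, d, h, w = 2, 512, 15, 20
transformer = Transformer(d_model=d, nhead=8, num_encoder_layers=2,
                          num_decoder_layers=2, dim_feedforward=1024)

src = torch.randn(bs, d, h, w)   # camera features after the 1x1 input_proj
pos = torch.randn(1, d, h, w)    # sine position embedding (singleton batch dim)
queries = torch.randn(100, d)    # a chunk of 100 query embeddings
proprio = torch.randn(bs, d)     # projected qpos token
extra_pos = torch.randn(1, d)    # position for the one prepended token

hs = transformer(src, None, queries, pos,
                 latent_input=None, proprio_input=proprio,
                 additional_pos_embed=extra_pos)
# The encoder sees h*w + 1 tokens; with return_intermediate off, the decoder output
# is returned as-is, so after the final transpose hs is roughly (100, d, bs).
print(hs.shape)
```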