From bee10c7ed23eb08127cbb3f5e51a3bc364f6063c Mon Sep 17 00:00:00 2001 From: ofir ozeri Date: Tue, 10 Dec 2024 19:59:47 +0200 Subject: [PATCH] a refactored modeling_act for cpu and memory optimization --- lerobot/common/policies/act/modeling_act.py | 148 ++++++++++++++------ 1 file changed, 107 insertions(+), 41 deletions(-) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 418863a14..1346a7c7e 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -81,7 +81,12 @@ def __init__( self.model = ACT(config) - self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + # Pre-compute and register expected image keys + self.register_buffer( + "expected_image_keys", + torch.tensor([k.startswith("observation.image") for k in config.input_shapes]) + ) + # self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] if config.temporal_ensemble_coeff is not None: self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size) @@ -106,9 +111,14 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: self.eval() batch = self.normalize_inputs(batch) - if len(self.expected_image_keys) > 0: - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) + if self.expected_image_keys.any(): + batch = dict(batch) + keys = [k for k, is_img in zip(self.config.input_shapes.keys(), self.expected_image_keys) if is_img] + batch["observation.images"] = torch.stack([batch[k] for k in keys], dim=-4) + + # if len(self.expected_image_keys) > 0: + # batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + # batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) # If we are doing temporal ensembling, do online updates where we keep track of the number of actions # we are ensembling over. @@ -134,9 +144,15 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) - if len(self.expected_image_keys) > 0: - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) + if self.expected_image_keys.any(): + batch = dict(batch) + keys = [k for k, is_img in zip(self.config.input_shapes.keys(), self.expected_image_keys) if is_img] + batch["observation.images"] = torch.stack([batch[k] for k in keys], dim=-4) + + # if len(self.expected_image_keys) > 0: + # batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + # batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) + batch = self.normalize_targets(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) @@ -151,7 +167,9 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: # KL-divergence per batch element, then take the mean over the batch. # (See App. B of https://arxiv.org/abs/1312.6114 for more details). mean_kld = ( - (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() + (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())) + .sum(-1) + .mean() ) loss_dict["kld_loss"] = mean_kld.item() loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight @@ -161,7 +179,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: return loss_dict -class ACTTemporalEnsembler: +class ACTTemporalEnsembler(nn.Module): def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None: """Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705. @@ -204,9 +222,27 @@ def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None: print("online", avg) ``` """ + super().__init__() self.chunk_size = chunk_size - self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)) - self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0) + # TODO: # These lines are redundant since we register them as buffers right after + # self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)) + # self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0) + + # Register weights as buffers instead of attributes to improve prefoemence + self.register_buffer( + "ensemble_weights", + torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)) + ) + self.register_buffer( + "ensemble_weights_cumsum", + torch.cumsum(self.ensemble_weights, dim=0) + ) + self.register_buffer( + "ones_template", + torch.ones((chunk_size, 1), dtype=torch.long) + ) + + self.reset() def reset(self): @@ -220,17 +256,20 @@ def update(self, actions: Tensor) -> Tensor: Takes a (batch, chunk_size, action_dim) sequence of actions, update the temporal ensemble for all time steps, and pop/return the next batch of actions in the sequence. """ - self.ensemble_weights = self.ensemble_weights.to(device=actions.device) - self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device) + + # TODO: Remove assumimng upgrade of tensores working + # self.ensemble_weights = self.ensemble_weights.to(device=actions.device) + # self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device) if self.ensembled_actions is None: # Initializes `self._ensembled_action` to the sequence of actions predicted during the first # time step of the episode. - self.ensembled_actions = actions.clone() + self.ensembled_actions = actions # Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor # operations later. - self.ensembled_actions_count = torch.ones( - (self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device - ) + self.ensembled_actions_count = self.ones_template.to(self.ensembled_actions.device) + # self.ensembled_actions_count = torch.ones( + # (self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device + # ) else: # self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute # the online update for those entries. @@ -367,6 +406,13 @@ def __init__(self, config: ACTConfig): # Final action regression head on the output of the transformer's decoder. self.action_head = nn.Linear(config.dim_model, config.output_shapes["action"][0]) + self.register_buffer("zero_latent", torch.zeros(1, config.latent_dim)) + self.register_buffer("decoder_template", torch.zeros(config.chunk_size, 1, config.dim_model)) + self.register_buffer( + "cls_pad_template", + torch.full((1, 2 if self.use_robot_state else 1), False) + ) + self._reset_parameters() def _reset_parameters(self): @@ -424,16 +470,19 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso # Prepare fixed positional embedding. # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case. - pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D) + pos_embed = self.vae_encoder_pos_enc # (1, S+2, D) # Prepare key padding mask for the transformer encoder. We have 1 or 2 extra tokens at the start of the # sequence depending whether we use the input states or not (cls and robot state) # False means not a padding token. - cls_joint_is_pad = torch.full( - (batch_size, 2 if self.use_robot_state else 1), - False, - device=batch["observation.state"].device, - ) + + cls_joint_is_pad = self.cls_pad_template.expand(batch_size, -1) + + # cls_joint_is_pad = torch.full( + # (batch_size, 2 if self.use_robot_state else 1), + # False, + # device=batch["observation.state"].device, + # ) key_padding_mask = torch.cat( [cls_joint_is_pad, batch["action_is_pad"]], axis=1 ) # (bs, seq+1 or 2) @@ -455,9 +504,10 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso # When not using the VAE encoder, we set the latent to be all zeros. mu = log_sigma_x2 = None # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer - latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to( - batch["observation.state"].device - ) + latent_sample = self.zero_latent.expand(batch_size, -1) + # latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to( + # batch["observation.state"].device + # ) # Prepare transformer encoder inputs. encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)] @@ -480,7 +530,7 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"] # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use # buffer - cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) + cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features) cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w) all_cam_features.append(cam_features) all_cam_pos_embeds.append(cam_pos_embed) @@ -498,11 +548,12 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso # Forward pass through the transformer modules. encoder_out = self.encoder(encoder_in_tokens, pos_embed=encoder_in_pos_embed) # TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer - decoder_in = torch.zeros( - (self.config.chunk_size, batch_size, self.config.dim_model), - dtype=encoder_in_pos_embed.dtype, - device=encoder_in_pos_embed.device, - ) + decoder_in = self.decoder_template.expand(-1, batch_size, -1) + # decoder_in = torch.zeros( + # (self.config.chunk_size, batch_size, self.config.dim_model), + # dtype=encoder_in_pos_embed.dtype, + # device=encoder_in_pos_embed.device, + # ) decoder_out = self.decoder( decoder_in, encoder_out, @@ -708,6 +759,20 @@ def __init__(self, dimension: int): # Inverse "common ratio" for the geometric progression in sinusoid frequencies. self._temperature = 10000 + # Register arange buffer to avoid device transfer + self.register_buffer('_pi_tensor', torch.tensor([self._two_pi])) + self.register_buffer('_eps_tensor', torch.tensor([self._eps])) + self.register_buffer('_temp_tensor', torch.tensor([self._temperature])) + self.register_buffer( + "dim_arange", + torch.arange(dimension, dtype=torch.float32) + ) + self.register_buffer( + "inverse_frequency", + self._temp_tensor ** (2 * (self.dim_arange // 2) / self.dimension) + ) + + def forward(self, x: Tensor) -> Tensor: """ Args: @@ -718,21 +783,22 @@ def forward(self, x: Tensor) -> Tensor: not_mask = torch.ones_like(x[0, :1]) # (1, H, W) # Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations # they would be range(0, H) and range(0, W). Keeping it at as is to match the original code. - y_range = not_mask.cumsum(1, dtype=torch.float32) - x_range = not_mask.cumsum(2, dtype=torch.float32) + y_range = not_mask.cumsum(1) + x_range = not_mask.cumsum(2) # "Normalize" the position index such that it ranges in [0, 2π]. # Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range # are non-zero by construction. This is an artifact of the original code. - y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi - x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi + y_range = y_range / (y_range[:, -1:, :] + self._eps_tensor) * self._pi_tensor + x_range = x_range / (x_range[:, :, -1:] + self._eps_tensor) * self._pi_tensor - inverse_frequency = self._temperature ** ( - 2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension - ) + + # inverse_frequency = self._temperature ** ( + # 2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension + # ) - x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1) - y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1) + x_range = x_range.unsqueeze(-1) / self.inverse_frequency # (1, H, W, 1) + y_range = y_range.unsqueeze(-1) / self.inverse_frequency # (1, H, W, 1) # Note: this stack then flatten operation results in interleaved sine and cosine terms. # pos_embed_x and pos_embed_y are (1, H, W, C // 2).