diff --git a/torchdrive/models/vista.py b/torchdrive/models/vista.py
index 68c642d..716d76a 100644
--- a/torchdrive/models/vista.py
+++ b/torchdrive/models/vista.py
@@ -1,6 +1,6 @@
 import os.path
-from typing import Tuple
 import time
+from typing import Tuple
 
 import torch
 import torch.nn.functional as F
diff --git a/torchdrive/tasks/diff_traj.py b/torchdrive/tasks/diff_traj.py
index 29ab504..f8c2795 100644
--- a/torchdrive/tasks/diff_traj.py
+++ b/torchdrive/tasks/diff_traj.py
@@ -1,13 +1,15 @@
 import math
 import os.path
-from collections import OrderedDict
+from collections import defaultdict, OrderedDict
 from typing import Dict, List, Optional, Tuple
 
 import matplotlib.pyplot as plt
 import torch
 import torch.nn.functional as F
+import torchmetrics
 from diffusers import EulerDiscreteScheduler
+from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_euler_angles
 from safetensors.torch import load_model
 from torch import nn
 from torch.utils.tensorboard import SummaryWriter
@@ -19,6 +21,7 @@
 from torchdrive.models.mlp import MLP
 from torchdrive.models.path import XYEncoder
 from torchdrive.models.transformer import transformer_init
+from torchdrive.models.vista import VistaSampler
 from torchdrive.tasks.context import Context
 from torchdrive.tasks.van import Van
 from torchdrive.transforms.batch import Compose, ImageTransform, NormalizeCarPosition
@@ -32,10 +35,8 @@
     render_color,
     render_pca,
 )
-from torchworld.transforms.transform3d import Transform3d
-from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_euler_angles
 from torchworld.transforms.mask import random_block_mask, true_mask
-from torchdrive.models.vista import VistaSampler
+from torchworld.transforms.transform3d import Transform3d
 
 
 def square_mask(mask: torch.Tensor, num_heads: int) -> torch.Tensor:
@@ -735,6 +736,7 @@ def __init__(
         num_inference_timesteps: int = 50,
         num_train_timesteps: int = 1000,
         max_seq_len: int = 256,
+        test: bool = False,
     ):
         super().__init__()
@@ -749,91 +751,22 @@
 
         self.model = ConvNextPathPred()
 
-        self.vista = VistaSampler()
-
-        """
-        self.encoders = nn.ModuleDict(
-            {
-                cam: MaskViT(
-                    cam_shape=cam_shape,
-                    dim=dim,
-                    attention_dropout=0.1,
-                )
-                for cam in cameras
-            }
-        )
-
-        # embedding
-        self.xy_embedding = XYGMMEncoder(dim=dim, max_dist=128.0)
-
-        self.denoiser = Denoiser(
-            max_seq_len=max_seq_len,
-            num_layers=num_layers,
-            num_heads=num_heads,
-            dim=dim,
-            mlp_dim=dim_feedforward,
-            attention_dropout=dropout,
-        )
-
-        self.static_features_encoder = nn.Sequential(
-            nn.Linear(1, dim),
-            nn.ReLU(inplace=True),
-            nn.Linear(dim, dim),
-        )
-        self.query_embed = nn.Parameter(
-            torch.empty(max_seq_len, dim).normal_(std=0.02)
-        )
-
-        self.noise_scheduler = EulerDiscreteScheduler(
-            num_train_timesteps=num_train_timesteps
-        )
-        self.noise_scheduler.set_timesteps(num_train_timesteps)
-        self.eval_noise_scheduler = EulerDiscreteScheduler(
-            num_train_timesteps=num_train_timesteps
-        )
-        self.eval_noise_scheduler.set_timesteps(self.num_inference_timesteps)
-        """
+        if not test:  # Vista is only needed for the dream loss during training
+            self.vista = VistaSampler()
 
         self.batch_transform = Compose(
             NormalizeCarPosition(start_frame=0),
-            #ImageTransform(
+            # ImageTransform(
             #     v2.RandomRotation(15, InterpolationMode.BILINEAR),
             #     v2.RandomErasing(),
-            #),
+            # ),
         )
 
+        self.test_mae = torchmetrics.MeanAbsoluteError()
+        self.test_mse = torchmetrics.MeanSquaredError()
+        self.test_losses = defaultdict(lambda: torchmetrics.aggregation.MeanMetric())
+
     def param_opts(self, lr: float) -> List[Dict[str, object]]:
-        """
-        return [
-            {
-                "name": "encoders",
-                "params": list(self.encoders.parameters()),
"params": list(self.encoders.parameters()), - "lr": lr, - "weight_decay": 1e-4, - }, - { - "name": "static_features", - "params": list(self.static_features_encoder.parameters()), - "lr": lr, - }, - { - "name": "query", - "params": [self.query_embed], - "lr": lr, - }, - { - "name": "denoiser", - "params": list(self.denoiser.parameters()), - "lr": lr, - "weight_decay": 1e-4, - }, - { - "name": "xy_embedding", - "params": list(self.xy_embedding.parameters()), - "lr": lr, - }, - ] - """ return [ { "name": "model", @@ -854,6 +787,34 @@ def should_log(self, global_step: int, BS: int) -> Tuple[bool, bool]: return log_img, log_text + def prepare_inputs( + self, batch: Batch + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + batch: a batch of data + + Returns: + positions: (bs, num_encode_frames, 2) + mask: (bs, num_encode_frames) + velocity: (bs,) + """ + world_to_car, mask, lengths = batch.long_cam_T + positions = batch.positions() + positions = positions[..., :2] + + # calculate velocity between first two frames to allow model to understand current speed + # TODO: convert this to a categorical embedding + velocity = positions[:, 1] - positions[:, 0] + assert positions.size(-1) == 2 + velocity = torch.linalg.vector_norm(velocity, dim=-1, keepdim=True) + + # approximately 0.5 fps since video is 12hz + positions = positions[:, ::6] + mask = mask[:, ::6] + + return positions, mask, velocity + def forward( self, batch: Batch, @@ -888,87 +849,7 @@ def forward( normalize_img(feats[0, 0]), ) - """ - # for size, device only - empty_mask = torch.empty(self.feat_shape, device=device) - - all_feats = [] - - for cam in self.cameras: - feats = batch.color[cam][:, : self.num_encode_frames] - block_size = min(*self.feat_shape) // 3 - if True: - mask = torch.ones_like(empty_mask).bool() - else: - mask = random_block_mask( - empty_mask, - block_size=(block_size, block_size), - num_blocks=8, - ) - - if writer is not None and log_text: - ctx.add_scalar( - f"{cam}/count", - mask.long().sum(), - ) - if writer is not None and log_img: - ctx.add_image( - f"{cam}/mask", - render_color(mask), - ) - - with autocast(): - # checkpoint encoders to save memory - encoder = self.encoders[cam] - unmasked, cam_feats = torch.utils.checkpoint.checkpoint( - encoder, - feats.flatten(0, 1), - mask, - use_reentrant=False, - ) - assert cam_feats.requires_grad, f"missing grad for cam {cam}" - - if writer is not None and log_img: - ctx.add_image( - f"{cam}/color", - normalize_img(feats[0, 0]), - ) - ctx.add_image( - f"{cam}/pca", - render_pca(unmasked[0].permute(1, 2, 0)), - ) - - if writer is not None and log_text: - register_log_grad_norm( - t=cam_feats, - writer=writer, - key="gradnorm/cam-encoder", - tag=cam, - ) - - # (n, seq_len, hidden_dim) -> (bs, num_encode_frames, seq_len, hidden_dim) - cam_feats = cam_feats.unflatten(0, feats.shape[:2]) - - # flatten time - # (bs, num_encode_frames, seq_len, hidden_dim) -> (bs, num_encode_frames * seq_len, hidden_dim) - cam_feats = cam_feats.flatten(1, 2) - - all_feats.append(cam_feats) - - input_tokens = torch.cat(all_feats, dim=1) - """ - - world_to_car, mask, lengths = batch.long_cam_T - positions = batch.positions() - positions = positions[..., :2] - - # calculate velocity between first two frames to allow model to understand current speed - # TODO: convert this to a categorical embedding - velocity = positions[:, 1] - positions[:, 0] - assert positions.size(-1) == 2 - velocity = torch.linalg.vector_norm(velocity, dim=-1, keepdim=True) - - # 
+        positions, mask, velocity = self.prepare_inputs(batch)
 
         lengths = mask.sum(dim=-1)
         min_len = lengths.amin()
@@ -976,30 +857,6 @@
 
         # truncate to shortest sequence
         pos_len = lengths.amin()
-        # if pos_len % align != 0:
-        #     pos_len -= pos_len % align
-        # assert pos_len >= 8
-        # positions = positions[:, :pos_len]
-        # mask = mask[:, :pos_len]
-
-        # approximately 0.5 fps since video is 12hz
-        positions = positions[:, ::6]
-        mask = mask[:, ::6]
-
-        """
-        # we need to be aligned to size 8
-        # pad length
-        align = 8
-        if positions.size(1) % align != 0:
-            pad = align - positions.size(1) % align
-            mask = F.pad(mask, (0, pad), value=False)
-            positions = F.pad(positions, (0, 0, 0, pad), value=0)
-        pos_len = positions.size(1)
-
-        assert positions.size(1) % align == 0
-        assert mask.size(1) % align == 0
-        assert positions.size(1) == mask.size(1)
-        """
 
         num_elements = mask.float().sum()
@@ -1016,53 +873,6 @@
         posmax = positions.abs().amax()
         assert posmax < 100000, positions
 
-        """
-        traj_embed = self.xy_embedding(positions)
-        """
-
-        """
-        noise = torch.randn(traj_embed.shape, device=traj_embed.device) / self.noise_scale
-        timesteps = torch.randint(
-            0,
-            self.noise_scheduler.config.num_train_timesteps,
-            (BS,),
-            device=traj_embed.device,
-            dtype=torch.int64,
-        )
-        traj_embed_noise = self.noise_scheduler.add_noise(traj_embed, noise, timesteps)
-
-        if writer and log_text:
-            ctx.add_scalars(
-                "paths/embed_scales",
-                {
-                    "embed": torch.linalg.vector_norm(traj_embed, dim=-1).mean().cpu(),
-                    "embed_with_noise": torch.linalg.vector_norm(traj_embed_noise, dim=-1).mean().cpu(),
-                    "noise": torch.linalg.vector_norm(noise, dim=-1).mean().cpu(),
-                },
-            )
-        """
-
-        """
-        query = self.query_embed.repeat(BS, 1, 1)
-
-        #with autocast():
-        # add static feature info to all condition keys to avoid noise
-        input_tokens = input_tokens + static_features
-
-        pred_embed = self.denoiser(query, mask, input_tokens)
-
-        # reduce to match target
-        pred_embed = pred_embed[:, :positions.size(1)]
-        """
-
-        """
-        noise_loss = F.mse_loss(pred_noise, noise, reduction="none")
-        noise_loss = noise_loss[mask]
-        losses["diffusion"] = noise_loss.mean()
-        """
-
-        # pred_loss, pred_traj, all_pred_traj = self.xy_embedding.loss(pred_embed, positions, mask)
-
         cam = self.cameras[0]
         pred_losses, pred_traj, all_pred_traj = self.model(
             velocity, batch.color[cam], positions, mask
         )
@@ -1073,13 +883,13 @@
 
         dreamed_imgs = []
         for i in range(BS):
-            cond_img = batch.color[cam][i:i+1, 0]
-            cond_traj = pred_traj[i:i+1]
+            cond_img = batch.color[cam][i : i + 1, 0]
+            cond_traj = pred_traj[i : i + 1]
 
             dreamed_img = self.vista.generate(cond_img, cond_traj)
             # add last img (frame 10 == 1s)
             dreamed_imgs.append(dreamed_img[-1])
-        
+
         # [BS, 1, 3, H, W]
         dream_img = torch.stack(dreamed_imgs, dim=0).unsqueeze(1)
@@ -1093,7 +903,7 @@
         dream_target, dream_mask, dream_positions, dream_pred = compute_dream_pos(
             positions[:, :pred_traj_len],
             mask[:, :pred_traj_len],
-            pred_traj[:, :pred_traj_len], 
+            pred_traj[:, :pred_traj_len],
             step=2,
         )
@@ -1102,20 +912,8 @@
         )
         for k, v in dream_losses.items():
             losses[f"dream-{k}"] = v
-
-
-        # noise_loss, noise_traj = self.y_embedding.loss(traj_embed_noise, positions)
-        # losses["ae/with_noise"] = (
-        #     noise_loss.mean() * 0.01
-        # )
-        # ae_loss, ae_traj, _ = self.xy_embedding.loss(traj_embed, positions, mask)
-        # losses["ae/ae"] = ae_loss.mean() * 0.1
 
         if writer and log_text:
-            # ctx.add_scalar(
-            #     "ae/ae",
-            #     ae_loss.mean().cpu(),
-            # )
 
             size = min(pred_traj.size(1), positions.size(1))
@@ -1137,44 +935,6 @@
 
             with torch.no_grad():
                 fig = plt.figure()
-                # generate prediction
-
-                """
-                pred_traj = torch.randn_like(noise[:1]) / self.noise_scale
-                self.eval_noise_scheduler.set_timesteps(self.num_inference_timesteps)
-                for timestep in self.eval_noise_scheduler.timesteps:
-                    with autocast():
-                        pred_traj = self.eval_noise_scheduler.scale_model_input(
-                            pred_traj, timestep
-                        )
-                        noise = self.denoiser(pred_traj, mask[:1], input_tokens[:1])
-                    pred_traj = self.eval_noise_scheduler.step(
-                        noise,
-                        timestep,
-                        pred_traj,
-                        generator=torch.Generator(device=device).manual_seed(0),
-                    ).prev_sample
-                pred_positions = self.xy_embedding.decode(pred_traj)[0, :pred_len].cpu()
-                """
-
-                """
-                noise_positions = self.xy_embedding.decode(traj_embed_noise[:1])[
-                    0,
-                    :pred_len,
-                ].cpu()
-                plt.plot(
-                    noise_positions[..., 0], noise_positions[..., 1], label="with_noise"
-                )
-                pos_positions = self.xy_embedding.decode(traj_embed[:1])[
-                    0, :pred_len
-                ].cpu()
-                """
-
-                """
-                ae_positions = ae_traj[0, :pred_len].cpu()
-                plt.plot(ae_positions[..., 0], ae_positions[..., 1], label="ae")
-                """
-
                 target = positions[0, :pred_len].detach().cpu()
                 plt.plot(target[..., 0], target[..., 1], label="target")
 
@@ -1208,10 +968,14 @@
                 plt.plot(target[..., 0], target[..., 1], label="dream_target")
 
                 pred_positions = dream_pred[0, :pred_len].cpu()
-                plt.plot(pred_positions[..., 0], pred_positions[..., 1], label="og_pred")
+                plt.plot(
+                    pred_positions[..., 0], pred_positions[..., 1], label="og_pred"
+                )
 
                 pred_positions = dream_traj[0, :pred_len].cpu()
-                plt.plot(pred_positions[..., 0], pred_positions[..., 1], label="new_pred")
+                plt.plot(
+                    pred_positions[..., 0], pred_positions[..., 1], label="new_pred"
+                )
 
                 fig.legend()
                 plt.gca().set_aspect("equal")
@@ -1223,7 +987,62 @@
 
         return losses
 
 
-def compute_dream_pos(positions: torch.Tensor, mask: torch.Tensor, pred_traj: torch.Tensor, step: int=2) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    def test(
+        self, batch: Batch, global_step: int, writer: Optional[SummaryWriter] = None
+    ) -> None:
+        batch = self.batch_transform(batch)
+
+        # metrics are accumulated into the torchmetrics objects rather than returned
+
+        BS = len(batch.distances)
+        device = batch.device()
+
+        log_img, log_text = self.should_log(global_step, BS)
+        ctx = Context(
+            log_img=log_img,
+            log_text=log_text,
+            global_step=global_step,
+            writer=writer,
+            output=None,
+            start_frame=0,
+            weights=batch.weight,
+            scaler=None,
+            name="test/",
+        )
+
+        positions, mask, velocity = self.prepare_inputs(batch)
+        cam = self.cameras[0]
+
+        pred_losses, pred_traj, all_pred_traj = self.model(
+            velocity, batch.color[cam], positions, mask
+        )
+        for name, loss in pred_losses.items():
+            metric = self.test_losses[name]
+            metric.to(device)
+            metric.update(loss.mean())
+
+        size = min(pred_traj.size(1), positions.size(1))
+
+        pred_traj = pred_traj[:, :size].flatten()
+        positions = positions[:, :size].flatten()
+
+        self.test_mae.update(pred_traj, positions)
+        self.test_mse.update(pred_traj, positions)
+
+        if log_text:
+            for name, loss in self.test_losses.items():
+                ctx.add_scalar(
+                    f"loss/{name}",
+                    loss.compute(),
+                )
+
+            ctx.add_scalar("mae", self.test_mae.compute())
+            ctx.add_scalar("mse", self.test_mse.compute())
+
+
+def compute_dream_pos(
+    positions: torch.Tensor, mask: torch.Tensor, pred_traj: torch.Tensor, step: int = 2
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Compute a new ground truth trajectory for the dreamer to use as a loss.
 
     Outputted trajectory is centered at (0, 0) and uses the new direction.
@@ -1239,19 +1058,28 @@ def compute_dream_pos(positions: torch.Tensor, mask: torch.Tensor, pred_traj: to
         dream_positions: (B, T-step, 2) positions in step coordinate frame
         dream_pred: (B, T-step, 2) pred_traj in step coordinate frame
     """
-    direction = pred_traj[:, step] - pred_traj[:, step-1]
+    direction = pred_traj[:, step] - pred_traj[:, step - 1]
     angle = torch.atan2(direction[:, 1], direction[:, 0])
 
-    rot = torch.stack([
-        torch.stack([
-            torch.cos(angle),
-            -torch.sin(angle),
-        ], dim=-1),
-        torch.stack([
-            torch.sin(angle),
-            torch.cos(angle),
-        ], dim=-1)
-    ], dim=-1)
+    rot = torch.stack(
+        [
+            torch.stack(
+                [
+                    torch.cos(angle),
+                    -torch.sin(angle),
+                ],
+                dim=-1,
+            ),
+            torch.stack(
+                [
+                    torch.sin(angle),
+                    torch.cos(angle),
+                ],
+                dim=-1,
+            ),
+        ],
+        dim=-1,
+    )
    rot = rot.pinverse()
 
     # drop old points
@@ -1260,9 +1088,9 @@
     pred_traj = pred_traj[:, step:]
 
     # use linear interpolation between pred_traj and positions
-    #factor = torch.arange(0, positions.size(1), device=positions.device) / (positions.size(1) - 1)
-    #factor = factor.unsqueeze(0).unsqueeze(-1)
-    #dream_pos = pred_traj * (1-factor) + positions * factor
+    # factor = torch.arange(0, positions.size(1), device=positions.device) / (positions.size(1) - 1)
+    # factor = factor.unsqueeze(0).unsqueeze(-1)
+    # dream_pos = pred_traj * (1-factor) + positions * factor
 
     # use ema interpolation between pred_traj and positions
     factor = torch.full((positions.size(1),), 0.5, device=positions.device)
@@ -1270,7 +1098,7 @@
 
     factor = torch.cumprod(factor, dim=0)
     factor = factor.unsqueeze(0).unsqueeze(-1)
 
-    dream_pos = pred_traj * factor + positions * (1-factor)
+    dream_pos = pred_traj * factor + positions * (1 - factor)
 
     origin = dream_pos[:, 0:1]
diff --git a/torchdrive/tasks/test_diff_traj.py b/torchdrive/tasks/test_diff_traj.py
index 24e15e1..e3fa1dc 100644
--- a/torchdrive/tasks/test_diff_traj.py
+++ b/torchdrive/tasks/test_diff_traj.py
@@ -5,6 +5,7 @@
 import torch
 
 from torchdrive.data import Batch, dummy_batch
 from torchdrive.tasks.diff_traj import (
+    compute_dream_pos,
     DiffTraj,
     square_mask,
     XEmbedding,
@@ -12,7 +13,6 @@
     XYLinearEmbedding,
     XYMLPEncoder,
     XYSineMLPEncoder,
-    compute_dream_pos,
 )
@@ -223,7 +223,9 @@ def test_compute_dream_pos(self):
         mask = torch.ones(2, 18)
         pred_traj = torch.rand(2, 18, 2)
 
-        dream_target, dream_mask, dream_positions, dream_pred = compute_dream_pos(positions, mask, pred_traj)
+        dream_target, dream_mask, dream_positions, dream_pred = compute_dream_pos(
+            positions, mask, pred_traj
+        )
         self.assertEqual(dream_target.shape, (2, 16, 2))
         self.assertEqual(dream_mask.shape, (2, 16))
         self.assertEqual(dream_positions.shape, (2, 16, 2))
diff --git a/torchdrive/train_config.py b/torchdrive/train_config.py
index 7ec8acd..ee09aa1 100644
--- a/torchdrive/train_config.py
+++ b/torchdrive/train_config.py
@@ -61,12 +61,15 @@ def create_dataset(self, smoke: bool = False) -> Tuple[Dataset, Optional[Dataset
             lidar=False,
             num_frames=self.num_frames,
         )
-        test_dataset = NuscenesDataset(
-            data_dir=self.dataset_path,
-            version="v1.0-mini" if smoke else "v1.0-test",
-            lidar=False,
-            num_frames=self.num_frames,
-        )
+        if smoke:
+            test_dataset = dataset
+        else:
+            test_dataset = NuscenesDataset(
+                data_dir=self.dataset_path,
+                version="v1.0-test",
+                lidar=False,
+                num_frames=self.num_frames,
+            )
 
         elif self.dataset == Datasets.DUMMY:
             from torchdrive.datasets.dummy import DummyDataset
@@ -128,6 +131,7 @@ def create_model(
         self,
         device: torch.device,
         compile_fn: Callable[[nn.Module], nn.Module] = lambda x: x,
+        test: bool = False,
     ) -> BEVTaskVan:
         from torchdrive.transforms.batch import (
             Compose,
@@ -340,12 +344,14 @@ def create_model(
         self,
         device: torch.device,
         compile_fn: Callable[[nn.Module], nn.Module] = lambda x: x,
+        test: bool = False,
     ) -> DiffTraj:
         model = DiffTraj(
             cameras=self.cameras,
             num_encode_frames=self.num_encode_frames,
             cam_shape=self.cam_shape,
             num_frames=self.num_frames,
+            test=test,
         ).to(device)
 
         # for cam_encoder in model.encoders.values():
@@ -433,6 +439,12 @@
         action="store_true",
         help="run with a smaller smoke test config",
     )
+    parser.add_argument(
+        "--test",
+        default=False,
+        action="store_true",
+        help="compute the test set metrics",
+    )
 
     parser.add_argument(
         "--config",
diff --git a/train.py b/train.py
index 179d9d7..a2ec75f 100644
--- a/train.py
+++ b/train.py
@@ -90,10 +90,12 @@
 
 dataset, test_dataset = config.create_dataset(smoke=args.smoke)
 
+if args.test:
+    dataset = test_dataset
 
 if RANK == 0:
     # pyre-fixme[6]: len
-    print(f"trainset size {len(dataset)}")
+    print(f"dataset size {len(dataset)}")
 
 if args.anomaly_detection:
     torch.set_anomaly_enabled(True)
@@ -122,7 +124,11 @@ def compile_parent(m: nn.Module) -> nn.Module:
 
 compile_fn = compile_parent
 
-model: BEVTaskVan = config.create_model(device=device, compile_fn=compile_fn)
+model: BEVTaskVan = config.create_model(
+    device=device,
+    compile_fn=compile_fn,
+    test=args.test,
+)
 
 if False and WORLD_SIZE > 1:
     ddp_model: torch.nn.Module = DistributedDataParallel(
@@ -274,7 +280,9 @@ def save(epoch: int) -> None:
         pin_memory=True,
         sampler=sampler,
     )
-    test_collator = TransferCollator(dataloader, batch_size=config.batch_size, device=device)
+    test_collator = TransferCollator(
+        dataloader, batch_size=config.batch_size, device=device
+    )
 
 
 meaned_losses: Dict[str, torchmetrics.aggregation.MeanMetric] = defaultdict(
@@ -296,8 +304,7 @@ def save(epoch: int) -> None:
 
 prof = None
 
-@record
-def run():
+def train() -> None:
     global global_step
 
     for epoch in range(NUM_EPOCHS):
@@ -407,5 +414,33 @@ def run():
         save(epoch + 1)
 
 
+def test() -> None:
+    # only show progress on rank 0
+    batch_iter = tqdm(collator, desc="test") if LOCAL_RANK == 0 else collator
+
+    ddp_model.eval()
+
+    with torch.no_grad():
+        for global_step, batch in enumerate(batch_iter):
+            batch = cast(Optional[Batch], batch)
+            if batch is None:
+                print("empty batch")
+                continue
+
+            batch = batch.to(device)
+
+            # model.test decides what to log internally via should_log
+
+            ddp_model.test(batch, global_step, writer=writer)
+
+
+@record
+def main() -> None:
+    if args.test:
+        test()
+    else:
+        train()
+
+
 if __name__ == "__main__":
-    run()
+    main()
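
Usage note: with this change, `python train.py --config <name> --test` swaps the
training split for the test split, builds the model with test=True (which skips
constructing the VistaSampler, only needed for the dream loss), and runs a single
no-grad pass that accumulates MAE/MSE plus per-loss running means. Below is a
minimal sketch of driving the new DiffTraj.test path directly; the camera name,
shapes, frame counts, and the zero-argument dummy_batch() call are illustrative
assumptions, not values taken from any real config:

    import torch
    from torch.utils.tensorboard import SummaryWriter

    from torchdrive.data import dummy_batch  # same helper the unit tests import
    from torchdrive.tasks.diff_traj import DiffTraj

    # test=True skips constructing the VistaSampler entirely
    model = DiffTraj(
        cameras=["CAM_FRONT"],  # assumed camera name
        num_encode_frames=4,  # assumed frame counts / camera shape
        cam_shape=(480, 640),
        num_frames=5,
        test=True,
    )
    model.eval()

    writer = SummaryWriter("/tmp/difftraj_test")  # placeholder log dir
    with torch.no_grad():
        # metrics accumulate across calls; a real run loops over the test set
        model.test(dummy_batch(), global_step=0, writer=writer)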