diff --git a/models/backbones/__init__.py b/models/backbones/__init__.py index 4b810a8..21a47cd 100644 --- a/models/backbones/__init__.py +++ b/models/backbones/__init__.py @@ -1,4 +1,4 @@ from .feature_pyramid import FeaturePyramidNetwork -from .fast_flow_unet import FastFlowUNet +from .fast_flow_unet import FastFlowUNet, FastFlowUNetXL -__all__ = ["FeaturePyramidNetwork", "FastFlowUNet"] \ No newline at end of file +__all__ = ["FeaturePyramidNetwork", "FastFlowUNet", "FastFlowUNetXL"] diff --git a/models/backbones/fast_flow_unet.py b/models/backbones/fast_flow_unet.py index 13c30ec..f87aeae 100644 --- a/models/backbones/fast_flow_unet.py +++ b/models/backbones/fast_flow_unet.py @@ -144,4 +144,80 @@ def forward(self, pc0_B: torch.Tensor, U = self.decoder_step3(T, Bstar) V = self.decoder_step4(U) - return V \ No newline at end of file + return V + + +class FastFlowUNetXL(nn.Module): + """ + FastFlowUNet with a 64 channel input and another conv stepdown layer. + """ + + def __init__(self) -> None: + super().__init__() + + self.encoder_step_1 = nn.Sequential(ConvWithNorms(64, 128, 3, 2, 1), + ConvWithNorms(128, 128, 3, 1, 1), + ConvWithNorms(128, 128, 3, 1, 1), + ConvWithNorms(128, 128, 3, 1, 1)) + self.encoder_step_2 = nn.Sequential(ConvWithNorms(128, 256, 3, 2, 1), + ConvWithNorms(256, 256, 3, 1, 1), + ConvWithNorms(256, 256, 3, 1, 1), + ConvWithNorms(256, 256, 3, 1, 1), + ConvWithNorms(256, 256, 3, 1, 1), + ConvWithNorms(256, 256, 3, 1, 1)) + self.encoder_step_3 = nn.Sequential(ConvWithNorms(256, 512, 3, 2, 1), + ConvWithNorms(512, 512, 3, 1, 1), + ConvWithNorms(512, 512, 3, 1, 1), + ConvWithNorms(512, 512, 3, 1, 1), + ConvWithNorms(512, 512, 3, 1, 1), + ConvWithNorms(512, 512, 3, 1, 1)) + self.encoder_step_4 = nn.Sequential(ConvWithNorms(512, 1024, 3, 2, 1), + ConvWithNorms(1024, 1024, 3, 1, 1), + ConvWithNorms(1024, 1024, 3, 1, 1), + ConvWithNorms(1024, 1024, 3, 1, 1), + ConvWithNorms(1024, 1024, 3, 1, 1), + ConvWithNorms(1024, 1024, 3, 1, 1)) + + self.decoder_step1 = UpsampleSkip(2048, 1024, 1024) + self.decoder_step2 = UpsampleSkip(1024, 512, 512) + self.decoder_step3 = UpsampleSkip(512, 256, 256) + self.decoder_step4 = UpsampleSkip(256, 128, 128) + self.decoder_step5 = nn.Conv2d(128, 128, 3, 1, 1) + + def forward(self, pc0_B: torch.Tensor, + pc1_B: torch.Tensor) -> torch.Tensor: + + expected_channels = 64 + assert pc0_B.shape[ + 1] == expected_channels, f"Expected {expected_channels} channels, got {pc0_B.shape[1]}" + assert pc1_B.shape[ + 1] == expected_channels, f"Expected {expected_channels} channels, got {pc1_B.shape[1]}" + + pc0_F = self.encoder_step_1(pc0_B) + pc0_L = self.encoder_step_2(pc0_F) + pc0_R = self.encoder_step_3(pc0_L) + pc0_T = self.encoder_step_4(pc0_R) + + pc1_F = self.encoder_step_1(pc1_B) + pc1_L = self.encoder_step_2(pc1_F) + pc1_R = self.encoder_step_3(pc1_L) + pc1_T = self.encoder_step_4(pc1_R) + + Tstar = torch.cat([pc0_T, pc1_T], + dim=1) # torch.Size([1, 2048, 32, 32]) + Rstar = torch.cat([pc0_R, pc1_R], + dim=1) # torch.Size([1, 1024, 64, 64]) + Lstar = torch.cat([pc0_L, pc1_L], + dim=1) # torch.Size([1, 512, 128, 128]) + Fstar = torch.cat([pc0_F, pc1_F], + dim=1) # torch.Size([1, 256, 256, 256]) + Bstar = torch.cat([pc0_B, pc1_B], + dim=1) # torch.Size([1, 128, 512, 512]) + + S = self.decoder_step1(Tstar, Rstar) + T = self.decoder_step2(S, Lstar) + U = self.decoder_step3(T, Fstar) + V = self.decoder_step4(U, Bstar) + W = self.decoder_step5(V) + + return W diff --git a/models/fast_flow_3d.py b/models/fast_flow_3d.py index db352ae..6abbff5 100644 --- a/models/fast_flow_3d.py +++ b/models/fast_flow_3d.py @@ -3,7 +3,7 @@ import numpy as np from models.embedders import HardEmbedder, DynamicEmbedder -from models.backbones import FastFlowUNet +from models.backbones import FastFlowUNet, FastFlowUNetXL from models.heads import FastFlowDecoder, FastFlowDecoderStepDown from pointclouds import from_fixed_array from pointclouds.losses import warped_pc_loss @@ -167,6 +167,10 @@ def __call__(self, input_batch, model_res_dict): # self._visualize_regressed_ground_truth_pcs(model_res, # gt_flow_array_stack) + assert len(estimated_flows) <= len( + gt_flow_array_stack + ), f"estimated_flows {len(estimated_flows)} > gt_flow_array_stack {len(gt_flow_array_stack)}" + total_loss = 0 # Iterate through the batch for est_flow, gt_flow_array in zip(estimated_flows, @@ -311,7 +315,8 @@ def __init__(self, POINT_CLOUD_RANGE, FEATURE_CHANNELS, SEQUENCE_LENGTH, - bottleneck_head=False) -> None: + bottleneck_head=False, + xl_backbone=False) -> None: super().__init__() self.SEQUENCE_LENGTH = SEQUENCE_LENGTH assert self.SEQUENCE_LENGTH == 2, "This implementation only supports a sequence length of 2." @@ -319,13 +324,16 @@ def __init__(self, pseudo_image_dims=PSEUDO_IMAGE_DIMS, point_cloud_range=POINT_CLOUD_RANGE, feat_channels=FEATURE_CHANNELS) - - self.backbone = FastFlowUNet() + if xl_backbone: + self.backbone = FastFlowUNetXL() + else: + self.backbone = FastFlowUNet() if bottleneck_head: self.head = FastFlowDecoderStepDown( voxel_pillar_size=VOXEL_SIZE[:2], num_stepdowns=3) else: - self.head = FastFlowDecoder() + self.head = FastFlowDecoder(pseudoimage_channels=FEATURE_CHANNELS * + 2) def _model_forward(self, pc0s, pc1s):