Skip to content
This repository has been archived by the owner on Feb 22, 2024. It is now read-only.

Commit

Permalink
Added ZeroFlow XL Student Model
Browse files Browse the repository at this point in the history
  • Loading branch information
kylevedder committed Aug 1, 2023
1 parent 803a87c commit 47d227c
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 8 deletions.
4 changes: 2 additions & 2 deletions models/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .feature_pyramid import FeaturePyramidNetwork
from .fast_flow_unet import FastFlowUNet
from .fast_flow_unet import FastFlowUNet, FastFlowUNetXL

__all__ = ["FeaturePyramidNetwork", "FastFlowUNet"]
__all__ = ["FeaturePyramidNetwork", "FastFlowUNet", "FastFlowUNetXL"]
78 changes: 77 additions & 1 deletion models/backbones/fast_flow_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,80 @@ def forward(self, pc0_B: torch.Tensor,
U = self.decoder_step3(T, Bstar)
V = self.decoder_step4(U)

return V
return V


class FastFlowUNetXL(nn.Module):
    """
    Larger variant of FastFlowUNet: expects 64 channel inputs and adds a
    fourth encoder stage together with a matching extra decoder step.

    Both inputs are run through the same encoder stages (the stage modules
    are shared, not duplicated). At every resolution the features of the two
    inputs are concatenated along dim 1 and handed to the ``UpsampleSkip``
    decoder steps; a final 128 -> 128 ``nn.Conv2d`` produces the output.
    """

    def __init__(self) -> None:
        super().__init__()

        def make_stage(in_channels: int, out_channels: int,
                       num_same_width: int) -> nn.Sequential:
            # One channel-changing conv followed by `num_same_width` convs
            # that keep the channel count.
            # NOTE(review): ConvWithNorms is defined elsewhere; its positional
            # args are presumably (in, out, kernel_size, stride, padding), so
            # the leading conv would halve the spatial size -- confirm.
            layers = [ConvWithNorms(in_channels, out_channels, 3, 2, 1)]
            layers.extend(
                ConvWithNorms(out_channels, out_channels, 3, 1, 1)
                for _ in range(num_same_width))
            return nn.Sequential(*layers)

        # Attribute names and creation order are kept stable so checkpoint
        # keys and parameter initialisation order do not change.
        self.encoder_step_1 = make_stage(64, 128, 3)
        self.encoder_step_2 = make_stage(128, 256, 5)
        self.encoder_step_3 = make_stage(256, 512, 5)
        self.encoder_step_4 = make_stage(512, 1024, 5)

        # First argument of each step equals the channel count of the tensor
        # passed first in forward(); the second equals the channel count of
        # the concatenated skip tensor passed second.
        self.decoder_step1 = UpsampleSkip(2048, 1024, 1024)
        self.decoder_step2 = UpsampleSkip(1024, 512, 512)
        self.decoder_step3 = UpsampleSkip(512, 256, 256)
        self.decoder_step4 = UpsampleSkip(256, 128, 128)
        self.decoder_step5 = nn.Conv2d(128, 128, 3, 1, 1)

    def _encode(self, pseudoimage: torch.Tensor):
        """Run one input through all four encoder stages.

        Returns the outputs of stages 1-4 in that order (128, 256, 512 and
        1024 channels respectively).
        """
        feat_128 = self.encoder_step_1(pseudoimage)
        feat_256 = self.encoder_step_2(feat_128)
        feat_512 = self.encoder_step_3(feat_256)
        feat_1024 = self.encoder_step_4(feat_512)
        return feat_128, feat_256, feat_512, feat_1024

    def forward(self, pc0_B: torch.Tensor,
                pc1_B: torch.Tensor) -> torch.Tensor:
        """Fuse the features of two 64 channel inputs into one 128 channel map.

        Args:
            pc0_B: first input; dim 1 must have 64 channels.
            pc1_B: second input; dim 1 must have 64 channels.

        Returns:
            Output of the final 128 -> 128 convolution.

        Raises:
            AssertionError: if either input does not have 64 channels.
        """
        expected_channels = 64
        assert pc0_B.shape[1] == expected_channels, \
            f"Expected {expected_channels} channels, got {pc0_B.shape[1]}"
        assert pc1_B.shape[1] == expected_channels, \
            f"Expected {expected_channels} channels, got {pc1_B.shape[1]}"

        # Encode pc0 completely before pc1 (same module call order as before).
        pc0_s1, pc0_s2, pc0_s3, pc0_s4 = self._encode(pc0_B)
        pc1_s1, pc1_s2, pc1_s3, pc1_s4 = self._encode(pc1_B)

        # Channel-wise concatenation of the two inputs at every level.
        # The example sizes are the original author's annotations, which
        # appear to assume a batch of 1 and a 512x512 input.
        skip_in = torch.cat([pc0_B, pc1_B], dim=1)  # e.g. [1, 128, 512, 512]
        skip_s1 = torch.cat([pc0_s1, pc1_s1], dim=1)  # e.g. [1, 256, 256, 256]
        skip_s2 = torch.cat([pc0_s2, pc1_s2], dim=1)  # e.g. [1, 512, 128, 128]
        skip_s3 = torch.cat([pc0_s3, pc1_s3], dim=1)  # e.g. [1, 1024, 64, 64]
        deepest = torch.cat([pc0_s4, pc1_s4], dim=1)  # e.g. [1, 2048, 32, 32]

        # Walk back up, merging in the next shallower skip tensor each step.
        decoded = self.decoder_step1(deepest, skip_s3)
        decoded = self.decoder_step2(decoded, skip_s2)
        decoded = self.decoder_step3(decoded, skip_s1)
        decoded = self.decoder_step4(decoded, skip_in)

        return self.decoder_step5(decoded)
18 changes: 13 additions & 5 deletions models/fast_flow_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np
from models.embedders import HardEmbedder, DynamicEmbedder
from models.backbones import FastFlowUNet
from models.backbones import FastFlowUNet, FastFlowUNetXL
from models.heads import FastFlowDecoder, FastFlowDecoderStepDown
from pointclouds import from_fixed_array
from pointclouds.losses import warped_pc_loss
Expand Down Expand Up @@ -167,6 +167,10 @@ def __call__(self, input_batch, model_res_dict):
# self._visualize_regressed_ground_truth_pcs(model_res,
# gt_flow_array_stack)

assert len(estimated_flows) <= len(
gt_flow_array_stack
), f"estimated_flows {len(estimated_flows)} > gt_flow_array_stack {len(gt_flow_array_stack)}"

total_loss = 0
# Iterate through the batch
for est_flow, gt_flow_array in zip(estimated_flows,
Expand Down Expand Up @@ -311,21 +315,25 @@ def __init__(self,
POINT_CLOUD_RANGE,
FEATURE_CHANNELS,
SEQUENCE_LENGTH,
bottleneck_head=False) -> None:
bottleneck_head=False,
xl_backbone=False) -> None:
super().__init__()
self.SEQUENCE_LENGTH = SEQUENCE_LENGTH
assert self.SEQUENCE_LENGTH == 2, "This implementation only supports a sequence length of 2."
self.embedder = DynamicEmbedder(voxel_size=VOXEL_SIZE,
pseudo_image_dims=PSEUDO_IMAGE_DIMS,
point_cloud_range=POINT_CLOUD_RANGE,
feat_channels=FEATURE_CHANNELS)

self.backbone = FastFlowUNet()
if xl_backbone:
self.backbone = FastFlowUNetXL()
else:
self.backbone = FastFlowUNet()
if bottleneck_head:
self.head = FastFlowDecoderStepDown(
voxel_pillar_size=VOXEL_SIZE[:2], num_stepdowns=3)
else:
self.head = FastFlowDecoder()
self.head = FastFlowDecoder(pseudoimage_channels=FEATURE_CHANNELS *
2)

def _model_forward(self, pc0s, pc1s):

Expand Down

0 comments on commit 47d227c

Please sign in to comment.