Skip to content

Commit

Permalink
Some fixes for crestereo (#6791)
Browse files Browse the repository at this point in the history
  • Loading branch information
YosuaMichael authored Oct 19, 2022
1 parent 78fdaf3 commit 7a62a54
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions torchvision/prototype/models/depth/stereo/crestereo.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ def _get_window_type(self, iteration: int) -> str:
return "1d" if iteration % 2 == 0 else "2d"

def forward(
self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor], num_iters: int = 10
self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor] = None, num_iters: int = 10
) -> List[Tensor]:
features = torch.cat([left_image, right_image], dim=0)
features = self.feature_encoder(features)
Expand All @@ -781,10 +781,10 @@ def forward(
ctx_pyramid = self.downsampling_pyramid(ctx)

# we store in reversed order because we process the pyramid from top to bottom
l_pyramid: Dict[str, Tensor] = {res: l_pyramid[idx] for idx, res in enumerate(self.resolutions)}
r_pyramid: Dict[str, Tensor] = {res: r_pyramid[idx] for idx, res in enumerate(self.resolutions)}
net_pyramid: Dict[str, Tensor] = {res: net_pyramid[idx] for idx, res in enumerate(self.resolutions)}
ctx_pyramid: Dict[str, Tensor] = {res: ctx_pyramid[idx] for idx, res in enumerate(self.resolutions)}
l_pyramid = {res: l_pyramid[idx] for idx, res in enumerate(self.resolutions)}
r_pyramid = {res: r_pyramid[idx] for idx, res in enumerate(self.resolutions)}
net_pyramid = {res: net_pyramid[idx] for idx, res in enumerate(self.resolutions)}
ctx_pyramid = {res: ctx_pyramid[idx] for idx, res in enumerate(self.resolutions)}

# offsets for sampling pixel candidates in the correlation ops
offsets: Dict[str, Tensor] = {}
Expand Down Expand Up @@ -1425,6 +1425,9 @@ def crestereo_base(*, weights: Optional[CREStereo_Base_Weights] = None, progress
.. autoclass:: torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights
:members:
"""

weights = CREStereo_Base_Weights.verify(weights)

return _crestereo(
weights=weights,
progress=progress,
Expand Down

0 comments on commit 7a62a54

Please sign in to comment.