Skip to content

Commit

Permalink
Add stereo preset transforms (#6549)
Browse files Browse the repository at this point in the history
* Added transforms for Stereo Matching

* changed implicit Y scaling to 0.

* Adressed some comments

* addressed type hint

* Added interpolation random interpolation strategy

* Aligned crop get params

* fixed bug in RandomErase

* Adressed scaling and typos

* Adressed occlusion typo

* Changed parameter order in F.erase

* fixed random erase

* Added inference preset transform for stereo matching

* added contiguous reshape to output tensors

* Adressed comments

* Modified the transform preset to use Tuple[int, int]

* adressed NITs

* added grayscale transform, align resize -> mask

* changed max disparity default behaviour

* added fixed resize, changed masking in sparse flow masking

* update to align with argparse

* changed default mask in asymetric pairs

* moved grayscale order

* changed grayscale api to accept to tensor variant

* mypy fix

* changed resize specs

* adressed nits

* added type hints

* mypy fix

* mypy fix

* mypy fix

Co-authored-by: Joao Gomes <jdsgomes@fb.com>
  • Loading branch information
TeodorPoncu and jdsgomes authored Sep 22, 2022
1 parent 2c1022e commit 0fcfaa1
Show file tree
Hide file tree
Showing 3 changed files with 864 additions and 0 deletions.
144 changes: 144 additions & 0 deletions references/depth/stereo/presets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
from typing import Optional, Tuple, Union

import torch
import transforms as T


class StereoMatchingEvalPreset(torch.nn.Module):
def __init__(
self,
mean: float = 0.5,
std: float = 0.5,
resize_size: Optional[Tuple[int, ...]] = None,
max_disparity: Optional[float] = None,
interpolation_type: str = "bilinear",
use_grayscale: bool = False,
) -> None:
super().__init__()

transforms = [
T.ToTensor(),
T.ConvertImageDtype(torch.float32),
]

if use_grayscale:
transforms.append(T.ConvertToGrayscale())

if resize_size is not None:
transforms.append(T.Resize(resize_size, interpolation_type=interpolation_type))

transforms.extend(
[
T.Normalize(mean=mean, std=std),
T.MakeValidDisparityMask(max_disparity=max_disparity),
T.ValidateModelInput(),
]
)

self.transforms = T.Compose(transforms)

def forward(self, images, disparities, masks):
return self.transforms(images, disparities, masks)


class StereoMatchingTrainPreset(torch.nn.Module):
def __init__(
self,
*,
resize_size: Optional[Tuple[int, ...]],
resize_interpolation_type: str = "bilinear",
# RandomResizeAndCrop params
crop_size: Tuple[int, int],
rescale_prob: float = 1.0,
scaling_type: str = "exponential",
scale_range: Tuple[float, float] = (-0.2, 0.5),
scale_interpolation_type: str = "bilinear",
# convert to grayscale
use_grayscale: bool = False,
# normalization params
mean: float = 0.5,
std: float = 0.5,
# processing device
gpu_transforms: bool = False,
# masking
max_disparity: Optional[int] = 256,
# SpatialShift params
spatial_shift_prob: float = 0.5,
spatial_shift_max_angle: float = 0.5,
spatial_shift_max_displacement: float = 0.5,
spatial_shift_interpolation_type: str = "bilinear",
# AssymetricColorJitter
gamma_range: Tuple[float, float] = (0.8, 1.2),
brightness: Union[int, Tuple[int, int]] = (0.8, 1.2),
contrast: Union[int, Tuple[int, int]] = (0.8, 1.2),
saturation: Union[int, Tuple[int, int]] = 0.0,
hue: Union[int, Tuple[int, int]] = 0.0,
asymmetric_jitter_prob: float = 1.0,
# RandomHorizontalFlip
horizontal_flip_prob: float = 0.5,
# RandomOcclusion
occlusion_prob: float = 0.0,
occlusion_px_range: Tuple[int, int] = (50, 100),
# RandomErase
erase_prob: float = 0.0,
erase_px_range: Tuple[int, int] = (50, 100),
erase_num_repeats: int = 1,
) -> None:

if scaling_type not in ["linear", "exponential"]:
raise ValueError(f"Unknown scaling type: {scaling_type}. Available types: linear, exponential")

super().__init__()
transforms = [T.ToTensor()]

# when fixing size across multiple datasets, we ensure
# that the same size is used for all datasets when cropping
if resize_size is not None:
transforms.append(T.Resize(resize_size, interpolation_type=resize_interpolation_type))

if gpu_transforms:
transforms.append(T.ToGPU())

# color handling
color_transforms = [
T.AsymmetricColorJitter(
brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=asymmetric_jitter_prob
),
T.AsymetricGammaAdjust(p=asymmetric_jitter_prob, gamma_range=gamma_range),
]

if use_grayscale:
color_transforms.append(T.ConvertToGrayscale())

transforms.extend(color_transforms)

transforms.extend(
[
T.RandomSpatialShift(
p=spatial_shift_prob,
max_angle=spatial_shift_max_angle,
max_px_shift=spatial_shift_max_displacement,
interpolation_type=spatial_shift_interpolation_type,
),
T.ConvertImageDtype(torch.float32),
T.RandomRescaleAndCrop(
crop_size=crop_size,
scale_range=scale_range,
rescale_prob=rescale_prob,
scaling_type=scaling_type,
interpolation_type=scale_interpolation_type,
),
T.RandomHorizontalFlip(horizontal_flip_prob),
# occlusion after flip, otherwise we're occluding the reference image
T.RandomOcclusion(p=occlusion_prob, occlusion_px_range=occlusion_px_range),
T.RandomErase(p=erase_prob, erase_px_range=erase_px_range, max_erase=erase_num_repeats),
T.Normalize(mean=mean, std=std),
T.MakeValidDisparityMask(max_disparity),
T.ValidateModelInput(),
]
)

self.transforms = T.Compose(transforms)

def forward(self, images, disparties, mask):
return self.transforms(images, disparties, mask)
Loading

0 comments on commit 0fcfaa1

Please sign in to comment.