Add stereo preset transforms (#6549)
* Added transforms for Stereo Matching
* Changed implicit Y scaling to 0
* Addressed some comments
* Addressed type hint
* Added random interpolation strategy
* Aligned crop get params
* Fixed bug in RandomErase
* Addressed scaling and typos
* Addressed occlusion typo
* Changed parameter order in F.erase
* Fixed random erase
* Added inference preset transform for stereo matching
* Added contiguous reshape to output tensors
* Addressed comments
* Modified the transform preset to use Tuple[int, int]
* Addressed nits
* Added grayscale transform, aligned resize -> mask
* Changed max disparity default behaviour
* Added fixed resize, changed masking in sparse flow masking
* Updated to align with argparse
* Changed default mask in asymmetric pairs
* Moved grayscale order
* Changed grayscale API to accept the to-tensor variant
* mypy fix
* Changed resize specs
* Addressed nits
* Added type hints
* mypy fix
* mypy fix
* mypy fix

Co-authored-by: Joao Gomes <jdsgomes@fb.com>
1 parent 2c1022e · commit 0fcfaa1 · 3 changed files with 864 additions and 0 deletions.
@@ -0,0 +1,144 @@
from typing import Optional, Tuple, Union

import torch
import transforms as T


class StereoMatchingEvalPreset(torch.nn.Module):
    def __init__(
        self,
        mean: float = 0.5,
        std: float = 0.5,
        resize_size: Optional[Tuple[int, ...]] = None,
        max_disparity: Optional[float] = None,
        interpolation_type: str = "bilinear",
        use_grayscale: bool = False,
    ) -> None:
        super().__init__()

        transforms = [
            T.ToTensor(),
            T.ConvertImageDtype(torch.float32),
        ]

        if use_grayscale:
            transforms.append(T.ConvertToGrayscale())

        if resize_size is not None:
            transforms.append(T.Resize(resize_size, interpolation_type=interpolation_type))

        transforms.extend(
            [
                T.Normalize(mean=mean, std=std),
                T.MakeValidDisparityMask(max_disparity=max_disparity),
                T.ValidateModelInput(),
            ]
        )

        self.transforms = T.Compose(transforms)

    def forward(self, images, disparities, masks):
        return self.transforms(images, disparities, masks)

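A minimal usage sketch for the evaluation preset follows. The pairwise input layout is inferred from the forward signature above and the sibling transforms module imported as T; left_image, right_image, and left_disparity are hypothetical placeholders, not names from this diff:

    # Hypothetical example: left/right views plus a left-view disparity map.
    preset = StereoMatchingEvalPreset(resize_size=(384, 1280), max_disparity=256.0)
    images, disparities, masks = preset(
        (left_image, right_image),    # e.g. PIL images or HxWxC arrays
        (left_disparity, None),       # the right-view disparity may be absent
        (None, None),                 # masks are produced by MakeValidDisparityMask
    )
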
class StereoMatchingTrainPreset(torch.nn.Module):
    def __init__(
        self,
        *,
        resize_size: Optional[Tuple[int, ...]],
        resize_interpolation_type: str = "bilinear",
        # RandomRescaleAndCrop params
        crop_size: Tuple[int, int],
        rescale_prob: float = 1.0,
        scaling_type: str = "exponential",
        scale_range: Tuple[float, float] = (-0.2, 0.5),
        scale_interpolation_type: str = "bilinear",
        # convert to grayscale
        use_grayscale: bool = False,
        # normalization params
        mean: float = 0.5,
        std: float = 0.5,
        # processing device
        gpu_transforms: bool = False,
        # masking
        max_disparity: Optional[int] = 256,
        # SpatialShift params
        spatial_shift_prob: float = 0.5,
        spatial_shift_max_angle: float = 0.5,
        spatial_shift_max_displacement: float = 0.5,
        spatial_shift_interpolation_type: str = "bilinear",
        # AsymmetricColorJitter params
        gamma_range: Tuple[float, float] = (0.8, 1.2),
        brightness: Union[int, Tuple[int, int]] = (0.8, 1.2),
        contrast: Union[int, Tuple[int, int]] = (0.8, 1.2),
        saturation: Union[int, Tuple[int, int]] = 0.0,
        hue: Union[int, Tuple[int, int]] = 0.0,
        asymmetric_jitter_prob: float = 1.0,
        # RandomHorizontalFlip params
        horizontal_flip_prob: float = 0.5,
        # RandomOcclusion params
        occlusion_prob: float = 0.0,
        occlusion_px_range: Tuple[int, int] = (50, 100),
        # RandomErase params
        erase_prob: float = 0.0,
        erase_px_range: Tuple[int, int] = (50, 100),
        erase_num_repeats: int = 1,
    ) -> None:

        if scaling_type not in ["linear", "exponential"]:
            raise ValueError(f"Unknown scaling type: {scaling_type}. Available types: linear, exponential")
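        # Assumption, not stated in this diff: with scaling_type="exponential"
        # the scale_range endpoints presumably act as base-2 exponents, so the
        # default (-0.2, 0.5) spans scale factors of roughly 2**-0.2 ~= 0.87 to
        # 2**0.5 ~= 1.41, whereas "linear" would use the sampled value as the
        # scale factor directly.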

        super().__init__()
        transforms = [T.ToTensor()]

        # when fixing size across multiple datasets, we ensure
        # that the same size is used for all datasets when cropping
        if resize_size is not None:
            transforms.append(T.Resize(resize_size, interpolation_type=resize_interpolation_type))

        if gpu_transforms:
            transforms.append(T.ToGPU())

        # color handling
        color_transforms = [
            T.AsymmetricColorJitter(
                brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=asymmetric_jitter_prob
            ),
            T.AsymetricGammaAdjust(p=asymmetric_jitter_prob, gamma_range=gamma_range),
        ]

        if use_grayscale:
            color_transforms.append(T.ConvertToGrayscale())

        transforms.extend(color_transforms)

        transforms.extend(
            [
                T.RandomSpatialShift(
                    p=spatial_shift_prob,
                    max_angle=spatial_shift_max_angle,
                    max_px_shift=spatial_shift_max_displacement,
                    interpolation_type=spatial_shift_interpolation_type,
                ),
                T.ConvertImageDtype(torch.float32),
                T.RandomRescaleAndCrop(
                    crop_size=crop_size,
                    scale_range=scale_range,
                    rescale_prob=rescale_prob,
                    scaling_type=scaling_type,
                    interpolation_type=scale_interpolation_type,
                ),
                T.RandomHorizontalFlip(horizontal_flip_prob),
                # occlusion after flip, otherwise we're occluding the reference image
                T.RandomOcclusion(p=occlusion_prob, occlusion_px_range=occlusion_px_range),
                T.RandomErase(p=erase_prob, erase_px_range=erase_px_range, max_erase=erase_num_repeats),
                T.Normalize(mean=mean, std=std),
                T.MakeValidDisparityMask(max_disparity),
                T.ValidateModelInput(),
            ]
        )

        self.transforms = T.Compose(transforms)

    def forward(self, images, disparities, masks):
        return self.transforms(images, disparities, masks)
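
A corresponding sketch for the training preset; the crop size, probabilities, and input placeholders below are illustrative assumptions, not values from the diff:

    # Hypothetical example: keyword-only construction, then one training sample.
    preset = StereoMatchingTrainPreset(
        resize_size=None,        # keep native resolution before random rescale/crop
        crop_size=(384, 512),    # illustrative crop size
        occlusion_prob=0.5,
    )
    images, disparities, masks = preset(
        (left_image, right_image),
        (left_disparity, right_disparity),
        (None, None),
    )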