Add transforms and presets for optical flow models (#5026)
1 parent 4dd8b5c · commit 47ae092
Showing 2 changed files with 325 additions and 0 deletions.
@@ -0,0 +1,64 @@
import torch
import transforms as T


class OpticalFlowPresetEval(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.transforms = T.Compose(
            [
                T.PILToTensor(),
                T.ConvertImageDtype(torch.float32),
                T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
                T.ValidateModelInput(),
            ]
        )

    def forward(self, img1, img2, flow, valid):
        return self.transforms(img1, img2, flow, valid)


class OpticalFlowPresetTrain(torch.nn.Module):
    def __init__(
        self,
        # RandomResizeAndCrop params
        crop_size,
        min_scale=-0.2,
        max_scale=0.5,
        stretch_prob=0.8,
        # AsymmetricColorJitter params
        brightness=0.4,
        contrast=0.4,
        saturation=0.4,
        hue=0.5 / 3.14,
        asymmetric_jitter_prob=0.2,
        # Random[H,V]Flip params
        do_flip=True,
    ):
        super().__init__()

        transforms = [
            T.PILToTensor(),
            T.AsymmetricColorJitter(
                brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=asymmetric_jitter_prob
            ),
            T.RandomResizeAndCrop(
                crop_size=crop_size, min_scale=min_scale, max_scale=max_scale, stretch_prob=stretch_prob
            ),
        ]

        if do_flip:
            transforms += [T.RandomHorizontalFlip(p=0.5), T.RandomVerticalFlip(p=0.1)]

        transforms += [
            T.ConvertImageDtype(torch.float32),
            T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
            T.RandomErasing(max_erase=2),
            T.MakeValidFlowMask(),
            T.ValidateModelInput(),
        ]
        self.transforms = T.Compose(transforms)

    def forward(self, img1, img2, flow, valid):
        return self.transforms(img1, img2, flow, valid)
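A minimal usage sketch of the training preset above (illustrative only: it assumes the transforms module shown below is importable as `transforms`, and uses synthetic PIL images and a NumPy flow field in place of a real dataset sample):

import numpy as np
from PIL import Image

img1 = Image.new("RGB", (256, 256))
img2 = Image.new("RGB", (256, 256))
flow = np.zeros((2, 256, 256), dtype=np.float32)  # dense flow, no validity mask yet

preset = OpticalFlowPresetTrain(crop_size=(224, 224))
img1_t, img2_t, flow_t, valid = preset(img1, img2, flow, None)
# img1_t, img2_t: float32 tensors in [-1, 1], cropped to 224x224
# flow_t: (2, 224, 224); valid: boolean mask produced by MakeValidFlowMask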
@@ -0,0 +1,261 @@
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F


class ValidateModelInput(torch.nn.Module):
    # Pass-through transform that checks shapes and dtypes to make sure the model gets what it expects
    def forward(self, img1, img2, flow, valid_flow_mask):
        assert all(isinstance(arg, torch.Tensor) for arg in (img1, img2, flow, valid_flow_mask) if arg is not None)
        assert all(arg.dtype == torch.float32 for arg in (img1, img2, flow) if arg is not None)

        assert img1.shape == img2.shape
        h, w = img1.shape[-2:]
        if flow is not None:
            assert flow.shape == (2, h, w)
        if valid_flow_mask is not None:
            assert valid_flow_mask.shape == (h, w)
            assert valid_flow_mask.dtype == torch.bool

        return img1, img2, flow, valid_flow_mask


class MakeValidFlowMask(torch.nn.Module):
    # This transform generates a valid_flow_mask if it doesn't exist.
    # The flow is considered valid if ||flow||_inf < threshold.
    # This is a no-op for Kitti and HD1K, which already come with a built-in flow mask.
    def __init__(self, threshold=1000):
        super().__init__()
        self.threshold = threshold

    def forward(self, img1, img2, flow, valid_flow_mask):
        if flow is not None and valid_flow_mask is None:
            valid_flow_mask = (flow.abs() < self.threshold).all(axis=0)
        return img1, img2, flow, valid_flow_mask
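
# A minimal sketch of the masking rule above (the `_demo` helper is
# illustrative, not part of the transforms API): pixels where any flow
# component reaches the threshold are marked invalid.
def _demo_make_valid_flow_mask():
    flow = torch.zeros(2, 2, 2)
    flow[0, 0, 0] = 1e9  # an out-of-range displacement, e.g. an unmatched pixel
    _, _, _, mask = MakeValidFlowMask(threshold=1000)(None, None, flow, None)
    return mask  # False at (0, 0), True everywhere else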
class ConvertImageDtype(torch.nn.Module):
    def __init__(self, dtype):
        super().__init__()
        self.dtype = dtype

    def forward(self, img1, img2, flow, valid_flow_mask):
        img1 = F.convert_image_dtype(img1, dtype=self.dtype)
        img2 = F.convert_image_dtype(img2, dtype=self.dtype)

        img1 = img1.contiguous()
        img2 = img2.contiguous()

        return img1, img2, flow, valid_flow_mask


class Normalize(torch.nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(self, img1, img2, flow, valid_flow_mask):
        img1 = F.normalize(img1, mean=self.mean, std=self.std)
        img2 = F.normalize(img2, mean=self.mean, std=self.std)

        return img1, img2, flow, valid_flow_mask
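
# Why mean=0.5, std=0.5 maps [0, 1] into [-1, 1]: normalize computes
# (x - mean) / std = (x - 0.5) / 0.5 = 2x - 1, so 0 -> -1 and 1 -> 1.
# A minimal check (the `_demo` helper is illustrative, not part of the API):
def _demo_normalize_range():
    img = torch.tensor([[[0.0, 0.5, 1.0]]])  # a 1x1x3 image
    out = F.normalize(img, mean=[0.5], std=[0.5])
    assert out.flatten().tolist() == [-1.0, 0.0, 1.0]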
class PILToTensor(torch.nn.Module):
    # Converts all inputs to tensors.
    # Technically the flow and the valid mask are numpy arrays, not PIL images, but we keep that naming
    # for consistency with the rest, e.g. the segmentation reference.
    def forward(self, img1, img2, flow, valid_flow_mask):
        img1 = F.pil_to_tensor(img1)
        img2 = F.pil_to_tensor(img2)
        if flow is not None:
            flow = torch.from_numpy(flow)
        if valid_flow_mask is not None:
            valid_flow_mask = torch.from_numpy(valid_flow_mask)

        return img1, img2, flow, valid_flow_mask


class AsymmetricColorJitter(T.ColorJitter):
    # p determines the probability of doing asymmetric vs symmetric color jittering
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=0.2):
        super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
        self.p = p

    def forward(self, img1, img2, flow, valid_flow_mask):
        if torch.rand(1) < self.p:
            # asymmetric: different transform for img1 and img2
            img1 = super().forward(img1)
            img2 = super().forward(img2)
        else:
            # symmetric: same transform for img1 and img2
            batch = torch.stack([img1, img2])
            batch = super().forward(batch)
            img1, img2 = batch[0], batch[1]

        return img1, img2, flow, valid_flow_mask
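
# The symmetric branch relies on ColorJitter sampling its random parameters
# once per call: stacking both frames into one batch and calling the parent
# forward once applies the exact same jitter to each. A minimal check
# (the `_demo` helper is illustrative, not part of the API):
def _demo_symmetric_jitter():
    jitter = AsymmetricColorJitter(brightness=0.4, p=0.0)  # p=0: always symmetric
    img = torch.rand(3, 8, 8)
    out1, out2, _, _ = jitter(img, img.clone(), None, None)
    assert torch.equal(out1, out2)  # identical frames stay identical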
class RandomErasing(T.RandomErasing):
    # This only erases img2, and with an extra max_erase param.
    # max_erase is needed because the RAFT training reference does:
    #   no erasing with probability 0.5
    #   1 erasure with probability 0.25
    #   2 erasures with probability 0.25
    # and there's no accurate way to achieve this otherwise.
    def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False, max_erase=1):
        super().__init__(p=p, scale=scale, ratio=ratio, value=value, inplace=inplace)
        self.max_erase = max_erase
        assert self.max_erase > 0

    def forward(self, img1, img2, flow, valid_flow_mask):
        if torch.rand(1) > self.p:
            return img1, img2, flow, valid_flow_mask

        # Sample in [1, max_erase] so that, combined with the p gate above,
        # max_erase=2 yields 0 erasures w.p. 0.5 and 1 or 2 erasures w.p. 0.25 each.
        for _ in range(torch.randint(1, self.max_erase + 1, size=(1,)).item()):
            x, y, h, w, v = self.get_params(img2, scale=self.scale, ratio=self.ratio, value=[self.value])
            img2 = F.erase(img2, x, y, h, w, v, self.inplace)

        return img1, img2, flow, valid_flow_mask
class RandomHorizontalFlip(T.RandomHorizontalFlip):
    def forward(self, img1, img2, flow, valid_flow_mask):
        if torch.rand(1) > self.p:
            return img1, img2, flow, valid_flow_mask

        img1 = F.hflip(img1)
        img2 = F.hflip(img2)
        flow = F.hflip(flow) * torch.tensor([-1, 1])[:, None, None]
        if valid_flow_mask is not None:
            valid_flow_mask = F.hflip(valid_flow_mask)
        return img1, img2, flow, valid_flow_mask


class RandomVerticalFlip(T.RandomVerticalFlip):
    def forward(self, img1, img2, flow, valid_flow_mask):
        if torch.rand(1) > self.p:
            return img1, img2, flow, valid_flow_mask

        img1 = F.vflip(img1)
        img2 = F.vflip(img2)
        flow = F.vflip(flow) * torch.tensor([1, -1])[:, None, None]
        if valid_flow_mask is not None:
            valid_flow_mask = F.vflip(valid_flow_mask)
        return img1, img2, flow, valid_flow_mask
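
# Why the sign flips: flow channel 0 is the horizontal displacement u and
# channel 1 the vertical displacement v. Mirroring both frames horizontally
# maps x to (w - 1 - x), so a pixel that moved right now moves left and u
# must be negated (and likewise for v under a vertical flip). A minimal
# check (the `_demo` helper is illustrative, not part of the API):
def _demo_hflip_sign():
    img = torch.rand(3, 4, 4)
    flow = torch.ones(2, 4, 4)  # every pixel moves one step right and one step down
    _, _, flipped, _ = RandomHorizontalFlip(p=1.0)(img, img.clone(), flow, None)
    assert flipped[0, 0, 0] == -1 and flipped[1, 0, 0] == 1  # u negated, v kept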
class RandomResizeAndCrop(torch.nn.Module):
    # This transform will resize the input with a given probability, and then crop it.
    # These are the reversed operations of the built-in RandomResizedCrop,
    # although the order of the operations doesn't matter too much: resizing a
    # crop would give the same result as cropping a resized image, up to
    # interpolation artifacts at the borders of the output.
    #
    # The reason we don't rely on RandomResizedCrop is a significant difference
    # in the parametrization of the two transforms, in particular in the way
    # the random parameters are sampled, which leads to fairly different
    # results (and a different epe). For more details see
    # https://github.com/pytorch/vision/pull/5026/files#r762932579
    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, stretch_prob=0.8):
        super().__init__()
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.stretch_prob = stretch_prob
        self.resize_prob = 0.8
        self.max_stretch = 0.2

    def forward(self, img1, img2, flow, valid_flow_mask):
        # randomly sample scale
        h, w = img1.shape[-2:]
        # Note: the original code uses + 1 instead of + 8 for sparse datasets (e.g. Kitti).
        # It shouldn't matter much.
        min_scale = max((self.crop_size[0] + 8) / h, (self.crop_size[1] + 8) / w)

        scale = 2 ** torch.empty(1, dtype=torch.float32).uniform_(self.min_scale, self.max_scale).item()
        scale_x = scale
        scale_y = scale
        if torch.rand(1) < self.stretch_prob:
            scale_x *= 2 ** torch.empty(1, dtype=torch.float32).uniform_(-self.max_stretch, self.max_stretch).item()
            scale_y *= 2 ** torch.empty(1, dtype=torch.float32).uniform_(-self.max_stretch, self.max_stretch).item()

        scale_x = max(scale_x, min_scale)
        scale_y = max(scale_y, min_scale)

        new_h, new_w = round(h * scale_y), round(w * scale_x)

        if torch.rand(1).item() < self.resize_prob:
            # rescale the images
            img1 = F.resize(img1, size=(new_h, new_w))
            img2 = F.resize(img2, size=(new_h, new_w))
            if valid_flow_mask is None:
                flow = F.resize(flow, size=(new_h, new_w))
                flow = flow * torch.tensor([scale_x, scale_y])[:, None, None]
            else:
                flow, valid_flow_mask = self._resize_sparse_flow(
                    flow, valid_flow_mask, scale_x=scale_x, scale_y=scale_y
                )

        # Note: for sparse datasets (Kitti), the original code uses a "margin".
        # See e.g. https://github.com/princeton-vl/RAFT/blob/master/core/utils/augmentor.py#L220:L220
        # We don't; it's not clear it matters much.
        y0 = torch.randint(0, img1.shape[1] - self.crop_size[0], size=(1,)).item()
        x0 = torch.randint(0, img1.shape[2] - self.crop_size[1], size=(1,)).item()

        img1 = F.crop(img1, y0, x0, self.crop_size[0], self.crop_size[1])
        img2 = F.crop(img2, y0, x0, self.crop_size[0], self.crop_size[1])
        flow = F.crop(flow, y0, x0, self.crop_size[0], self.crop_size[1])
        if valid_flow_mask is not None:
            valid_flow_mask = F.crop(valid_flow_mask, y0, x0, self.crop_size[0], self.crop_size[1])

        return img1, img2, flow, valid_flow_mask
    def _resize_sparse_flow(self, flow, valid_flow_mask, scale_x=1.0, scale_y=1.0):
        # This resizes both the flow and the valid_flow_mask (which is assumed to be reasonably sparse).
        # There are as many non-zero values in the resized flow as in the original one (up to
        # out-of-bounds values), so for example if scale_x = scale_y = 2, the sparsity of the
        # output flow is multiplied by 4.

        h, w = flow.shape[-2:]

        h_new = int(round(h * scale_y))
        w_new = int(round(w * scale_x))
        flow_new = torch.zeros(size=[2, h_new, w_new], dtype=flow.dtype)
        valid_new = torch.zeros(size=[h_new, w_new], dtype=valid_flow_mask.dtype)

        jj, ii = torch.meshgrid(torch.arange(w), torch.arange(h), indexing="xy")

        ii_valid, jj_valid = ii[valid_flow_mask], jj[valid_flow_mask]

        ii_valid_new = torch.round(ii_valid.to(float) * scale_y).to(torch.long)
        jj_valid_new = torch.round(jj_valid.to(float) * scale_x).to(torch.long)

        within_bounds_mask = (0 <= ii_valid_new) & (ii_valid_new < h_new) & (0 <= jj_valid_new) & (jj_valid_new < w_new)

        ii_valid = ii_valid[within_bounds_mask]
        jj_valid = jj_valid[within_bounds_mask]
        ii_valid_new = ii_valid_new[within_bounds_mask]
        jj_valid_new = jj_valid_new[within_bounds_mask]

        valid_flow_new = flow[:, ii_valid, jj_valid]
        valid_flow_new[0] *= scale_x
        valid_flow_new[1] *= scale_y

        flow_new[:, ii_valid_new, jj_valid_new] = valid_flow_new
        valid_new[ii_valid_new, jj_valid_new] = 1

        return flow_new, valid_new
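
# A minimal sketch of the sparse resize above (the `_demo` helper is
# illustrative, not part of the API): doubling both scales remaps each valid
# pixel to its rounded new coordinates and scales its displacement, leaving
# the number of valid pixels unchanged while the area quadruples.
def _demo_resize_sparse_flow():
    t = RandomResizeAndCrop(crop_size=(1, 1))
    flow = torch.zeros(2, 4, 4)
    valid = torch.zeros(4, 4, dtype=torch.bool)
    flow[:, 1, 2] = 3.0  # one valid pixel at (i=1, j=2) with displacement (3, 3)
    valid[1, 2] = True
    flow_new, valid_new = t._resize_sparse_flow(flow, valid, scale_x=2.0, scale_y=2.0)
    assert flow_new.shape == (2, 8, 8) and valid_new.sum() == 1
    assert flow_new[0, 2, 4] == 6.0  # u scaled by scale_x at the remapped location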
class Compose(torch.nn.Module):
    def __init__(self, transforms):
        super().__init__()
        self.transforms = transforms

    def forward(self, img1, img2, flow, valid_flow_mask):
        for t in self.transforms:
            img1, img2, flow, valid_flow_mask = t(img1, img2, flow, valid_flow_mask)
        return img1, img2, flow, valid_flow_mask
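Finally, a hand-rolled pipeline sketch showing how Compose threads all four arguments through each transform (illustrative only; it mirrors OpticalFlowPresetEval from the presets file, with a MakeValidFlowMask step added):

import numpy as np
from PIL import Image

pipeline = Compose(
    [
        PILToTensor(),
        ConvertImageDtype(torch.float32),
        Normalize(mean=0.5, std=0.5),
        MakeValidFlowMask(),
        ValidateModelInput(),
    ]
)

img1 = Image.new("RGB", (64, 48))
img2 = Image.new("RGB", (64, 48))
flow = np.random.randn(2, 48, 64).astype(np.float32)

img1, img2, flow, valid = pipeline(img1, img2, flow, None)
print(img1.shape, flow.shape, valid.dtype)  # torch.Size([3, 48, 64]) torch.Size([2, 48, 64]) torch.bool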