Add transforms and presets for optical flow models #5026

Merged: 7 commits, Dec 7, 2021
64 changes: 64 additions & 0 deletions references/optical_flow/presets.py
@@ -0,0 +1,64 @@
import torch
import transforms as T


class OpticalFlowPresetEval(torch.nn.Module):
def __init__(self):
super().__init__()

self.transforms = T.Compose(
[
T.PILToTensor(),
T.ConvertImageDtype(torch.float32),
T.Normalize(mean=0.5, std=0.5), # map [0, 1] into [-1, 1]
T.ValidateModelInput(),
]
)

def forward(self, img1, img2, flow, valid):
return self.transforms(img1, img2, flow, valid)


class OpticalFlowPresetTrain(torch.nn.Module):
def __init__(
self,
# RandomResizeAndCrop params
crop_size,
min_scale=-0.2,
max_scale=0.5,
stretch_prob=0.8,
# AsymmetricColorJitter params
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.5 / 3.14,
        asymmetric_jitter_prob=0.2,
        # Random[H,V]Flip params
do_flip=True,
):
super().__init__()

transforms = [
T.PILToTensor(),
T.AsymmetricColorJitter(
brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=asymmetric_jitter_prob
),
T.RandomResizeAndCrop(
crop_size=crop_size, min_scale=min_scale, max_scale=max_scale, stretch_prob=stretch_prob
),
]

if do_flip:
transforms += [T.RandomHorizontalFlip(p=0.5), T.RandomVerticalFlip(p=0.1)]

transforms += [
T.ConvertImageDtype(torch.float32),
T.Normalize(mean=0.5, std=0.5), # map [0, 1] into [-1, 1]
T.RandomErasing(max_erase=2),
T.MakeValidFlowMask(),
T.ValidateModelInput(),
]
self.transforms = T.Compose(transforms)

def forward(self, img1, img2, flow, valid):
return self.transforms(img1, img2, flow, valid)
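
For orientation, here is a minimal usage sketch of the two presets (not part of the diff): it assumes it is run from references/optical_flow/, and the input sizes and the crop size are made-up values standing in for what an optical flow dataset would provide.

import numpy as np
import torch
from PIL import Image

from presets import OpticalFlowPresetEval, OpticalFlowPresetTrain

train_preset = OpticalFlowPresetTrain(crop_size=(368, 496))  # illustrative crop size
eval_preset = OpticalFlowPresetEval()

# Dummy inputs mimicking what an optical flow dataset would return:
# two PIL frames and a (2, H, W) numpy flow map, with no valid mask.
img1 = Image.new("RGB", (1024, 436))  # PIL sizes are (W, H)
img2 = Image.new("RGB", (1024, 436))
flow = np.zeros((2, 436, 1024), dtype=np.float32)

img1_t, img2_t, flow_t, valid = train_preset(img1, img2, flow, None)
print(img1_t.shape, flow_t.shape, valid.shape)
# torch.Size([3, 368, 496]) torch.Size([2, 368, 496]) torch.Size([368, 496])
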
261 changes: 261 additions & 0 deletions references/optical_flow/transforms.py
@@ -0,0 +1,261 @@
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F


class ValidateModelInput(torch.nn.Module):
# Pass-through transform that checks the shape and dtypes to make sure the model gets what it expects
def forward(self, img1, img2, flow, valid_flow_mask):

assert all(isinstance(arg, torch.Tensor) for arg in (img1, img2, flow, valid_flow_mask) if arg is not None)
assert all(arg.dtype == torch.float32 for arg in (img1, img2, flow) if arg is not None)

assert img1.shape == img2.shape
h, w = img1.shape[-2:]
if flow is not None:
assert flow.shape == (2, h, w)
if valid_flow_mask is not None:
assert valid_flow_mask.shape == (h, w)
assert valid_flow_mask.dtype == torch.bool

return img1, img2, flow, valid_flow_mask


class MakeValidFlowMask(torch.nn.Module):
# This transform generates a valid_flow_mask if it doesn't exist.
# The flow is considered valid if ||flow||_inf < threshold
# This is a noop for Kitti and HD1K which already come with a built-in flow mask.
def __init__(self, threshold=1000):
super().__init__()
self.threshold = threshold

def forward(self, img1, img2, flow, valid_flow_mask):
if flow is not None and valid_flow_mask is None:
valid_flow_mask = (flow.abs() < self.threshold).all(axis=0)
return img1, img2, flow, valid_flow_mask


class ConvertImageDtype(torch.nn.Module):
def __init__(self, dtype):
super().__init__()
self.dtype = dtype

def forward(self, img1, img2, flow, valid_flow_mask):
img1 = F.convert_image_dtype(img1, dtype=self.dtype)
img2 = F.convert_image_dtype(img2, dtype=self.dtype)

img1 = img1.contiguous()
img2 = img2.contiguous()

return img1, img2, flow, valid_flow_mask


class Normalize(torch.nn.Module):
def __init__(self, mean, std):
super().__init__()
self.mean = mean
self.std = std

def forward(self, img1, img2, flow, valid_flow_mask):
img1 = F.normalize(img1, mean=self.mean, std=self.std)
img2 = F.normalize(img2, mean=self.mean, std=self.std)

return img1, img2, flow, valid_flow_mask


class PILToTensor(torch.nn.Module):
# Converts all inputs to tensors
# Technically the flow and the valid mask are numpy arrays, not PIL images, but we keep that naming
# for consistency with the rest, e.g. the segmentation reference.
def forward(self, img1, img2, flow, valid_flow_mask):
img1 = F.pil_to_tensor(img1)
img2 = F.pil_to_tensor(img2)
if flow is not None:
flow = torch.from_numpy(flow)
if valid_flow_mask is not None:
valid_flow_mask = torch.from_numpy(valid_flow_mask)

return img1, img2, flow, valid_flow_mask


class AsymmetricColorJitter(T.ColorJitter):
    # p determines the probability of doing asymmetric vs symmetric color jittering
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=0.2):
super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
self.p = p

def forward(self, img1, img2, flow, valid_flow_mask):

if torch.rand(1) < self.p:
# asymmetric: different transform for img1 and img2
img1 = super().forward(img1)
img2 = super().forward(img2)
else:
# symmetric: same transform for img1 and img2
[Review thread on the symmetric branch below]

Contributor: @NicolasHug: So does the p determine the probability of doing symmetric vs asymmetric? If yes I would add a comment to clarify.
@pmeier: Could you please check this strange transform to confirm it's supported by the new Transforms API?

Collaborator: As it stands, this would not be supported. A transform always treats a sample as an atomic unit, so multiple images in the same sample would be transformed with the same parameters.

@NicolasHug (Member, Author), Dec 6, 2021: OK, I'll clarify. Ultimately this is a special case of RandomApply(t1, t2, [p, 1 - p]), so there's nothing too fancy here. t2 can be a Sequential(take_care_of_img1_only, take_care_of_img2_only).

Contributor: @NicolasHug Sounds good, just add comments. No need to use RandomApply here.

@datumbox (Contributor), Dec 6, 2021: @pmeier No worries, this is why we give the option for someone to write custom transforms without the magic of the new API, for weird cases like this. Could you now confirm that this is indeed a workaround we can apply?

Collaborator: @NicolasHug

> Sequential(take_care_of_img1_only, take_care_of_img2_only)

I'm guessing take_care_of_img1_only and take_care_of_img2_only are transforms here, correct? If yes, how would you tell the transform to only handle one or the other image if both receive the full sample?

I think this is one of the cases @datumbox mentioned where we need to circumvent the automatic dispatch a little. In case we want to transform both images separately, we could split the sample and perform the transformation once for the sample minus image 2 and once for image 2. The problem I see with this is that it can't be automated without assumptions about how the sample is structured. So we either need to use the same structure for every dataset (for example a flat dictionary with image1 and image2 keys) or provide a way to parametrize the transform.

@NicolasHug (Member, Author):

> I'm guessing take_care_of_img1_only and take_care_of_img2_only are transforms here, correct? If yes, how would you tell the transform to only handle one or the other image if both receive the full sample?

Each transform would receive the entire input (which IIRC is a dict) and operate on a subset of that dict.

Are you suggesting that img1 and img2 would be concatenated?

[End of review thread. A minimal sketch of the RandomApply idea discussed here follows this class.]

batch = torch.stack([img1, img2])
batch = super().forward(batch)
img1, img2 = batch[0], batch[1]

return img1, img2, flow, valid_flow_mask
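
The resolved thread above asks whether this asymmetric-vs-symmetric jitter can be expressed in the new Transforms API as RandomApply(t1, t2, [p, 1 - p]). As a rough sketch of that idea only (TwoWayColorJitter and its two-image calling convention are invented for this example, not an existing torchvision API):

import torch
import torchvision.transforms as T


class TwoWayColorJitter(torch.nn.Module):
    # Hypothetical stand-in for "RandomApply(t1, t2, [p, 1 - p])": with probability p,
    # jitter the two frames independently; otherwise jitter them with shared parameters.
    def __init__(self, jitter, p=0.2):
        super().__init__()
        self.jitter = jitter
        self.p = p

    def forward(self, img1, img2):
        if torch.rand(1) < self.p:
            # asymmetric branch: two independent parameter draws
            return self.jitter(img1), self.jitter(img2)
        # symmetric branch: stacking forces a single parameter draw for both frames
        batch = self.jitter(torch.stack([img1, img2]))
        return batch[0], batch[1]


jitter = T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5 / 3.14)
two_way = TwoWayColorJitter(jitter, p=0.2)
img1, img2 = two_way(torch.rand(3, 64, 96), torch.rand(3, 64, 96))

Stacking the two frames into a batch is what guarantees a single parameter draw in the symmetric branch, which is exactly what the class above does.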


class RandomErasing(T.RandomErasing):
    # This only erases img2, and adds an extra max_erase param.
    # max_erase is needed because the RAFT training reference does:
    #   0 erasures with 0.5 probability
    #   1 erasure  with 0.25 probability
    #   2 erasures with 0.25 probability
    # and there's no accurate way to achieve this otherwise.
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False, max_erase=1):
super().__init__(p=p, scale=scale, ratio=ratio, value=value, inplace=inplace)
self.max_erase = max_erase
assert self.max_erase > 0

def forward(self, img1, img2, flow, valid_flow_mask):
if torch.rand(1) > self.p:
return img1, img2, flow, valid_flow_mask

for _ in range(torch.randint(self.max_erase, size=(1,)).item()):
x, y, h, w, v = self.get_params(img2, scale=self.scale, ratio=self.ratio, value=[self.value])
img2 = F.erase(img2, x, y, h, w, v, self.inplace)

return img1, img2, flow, valid_flow_mask


class RandomHorizontalFlip(T.RandomHorizontalFlip):
def forward(self, img1, img2, flow, valid_flow_mask):
if torch.rand(1) > self.p:
return img1, img2, flow, valid_flow_mask

img1 = F.hflip(img1)
img2 = F.hflip(img2)
flow = F.hflip(flow) * torch.tensor([-1, 1])[:, None, None]
if valid_flow_mask is not None:
valid_flow_mask = F.hflip(valid_flow_mask)
return img1, img2, flow, valid_flow_mask
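
A quick sanity check of the flow sign convention used above (an illustrative snippet with made-up values, not part of the PR): mirroring both frames horizontally negates the x-component of the displacement and leaves the y-component untouched.

import torch
import torchvision.transforms.functional as F

# A pixel that moves 3 px to the right in the original pair moves 3 px to the left
# after both frames are mirrored horizontally; the vertical component is unchanged.
flow = torch.zeros(2, 4, 4)
flow[0] += 3.0  # x-displacement of every pixel
flipped = F.hflip(flow) * torch.tensor([-1, 1])[:, None, None]
print(flipped[0, 0, 0].item(), flipped[1, 0, 0].item())  # -3.0 0.0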


class RandomVerticalFlip(T.RandomVerticalFlip):
def forward(self, img1, img2, flow, valid_flow_mask):
if torch.rand(1) > self.p:
return img1, img2, flow, valid_flow_mask

img1 = F.vflip(img1)
img2 = F.vflip(img2)
flow = F.vflip(flow) * torch.tensor([1, -1])[:, None, None]
if valid_flow_mask is not None:
valid_flow_mask = F.vflip(valid_flow_mask)
return img1, img2, flow, valid_flow_mask


class RandomResizeAndCrop(torch.nn.Module):
    # This transform will resize the input with a given probability, and then crop it.
    # These are the reversed operations of the built-in RandomResizedCrop,
    # although the order of the operations doesn't matter too much: resizing a
    # crop would give the same result as cropping a resized image, up to
    # interpolation artifacts at the borders of the output.
    #
    # The reason we don't rely on RandomResizedCrop is that the two transforms are
    # parametrized very differently, in particular in the way the random parameters
    # are sampled, which leads to fairly different results (and a different EPE).
    # For more details see
    # https://github.com/pytorch/vision/pull/5026/files#r762932579
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, stretch_prob=0.8):
super().__init__()
self.crop_size = crop_size
self.min_scale = min_scale
self.max_scale = max_scale
self.stretch_prob = stretch_prob
self.resize_prob = 0.8
self.max_stretch = 0.2

def forward(self, img1, img2, flow, valid_flow_mask):
# randomly sample scale
h, w = img1.shape[-2:]
        # Note: in the original code, they use + 1 instead of + 8 for sparse datasets (e.g. Kitti).
        # It shouldn't matter much.
min_scale = max((self.crop_size[0] + 8) / h, (self.crop_size[1] + 8) / w)

scale = 2 ** torch.empty(1, dtype=torch.float32).uniform_(self.min_scale, self.max_scale).item()
scale_x = scale
scale_y = scale
if torch.rand(1) < self.stretch_prob:
scale_x *= 2 ** torch.empty(1, dtype=torch.float32).uniform_(-self.max_stretch, self.max_stretch).item()
scale_y *= 2 ** torch.empty(1, dtype=torch.float32).uniform_(-self.max_stretch, self.max_stretch).item()

scale_x = max(scale_x, min_scale)
scale_y = max(scale_y, min_scale)

new_h, new_w = round(h * scale_y), round(w * scale_x)

if torch.rand(1).item() < self.resize_prob:
# rescale the images
img1 = F.resize(img1, size=(new_h, new_w))
img2 = F.resize(img2, size=(new_h, new_w))
if valid_flow_mask is None:
flow = F.resize(flow, size=(new_h, new_w))
flow = flow * torch.tensor([scale_x, scale_y])[:, None, None]
else:
flow, valid_flow_mask = self._resize_sparse_flow(
flow, valid_flow_mask, scale_x=scale_x, scale_y=scale_y
)

        # Note: For sparse datasets (Kitti), the original code uses a "margin".
        # See e.g. https://github.com/princeton-vl/RAFT/blob/master/core/utils/augmentor.py#L220:L220
        # We don't; it's not clear whether it matters much.
y0 = torch.randint(0, img1.shape[1] - self.crop_size[0], size=(1,)).item()
x0 = torch.randint(0, img1.shape[2] - self.crop_size[1], size=(1,)).item()

img1 = F.crop(img1, y0, x0, self.crop_size[0], self.crop_size[1])
img2 = F.crop(img2, y0, x0, self.crop_size[0], self.crop_size[1])
flow = F.crop(flow, y0, x0, self.crop_size[0], self.crop_size[1])
if valid_flow_mask is not None:
valid_flow_mask = F.crop(valid_flow_mask, y0, x0, self.crop_size[0], self.crop_size[1])

return img1, img2, flow, valid_flow_mask

def _resize_sparse_flow(self, flow, valid_flow_mask, scale_x=1.0, scale_y=1.0):
        # This resizes both the flow and the valid_flow_mask (which is assumed to be reasonably sparse).
        # There are as many valid values in the resized flow as in the original flow (up to those that
        # fall out of bounds), so for example if scale_x = scale_y = 2, the output flow is 4x sparser
        # than the input (the fraction of valid pixels is divided by 4).

h, w = flow.shape[-2:]

h_new = int(round(h * scale_y))
w_new = int(round(w * scale_x))
flow_new = torch.zeros(size=[2, h_new, w_new], dtype=flow.dtype)
valid_new = torch.zeros(size=[h_new, w_new], dtype=valid_flow_mask.dtype)

jj, ii = torch.meshgrid(torch.arange(w), torch.arange(h), indexing="xy")

ii_valid, jj_valid = ii[valid_flow_mask], jj[valid_flow_mask]

ii_valid_new = torch.round(ii_valid.to(float) * scale_y).to(torch.long)
jj_valid_new = torch.round(jj_valid.to(float) * scale_x).to(torch.long)

within_bounds_mask = (0 <= ii_valid_new) & (ii_valid_new < h_new) & (0 <= jj_valid_new) & (jj_valid_new < w_new)

ii_valid = ii_valid[within_bounds_mask]
jj_valid = jj_valid[within_bounds_mask]
ii_valid_new = ii_valid_new[within_bounds_mask]
jj_valid_new = jj_valid_new[within_bounds_mask]

valid_flow_new = flow[:, ii_valid, jj_valid]
valid_flow_new[0] *= scale_x
valid_flow_new[1] *= scale_y

flow_new[:, ii_valid_new, jj_valid_new] = valid_flow_new
valid_new[ii_valid_new, jj_valid_new] = 1

return flow_new, valid_new
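
To make the sparsity behaviour described above concrete, here is a tiny worked example (illustrative only: it calls the private helper directly on made-up values):

import torch
from transforms import RandomResizeAndCrop  # this file

t = RandomResizeAndCrop(crop_size=(1, 1))  # crop_size is irrelevant for the helper

# A 2x2 flow map with a single valid pixel at (row=1, col=0).
flow = torch.zeros(2, 2, 2)
flow[:, 1, 0] = torch.tensor([3.0, -2.0])
valid = torch.zeros(2, 2, dtype=torch.bool)
valid[1, 0] = True

# Upscale by 2 in each direction: the output is 4x4 but still contains exactly one
# valid pixel, so the fraction of valid pixels drops by a factor of 4.
flow_new, valid_new = t._resize_sparse_flow(flow, valid, scale_x=2.0, scale_y=2.0)
print(valid_new.sum().item())  # 1: the number of valid pixels is preserved
print(flow_new[:, 2, 0])       # tensor([ 6., -4.]): the flow values are rescaled too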


class Compose(torch.nn.Module):
def __init__(self, transforms):
super().__init__()
self.transforms = transforms

def forward(self, img1, img2, flow, valid_flow_mask):
for t in self.transforms:
img1, img2, flow, valid_flow_mask = t(img1, img2, flow, valid_flow_mask)
return img1, img2, flow, valid_flow_mask
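
Finally, a minimal smoke test of the four-argument calling convention shared by all of the transforms above (a sketch, not part of the PR; the dummy PIL/numpy inputs and their shapes are assumptions):

import numpy as np
import torch
from PIL import Image

import transforms as T  # the module defined above (references/optical_flow/transforms.py)

pipeline = T.Compose(
    [
        T.PILToTensor(),
        T.ConvertImageDtype(torch.float32),
        T.Normalize(mean=0.5, std=0.5),
        T.MakeValidFlowMask(),
        T.ValidateModelInput(),
    ]
)

img1 = Image.new("RGB", (96, 64))  # PIL sizes are (W, H)
img2 = Image.new("RGB", (96, 64))
flow = np.zeros((2, 64, 96), dtype=np.float32)

img1, img2, flow, valid = pipeline(img1, img2, flow, None)
print(img1.shape, img1.dtype)    # torch.Size([3, 64, 96]) torch.float32
print(valid.shape, valid.dtype)  # torch.Size([64, 96]) torch.bool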