Skip to content

Commit

Permalink
[fbsync] Some updates to optical flow datasets (#5004)
Browse files (browse the repository at this point in the history)
Reviewed By: NicolasHug

Differential Revision: D32950932

fbshipit-source-id: a775ede624bf71cf8716f276892831717e55a4f7
  • Loading branch information
Vincent Moens authored and facebook-github-bot committed Dec 9, 2021
1 parent 8743344 commit 683167f
Showing 1 changed file with 48 additions and 32 deletions.
80 changes: 48 additions & 32 deletions torchvision/datasets/_optical_flow.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -24,9 +24,9 @@


class FlowDataset(ABC, VisionDataset):
# Some datasets like Kitti have a built-in valid mask, indicating which flow values are valid
# For those we return (img1, img2, flow, valid), and for the rest we return (img1, img2, flow),
# and it's up to whatever consumes the dataset to decide what `valid` should be.
# Some datasets like Kitti have a built-in valid_flow_mask, indicating which flow values are valid
# For those we return (img1, img2, flow, valid_flow_mask), and for the rest we return (img1, img2, flow),
# and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
_has_builtin_flow_mask = False

def __init__(self, root, transforms=None):
Expand All @@ -38,11 +38,14 @@ def __init__(self, root, transforms=None):
self._image_list = []

def _read_img(self, file_name):
return Image.open(file_name)
img = Image.open(file_name)
if img.mode != "RGB":
img = img.convert("RGB")
return img

@abstractmethod
def _read_flow(self, file_name):
# Return the flow or a tuple with the flow and the valid mask if _has_builtin_flow_mask is True
# Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True
pass

def __getitem__(self, index):
Expand All @@ -53,23 +56,27 @@ def __getitem__(self, index):
if self._flow_list: # it will be empty for some dataset when split="test"
flow = self._read_flow(self._flow_list[index])
if self._has_builtin_flow_mask:
flow, valid = flow
flow, valid_flow_mask = flow
else:
valid = None
valid_flow_mask = None
else:
flow = valid = None
flow = valid_flow_mask = None

if self.transforms is not None:
img1, img2, flow, valid = self.transforms(img1, img2, flow, valid)
img1, img2, flow, valid_flow_mask = self.transforms(img1, img2, flow, valid_flow_mask)

if self._has_builtin_flow_mask:
return img1, img2, flow, valid
if self._has_builtin_flow_mask or valid_flow_mask is not None:
# The `or valid_flow_mask is not None` part is here because the mask can be generated within a transform
return img1, img2, flow, valid_flow_mask
else:
return img1, img2, flow

def __len__(self):
return len(self._image_list)

def __rmul__(self, v):
return torch.utils.data.ConcatDataset([self] * v)


class Sintel(FlowDataset):
"""`Sintel <http://sintel.is.tue.mpg.de/>`_ Dataset for optical flow.
Expand Down Expand Up @@ -107,8 +114,8 @@ class Sintel(FlowDataset):
pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
details on the different passes.
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``valid`` is expected for consistency with other datasets which
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
``valid_flow_mask`` is expected for consistency with other datasets which
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
"""

Expand Down Expand Up @@ -140,9 +147,11 @@ def __getitem__(self, index):
index(int): The index of the example to retrieve
Returns:
tuple: If ``split="train"`` a 3-tuple with ``(img1, img2, flow)``.
The flow is a numpy array of shape (2, H, W) and the images are PIL images. If `split="test"`, a
3-tuple with ``(img1, img2, None)`` is returned.
tuple: A 3-tuple with ``(img1, img2, flow)``.
The flow is a numpy array of shape (2, H, W) and the images are PIL images.
``flow`` is None if ``split="test"``.
If a valid flow mask is generated within the ``transforms`` parameter,
a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
"""
return super().__getitem__(index)

Expand All @@ -167,7 +176,7 @@ class KittiFlow(FlowDataset):
root (string): Root directory of the KittiFlow Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
"""

_has_builtin_flow_mask = True
Expand Down Expand Up @@ -199,11 +208,11 @@ def __getitem__(self, index):
index(int): The index of the example to retrieve
Returns:
tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow,
valid)`` where ``valid`` is a numpy boolean mask of shape (H, W)
tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)``
where ``valid_flow_mask`` is a numpy boolean mask of shape (H, W)
indicating which flow values are valid. The flow is a numpy array of
shape (2, H, W) and the images are PIL images. If `split="test"`, a
4-tuple with ``(img1, img2, None, None)`` is returned.
shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
``split="test"``.
"""
return super().__getitem__(index)

Expand Down Expand Up @@ -232,8 +241,8 @@ class FlyingChairs(FlowDataset):
root (string): Root directory of the FlyingChairs Dataset.
split (string, optional): The dataset split, either "train" (default) or "val"
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``valid`` is expected for consistency with other datasets which
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
``valid_flow_mask`` is expected for consistency with other datasets which
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
"""

Expand Down Expand Up @@ -269,6 +278,9 @@ def __getitem__(self, index):
Returns:
tuple: A 3-tuple with ``(img1, img2, flow)``.
The flow is a numpy array of shape (2, H, W) and the images are PIL images.
``flow`` is None if ``split="val"``.
If a valid flow mask is generated within the ``transforms`` parameter,
a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
"""
return super().__getitem__(index)

Expand Down Expand Up @@ -300,8 +312,8 @@ class FlyingThings3D(FlowDataset):
details on the different passes.
camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both".
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``valid`` is expected for consistency with other datasets which
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
``valid_flow_mask`` is expected for consistency with other datasets which
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
"""

Expand Down Expand Up @@ -357,6 +369,9 @@ def __getitem__(self, index):
Returns:
tuple: A 3-tuple with ``(img1, img2, flow)``.
The flow is a numpy array of shape (2, H, W) and the images are PIL images.
``flow`` is None if ``split="test"``.
If a valid flow mask is generated within the ``transforms`` parameter,
a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
"""
return super().__getitem__(index)

Expand All @@ -382,7 +397,7 @@ class HD1K(FlowDataset):
root (string): Root directory of the HD1K Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
"""

_has_builtin_flow_mask = True
Expand Down Expand Up @@ -422,11 +437,11 @@ def __getitem__(self, index):
index(int): The index of the example to retrieve
Returns:
tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow,
valid)`` where ``valid`` is a numpy boolean mask of shape (H, W)
tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` where ``valid_flow_mask``
is a numpy boolean mask of shape (H, W)
indicating which flow values are valid. The flow is a numpy array of
shape (2, H, W) and the images are PIL images. If `split="test"`, a
4-tuple with ``(img1, img2, None, None)`` is returned.
shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
``split="test"``.
"""
return super().__getitem__(index)

Expand All @@ -451,11 +466,12 @@ def _read_flo(file_name):
def _read_16bits_png_with_flow_and_valid_mask(file_name):

flow_and_valid = _read_png_16(file_name).to(torch.float32)
flow, valid = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
flow = (flow - 2 ** 15) / 64 # This conversion is explained somewhere on the kitti archive
valid_flow_mask = valid_flow_mask.bool()

# For consistency with other datasets, we convert to numpy
return flow.numpy(), valid.numpy()
return flow.numpy(), valid_flow_mask.numpy()


def _read_pfm(file_name):
Expand Down

0 comments on commit 683167f

Please sign in to comment.