diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index 80063a6e3de..7af90d5e78a 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -24,9 +24,9 @@ class FlowDataset(ABC, VisionDataset): - # Some datasets like Kitti have a built-in valid mask, indicating which flow values are valid - # For those we return (img1, img2, flow, valid), and for the rest we return (img1, img2, flow), - # and it's up to whatever consumes the dataset to decide what `valid` should be. + # Some datasets like Kitti have a built-in valid_flow_mask, indicating which flow values are valid + # For those we return (img1, img2, flow, valid_flow_mask), and for the rest we return (img1, img2, flow), + # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be. _has_builtin_flow_mask = False def __init__(self, root, transforms=None): @@ -38,11 +38,14 @@ def __init__(self, root, transforms=None): self._image_list = [] def _read_img(self, file_name): - return Image.open(file_name) + img = Image.open(file_name) + if img.mode != "RGB": + img = img.convert("RGB") + return img @abstractmethod def _read_flow(self, file_name): - # Return the flow or a tuple with the flow and the valid mask if _has_builtin_flow_mask is True + # Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True pass def __getitem__(self, index): @@ -53,23 +56,27 @@ def __getitem__(self, index): if self._flow_list: # it will be empty for some dataset when split="test" flow = self._read_flow(self._flow_list[index]) if self._has_builtin_flow_mask: - flow, valid = flow + flow, valid_flow_mask = flow else: - valid = None + valid_flow_mask = None else: - flow = valid = None + flow = valid_flow_mask = None if self.transforms is not None: - img1, img2, flow, valid = self.transforms(img1, img2, flow, valid) + img1, img2, flow, valid_flow_mask = self.transforms(img1, img2, flow, valid_flow_mask) - if self._has_builtin_flow_mask: - return img1, img2, flow, valid + if self._has_builtin_flow_mask or valid_flow_mask is not None: + # The `or valid_flow_mask is not None` part is here because the mask can be generated within a transform + return img1, img2, flow, valid_flow_mask else: return img1, img2, flow def __len__(self): return len(self._image_list) + def __rmul__(self, v): + return torch.utils.data.ConcatDataset([self] * v) + class Sintel(FlowDataset): """`Sintel `_ Dataset for optical flow. @@ -107,8 +114,8 @@ class Sintel(FlowDataset): pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for details on the different passes. transforms (callable, optional): A function/transform that takes in - ``img1, img2, flow, valid`` and returns a transformed version. - ``valid`` is expected for consistency with other datasets which + ``img1, img2, flow, valid_flow_mask`` and returns a transformed version. + ``valid_flow_mask`` is expected for consistency with other datasets which return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. """ @@ -140,9 +147,11 @@ def __getitem__(self, index): index(int): The index of the example to retrieve Returns: - tuple: If ``split="train"`` a 3-tuple with ``(img1, img2, flow)``. - The flow is a numpy array of shape (2, H, W) and the images are PIL images. If `split="test"`, a - 3-tuple with ``(img1, img2, None)`` is returned. + tuple: A 3-tuple with ``(img1, img2, flow)``. + The flow is a numpy array of shape (2, H, W) and the images are PIL images. + ``flow`` is None if ``split="test"``. + If a valid flow mask is generated within the ``transforms`` parameter, + a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned. """ return super().__getitem__(index) @@ -167,7 +176,7 @@ class KittiFlow(FlowDataset): root (string): Root directory of the KittiFlow Dataset. split (string, optional): The dataset split, either "train" (default) or "test" transforms (callable, optional): A function/transform that takes in - ``img1, img2, flow, valid`` and returns a transformed version. + ``img1, img2, flow, valid_flow_mask`` and returns a transformed version. """ _has_builtin_flow_mask = True @@ -199,11 +208,11 @@ def __getitem__(self, index): index(int): The index of the example to retrieve Returns: - tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow, - valid)`` where ``valid`` is a numpy boolean mask of shape (H, W) + tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` + where ``valid_flow_mask`` is a numpy boolean mask of shape (H, W) indicating which flow values are valid. The flow is a numpy array of - shape (2, H, W) and the images are PIL images. If `split="test"`, a - 4-tuple with ``(img1, img2, None, None)`` is returned. + shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if + ``split="test"``. """ return super().__getitem__(index) @@ -232,8 +241,8 @@ class FlyingChairs(FlowDataset): root (string): Root directory of the FlyingChairs Dataset. split (string, optional): The dataset split, either "train" (default) or "val" transforms (callable, optional): A function/transform that takes in - ``img1, img2, flow, valid`` and returns a transformed version. - ``valid`` is expected for consistency with other datasets which + ``img1, img2, flow, valid_flow_mask`` and returns a transformed version. + ``valid_flow_mask`` is expected for consistency with other datasets which return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. """ @@ -269,6 +278,9 @@ def __getitem__(self, index): Returns: tuple: A 3-tuple with ``(img1, img2, flow)``. The flow is a numpy array of shape (2, H, W) and the images are PIL images. + ``flow`` is None if ``split="val"``. + If a valid flow mask is generated within the ``transforms`` parameter, + a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned. """ return super().__getitem__(index) @@ -300,8 +312,8 @@ class FlyingThings3D(FlowDataset): details on the different passes. camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both". transforms (callable, optional): A function/transform that takes in - ``img1, img2, flow, valid`` and returns a transformed version. - ``valid`` is expected for consistency with other datasets which + ``img1, img2, flow, valid_flow_mask`` and returns a transformed version. + ``valid_flow_mask`` is expected for consistency with other datasets which return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. """ @@ -357,6 +369,9 @@ def __getitem__(self, index): Returns: tuple: A 3-tuple with ``(img1, img2, flow)``. The flow is a numpy array of shape (2, H, W) and the images are PIL images. + ``flow`` is None if ``split="test"``. + If a valid flow mask is generated within the ``transforms`` parameter, + a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned. """ return super().__getitem__(index) @@ -382,7 +397,7 @@ class HD1K(FlowDataset): root (string): Root directory of the HD1K Dataset. split (string, optional): The dataset split, either "train" (default) or "test" transforms (callable, optional): A function/transform that takes in - ``img1, img2, flow, valid`` and returns a transformed version. + ``img1, img2, flow, valid_flow_mask`` and returns a transformed version. """ _has_builtin_flow_mask = True @@ -422,11 +437,11 @@ def __getitem__(self, index): index(int): The index of the example to retrieve Returns: - tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow, - valid)`` where ``valid`` is a numpy boolean mask of shape (H, W) + tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` where ``valid_flow_mask`` + is a numpy boolean mask of shape (H, W) indicating which flow values are valid. The flow is a numpy array of - shape (2, H, W) and the images are PIL images. If `split="test"`, a - 4-tuple with ``(img1, img2, None, None)`` is returned. + shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if + ``split="test"``. """ return super().__getitem__(index) @@ -451,11 +466,12 @@ def _read_flo(file_name): def _read_16bits_png_with_flow_and_valid_mask(file_name): flow_and_valid = _read_png_16(file_name).to(torch.float32) - flow, valid = flow_and_valid[:2, :, :], flow_and_valid[2, :, :] + flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :] flow = (flow - 2 ** 15) / 64 # This conversion is explained somewhere on the kitti archive + valid_flow_mask = valid_flow_mask.bool() # For consistency with other datasets, we convert to numpy - return flow.numpy(), valid.numpy() + return flow.numpy(), valid_flow_mask.numpy() def _read_pfm(file_name):