[fbsync] remove datapoints compatibility for prototype datasets (#7154)
Reviewed By: vmoens

Differential Revision: D44416260

fbshipit-source-id: 9437edffc8a7ccf08f381c5147d9a8f3e18530a3
NicolasHug authored and facebook-github-bot committed Mar 28, 2023
1 parent 049e7e2 commit 6fb095e
Showing 7 changed files with 25 additions and 36 deletions.
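The same pattern repeats across all seven dataset files. A minimal standalone sketch of it (ann and its shape are placeholders, not taken from any of the datasets): annotation fields that used to be wrapped in the private Datapoint base class are now returned as plain tensors, so the prototype transforms no longer need a no-op datapoint to pass them through.

import numpy as np
import torch

# Before: contour=Datapoint(ann["obj_contour"].T) wrapped the array in the
# private Datapoint base class so the transforms would leave it untouched.
# After: the same field is a plain tensor.
ann = {"obj_contour": np.zeros((2, 50))}  # placeholder annotation data
contour = torch.as_tensor(ann["obj_contour"].T)
assert type(contour) is torch.Tensor  # no tensor subclass involved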
16 changes: 10 additions & 6 deletions test/test_prototype_datasets_builtin.py
@@ -21,6 +21,7 @@
 from torchdata.datapipes.utils import StreamWrapper
 from torchvision._utils import sequence_to_str
 from torchvision.prototype import datapoints, datasets, transforms
+from torchvision.prototype.datasets.utils import EncodedImage
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
@@ -136,18 +137,21 @@ def make_msg_and_close(head):
         raise AssertionError(make_msg_and_close("The following streams were not closed after a full iteration:"))
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_no_simple_tensors(self, dataset_mock, config):
+    def test_no_unaccompanied_simple_tensors(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)
+        sample = next_consume(iter(dataset))
 
         simple_tensors = {
-            key
-            for key, value in next_consume(iter(dataset)).items()
-            if torchvision.prototype.transforms.utils.is_simple_tensor(value)
+            key for key, value in sample.items() if torchvision.prototype.transforms.utils.is_simple_tensor(value)
         }
-        if simple_tensors:
+
+        if simple_tensors and not any(
+            isinstance(item, (datapoints.Image, datapoints.Video, EncodedImage)) for item in sample.values()
+        ):
             raise AssertionError(
                 f"The values of key(s) "
-                f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained simple tensors."
+                f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained simple tensors, "
+                f"but didn't find any (encoded) image or video."
             )
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
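For reference, a self-contained sketch of the relaxed check above; is_simple_tensor and FakeImage are stand-ins for torchvision.prototype.transforms.utils.is_simple_tensor and the image/video datapoint classes. Plain tensors in a sample are now acceptable as long as at least one (encoded) image or video accompanies them.

import torch

def is_simple_tensor(value):
    # stand-in: a plain torch.Tensor that is not a tensor subclass
    return isinstance(value, torch.Tensor) and type(value) is torch.Tensor

class FakeImage(torch.Tensor):  # stand-in for datapoints.Image / EncodedImage
    pass

sample = {
    "image": torch.zeros(3, 4, 4).as_subclass(FakeImage),
    "contour": torch.zeros(2, 50),  # unwrapped "simple" tensor
}
simple_tensors = {key for key, value in sample.items() if is_simple_tensor(value)}
# the test now only fails when simple tensors appear without any image/video
assert not (simple_tensors and not any(isinstance(v, FakeImage) for v in sample.values()))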
19 changes: 1 addition & 18 deletions torchvision/prototype/datapoints/_datapoint.py
@@ -29,26 +29,9 @@ def _to_tensor(
         requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
         return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)
 
-    # FIXME: this is just here for BC with the prototype datasets. Some datasets use the Datapoint directly to have a
-    # a no-op input for the prototype transforms. For this use case, we can't use plain tensors, since they will be
-    # interpreted as images. We should decide if we want a public no-op datapoint like `GenericDatapoint` or make this
-    # one public again.
-    def __new__(
-        cls,
-        data: Any,
-        dtype: Optional[torch.dtype] = None,
-        device: Optional[Union[torch.device, str, int]] = None,
-        requires_grad: Optional[bool] = None,
-    ) -> Datapoint:
-        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
-        return tensor.as_subclass(Datapoint)
-
     @classmethod
     def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
-        # FIXME: this is just here for BC with the prototype datasets. See __new__ for details. If that is resolved,
-        # this method should be made abstract
-        # raise NotImplementedError
-        return tensor.as_subclass(cls)
+        raise NotImplementedError
 
     _NO_WRAPPING_EXCEPTIONS = {
         torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
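With the BC shim gone, the base class can no longer be used as a generic no-op wrapper, and wrap_like is abstract in everything but name. A minimal sketch of the resulting contract, with an illustrative Label subclass (the real subclasses live in torchvision.prototype.datapoints):

import torch

class Datapoint(torch.Tensor):
    @classmethod
    def wrap_like(cls, other, tensor):
        raise NotImplementedError  # concrete subclasses must override

class Label(Datapoint):  # illustrative subclass
    @classmethod
    def wrap_like(cls, other, tensor):
        return tensor.as_subclass(cls)

label = torch.tensor([1, 2]).as_subclass(Label)
out = Label.wrap_like(label, torch.tensor([3, 4]))  # re-wraps a plain tensor
assert type(out) is Label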
5 changes: 3 additions & 2 deletions torchvision/prototype/datasets/_builtin/caltech.py
@@ -3,9 +3,10 @@
 from typing import Any, BinaryIO, Dict, List, Tuple, Union
 
 import numpy as np
+
+import torch
 from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper
 from torchvision.prototype.datapoints import BoundingBox, Label
-from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
@@ -115,7 +116,7 @@ def _prepare_sample(
                 format="xyxy",
                 spatial_size=image.spatial_size,
             ),
-            contour=Datapoint(ann["obj_contour"].T),
+            contour=torch.as_tensor(ann["obj_contour"].T),
         )
 
     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
4 changes: 2 additions & 2 deletions torchvision/prototype/datasets/_builtin/celeba.py
@@ -2,9 +2,9 @@
 import pathlib
 from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tuple, Union
 
+import torch
 from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper
 from torchvision.prototype.datapoints import BoundingBox, Label
-from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
@@ -149,7 +149,7 @@ def _prepare_sample(
                 spatial_size=image.spatial_size,
             ),
             landmarks={
-                landmark: Datapoint((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"])))
+                landmark: torch.tensor((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"])))
                 for landmark in {key[:-2] for key in landmarks.keys()}
             },
         )
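Note that the two replacement functions in this commit are not interchangeable by accident: torch.as_tensor reuses the input's memory when it can (which suits the numpy arrays elsewhere in the diff), while torch.tensor always copies. The landmark coordinates here are fresh Python ints with nothing to share, so torch.tensor is the natural fit. A quick illustration:

import numpy as np
import torch

point = torch.tensor((int(42), int(17)))  # always copies its input

arr = np.zeros((2, 50))
t = torch.as_tensor(arr)  # shares memory with arr when dtypes allow
arr[0, 0] = 1.0
assert t[0, 0] == 1.0  # the tensor sees the change: no copy was made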
5 changes: 2 additions & 3 deletions torchvision/prototype/datasets/_builtin/coco.py
@@ -15,7 +15,6 @@
     UnBatcher,
 )
 from torchvision.prototype.datapoints import BoundingBox, Label, Mask
-from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
@@ -124,8 +123,8 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st
                     ]
                 )
             ),
-            areas=Datapoint([ann["area"] for ann in anns]),
-            crowds=Datapoint([ann["iscrowd"] for ann in anns], dtype=torch.bool),
+            areas=torch.as_tensor([ann["area"] for ann in anns]),
+            crowds=torch.as_tensor([ann["iscrowd"] for ann in anns], dtype=torch.bool),
             bounding_boxes=BoundingBox(
                 [ann["bbox"] for ann in anns],
                 format="xywh",
4 changes: 2 additions & 2 deletions torchvision/prototype/datasets/_builtin/cub200.py
@@ -3,6 +3,7 @@
 import pathlib
 from typing import Any, BinaryIO, Callable, Dict, List, Optional, Tuple, Union
 
+import torch
 from torchdata.datapipes.iter import (
     CSVDictParser,
     CSVParser,
@@ -15,7 +16,6 @@
 )
 from torchdata.datapipes.map import IterToMapConverter
 from torchvision.prototype.datapoints import BoundingBox, Label
-from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
@@ -162,7 +162,7 @@ def _2010_prepare_ann(
                 format="xyxy",
                 spatial_size=spatial_size,
             ),
-            segmentation=Datapoint(content["seg"]),
+            segmentation=torch.as_tensor(content["seg"]),
         )
 
     def _prepare_sample(
8 changes: 5 additions & 3 deletions torchvision/prototype/datasets/_builtin/sbd.py
@@ -3,8 +3,8 @@
 from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union
 
 import numpy as np
+import torch
 from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
-from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
@@ -92,8 +92,10 @@ def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[st
             image=EncodedImage.from_file(image_buffer),
             ann_path=ann_path,
             # the boundaries are stored in sparse CSC format, which is not supported by PyTorch
-            boundaries=Datapoint(np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()])),
-            segmentation=Datapoint(anns["Segmentation"].item()),
+            boundaries=torch.as_tensor(
+                np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()])
+            ),
+            segmentation=torch.as_tensor(anns["Segmentation"].item()),
         )
 
     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
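The comment kept in the hunk above explains the detour through numpy: the SBD annotations come out as scipy sparse CSC matrices, which this code densifies before handing them to torch.as_tensor. A toy reproduction of that conversion (the matrix contents and count are made up):

import numpy as np
import torch
from scipy.sparse import csc_matrix

# made-up stand-ins for the entries of anns["Boundaries"].item()
raw_boundaries = [csc_matrix(np.eye(4, dtype=np.uint8)) for _ in range(3)]

# densify each CSC matrix, stack to (num_boundaries, H, W), wrap as a tensor
boundaries = torch.as_tensor(np.stack([b.toarray() for b in raw_boundaries]))
assert boundaries.shape == (3, 4, 4)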
