Skip to content

Commit 272e080

Browse files
authored
add initial chunk of prototype transforms (#4861)
* add initial chunk of prototype transforms
* fix tests
* add error message
* fix more imports
* add explicit no-ops
* add test for no-ops
* cleanup
1 parent 57e6e30 commit 272e080

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+901
-120
lines changed

test/test_prototype_datasets_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44
from torchvision.prototype import datasets
5-
from torchvision.prototype.datasets.utils._internal import FrozenMapping, FrozenBunch
5+
from torchvision.prototype.utils._internal import FrozenMapping, FrozenBunch
66

77

88
def make_minimal_dataset_info(name="name", type=datasets.utils.DatasetType.RAW, categories=None, **kwargs):

test/test_prototype_features.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@
1111
make_tensor = functools.partial(_make_tensor, device="cpu", dtype=torch.float32)
1212

1313

14+
def make_image(**kwargs):
15+
data = make_tensor((3, *torch.randint(16, 33, (2,)).tolist()))
16+
return features.Image(data, **kwargs)
17+
18+
1419
def make_bounding_box(*, format="xyxy", image_size=(10, 10)):
1520
if isinstance(format, str):
1621
format = features.BoundingBoxFormat[format]
@@ -42,6 +47,7 @@ def make_bounding_box(*, format="xyxy", image_size=(10, 10)):
4247

4348

4449
MAKE_DATA_MAP = {
50+
features.Image: make_image,
4551
features.BoundingBox: make_bounding_box,
4652
}
4753

test/test_prototype_transforms.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import pytest
2+
from torchvision.prototype import transforms, features
3+
from torchvision.prototype.utils._internal import sequence_to_str
4+
5+
6+
FEATURE_TYPES = {
7+
feature_type
8+
for name, feature_type in features.__dict__.items()
9+
if not name.startswith("_")
10+
and isinstance(feature_type, type)
11+
and issubclass(feature_type, features.Feature)
12+
and feature_type is not features.Feature
13+
}
14+
15+
TRANSFORM_TYPES = tuple(
16+
transform_type
17+
for name, transform_type in transforms.__dict__.items()
18+
if not name.startswith("_")
19+
and isinstance(transform_type, type)
20+
and issubclass(transform_type, transforms.Transform)
21+
and transform_type is not transforms.Transform
22+
)
23+
24+
25+
def test_feature_type_support():
26+
missing_feature_types = FEATURE_TYPES - set(transforms.Transform._BUILTIN_FEATURE_TYPES)
27+
if missing_feature_types:
28+
names = sorted([feature_type.__name__ for feature_type in missing_feature_types])
29+
raise AssertionError(
30+
f"The feature(s) {sequence_to_str(names, separate_last='and ')} is/are exposed at "
31+
f"`torchvision.prototype.features`, but are missing in Transform._BUILTIN_FEATURE_TYPES. "
32+
f"Please add it/them to the collection."
33+
)
34+
35+
36+
@pytest.mark.parametrize(
37+
"transform_type",
38+
[transform_type for transform_type in TRANSFORM_TYPES if transform_type is not transforms.Identity],
39+
ids=lambda transform_type: transform_type.__name__,
40+
)
41+
def test_no_op(transform_type):
42+
unsupported_features = (
43+
FEATURE_TYPES - transform_type.supported_feature_types() - set(transform_type.NO_OP_FEATURE_TYPES)
44+
)
45+
if unsupported_features:
46+
names = sorted([feature_type.__name__ for feature_type in unsupported_features])
47+
raise AssertionError(
48+
f"The feature(s) {sequence_to_str(names, separate_last='and ')} are neither supported nor declared as "
49+
f"no-op for transform `{transform_type.__name__}`. Please either implement a feature transform for them, "
50+
f"or add them to the `{transform_type.__name__}.NO_OP_FEATURE_TYPES` collection."
51+
)

torchvision/prototype/datasets/_builtin/imagenet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
Enumerator,
2121
getitem,
2222
read_mat,
23-
FrozenMapping,
2423
)
24+
from torchvision.prototype.utils._internal import FrozenMapping
2525

2626

2727
class ImageNet(Dataset):
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from . import _internal
22
from ._dataset import DatasetType, DatasetConfig, DatasetInfo, Dataset
3+
from ._query import SampleQuery
34
from ._resource import LocalResource, OnlineResource, HttpResource, GDriveResource

torchvision/prototype/datasets/utils/_dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99

1010
import torch
1111
from torch.utils.data import IterDataPipe
12+
from torchvision.prototype.utils._internal import FrozenBunch, make_repr
1213
from torchvision.prototype.utils._internal import add_suggestion, sequence_to_str
1314

1415
from .._home import use_sharded_dataset
15-
from ._internal import FrozenBunch, make_repr, BUILTIN_DIR, _make_sharded_datapipe
16+
from ._internal import BUILTIN_DIR, _make_sharded_datapipe
1617
from ._resource import OnlineResource
1718

1819

torchvision/prototype/datasets/utils/_internal.py

Lines changed: 0 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import csv
21
import enum
32
import gzip
43
import io
@@ -7,7 +6,6 @@
76
import os.path
87
import pathlib
98
import pickle
10-
import textwrap
119
from typing import (
1210
Sequence,
1311
Callable,
@@ -18,10 +16,7 @@
1816
Iterator,
1917
Dict,
2018
Optional,
21-
NoReturn,
2219
IO,
23-
Iterable,
24-
Mapping,
2520
Sized,
2621
)
2722
from typing import cast
@@ -38,10 +33,6 @@
3833
__all__ = [
3934
"INFINITE_BUFFER_SIZE",
4035
"BUILTIN_DIR",
41-
"make_repr",
42-
"FrozenMapping",
43-
"FrozenBunch",
44-
"create_categories_file",
4536
"read_mat",
4637
"image_buffer_from_array",
4738
"SequenceIterator",
@@ -62,82 +53,6 @@
6253
BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin"
6354

6455

65-
def make_repr(name: str, items: Iterable[Tuple[str, Any]]) -> str:
66-
def to_str(sep: str) -> str:
67-
return sep.join([f"{key}={value}" for key, value in items])
68-
69-
prefix = f"{name}("
70-
postfix = ")"
71-
body = to_str(", ")
72-
73-
line_length = int(os.environ.get("COLUMNS", 80))
74-
body_too_long = (len(prefix) + len(body) + len(postfix)) > line_length
75-
multiline_body = len(str(body).splitlines()) > 1
76-
if not (body_too_long or multiline_body):
77-
return prefix + body + postfix
78-
79-
body = textwrap.indent(to_str(",\n"), " " * 2)
80-
return f"{prefix}\n{body}\n{postfix}"
81-
82-
83-
class FrozenMapping(Mapping[K, D]):
84-
def __init__(self, *args: Any, **kwargs: Any) -> None:
85-
data = dict(*args, **kwargs)
86-
self.__dict__["__data__"] = data
87-
self.__dict__["__final_hash__"] = hash(tuple(data.items()))
88-
89-
def __getitem__(self, item: K) -> D:
90-
return cast(Mapping[K, D], self.__dict__["__data__"])[item]
91-
92-
def __iter__(self) -> Iterator[K]:
93-
return iter(self.__dict__["__data__"].keys())
94-
95-
def __len__(self) -> int:
96-
return len(self.__dict__["__data__"])
97-
98-
def __setitem__(self, key: K, value: Any) -> NoReturn:
99-
raise RuntimeError(f"'{type(self).__name__}' object is immutable")
100-
101-
def __delitem__(self, key: K) -> NoReturn:
102-
raise RuntimeError(f"'{type(self).__name__}' object is immutable")
103-
104-
def __hash__(self) -> int:
105-
return cast(int, self.__dict__["__final_hash__"])
106-
107-
def __eq__(self, other: Any) -> bool:
108-
if not isinstance(other, FrozenMapping):
109-
return NotImplemented
110-
111-
return hash(self) == hash(other)
112-
113-
def __repr__(self) -> str:
114-
return repr(self.__dict__["__data__"])
115-
116-
117-
class FrozenBunch(FrozenMapping):
118-
def __getattr__(self, name: str) -> Any:
119-
try:
120-
return self[name]
121-
except KeyError as error:
122-
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") from error
123-
124-
def __setattr__(self, key: Any, value: Any) -> NoReturn:
125-
raise RuntimeError(f"'{type(self).__name__}' object is immutable")
126-
127-
def __delattr__(self, item: Any) -> NoReturn:
128-
raise RuntimeError(f"'{type(self).__name__}' object is immutable")
129-
130-
def __repr__(self) -> str:
131-
return make_repr(type(self).__name__, self.items())
132-
133-
134-
def create_categories_file(
135-
root: Union[str, pathlib.Path], name: str, categories: Sequence[Union[str, Sequence[str]]], **fmtparams: Any
136-
) -> None:
137-
with open(pathlib.Path(root) / f"{name}.categories", "w", newline="") as file:
138-
csv.writer(file, **fmtparams).writerows(categories)
139-
140-
14156
def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any:
14257
try:
14358
import scipy.io as sio
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import collections.abc
2+
from typing import Any, Callable, Iterator, Optional, Tuple, TypeVar, cast
3+
4+
from torchvision.prototype.features import BoundingBox, Image
5+
6+
T = TypeVar("T")
7+
8+
9+
class SampleQuery:
10+
def __init__(self, sample: Any) -> None:
11+
self.sample = sample
12+
13+
@staticmethod
14+
def _query_recursively(sample: Any, fn: Callable[[Any], Optional[T]]) -> Iterator[T]:
15+
if isinstance(sample, (collections.abc.Sequence, collections.abc.Mapping)):
16+
for item in sample.values() if isinstance(sample, collections.abc.Mapping) else sample:
17+
yield from SampleQuery._query_recursively(item, fn)
18+
else:
19+
result = fn(sample)
20+
if result is not None:
21+
yield result
22+
23+
def query(self, fn: Callable[[Any], Optional[T]]) -> T:
24+
results = set(self._query_recursively(self.sample, fn))
25+
if not results:
26+
raise RuntimeError("Query turned up empty.")
27+
elif len(results) > 1:
28+
raise RuntimeError(f"Found more than one result: {results}")
29+
30+
return results.pop()
31+
32+
def image_size(self) -> Tuple[int, int]:
33+
def fn(sample: Any) -> Optional[Tuple[int, int]]:
34+
if isinstance(sample, Image):
35+
return cast(Tuple[int, int], sample.shape[-2:])
36+
elif isinstance(sample, BoundingBox):
37+
return sample.image_size
38+
else:
39+
return None
40+
41+
return self.query(fn)

torchvision/prototype/features/_image.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@ class Image(Feature):
1818
color_spaces = ColorSpace
1919
color_space: ColorSpace
2020

21+
@classmethod
22+
def _to_tensor(cls, data, *, dtype, device):
23+
tensor = torch.as_tensor(data, dtype=dtype, device=device)
24+
if tensor.ndim == 2:
25+
tensor = tensor.unsqueeze(0)
26+
elif tensor.ndim != 3:
27+
raise ValueError("Only single images with 2 or 3 dimensions are allowed.")
28+
return tensor
29+
2130
@classmethod
2231
def _parse_meta_data(
2332
cls,

torchvision/prototype/models/alexnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
from functools import partial
33
from typing import Any, Optional
44

5+
from torchvision.prototype.transforms import ImageNetEval
56
from torchvision.transforms.functional import InterpolationMode
67

78
from ...models.alexnet import AlexNet
8-
from ..transforms.presets import ImageNetEval
99
from ._api import Weights, WeightEntry
1010
from ._meta import _IMAGENET_CATEGORIES
1111

0 commit comments

Comments
 (0)