Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/test_prototype_datasets_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest
from torchvision.prototype import datasets
from torchvision.prototype.datasets.utils._internal import FrozenMapping, FrozenBunch
from torchvision.prototype.utils._internal import FrozenMapping, FrozenBunch


def make_minimal_dataset_info(name="name", type=datasets.utils.DatasetType.RAW, categories=None, **kwargs):
Expand Down
6 changes: 6 additions & 0 deletions test/test_prototype_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
make_tensor = functools.partial(_make_tensor, device="cpu", dtype=torch.float32)


def make_image(**kwargs):
data = make_tensor((3, *torch.randint(16, 33, (2,)).tolist()))
return features.Image(data, **kwargs)


def make_bounding_box(*, format="xyxy", image_size=(10, 10)):
if isinstance(format, str):
format = features.BoundingBoxFormat[format]
Expand Down Expand Up @@ -42,6 +47,7 @@ def make_bounding_box(*, format="xyxy", image_size=(10, 10)):


MAKE_DATA_MAP = {
features.Image: make_image,
features.BoundingBox: make_bounding_box,
}

Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/datasets/_builtin/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
Enumerator,
getitem,
read_mat,
FrozenMapping,
)
from torchvision.prototype.utils._internal import FrozenMapping


class ImageNet(Dataset):
Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/datasets/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import _internal
from ._dataset import DatasetType, DatasetConfig, DatasetInfo, Dataset
from ._query import SampleQuery
from ._resource import LocalResource, OnlineResource, HttpResource, GDriveResource
3 changes: 2 additions & 1 deletion torchvision/prototype/datasets/utils/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@

import torch
from torch.utils.data import IterDataPipe
from torchvision.prototype.utils._internal import FrozenBunch, make_repr
from torchvision.prototype.utils._internal import add_suggestion, sequence_to_str

from .._home import use_sharded_dataset
from ._internal import FrozenBunch, make_repr, BUILTIN_DIR, _make_sharded_datapipe
from ._internal import BUILTIN_DIR, _make_sharded_datapipe
from ._resource import OnlineResource


Expand Down
85 changes: 0 additions & 85 deletions torchvision/prototype/datasets/utils/_internal.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import csv
import enum
import gzip
import io
Expand All @@ -7,7 +6,6 @@
import os.path
import pathlib
import pickle
import textwrap
from typing import (
Sequence,
Callable,
Expand All @@ -18,10 +16,7 @@
Iterator,
Dict,
Optional,
NoReturn,
IO,
Iterable,
Mapping,
Sized,
)
from typing import cast
Expand All @@ -38,10 +33,6 @@
__all__ = [
"INFINITE_BUFFER_SIZE",
"BUILTIN_DIR",
"make_repr",
"FrozenMapping",
"FrozenBunch",
"create_categories_file",
"read_mat",
"image_buffer_from_array",
"SequenceIterator",
Expand All @@ -62,82 +53,6 @@
BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin"


def make_repr(name: str, items: Iterable[Tuple[str, Any]]) -> str:
def to_str(sep: str) -> str:
return sep.join([f"{key}={value}" for key, value in items])

prefix = f"{name}("
postfix = ")"
body = to_str(", ")

line_length = int(os.environ.get("COLUMNS", 80))
body_too_long = (len(prefix) + len(body) + len(postfix)) > line_length
multiline_body = len(str(body).splitlines()) > 1
if not (body_too_long or multiline_body):
return prefix + body + postfix

body = textwrap.indent(to_str(",\n"), " " * 2)
return f"{prefix}\n{body}\n{postfix}"


class FrozenMapping(Mapping[K, D]):
def __init__(self, *args: Any, **kwargs: Any) -> None:
data = dict(*args, **kwargs)
self.__dict__["__data__"] = data
self.__dict__["__final_hash__"] = hash(tuple(data.items()))

def __getitem__(self, item: K) -> D:
return cast(Mapping[K, D], self.__dict__["__data__"])[item]

def __iter__(self) -> Iterator[K]:
return iter(self.__dict__["__data__"].keys())

def __len__(self) -> int:
return len(self.__dict__["__data__"])

def __setitem__(self, key: K, value: Any) -> NoReturn:
raise RuntimeError(f"'{type(self).__name__}' object is immutable")

def __delitem__(self, key: K) -> NoReturn:
raise RuntimeError(f"'{type(self).__name__}' object is immutable")

def __hash__(self) -> int:
return cast(int, self.__dict__["__final_hash__"])

def __eq__(self, other: Any) -> bool:
if not isinstance(other, FrozenMapping):
return NotImplemented

return hash(self) == hash(other)

def __repr__(self) -> str:
return repr(self.__dict__["__data__"])


class FrozenBunch(FrozenMapping):
def __getattr__(self, name: str) -> Any:
try:
return self[name]
except KeyError as error:
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") from error

def __setattr__(self, key: Any, value: Any) -> NoReturn:
raise RuntimeError(f"'{type(self).__name__}' object is immutable")

def __delattr__(self, item: Any) -> NoReturn:
raise RuntimeError(f"'{type(self).__name__}' object is immutable")

def __repr__(self) -> str:
return make_repr(type(self).__name__, self.items())


def create_categories_file(
root: Union[str, pathlib.Path], name: str, categories: Sequence[Union[str, Sequence[str]]], **fmtparams: Any
) -> None:
with open(pathlib.Path(root) / f"{name}.categories", "w", newline="") as file:
csv.writer(file, **fmtparams).writerows(categories)


def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any:
try:
import scipy.io as sio
Expand Down
41 changes: 41 additions & 0 deletions torchvision/prototype/datasets/utils/_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import collections.abc
from typing import Any, Callable, Iterator, Optional, Tuple, TypeVar, cast

from torchvision.prototype.features import BoundingBox, Image

T = TypeVar("T")


class SampleQuery:
def __init__(self, sample: Any) -> None:
self.sample = sample

@staticmethod
def _query_recursively(sample: Any, fn: Callable[[Any], Optional[T]]) -> Iterator[T]:
if isinstance(sample, (collections.abc.Sequence, collections.abc.Mapping)):
for item in sample.values() if isinstance(sample, collections.abc.Mapping) else sample:
yield from SampleQuery._query_recursively(item, fn)
else:
result = fn(sample)
if result is not None:
yield result

def query(self, fn: Callable[[Any], Optional[T]]) -> T:
results = set(self._query_recursively(self.sample, fn))
if not results:
raise RuntimeError("Query turned up empty.")
elif len(results) > 1:
raise RuntimeError(f"Found more than one result: {results}")

return results.pop()

def image_size(self) -> Tuple[int, int]:
def fn(sample: Any) -> Optional[Tuple[int, int]]:
if isinstance(sample, Image):
return cast(Tuple[int, int], sample.shape[-2:])
elif isinstance(sample, BoundingBox):
return sample.image_size
else:
return None

return self.query(fn)
9 changes: 9 additions & 0 deletions torchvision/prototype/features/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ class Image(Feature):
color_spaces = ColorSpace
color_space: ColorSpace

@classmethod
def _to_tensor(cls, data, *, dtype, device):
tensor = torch.as_tensor(data, dtype=dtype, device=device)
if tensor.ndim == 2:
tensor = tensor.unsqueeze(0)
elif tensor.ndim != 3:
raise ValueError("Only single images with 2 or 3 dimensions are allowed.")
return tensor

@classmethod
def _parse_meta_data(
cls,
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/models/alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from functools import partial
from typing import Any, Optional

from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode

from ...models.alexnet import AlexNet
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES

Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/models/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from typing import Any, Optional, Tuple

import torch.nn as nn
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode

from ...models.densenet import DenseNet
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES

Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from typing import Any, Optional, Union

from torchvision.prototype.transforms import CocoEval
from torchvision.transforms.functional import InterpolationMode

from ....models.detection.faster_rcnn import (
Expand All @@ -12,7 +13,6 @@
misc_nn_ops,
overwrite_eps,
)
from ...transforms.presets import CocoEval
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from typing import Any, Optional

from torch import nn
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode

from ...models.efficientnet import EfficientNet, MBConvConfig
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES

Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/models/googlenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from functools import partial
from typing import Any, Optional

from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode

from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES

Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/models/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from functools import partial
from typing import Any, Optional

from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode

from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES

Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/models/mnasnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from functools import partial
from typing import Any, Optional

from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode

from ...models.mnasnet import MNASNet
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES

Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/models/mobilenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from functools import partial
from typing import Any, Optional

from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode

from ...models.mobilenetv2 import MobileNetV2
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES

Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/models/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from functools import partial
from typing import Any, Optional, List

from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode

from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES

Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/models/quantization/googlenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
from functools import partial
from typing import Any, Optional, Union

from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode

from ....models.quantization.googlenet import (
QuantizableGoogLeNet,
_replace_relu,
quantize_model,
)
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from ..googlenet import GoogLeNetWeights
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/models/quantization/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
from functools import partial
from typing import Any, Optional, Union

from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode

from ....models.quantization.inception import (
QuantizableInception3,
_replace_relu,
quantize_model,
)
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from ..inception import InceptionV3Weights
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/models/quantization/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import partial
from typing import Any, List, Optional, Type, Union

from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode

from ....models.quantization.resnet import (
Expand All @@ -11,7 +12,6 @@
_replace_relu,
quantize_model,
)
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from ..resnet import ResNet18Weights, ResNet50Weights, ResNeXt101_32x8dWeights
Expand Down
Loading