Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove categories metadata from (OneHot)Label datapoint #7171

Closed
wants to merge 1 commit (author and source/target branch names missing from this capture)
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 27 additions & 48 deletions torchvision/prototype/datapoints/_label.py
Original file line number Diff line number Diff line change
@@ -1,78 +1,57 @@
from __future__ import annotations

from typing import Any, Optional, Sequence, Type, TypeVar, Union
from typing import Any, Optional, Union

import torch
from torch.utils._pytree import tree_map

from ._datapoint import Datapoint


L = TypeVar("L", bound="_LabelBase")


class _LabelBase(Datapoint):
categories: Optional[Sequence[str]]

class Label(Datapoint):
@classmethod
def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L:
label_base = tensor.as_subclass(cls)
label_base.categories = categories
return label_base
def _wrap(cls, tensor: torch.Tensor) -> Label:
return tensor.as_subclass(cls)

def __new__(
cls: Type[L],
cls,
data: Any,
*,
categories: Optional[Sequence[str]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: Optional[bool] = None,
) -> L:
) -> Label:
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
return cls._wrap(tensor, categories=categories)
return cls._wrap(tensor)

@classmethod
def wrap_like(cls: Type[L], other: L, tensor: torch.Tensor, *, categories: Optional[Sequence[str]] = None) -> L:
return cls._wrap(
tensor,
categories=categories if categories is not None else other.categories,
)

@classmethod
def from_category(
cls: Type[L],
category: str,
*,
categories: Sequence[str],
**kwargs: Any,
) -> L:
return cls(categories.index(category), categories=categories, **kwargs)


class Label(_LabelBase):
def to_categories(self) -> Any:
if self.categories is None:
raise RuntimeError("Label does not have categories")
def wrap_like(
cls,
other: Label,
tensor: torch.Tensor,
) -> Label:
return cls._wrap(tensor)

return tree_map(lambda idx: self.categories[idx], self.tolist())

class OneHotLabel(Datapoint):
@classmethod
def _wrap(cls, tensor: torch.Tensor) -> OneHotLabel:
return tensor.as_subclass(cls)

class OneHotLabel(_LabelBase):
def __new__(
cls,
data: Any,
*,
categories: Optional[Sequence[str]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False,
requires_grad: Optional[bool] = None,
) -> OneHotLabel:
one_hot_label = super().__new__(
cls, data, categories=categories, dtype=dtype, device=device, requires_grad=requires_grad
)

if categories is not None and len(categories) != one_hot_label.shape[-1]:
raise ValueError()
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
return cls._wrap(tensor)

return one_hot_label
@classmethod
def wrap_like(
cls,
other: OneHotLabel,
tensor: torch.Tensor,
) -> OneHotLabel:
return cls._wrap(tensor)
16 changes: 12 additions & 4 deletions torchvision/prototype/datasets/_builtin/caltech.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,15 @@

import numpy as np
from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper
from torchvision.prototype.datapoints import BoundingBox, Label
from torchvision.prototype.datapoints import BoundingBox
from torchvision.prototype.datapoints._datapoint import Datapoint
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils import (
Dataset,
EncodedImage,
GDriveResource,
LabelWithCategories,
OnlineResource,
)
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
Expand Down Expand Up @@ -106,7 +112,7 @@ def _prepare_sample(
ann = read_mat(ann_buffer)

return dict(
label=Label.from_category(category, categories=self._categories),
label=LabelWithCategories.from_category(category, categories=self._categories),
image_path=image_path,
image=image,
ann_path=ann_path,
Expand Down Expand Up @@ -188,7 +194,9 @@ def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
return dict(
path=path,
image=EncodedImage.from_file(buffer),
label=Label(int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self._categories),
label=LabelWithCategories(
int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self._categories
),
)

def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
Expand Down
12 changes: 9 additions & 3 deletions torchvision/prototype/datasets/_builtin/celeba.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@
from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tuple, Union

from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper
from torchvision.prototype.datapoints import BoundingBox, Label
from torchvision.prototype.datapoints import BoundingBox
from torchvision.prototype.datapoints._datapoint import Datapoint
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils import (
Dataset,
EncodedImage,
GDriveResource,
LabelWithCategories,
OnlineResource,
)
from torchvision.prototype.datasets.utils._internal import (
getitem,
hint_sharding,
Expand Down Expand Up @@ -141,7 +147,7 @@ def _prepare_sample(
return dict(
path=path,
image=image,
identity=Label(int(identity["identity"])),
identity=LabelWithCategories(int(identity["identity"])),
attributes={attr: value == "1" for attr, value in attributes.items()},
bounding_box=BoundingBox(
[int(bounding_box[key]) for key in ("x_1", "y_1", "width", "height")],
Expand Down
6 changes: 3 additions & 3 deletions torchvision/prototype/datasets/_builtin/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

import numpy as np
from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper
from torchvision.prototype.datapoints import Image, Label
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datapoints import Image
from torchvision.prototype.datasets.utils import Dataset, HttpResource, LabelWithCategories, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
Expand Down Expand Up @@ -70,7 +70,7 @@ def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]:
image_array, category_idx = data
return dict(
image=Image(image_array),
label=Label(category_idx, categories=self._categories),
label=LabelWithCategories(category_idx, categories=self._categories),
)

def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
Expand Down
11 changes: 8 additions & 3 deletions torchvision/prototype/datasets/_builtin/clevr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union

from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, JsonParser, Mapper, UnBatcher
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils import (
Dataset,
EncodedImage,
HttpResource,
LabelWithCategories,
OnlineResource,
)
from torchvision.prototype.datasets.utils._internal import (
getitem,
hint_sharding,
Expand Down Expand Up @@ -66,7 +71,7 @@ def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Optional[Dict[str, A
return dict(
path=path,
image=EncodedImage.from_file(buffer),
label=Label(len(scenes_data["objects"])) if scenes_data else None,
label=LabelWithCategories(len(scenes_data["objects"])) if scenes_data else None,
)

def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
Expand Down
12 changes: 9 additions & 3 deletions torchvision/prototype/datasets/_builtin/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,15 @@
Mapper,
UnBatcher,
)
from torchvision.prototype.datapoints import BoundingBox, Label, Mask
from torchvision.prototype.datapoints import BoundingBox, Mask
from torchvision.prototype.datapoints._datapoint import Datapoint
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils import (
Dataset,
EncodedImage,
HttpResource,
LabelWithCategories,
OnlineResource,
)
from torchvision.prototype.datasets.utils._internal import (
getitem,
hint_sharding,
Expand Down Expand Up @@ -131,7 +137,7 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st
format="xywh",
spatial_size=spatial_size,
),
labels=Label(labels, categories=self._categories),
labels=LabelWithCategories(labels, categories=self._categories),
super_categories=[self._category_to_super_category[self._categories[label]] for label in labels],
ann_ids=[ann["id"] for ann in anns],
)
Expand Down
12 changes: 9 additions & 3 deletions torchvision/prototype/datasets/_builtin/country211.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@
from typing import Any, Dict, List, Tuple, Union

from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource

from torchvision.prototype.datasets.utils import (
Dataset,
EncodedImage,
HttpResource,
LabelWithCategories,
OnlineResource,
)
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
Expand Down Expand Up @@ -53,7 +59,7 @@ def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
path, buffer = data
category = pathlib.Path(path).parent.name
return dict(
label=Label.from_category(category, categories=self._categories),
label=LabelWithCategories.from_category(category, categories=self._categories),
path=path,
image=EncodedImage.from_file(buffer),
)
Expand Down
12 changes: 9 additions & 3 deletions torchvision/prototype/datasets/_builtin/cub200.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,15 @@
Mapper,
)
from torchdata.datapipes.map import IterToMapConverter
from torchvision.prototype.datapoints import BoundingBox, Label
from torchvision.prototype.datapoints import BoundingBox
from torchvision.prototype.datapoints._datapoint import Datapoint
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils import (
Dataset,
EncodedImage,
GDriveResource,
LabelWithCategories,
OnlineResource,
)
from torchvision.prototype.datasets.utils._internal import (
getitem,
hint_sharding,
Expand Down Expand Up @@ -180,7 +186,7 @@ def _prepare_sample(
return dict(
prepare_ann_fn(anns_data, image.spatial_size),
image=image,
label=Label(
label=LabelWithCategories(
int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]) - 1,
categories=self._categories,
),
Expand Down
12 changes: 9 additions & 3 deletions torchvision/prototype/datasets/_builtin/dtd.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union

from torchdata.datapipes.iter import CSVParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource

from torchvision.prototype.datasets.utils import (
Dataset,
EncodedImage,
HttpResource,
LabelWithCategories,
OnlineResource,
)
from torchvision.prototype.datasets.utils._internal import (
getitem,
hint_sharding,
Expand Down Expand Up @@ -89,7 +95,7 @@ def _prepare_sample(self, data: Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO

return dict(
joint_categories={category for category in joint_categories if category},
label=Label.from_category(category, categories=self._categories),
label=LabelWithCategories.from_category(category, categories=self._categories),
path=path,
image=EncodedImage.from_file(buffer),
)
Expand Down
11 changes: 8 additions & 3 deletions torchvision/prototype/datasets/_builtin/eurosat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@
from typing import Any, Dict, List, Tuple, Union

from torchdata.datapipes.iter import IterDataPipe, Mapper
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils import (
Dataset,
EncodedImage,
HttpResource,
LabelWithCategories,
OnlineResource,
)
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling

from .._api import register_dataset, register_info
Expand Down Expand Up @@ -51,7 +56,7 @@ def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
path, buffer = data
category = pathlib.Path(path).parent.name
return dict(
label=Label.from_category(category, categories=self._categories),
label=LabelWithCategories.from_category(category, categories=self._categories),
path=path,
image=EncodedImage.from_file(buffer),
)
Expand Down
6 changes: 3 additions & 3 deletions torchvision/prototype/datasets/_builtin/fer2013.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import torch
from torchdata.datapipes.iter import CSVDictParser, IterDataPipe, Mapper
from torchvision.prototype.datapoints import Image, Label
from torchvision.prototype.datasets.utils import Dataset, KaggleDownloadResource, OnlineResource
from torchvision.prototype.datapoints import Image
from torchvision.prototype.datasets.utils import Dataset, KaggleDownloadResource, LabelWithCategories, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling

from .._api import register_dataset, register_info
Expand Down Expand Up @@ -49,7 +49,7 @@ def _prepare_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:

return dict(
image=Image(torch.tensor([int(idx) for idx in data["pixels"].split()], dtype=torch.uint8).reshape(48, 48)),
label=Label(int(label_id), categories=self._categories) if label_id is not None else None,
label=LabelWithCategories(int(label_id), categories=self._categories) if label_id is not None else None,
)

def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
Expand Down
12 changes: 9 additions & 3 deletions torchvision/prototype/datasets/_builtin/food101.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union

from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource

from torchvision.prototype.datasets.utils import (
Dataset,
EncodedImage,
HttpResource,
LabelWithCategories,
OnlineResource,
)
from torchvision.prototype.datasets.utils._internal import (
getitem,
hint_sharding,
Expand Down Expand Up @@ -57,7 +63,7 @@ def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
def _prepare_sample(self, data: Tuple[str, Tuple[str, BinaryIO]]) -> Dict[str, Any]:
id, (path, buffer) = data
return dict(
label=Label.from_category(id.split("/", 1)[0], categories=self._categories),
label=LabelWithCategories.from_category(id.split("/", 1)[0], categories=self._categories),
path=path,
image=EncodedImage.from_file(buffer),
)
Expand Down
Loading