Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate Fer2013 prototype dataset #5759

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions test/builtin_dataset_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,13 +1013,14 @@ def dtd(info, root, config):
return num_samples_map[config]


# @register_mock
def fer2013(info, root, config):
num_samples = 5 if config.split == "train" else 3
@register_mock(configs=combinations_grid(split=("train", "test")))
def fer2013(root, config):
split = config["split"]
num_samples = 5 if split == "train" else 3

path = root / f"{config.split}.csv"
path = root / f"{split}.csv"
with open(path, "w", newline="") as file:
field_names = ["emotion"] if config.split == "train" else []
field_names = ["emotion"] if split == "train" else []
field_names.append("pixels")

file.write(",".join(field_names) + "\n")
Expand All @@ -1029,7 +1030,7 @@ def fer2013(info, root, config):
rowdict = {
"pixels": " ".join([str(int(pixel)) for pixel in torch.randint(256, (48 * 48,), dtype=torch.uint8)])
}
if config.split == "train":
if split == "train":
rowdict["emotion"] = int(torch.randint(7, ()))
writer.writerow(rowdict)

Expand Down
60 changes: 37 additions & 23 deletions torchvision/prototype/datasets/_builtin/fer2013.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Any, Dict, List, cast
import pathlib
from typing import Any, Dict, List, cast, Union

import torch
from torchdata.datapipes.iter import IterDataPipe, Mapper, CSVDictParser
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
Dataset2,
OnlineResource,
KaggleDownloadResource,
)
Expand All @@ -15,26 +14,40 @@
)
from torchvision.prototype.features import Label, Image

from .._api import register_dataset, register_info

NAME = "fer2013"


@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"))

class FER2013(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"fer2013",
homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge",
categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"),
valid_options=dict(split=("train", "test")),
)

@register_dataset(NAME)
class FER2013(Dataset2):
"""FER 2013 Dataset
homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
"""

def __init__(
self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "test"})
self._categories = _info()["categories"]

super().__init__(root, skip_integrity_check=skip_integrity_check)

_CHECKSUMS = {
"train": "a2b7c9360cc0b38d21187e5eece01c2799fce5426cdeecf746889cc96cda2d10",
"test": "dec8dfe8021e30cd6704b85ec813042b4a5d99d81cb55e023291a94104f575c3",
}

def resources(self, config: DatasetConfig) -> List[OnlineResource]:
def _resources(self) -> List[OnlineResource]:
archive = KaggleDownloadResource(
cast(str, self.info.homepage),
file_name=f"{config.split}.csv.zip",
sha256=self._CHECKSUMS[config.split],
"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge",
file_name=f"{self._split}.csv.zip",
sha256=self._CHECKSUMS[self._split],
)
return [archive]

Expand All @@ -43,17 +56,18 @@ def _prepare_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:

return dict(
image=Image(torch.tensor([int(idx) for idx in data["pixels"].split()], dtype=torch.uint8).reshape(48, 48)),
label=Label(int(label_id), categories=self.categories) if label_id is not None else None,
label=Label(int(label_id), categories=self._categories) if label_id is not None else None,
)

def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = CSVDictParser(dp)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)

def __len__(self) -> int:
return {
"train": 28_709,
"test": 3_589,
}[self._split]
NicolasHug marked this conversation as resolved.
Show resolved Hide resolved