-
Notifications
You must be signed in to change notification settings - Fork 7k
/
Copy pathfer2013.py
120 lines (99 loc) · 5 KB
/
fer2013.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import csv
import pathlib
from typing import Any, Callable, Optional, Tuple, Union
import torch
from PIL import Image
from .utils import check_integrity, verify_str_arg
from .vision import VisionDataset
class FER2013(VisionDataset):
    """`FER2013
    <https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.

    .. note::
        Test labels are only available when ``fer2013.csv`` OR
        ``icml_face_data.csv`` is present in ``root/fer2013/``. When only
        ``train.csv`` and ``test.csv`` are present, every test target is
        ``None``.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where directory
            ``root/fer2013`` exists. This directory may contain either
            ``fer2013.csv``, ``icml_face_data.csv``, or both ``train.csv`` and
            ``test.csv``. Precedence is given in that order, i.e. if
            ``fer2013.csv`` is present then the rest of the files will be
            ignored. All these (combinations of) files contain the same data and
            are supported for convenience, but only ``fer2013.csv`` and
            ``icml_face_data.csv`` are able to return non-None test labels.
        split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
    """

    # Maps a resource name to ``(file name, expected md5)``.
    _RESOURCES = {
        "train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"),
        "test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"),
        # fer2013.csv and icml_face_data.csv each hold BOTH the train and the
        # test instances and, unlike test.csv, also carry the test labels.
        # These two files therefore take precedence over train.csv/test.csv.
        # They contain the same data but differ in column order and in the
        # header spelling (note the leading spaces in the icml header):
        # $ head -n 1 fer2013.csv icml_face_data.csv train.csv test.csv
        # ==> fer2013.csv <==
        # emotion,pixels,Usage
        #
        # ==> icml_face_data.csv <==
        # emotion, Usage, pixels
        #
        # ==> train.csv <==
        # emotion,pixels
        #
        # ==> test.csv <==
        # pixels
        "fer": ("fer2013.csv", "f8428a1edbd21e88f42c73edd2a14f95"),
        "icml": ("icml_face_data.csv", "b114b9e04e6949e5fe8b6a98b3892b1d"),
    }

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """Locate, verify and eagerly load the CSV file backing ``split``.

        Raises:
            RuntimeError: If the selected CSV file is missing or fails the
                md5 integrity check.
        """
        self._split = verify_str_arg(split, "split", ("train", "test"))
        super().__init__(root, transform=transform, target_transform=target_transform)

        dataset_dir = pathlib.Path(self.root) / "fer2013"

        # Pick the backing file: fer2013.csv first, then icml_face_data.csv,
        # and only then the per-split train.csv / test.csv.
        from_icml = False
        if (dataset_dir / self._RESOURCES["fer"][0]).exists():
            resource = "fer"
        elif (dataset_dir / self._RESOURCES["icml"][0]).exists():
            resource = "icml"
            from_icml = True
        else:
            resource = self._split
        is_combined_file = resource != self._split

        file_name, md5 = self._RESOURCES[resource]
        data_file = dataset_dir / file_name
        if not check_integrity(str(data_file), md5=md5):
            raise RuntimeError(
                f"{file_name} not found in {dataset_dir} or corrupted. "
                f"You can download it from "
                f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
            )

        # The icml header has a space after each comma, which csv.DictReader
        # keeps as part of the column name.
        if from_icml:
            pixels_key, usage_key = " pixels", " Usage"
        else:
            pixels_key, usage_key = "pixels", "Usage"

        # Combined files need filtering on the usage column; the per-split
        # files already contain exactly one split.
        if not is_combined_file:
            wanted_usages = None
        elif self._split == "train":
            wanted_usages = ("Training",)
        else:
            wanted_usages = ("PublicTest", "PrivateTest")

        # Only the plain test.csv lacks an "emotion" column.
        has_labels = is_combined_file or self._split == "train"

        samples = []
        with open(data_file, "r", newline="") as handle:
            for record in csv.DictReader(handle):
                if wanted_usages is not None and record[usage_key] not in wanted_usages:
                    continue
                intensities = [int(value) for value in record[pixels_key].split()]
                picture = torch.tensor(intensities, dtype=torch.uint8).reshape(48, 48)
                label = int(record["emotion"]) if has_labels else None
                samples.append((picture, label))
        self._samples = samples

    def __len__(self) -> int:
        """Return the number of samples loaded for the selected split."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return the ``(image, target)`` pair stored at position ``idx``.

        The stored 48x48 uint8 tensor is converted to a PIL image before
        ``transform`` is applied; ``target`` is the integer emotion label, or
        ``None`` when no labels are available, passed through
        ``target_transform`` if one was given.
        """
        pixels, target = self._samples[idx]
        image = Image.fromarray(pixels.numpy())

        transform = self.transform
        if transform is not None:
            image = transform(image)

        target_transform = self.target_transform
        if target_transform is not None:
            target = target_transform(target)

        return image, target

    def extra_repr(self) -> str:
        """Return the split description string, ``split=<split>``."""
        return f"split={self._split}"