Skip to content

Commit

Permalink
Merge pull request #140 from twosixlabs/tuple-dataset-adapter-improve…
Browse files Browse the repository at this point in the history
…ment

Support arbitrary-length tuple datasets
  • Loading branch information
treubig26 authored Apr 18, 2024
2 parents 124007e + ce1b112 commit 41e140c
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 22 deletions.
6 changes: 1 addition & 5 deletions examples/src/armory/examples/image_classification/food101.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,7 @@ def load_torchvision_dataset(

labels = tv_dataset.classes

armory_dataset = armory.dataset.TupleDataset(
tv_dataset,
x_key="image",
y_key="label",
)
armory_dataset = armory.dataset.TupleDataset(tv_dataset, ("image", "label"))

dataloader = armory.dataset.ImageClassificationDataLoader(
armory_dataset,
Expand Down
20 changes: 8 additions & 12 deletions library/src/armory/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Armory Dataset Classes"""

from typing import Any, Callable, List, Mapping, Sequence, cast
from typing import Any, Callable, Iterable, List, Mapping, Sequence, cast

import numpy as np
import torch
Expand Down Expand Up @@ -67,33 +67,29 @@ class TupleDataset(ArmoryDataset):
print(dataset[0])
# output: [[0, 0, 0], [0, 0, 0]], [5]
tuple_ds = TupleDataset(dataset, x_key="image", y_key="label")
tuple_ds = TupleDataset(dataset, ("image", "label"))
print(tuple_ds[0])
# output: {'image': [[0, 0, 0], [0, 0, 0]], 'label': [5]}
"""

def __init__(
self,
dataset,
x_key: str,
y_key: str,
keys: Iterable[str],
):
"""
Initializes the dataset.
Args:
dataset: Source dataset where samples are a two-entry tuple of data,
or x, and target, or y.
x_key: Key name to use for x data in the adapted sample dictionary
y_key: Key name to use for y data in the adapted sample dictionary
dataset: Source dataset where samples are tuples of data
keys: List of key names to use for each element in the sample tuple
when converted to a dictionary
"""
super().__init__(dataset, self._adapt)
self._x_key = x_key
self._y_key = y_key
self._keys = keys

def _adapt(self, sample):
x, y = sample
return {self._x_key: x, self._y_key: y}
return dict(zip(self._keys, sample))


def _collate_by_type(values: List):
Expand Down
36 changes: 31 additions & 5 deletions library/tests/unit/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,44 @@ def adapter(data):
assert_array_equal(sample["y"], np.array([4]), strict=True)


def test_TupleDataset():
def test_TupleDataset_with_two_elements():
raw_dataset = [
([1, 2, 3], 4),
([5, 6, 7], 8),
]

dataset = TupleDataset(raw_dataset, x_key="data", y_key="target")
dataset = TupleDataset(raw_dataset, ("data", "target"))
assert len(dataset) == 2

sample = dataset[1]
assert sample["data"] == [5, 6, 7]
assert sample["target"] == 8
assert dataset[0] == {
"data": [1, 2, 3],
"target": 4,
}
assert dataset[1] == {
"data": [5, 6, 7],
"target": 8,
}


def test_TupleDataset_with_three_elements():
    """Check that three-entry tuple samples are adapted to three-key dicts."""
    key_names = ("image_id", "data", "target")
    raw_dataset = [
        (7, [1, 2, 3], 4),
        (3, [5, 6, 7], 8),
    ]

    dataset = TupleDataset(raw_dataset, key_names)

    # The adapter must not change the number of samples.
    assert len(dataset) == 2

    # Each tuple element is paired with the key at the same position.
    assert dataset[0] == {"image_id": 7, "data": [1, 2, 3], "target": 4}
    assert dataset[1] == {"image_id": 3, "data": [5, 6, 7], "target": 8}


def test_ImageClassificationDataLoader():
Expand Down

0 comments on commit 41e140c

Please sign in to comment.