Skip to content

Commit

Permalink
Merge pull request #265 from marcorax/develop-1
Browse files Browse the repository at this point in the history
Add stabilization to NMNIST dataset
  • Loading branch information
biphasic authored Aug 10, 2023
2 parents 639d469 + ea97d58 commit 4b89557
Showing 1 changed file with 49 additions and 2 deletions.
51 changes: 49 additions & 2 deletions tonic/datasets/nmnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Callable, Optional

import numpy as np

from tonic.dataset import Dataset
from tonic.io import read_mnist_file

Expand All @@ -28,6 +27,7 @@ class NMNIST(Dataset):
train (bool): If True, uses training subset, otherwise testing subset.
first_saccade_only (bool): If True, only work with events of the first of three saccades.
Results in about a third of the events overall.
stabilize (bool): If True, it stabilizes egomotion of the saccades, centering the digit.
transform (callable, optional): A callable of transforms to apply to the data.
target_transform (callable, optional): A callable of transforms to apply to the targets/labels.
transforms (callable, optional): A callable of transforms that is applied to both data and
Expand Down Expand Up @@ -66,6 +66,7 @@ def __init__(
save_to: str,
train: bool = True,
first_saccade_only: bool = False,
stabilize: bool = False,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
Expand All @@ -78,6 +79,7 @@ def __init__(
)
self.train = train
self.first_saccade_only = first_saccade_only
self.stabilize = stabilize

if train:
self.filename = self.train_filename
Expand Down Expand Up @@ -105,11 +107,14 @@ def __init__(
def __getitem__(self, index):
"""
Returns:
a tuple of (events, target) where target is the index of the target class.
a tuple of (events, target) where target is the index of the target
class.
"""
events = read_mnist_file(self.data[index], dtype=self.dtype)
if self.first_saccade_only:
events = events[events["t"] < 1e5]
if self.stabilize:
events = stabilize(events)
target = self.targets[index]
if self.transform is not None:
events = self.transform(events)
Expand All @@ -127,3 +132,45 @@ def _check_exists(self) -> bool:
self._is_file_present()
and self._folder_contains_at_least_n_files_of_type(10000, ".bin")
)


def stabilize(events):
"""
Stabilize digits, code ported from https://www.garrickorchard.com/datasets/n-mnist
Returns:
stabilized events, removing the egomotion caused by saccades.
"""

stab_x = np.asarray(events["x"], dtype=np.float64)
stab_y = np.asarray(events["y"], dtype=np.float64)

# original code might result in a small offset, fixed manually
x_off = 4
y_off = 2

saccade_1_index = events["t"] <= 105e3
stab_x[saccade_1_index] = x_off + stab_x[saccade_1_index] - \
3.5*events["t"][saccade_1_index]/105e3
stab_y[saccade_1_index] = y_off + stab_y[saccade_1_index] - \
7*events["t"][saccade_1_index]/105e3

saccade_2_index = (events["t"] > 105e3) * (events["t"] <= 210e3)
stab_x[saccade_2_index] = x_off + stab_x[saccade_2_index] - \
3.5 - 3.5*(events["t"][saccade_2_index] - 105e3)/105e3
stab_y[saccade_2_index] = y_off + stab_y[saccade_2_index] - \
7 + 7*(events["t"][saccade_2_index] - 105e3)/105e3

saccade_3_index = (events["t"] > 210e3)
stab_x[saccade_3_index] = x_off + stab_x[saccade_3_index] - \
7 + 7*(events["t"][saccade_3_index]-210e3)/105e3
# events["y"] remains almonst unchaged because it is a horizontal saccade
stab_y[saccade_3_index] = y_off + stab_y[saccade_3_index]

events["x"] = np.asarray(np.round(stab_x), dtype=np.int64)
events["y"] = np.asarray(np.round(stab_y), dtype=np.int64)

nulls = (stab_x < 0) + (stab_y < 0) + (stab_x > 33) + (stab_y > 33)

events = events[nulls == 0]

return events

0 comments on commit 4b89557

Please sign in to comment.