Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds a fixed random sampler for dataloader #14

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion biogtr/config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# to implement - config class that handles getters/setters
"""Data structures for handling config parsing."""
from biogtr.datasets.cell_tracking_dataset import CellTrackingDataset
from biogtr.datasets.data_utils import FixedRandomSampler
from biogtr.datasets.microscopy_dataset import MicroscopyDataset
from biogtr.datasets.sleap_dataset import SleapDataset
from biogtr.datasets.cell_tracking_dataset import CellTrackingDataset
from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer
from biogtr.models.gtr_runner import GTRRunner
from biogtr.models.model_utils import init_optimizer, init_scheduler, init_logger
Expand Down Expand Up @@ -180,12 +181,28 @@
)
else:
generator = None

if "sampler_params" in dataloader_params:
sampler_params = dataloader_params.sampler_params
sampler = FixedRandomSampler(dataset, **sampler_params)

Check warning on line 187 in biogtr/config.py

View check run for this annotation

Codecov / codecov/patch

biogtr/config.py#L185-L187

Added lines #L185 - L187 were not covered by tests

# make sure shuffle is set to false if using fixed random sampler
if "shuffle" in dataloader_params:
dataloader_params["shuffle"] = False

Check warning on line 191 in biogtr/config.py

View check run for this annotation

Codecov / codecov/patch

biogtr/config.py#L190-L191

Added lines #L190 - L191 were not covered by tests

# now remove from dataloader_params
del dataloader_params["sampler_params"]

Check warning on line 194 in biogtr/config.py

View check run for this annotation

Codecov / codecov/patch

biogtr/config.py#L194

Added line #L194 was not covered by tests

else:
sampler = None

Check warning on line 197 in biogtr/config.py

View check run for this annotation

Codecov / codecov/patch

biogtr/config.py#L197

Added line #L197 was not covered by tests

return torch.utils.data.DataLoader(
dataset=dataset,
batch_size=1,
pin_memory=pin_memory,
generator=generator,
collate_fn=dataset.no_batching_fn,
sampler=sampler,
**dataloader_params,
)

Expand Down
2 changes: 1 addition & 1 deletion biogtr/datasets/cell_tracking_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
clip_length: int = 10,
mode: str = "train",
augmentations: Optional[dict] = None,
gt_list: str = None,
gt_list: Optional[str] = None,
):
"""Initialize CellTrackingDataset.

Expand Down
54 changes: 53 additions & 1 deletion biogtr/datasets/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from PIL import Image
from numpy.typing import ArrayLike
from torchvision.transforms import functional as tvf
from typing import List, Dict
from typing import List, Dict, Optional
from xml.etree import cElementTree as et
import albumentations as A
import math
Expand Down Expand Up @@ -473,3 +473,55 @@

plt.tight_layout()
plt.show()


class FixedRandomSampler(torch.utils.data.Sampler):
    """Sampler that yields a reproducible, seeded sequence of random indices.

    The whole index sequence is drawn once, with replacement, when the sampler
    is constructed. Every iteration over the sampler replays that same
    sequence, so two samplers built with the same seed and dataset length
    produce identical orderings.
    """

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        seed: int = 1234,
        num_epochs: Optional[int] = None,
    ):
        """Initialize the sampler and pre-generate its indices.

        Args:
            dataset: The dataset to sample from. Only its length is used.
            seed: The seed for the random number generator. Default is 1234.
            num_epochs: If given, ``len(dataset) * num_epochs`` indices are
                generated; otherwise ``len(dataset)`` indices are generated.

        Raises:
            ValueError: If ``num_epochs`` is negative.
        """
        # Fail early with a clear message instead of an opaque numpy size error.
        if num_epochs is not None and num_epochs < 0:
            raise ValueError(f"num_epochs must be non-negative, got {num_epochs}")

        self.num_samples = len(dataset)
        self.seed = seed
        if num_epochs is None:
            self.size = self.num_samples
        else:
            self.size = self.num_samples * num_epochs
        self.indices = self._generate_indices()

    def _generate_indices(self) -> np.ndarray:
        """Draw the index sequence from a generator seeded with ``self.seed``.

        Returns:
            A 1-D integer array of length ``self.size`` whose values lie in
            ``[0, self.num_samples)``, sampled with replacement.
        """
        rng = np.random.default_rng(seed=self.seed)
        # An integer population is sampled as np.arange(num_samples), which
        # avoids materialising the range first.
        return rng.choice(self.num_samples, size=self.size, replace=True)

    def __iter__(self):
        """Iterate over the pre-generated indices.

        Returns:
            An iterator of plain Python ``int`` indices.
        """
        # tolist() converts numpy integer scalars to built-in ints so that
        # datasets receive the same index type the stock torch samplers yield.
        return iter(self.indices.tolist())

    def __len__(self) -> int:
        """Return the total number of indices produced per iteration."""
        return self.size

Check warning on line 527 in biogtr/datasets/data_utils.py

View check run for this annotation

Codecov / codecov/patch

biogtr/datasets/data_utils.py#L527

Added line #L527 was not covered by tests
3 changes: 2 additions & 1 deletion biogtr/datasets/microscopy_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from biogtr.datasets.base_dataset import BaseDataset
from torch.utils.data import Dataset
from torchvision.transforms import functional as tvf
from typing import Optional
import albumentations as A
import numpy as np
import random
Expand All @@ -23,7 +24,7 @@ def __init__(
chunk: bool = False,
clip_length: int = 10,
mode: str = "Train",
augmentations: dict = None,
augmentations: Optional[dict] = None,
):
"""Initialize MicroscopyDataset.

Expand Down
15 changes: 8 additions & 7 deletions biogtr/datasets/sleap_dataset.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Module containing logic for loading sleap datasets."""
from biogtr.datasets import data_utils
from biogtr.datasets.base_dataset import BaseDataset
from torchvision.transforms import functional as tvf
from typing import List, Optional
import albumentations as A
import torch
import imageio
import numpy as np
import sleap_io as sio
import random
from biogtr.datasets import data_utils
from biogtr.datasets.base_dataset import BaseDataset
from torchvision.transforms import functional as tvf
from typing import List
import sleap_io as sio
import torch


class SleapDataset(BaseDataset):
Expand All @@ -23,7 +23,7 @@ def __init__(
chunk: bool = True,
clip_length: int = 500,
mode: str = "train",
augmentations: dict = None,
augmentations: Optional[dict] = None,
):
"""Initialize SleapDataset.

Expand Down Expand Up @@ -137,6 +137,7 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict
gt_track_ids, bboxes, crops, poses, shown_poses = [], [], [], [], []

i = int(i)
print(i)

lf = video[i]
img = vid_reader.get_data(i)
Expand Down
40 changes: 38 additions & 2 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Test dataset logic."""
from biogtr.datasets.base_dataset import BaseDataset
from biogtr.datasets.data_utils import get_max_padding
from biogtr.datasets.data_utils import get_max_padding, FixedRandomSampler
from biogtr.datasets.microscopy_dataset import MicroscopyDataset
from biogtr.datasets.sleap_dataset import SleapDataset
from biogtr.datasets.tracking_dataset import TrackingDataset
from biogtr.datasets.cell_tracking_dataset import CellTrackingDataset
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pytest
import torch

Expand Down Expand Up @@ -388,3 +389,38 @@ def test_augmentations(two_flies, ten_icy_particles):
b = augs_instances[0]["crops"]

assert not torch.all(a.eq(b))


def test_fixed_random_sampler():
    """Check FixedRandomSampler length, reproducibility and seed sensitivity."""
    seed = 12345

    # stand-in dataset holding 100 random rows
    data = TensorDataset(torch.rand(100, 10))
    n_items = len(data)

    # seed only: one drawn index per dataset item, duplicates allowed
    drawn = list(FixedRandomSampler(data, seed=seed))
    assert len(drawn) == n_items
    assert len(set(drawn)) <= n_items

    # seed plus num_epochs: the index count scales with the epoch count
    drawn = list(FixedRandomSampler(data, seed=seed, num_epochs=2))
    assert len(drawn) == 2 * n_items

    # two samplers sharing a seed must agree exactly
    first, second = (list(FixedRandomSampler(data, seed=seed)) for _ in range(2))
    np.testing.assert_array_equal(first, second)

    # neighbouring seeds must produce different index sequences
    first = list(FixedRandomSampler(data, seed=seed))
    second = list(FixedRandomSampler(data, seed=seed + 1))
    with pytest.raises(AssertionError):
        np.testing.assert_array_equal(first, second)
Loading