From ff69334f03770678f33d3aca58790500ffb78fdf Mon Sep 17 00:00:00 2001
From: sheridana
Date: Thu, 29 Jun 2023 20:32:07 -0400
Subject: [PATCH 1/4] Add fixed random sampler, expose in config

---
 biogtr/config.py              | 19 +++++++++++-
 biogtr/datasets/data_utils.py | 55 ++++++++++++++++++++++++++++++++++-
 2 files changed, 72 insertions(+), 2 deletions(-)

diff --git a/biogtr/config.py b/biogtr/config.py
index 8b5bb22e..950c1722 100644
--- a/biogtr/config.py
+++ b/biogtr/config.py
@@ -1,8 +1,9 @@
 # to implement - config class that handles getters/setters
 """Data structures for handling config parsing."""
+from biogtr.datasets.cell_tracking_dataset import CellTrackingDataset
+from biogtr.datasets.data_utils import FixedRandomSampler
 from biogtr.datasets.microscopy_dataset import MicroscopyDataset
 from biogtr.datasets.sleap_dataset import SleapDataset
-from biogtr.datasets.cell_tracking_dataset import CellTrackingDataset
 from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer
 from biogtr.models.gtr_runner import GTRRunner
 from biogtr.models.model_utils import init_optimizer, init_scheduler, init_logger
@@ -180,12 +181,28 @@ def get_dataloader(
             )
         else:
             generator = None
+
+        if "sampler_params" in dataloader_params:
+            sampler_params = dataloader_params["sampler_params"]
+            sampler = FixedRandomSampler(dataset, **sampler_params)
+
+            # DataLoader forbids shuffle=True when a custom sampler is supplied
+            if "shuffle" in dataloader_params:
+                dataloader_params["shuffle"] = False
+
+            # drop sampler_params so it is not forwarded to the DataLoader
+            del dataloader_params["sampler_params"]
+
+        else:
+            sampler = None
+
         return torch.utils.data.DataLoader(
             dataset=dataset,
             batch_size=1,
             pin_memory=pin_memory,
             generator=generator,
             collate_fn=dataset.no_batching_fn,
+            sampler=sampler,
             **dataloader_params,
         )

diff --git a/biogtr/datasets/data_utils.py b/biogtr/datasets/data_utils.py
index 81db2a42..99066811 100644
--- a/biogtr/datasets/data_utils.py
+++ b/biogtr/datasets/data_utils.py
@@ -2,7 +2,7 @@
 from PIL import Image
 from numpy.typing import ArrayLike
 from torchvision.transforms import functional as tvf
-from typing import List, Dict
+from typing import List, Dict, Optional
 from xml.etree import cElementTree as et
 import albumentations as A
 import math
@@ -473,3 +473,56 @@ def view_training_batch(
     plt.tight_layout()
 
     plt.show()
+
+
+class FixedRandomSampler(torch.utils.data.Sampler):
+    """Sampler that draws dataset indices using a fixed random seed."""
+
+    def __init__(
+        self,
+        dataset: torch.utils.data.Dataset,
+        seed: int = 1234,
+        num_epochs: Optional[int] = None,
+    ):
+        """Custom sampler that generates indices using a fixed random seed.
+
+        Args:
+            dataset: The dataset to sample from.
+            seed: The seed for the random number generator. Default is 1234.
+            num_epochs: The number of epochs to generate indices
+                for. Allows for random sampling over epochs.
+        """
+
+        self.num_samples = len(dataset)
+        self.seed = seed
+        if num_epochs is None:
+            self.size = self.num_samples
+        else:
+            self.size = self.num_samples * num_epochs
+        self.indices = self._generate_indices()
+
+    def _generate_indices(self):
+        """Generates indices for sampling.
+
+        Returns:
+            indices (np.ndarray): The array of generated indices.
+        """
+        rng = np.random.default_rng(seed=self.seed)
+        indices = rng.choice(range(self.num_samples), size=self.size, replace=True)
+        return indices
+
+    def __iter__(self):
+        """Iterator for generating indices.
+
+        Returns:
+            iterator (Iterator): An iterator over the generated indices.
+        """
+        return iter(self.indices)
+
+    def __len__(self):
+        """Returns the total number of indices.
+
+        Returns:
+            size (int): The total number of indices.
+        """
+        return self.size

From 5f9205286416a1bea0d63e3e50f8c2fd8ba41058 Mon Sep 17 00:00:00 2001
From: sheridana
Date: Thu, 29 Jun 2023 20:32:20 -0400
Subject: [PATCH 2/4] Add fixed random sampler test

---
 tests/test_datasets.py | 40 ++++++++++++++++++++++++++++++++++++++--
 1 file changed, 38 insertions(+), 2 deletions(-)

diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index d4be9776..7debeb2c 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -1,11 +1,12 @@
 """Test dataset logic."""
 from biogtr.datasets.base_dataset import BaseDataset
-from biogtr.datasets.data_utils import get_max_padding
+from biogtr.datasets.data_utils import get_max_padding, FixedRandomSampler
 from biogtr.datasets.microscopy_dataset import MicroscopyDataset
 from biogtr.datasets.sleap_dataset import SleapDataset
 from biogtr.datasets.tracking_dataset import TrackingDataset
 from biogtr.datasets.cell_tracking_dataset import CellTrackingDataset
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, TensorDataset
+import numpy as np
 import pytest
 import torch
 
@@ -388,3 +389,38 @@ def test_augmentations(two_flies, ten_icy_particles):
     b = augs_instances[0]["crops"]
 
     assert not torch.all(a.eq(b))
+
+
+def test_fixed_random_sampler():
+    """Test FixedRandomSampler logic."""
+
+    # dummy dataset
+    dataset = TensorDataset(torch.rand(100, 10))
+
+    random_seed = 12345
+
+    # Test only seed
+    sampler = FixedRandomSampler(dataset, seed=random_seed)
+    sample_indices = list(iter(sampler))
+    assert len(sample_indices) == len(dataset)
+    assert all(0 <= idx < len(dataset) for idx in sample_indices)
+
+    # Test both seed and num_epochs
+    sampler = FixedRandomSampler(dataset, seed=random_seed, num_epochs=2)
+    sample_indices = list(iter(sampler))
+    assert len(sample_indices) == 2 * len(dataset)
+
+    # Test fixed seed sampling consistency
+    sampler1 = FixedRandomSampler(dataset, seed=random_seed)
+    sampler2 = FixedRandomSampler(dataset, seed=random_seed)
+    sample_indices1 = list(iter(sampler1))
+    sample_indices2 = list(iter(sampler2))
+    np.testing.assert_array_equal(sample_indices1, sample_indices2)
+
+    # Test different seeds give different results
+    sampler1 = FixedRandomSampler(dataset, seed=random_seed)
+    sampler2 = FixedRandomSampler(dataset, seed=random_seed + 1)
+    sample_indices1 = list(iter(sampler1))
+    sample_indices2 = list(iter(sampler2))
+    with pytest.raises(AssertionError):
+        np.testing.assert_array_equal(sample_indices1, sample_indices2)

From d8a44a6d6f837e0fbd74905abf819ed540ca3509 Mon Sep 17 00:00:00 2001
From: sheridana
Date: Thu, 29 Jun 2023 20:32:42 -0400
Subject: [PATCH 3/4] Quick typing fix for datasets

---
 biogtr/datasets/cell_tracking_dataset.py |  2 +-
 biogtr/datasets/microscopy_dataset.py    |  3 ++-
 biogtr/datasets/sleap_dataset.py         | 14 +++++++-------
 3 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/biogtr/datasets/cell_tracking_dataset.py b/biogtr/datasets/cell_tracking_dataset.py
index 6e7b2076..cc872842 100644
--- a/biogtr/datasets/cell_tracking_dataset.py
+++ b/biogtr/datasets/cell_tracking_dataset.py
@@ -28,7 +28,7 @@ def __init__(
         clip_length: int = 10,
         mode: str = "train",
         augmentations: Optional[dict] = None,
-        gt_list: str = None,
+        gt_list: Optional[str] = None,
     ):
         """Initialize CellTrackingDataset.
 
diff --git a/biogtr/datasets/microscopy_dataset.py b/biogtr/datasets/microscopy_dataset.py
index 5afd582d..b93d5de8 100644
--- a/biogtr/datasets/microscopy_dataset.py
+++ b/biogtr/datasets/microscopy_dataset.py
@@ -4,6 +4,7 @@
 from biogtr.datasets.base_dataset import BaseDataset
 from torch.utils.data import Dataset
 from torchvision.transforms import functional as tvf
+from typing import Optional
 import albumentations as A
 import numpy as np
 import random
@@ -23,7 +24,7 @@ def __init__(
         chunk: bool = False,
         clip_length: int = 10,
         mode: str = "Train",
-        augmentations: dict = None,
+        augmentations: Optional[dict] = None,
     ):
         """Initialize MicroscopyDataset.
 
diff --git a/biogtr/datasets/sleap_dataset.py b/biogtr/datasets/sleap_dataset.py
index c395f235..cdddd009 100644
--- a/biogtr/datasets/sleap_dataset.py
+++ b/biogtr/datasets/sleap_dataset.py
@@ -1,14 +1,14 @@
 """Module containing logic for loading sleap datasets."""
+from biogtr.datasets import data_utils
+from biogtr.datasets.base_dataset import BaseDataset
+from torchvision.transforms import functional as tvf
+from typing import List, Optional
 import albumentations as A
-import torch
 import imageio
 import numpy as np
-import sleap_io as sio
 import random
-from biogtr.datasets import data_utils
-from biogtr.datasets.base_dataset import BaseDataset
-from torchvision.transforms import functional as tvf
-from typing import List
+import sleap_io as sio
+import torch
 
 
 class SleapDataset(BaseDataset):
@@ -23,7 +23,7 @@ def __init__(
         chunk: bool = True,
         clip_length: int = 500,
         mode: str = "train",
-        augmentations: dict = None,
+        augmentations: Optional[dict] = None,
     ):
         """Initialize SleapDataset.

From f5f14a2b18ab903b03c4ca910e7270c0e162d231 Mon Sep 17 00:00:00 2001
From: sheridana
Date: Thu, 29 Jun 2023 20:41:41 -0400
Subject: [PATCH 4/4] Fix docstring

---
 biogtr/datasets/data_utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/biogtr/datasets/data_utils.py b/biogtr/datasets/data_utils.py
index 99066811..80fc8a79 100644
--- a/biogtr/datasets/data_utils.py
+++ b/biogtr/datasets/data_utils.py
@@ -492,7 +492,6 @@ def __init__(
             num_epochs: The number of epochs to generate indices
                 for. Allows for random sampling over epochs.
         """
-
        self.num_samples = len(dataset)
         self.seed = seed
         if num_epochs is None:
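
Usage sketch (not part of the patches above): a minimal example of driving the
sampler added in PATCH 1/4 directly, mirroring the test in PATCH 2/4. The seed
and epoch values are illustrative, and the TensorDataset is a stand-in for a
real tracking dataset.

    from biogtr.datasets.data_utils import FixedRandomSampler
    from torch.utils.data import DataLoader, TensorDataset
    import torch

    # Stand-in dataset; any torch Dataset with __len__ works.
    dataset = TensorDataset(torch.rand(100, 10))

    # Same seed -> identical index stream on every run; num_epochs=2 pre-draws
    # two epochs' worth of indices (sampled with replacement).
    sampler = FixedRandomSampler(dataset, seed=1234, num_epochs=2)
    assert len(sampler) == 2 * len(dataset)

    # shuffle must stay False whenever a custom sampler is passed; the
    # get_dataloader changes in PATCH 1/4 enforce this from the config side.
    loader = DataLoader(dataset, batch_size=1, sampler=sampler, shuffle=False)

Pre-generating the full index array trades a little memory for exact
reproducibility of the sampling order across runs and epochs.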