Skip to content

Commit

Permalink
Merge pull request sinzlab#32 from MaxFBurg/master
Browse files Browse the repository at this point in the history
[add] argument to turn off shuffling for the train loader
  • Loading branch information
KonstantinWilleke authored Jan 28, 2022
2 parents 6136f6c + fe64d03 commit 8347c4b
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion nndichromacy/datasets/mouse_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def static_loader(
include_px_position=None,
image_reshape_list=None,
trial_idx_selection=None,
train_shuffle: bool = True,
):
"""
returns a single data loader
Expand Down Expand Up @@ -299,7 +300,7 @@ def static_loader(

sampler = (
SubsetRandomSampler(subset_idx)
if tier == "train"
if tier == "train" and train_shuffle is True
else SubsetSequentialSampler(subset_idx)
)
dataloaders[tier] = DataLoader(dat, sampler=sampler, batch_size=batch_size)
Expand Down Expand Up @@ -341,6 +342,7 @@ def static_loaders(
include_px_position=None,
image_reshape_list=None,
trial_idx_selection=None,
train_shuffle: bool = True,
):
"""
Returns a dictionary of dataloaders (i.e., trainloaders, valloaders, and testloaders) for >= 1 dataset(s).
Expand Down Expand Up @@ -423,6 +425,7 @@ def static_loaders(
include_px_position=include_px_position,
image_reshape_list=image_reshape_list,
trial_idx_selection=trial_idx_selection,
train_shuffle=train_shuffle,
)
if not return_test_sampler:
for k in dls:
Expand Down

0 comments on commit 8347c4b

Please sign in to comment.