Skip to content

Commit

Permalink
Merge pull request #37 from Visual-Behavior/raft_datamodule_fix
Browse files Browse the repository at this point in the history
use randomsampler in train_dataloader
  • Loading branch information
thibo73800 authored Sep 8, 2021
2 parents 51df902 + bdcf1f7 commit f552d2b
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions alonet/raft/data_modules/chairs2raft.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from torch.utils.data import SequentialSampler
from torch.utils.data import SequentialSampler, RandomSampler
import pytorch_lightning as pl

from alodataset import FlyingChairs2Dataset, Split
Expand All @@ -13,7 +13,7 @@ def __init__(self, args):
def train_dataloader(self):
split = Split.VAL if self.train_on_val else Split.TRAIN
dataset = FlyingChairs2Dataset(split=split, transform_fn=self.train_transform, sample=self.sample)
sampler = SequentialSampler if self.sequential else None
sampler = SequentialSampler if self.sequential else RandomSampler
return dataset.train_loader(batch_size=self.batch_size, num_workers=self.num_workers, sampler=sampler)

def val_dataloader(self):
Expand Down
4 changes: 2 additions & 2 deletions alonet/raft/data_modules/chairssdhom2raft.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from torch.utils.data import SequentialSampler
from torch.utils.data import SequentialSampler, RandomSampler

from alodataset import ChairsSDHomDataset, Split
from alonet.raft.data_modules import Data2RAFT
Expand All @@ -12,7 +12,7 @@ def __init__(self, args):
def train_dataloader(self):
split = Split.VAL if self.train_on_val else Split.TRAIN
dataset = ChairsSDHomDataset(split=split, transform_fn=self.train_transform, sample=self.sample)
sampler = SequentialSampler if self.sequential else None
sampler = SequentialSampler if self.sequential else RandomSampler
return dataset.train_loader(batch_size=self.batch_size, num_workers=self.num_workers, sampler=sampler)

def val_dataloader(self):
Expand Down
4 changes: 2 additions & 2 deletions alonet/raft/data_modules/sintel2raft.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from torch.utils.data import SequentialSampler
from torch.utils.data import SequentialSampler, RandomSampler

# import pytorch_lightning as pl

Expand All @@ -24,7 +24,7 @@ def train_dataloader(self):
transform_fn=lambda f: self.train_transform(f["left"]),
)

sampler = SequentialSampler if self.sequential else None
sampler = SequentialSampler if self.sequential else RandomSampler
return dataset.train_loader(batch_size=self.batch_size, num_workers=self.num_workers, sampler=sampler)

def val_dataloader(self):
Expand Down
4 changes: 2 additions & 2 deletions alonet/raft/data_modules/things2raft.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from torch.utils.data import SequentialSampler
from torch.utils.data import SequentialSampler, RandomSampler

# import pytorch_lightning as pl

Expand Down Expand Up @@ -42,7 +42,7 @@ def train_dataloader(self):

dataset = MergeDataset(datasets)

sampler = SequentialSampler if self.sequential else None
sampler = SequentialSampler if self.sequential else RandomSampler
return dataset.train_loader(batch_size=self.batch_size, num_workers=self.num_workers, sampler=sampler)

def val_dataloader(self):
Expand Down

0 comments on commit f552d2b

Please sign in to comment.