Skip to content

Commit

Permalink
added if statement to account for IterableDatasets doing distributed …
Browse files Browse the repository at this point in the history
…training (#2151)
  • Loading branch information
ShirleyWangCVR authored Oct 8, 2022
1 parent 6c746fa commit 9d2312b
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions mmseg/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import Registry, build_from_cfg, digit_version
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, IterableDataset

from .samplers import DistributedSampler

Expand Down Expand Up @@ -129,12 +129,17 @@ def build_dataloader(dataset,
DataLoader: A PyTorch dataloader.
"""
rank, world_size = get_dist_info()
if dist:
if dist and not isinstance(dataset, IterableDataset):
sampler = DistributedSampler(
dataset, world_size, rank, shuffle=shuffle, seed=seed)
shuffle = False
batch_size = samples_per_gpu
num_workers = workers_per_gpu
elif dist:
sampler = None
shuffle = False
batch_size = samples_per_gpu
num_workers = workers_per_gpu
else:
sampler = None
batch_size = num_gpus * samples_per_gpu
Expand Down

0 comments on commit 9d2312b

Please sign in to comment.