diff --git a/mmseg/datasets/builder.py b/mmseg/datasets/builder.py
index f7a9926111..3ef328d0d6 100644
--- a/mmseg/datasets/builder.py
+++ b/mmseg/datasets/builder.py
@@ -4,11 +4,11 @@
 from functools import partial
 
 import numpy as np
+import torch
 from mmcv.parallel import collate
 from mmcv.runner import get_dist_info
 from mmcv.utils import Registry, build_from_cfg
-from mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader
-from torch.utils.data import DistributedSampler
+from torch.utils.data import DataLoader, DistributedSampler
 
 if platform.system() != 'Windows':
     # https://github.com/pytorch/pytorch/issues/973
@@ -84,7 +84,7 @@ def build_dataloader(dataset,
                      seed=None,
                      drop_last=False,
                      pin_memory=True,
-                     dataloader_type='PoolDataLoader',
+                     persistent_workers=True,
                      **kwargs):
     """Build PyTorch DataLoader.
 
@@ -106,7 +106,11 @@ def build_dataloader(dataset,
             Default: False
         pin_memory (bool): Whether to use pin_memory in DataLoader.
             Default: True
-        dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader'
+        persistent_workers (bool): If True, the data loader will not shutdown
+            the worker processes after a dataset has been consumed once.
+            This allows to maintain the workers Dataset instances alive.
+            This argument only takes effect when PyTorch>=1.7.0.
+            Default: True
         kwargs: any keyword argument to be used to initialize DataLoader
 
     Returns:
@@ -128,26 +132,35 @@ def build_dataloader(dataset,
     init_fn = partial(
         worker_init_fn, num_workers=num_workers, rank=rank,
         seed=seed) if seed is not None else None
 
-    assert dataloader_type in (
-        'DataLoader',
-        'PoolDataLoader'), f'unsupported dataloader {dataloader_type}'
-
-    if dataloader_type == 'PoolDataLoader':
-        dataloader = PoolDataLoader
-    elif dataloader_type == 'DataLoader':
-        dataloader = DataLoader
-
-    data_loader = dataloader(
-        dataset,
-        batch_size=batch_size,
-        sampler=sampler,
-        num_workers=num_workers,
-        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
-        pin_memory=pin_memory,
-        shuffle=shuffle,
-        worker_init_fn=init_fn,
-        drop_last=drop_last,
-        **kwargs)
+    # Compare the torch version numerically: a plain string comparison
+    # would wrongly treat e.g. '1.10.0' as older than '1.7.0'.
+    torch_version = tuple(
+        int(v) for v in torch.__version__.split('+')[0].split('.')[:2])
+    if torch_version >= (1, 7):
+        data_loader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            sampler=sampler,
+            num_workers=num_workers,
+            collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
+            pin_memory=pin_memory,
+            shuffle=shuffle,
+            worker_init_fn=init_fn,
+            drop_last=drop_last,
+            persistent_workers=persistent_workers,
+            **kwargs)
+    else:
+        data_loader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            sampler=sampler,
+            num_workers=num_workers,
+            collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
+            pin_memory=pin_memory,
+            shuffle=shuffle,
+            worker_init_fn=init_fn,
+            drop_last=drop_last,
+            **kwargs)
 
     return data_loader