Skip to content

Commit 170a9d1

Browse files
authored
[Feature] Support persistent_workers in DataLoader (PyTorch>=1.7.0) (#646)
1 parent 98067be commit 170a9d1

File tree

1 file changed

+33
-24
lines changed

1 file changed

+33
-24
lines changed

mmseg/datasets/builder.py

+33-24
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
from functools import partial
55

66
import numpy as np
7+
import torch
78
from mmcv.parallel import collate
89
from mmcv.runner import get_dist_info
910
from mmcv.utils import Registry, build_from_cfg
10-
from mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader
11-
from torch.utils.data import DistributedSampler
11+
from torch.utils.data import DataLoader, DistributedSampler
1212

1313
if platform.system() != 'Windows':
1414
# https://github.com/pytorch/pytorch/issues/973
@@ -84,7 +84,7 @@ def build_dataloader(dataset,
8484
seed=None,
8585
drop_last=False,
8686
pin_memory=True,
87-
dataloader_type='PoolDataLoader',
87+
persistent_workers=True,
8888
**kwargs):
8989
"""Build PyTorch DataLoader.
9090
@@ -106,7 +106,11 @@ def build_dataloader(dataset,
106106
Default: False
107107
pin_memory (bool): Whether to use pin_memory in DataLoader.
108108
Default: True
109-
dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader'
109+
persistent_workers (bool): If True, the data loader will not shutdown
110+
the worker processes after a dataset has been consumed once.
111+
This allows the worker Dataset instances to be kept alive.
112+
The argument only takes effect when PyTorch>=1.7.0 is used.
113+
Default: True
110114
kwargs: any keyword argument to be used to initialize DataLoader
111115
112116
Returns:
@@ -128,26 +132,31 @@ def build_dataloader(dataset,
128132
worker_init_fn, num_workers=num_workers, rank=rank,
129133
seed=seed) if seed is not None else None
130134

131-
assert dataloader_type in (
132-
'DataLoader',
133-
'PoolDataLoader'), f'unsupported dataloader {dataloader_type}'
134-
135-
if dataloader_type == 'PoolDataLoader':
136-
dataloader = PoolDataLoader
137-
elif dataloader_type == 'DataLoader':
138-
dataloader = DataLoader
139-
140-
data_loader = dataloader(
141-
dataset,
142-
batch_size=batch_size,
143-
sampler=sampler,
144-
num_workers=num_workers,
145-
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
146-
pin_memory=pin_memory,
147-
shuffle=shuffle,
148-
worker_init_fn=init_fn,
149-
drop_last=drop_last,
150-
**kwargs)
135+
if torch.__version__ >= '1.7.0':
136+
data_loader = DataLoader(
137+
dataset,
138+
batch_size=batch_size,
139+
sampler=sampler,
140+
num_workers=num_workers,
141+
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
142+
pin_memory=pin_memory,
143+
shuffle=shuffle,
144+
worker_init_fn=init_fn,
145+
drop_last=drop_last,
146+
persistent_workers=persistent_workers,
147+
**kwargs)
148+
else:
149+
data_loader = DataLoader(
150+
dataset,
151+
batch_size=batch_size,
152+
sampler=sampler,
153+
num_workers=num_workers,
154+
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
155+
pin_memory=pin_memory,
156+
shuffle=shuffle,
157+
worker_init_fn=init_fn,
158+
drop_last=drop_last,
159+
**kwargs)
151160

152161
return data_loader
153162

0 commit comments

Comments
 (0)