 from functools import partial
 
 import numpy as np
+import torch
 from mmcv.parallel import collate
 from mmcv.runner import get_dist_info
 from mmcv.utils import Registry, build_from_cfg
-from mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader
-from torch.utils.data import DistributedSampler
+from torch.utils.data import DataLoader, DistributedSampler
 
 if platform.system() != 'Windows':
     # https://github.com/pytorch/pytorch/issues/973
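
(Aside, not part of the diff: the linked PyTorch issue concerns DataLoader worker processes exhausting the per-process open-file limit. The guarded block that follows this comment in the file raises the soft RLIMIT_NOFILE; below is a minimal standalone sketch of that technique, with 4096 as an assumed floor since the actual value is not visible in this hunk.)

import resource

# Raise the soft limit on open file descriptors so that many DataLoader
# workers (each holding pipes and shared-memory files) do not trip
# "Too many open files". The hard limit is left untouched.
soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
soft_limit = min(max(4096, soft_limit), hard_limit)  # 4096 is an assumed floor
resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
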
@@ -84,7 +84,7 @@ def build_dataloader(dataset,
                      seed=None,
                      drop_last=False,
                      pin_memory=True,
-                     dataloader_type='PoolDataLoader',
+                     persistent_workers=True,
                      **kwargs):
     """Build PyTorch DataLoader.
 
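
(For illustration only, not part of the commit: a sketch of a call site opting out of the new flag. The samples_per_gpu and workers_per_gpu parameters are assumed from the surrounding file; only the tail of the signature is visible in this hunk.)

# Hypothetical call site: keep worker processes ephemeral by
# passing the new keyword explicitly.
data_loader = build_dataloader(
    dataset,                    # an mmseg dataset instance
    samples_per_gpu=2,          # batch size per GPU (assumed parameter)
    workers_per_gpu=2,          # loader workers per GPU (assumed parameter)
    persistent_workers=False)   # override the new default
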
@@ -106,7 +106,11 @@ def build_dataloader(dataset,
             Default: False
         pin_memory (bool): Whether to use pin_memory in DataLoader.
             Default: True
-        dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader'
+        persistent_workers (bool): If True, the data loader will not shut
+            down the worker processes after the dataset has been consumed
+            once, keeping the worker Dataset instances alive. The argument
+            only takes effect when PyTorch >= 1.7.0.
+            Default: True
         kwargs: any keyword argument to be used to initialize DataLoader
 
     Returns:
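
(Background, not part of the diff: persistent_workers is a standard torch.utils.data.DataLoader flag introduced in PyTorch 1.7.0. A minimal sketch of what it changes, using a toy dataset rather than anything from this repository:)

import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    """Tiny dataset used only to illustrate the flag."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.tensor(idx)

if __name__ == '__main__':  # guard needed where workers are spawned
    # With persistent_workers=True the worker processes (and the Dataset
    # copies they hold) survive across epochs instead of being re-created
    # each time a new iterator starts. Requires num_workers > 0.
    loader = DataLoader(
        ToyDataset(), batch_size=2, num_workers=2, persistent_workers=True)
    for epoch in range(2):  # workers are reused between these two epochs
        for batch in loader:
            pass
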
@@ -128,26 +132,31 @@ def build_dataloader(dataset,
         worker_init_fn, num_workers=num_workers, rank=rank,
         seed=seed) if seed is not None else None
 
-    assert dataloader_type in (
-        'DataLoader',
-        'PoolDataLoader'), f'unsupported dataloader {dataloader_type}'
-
-    if dataloader_type == 'PoolDataLoader':
-        dataloader = PoolDataLoader
-    elif dataloader_type == 'DataLoader':
-        dataloader = DataLoader
-
-    data_loader = dataloader(
-        dataset,
-        batch_size=batch_size,
-        sampler=sampler,
-        num_workers=num_workers,
-        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
-        pin_memory=pin_memory,
-        shuffle=shuffle,
-        worker_init_fn=init_fn,
-        drop_last=drop_last,
-        **kwargs)
+    if torch.__version__ >= '1.7.0':
+        data_loader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            sampler=sampler,
+            num_workers=num_workers,
+            collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
+            pin_memory=pin_memory,
+            shuffle=shuffle,
+            worker_init_fn=init_fn,
+            drop_last=drop_last,
+            persistent_workers=persistent_workers,
+            **kwargs)
+    else:
+        data_loader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            sampler=sampler,
+            num_workers=num_workers,
+            collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
+            pin_memory=pin_memory,
+            shuffle=shuffle,
+            worker_init_fn=init_fn,
+            drop_last=drop_last,
+            **kwargs)
 
     return data_loader
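
(A caveat worth flagging, not addressed by this commit: torch.__version__ >= '1.7.0' is a lexicographic string comparison, so a release such as '1.10.0' compares less than '1.7.0' and would wrongly fall into the else branch. A sketch of a more robust gate using the third-party packaging library follows; mmcv's digit_version utility can serve the same purpose:)

import torch
from packaging import version

# Parse the version so '1.10.0' correctly compares greater than
# '1.7.0'; plain string comparison gets multi-digit components wrong.
if version.parse(torch.__version__) >= version.parse('1.7.0'):
    extra_kwargs = dict(persistent_workers=True)
else:
    extra_kwargs = {}

Routing the flag through a kwargs dict like this would also avoid duplicating the entire DataLoader call across the two branches.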