build_loader.py
from functools import partial

import resource

from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from torch.utils.data import DataLoader

from sampler import (ClassAwareSampler, DistributedGroupSampler,
                     DistributedSampler, GroupSampler, MixSampler)

# Raise the soft limit on open file descriptors: DataLoader workers that
# share tensors via file descriptors can exhaust the default limit.
# https://github.com/pytorch/pytorch/issues/973
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (12288, rlimit[1]))
def build_dataloader(dataset,
                     imgs_per_gpu,
                     workers_per_gpu=2,
                     num_gpus=1,
                     dist=True,
                     sampler=None,
                     **kwargs):
    """Build a DataLoader with the sampler named by ``sampler``.

    In distributed mode each process loads ``imgs_per_gpu`` images with
    ``workers_per_gpu`` workers; in non-distributed mode a single process
    loads ``num_gpus * imgs_per_gpu`` images with ``num_gpus *
    workers_per_gpu`` workers.
    """
    shuffle = kwargs.get('shuffle', True)
    if dist:
        rank, world_size = get_dist_info()
        sampler = DistributedSampler(dataset, world_size, rank, shuffle=shuffle)
        batch_size = imgs_per_gpu
        num_workers = workers_per_gpu
        # The sampler already shuffles, so the DataLoader must not.
        kwargs.update(shuffle=False)
    else:
        if sampler is not None:
            # ``sampler`` arrives as a string and is matched by substring.
            if 'ClassAware' in sampler:
                sampler = ClassAwareSampler(data_source=dataset)
            elif 'Mix' in sampler:
                sampler = MixSampler(data_source=dataset)
            elif 'Group' in sampler:
                sampler = GroupSampler(dataset, imgs_per_gpu) if shuffle else None
            else:
                raise NameError('unsupported sampler type: {}'.format(sampler))
            # A custom sampler and DataLoader shuffling are mutually
            # exclusive, so shuffle only when no sampler was built.
            shuffle = sampler is None and shuffle
        batch_size = num_gpus * imgs_per_gpu
        num_workers = num_gpus * workers_per_gpu
        kwargs.update(shuffle=shuffle)
    data_loader = DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        # mmcv's collate groups samples so each GPU gets imgs_per_gpu of them.
        collate_fn=partial(collate, samples_per_gpu=imgs_per_gpu),
        pin_memory=False,
        drop_last=False,
        **kwargs)
    return data_loader
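

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). ToyDataset below is a
# hypothetical stand-in; any torch Dataset the samplers above accept would
# work. With sampler=None and dist=False the loader falls back to plain
# DataLoader shuffling, which avoids assuming anything about the custom
# samplers' dataset requirements.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import Dataset

    class ToyDataset(Dataset):
        """Hypothetical stand-in; a real dataset would return images/labels."""

        def __len__(self):
            return 16

        def __getitem__(self, idx):
            return idx

    # Non-distributed, no custom sampler: the effective batch size is
    # num_gpus * imgs_per_gpu = 4 and the DataLoader itself shuffles.
    loader = build_dataloader(
        ToyDataset(),
        imgs_per_gpu=2,
        workers_per_gpu=0,
        num_gpus=2,
        dist=False)
    for batch in loader:
        print(batch)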