Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support custom persistent_workers #6435

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mmdet/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ def train_detector(model,
num_gpus=len(cfg.gpu_ids),
dist=distributed,
seed=cfg.seed,
runner_type=runner_type) for ds in dataset
runner_type=runner_type,
pin_memory=cfg.data.get('pin_memory', False),
persistent_workers=cfg.data.get('persistent_workers', False))
for ds in dataset
]

# put model on gpus
Expand Down
18 changes: 16 additions & 2 deletions mmdet/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import Registry, build_from_cfg
from mmcv.utils import TORCH_VERSION, Registry, build_from_cfg, digit_version
from torch.utils.data import DataLoader

from .samplers import (DistributedGroupSampler, DistributedSampler,
Expand Down Expand Up @@ -90,6 +90,8 @@ def build_dataloader(dataset,
shuffle=True,
seed=None,
runner_type='EpochBasedRunner',
pin_memory=False,
persistent_workers=False,
**kwargs):
"""Build PyTorch DataLoader.

Expand All @@ -106,7 +108,15 @@ def build_dataloader(dataset,
dist (bool): Distributed training/test or not. Default: True.
shuffle (bool): Whether to shuffle the data at every epoch.
Default: True.
seed (int, Optional): Seed to be used. Default: None.
runner_type (str): Type of runner. Default: `EpochBasedRunner`
pin_memory (bool): Whether to use pin_memory in DataLoader.
Default: False.
persistent_workers (bool): If True, the data loader will not shutdown
the worker processes after a dataset has been consumed once.
This keeps the worker Dataset instances alive between epochs.
hhaAndroid marked this conversation as resolved.
Show resolved Hide resolved
This argument only takes effect when PyTorch>=1.7.0.
hhaAndroid marked this conversation as resolved.
Show resolved Hide resolved
Default: False.
kwargs: any keyword argument to be used to initialize DataLoader

Returns:
Expand Down Expand Up @@ -163,14 +173,18 @@ def build_dataloader(dataset,
worker_init_fn, num_workers=num_workers, rank=rank,
seed=seed) if seed is not None else None

if (TORCH_VERSION != 'parrots'
and digit_version(TORCH_VERSION) >= digit_version('1.7.0')):
kwargs['persistent_workers'] = persistent_workers
hhaAndroid marked this conversation as resolved.
Show resolved Hide resolved

data_loader = DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers,
batch_sampler=batch_sampler,
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
pin_memory=False,
pin_memory=pin_memory,
hhaAndroid marked this conversation as resolved.
Show resolved Hide resolved
worker_init_fn=init_fn,
**kwargs)

Expand Down
4 changes: 3 additions & 1 deletion mmdet/models/dense_heads/dense_test_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def aug_test_bboxes(self, feats, img_metas, rescale=False):

if merged_bboxes.numel() == 0:
det_bboxes = torch.cat([merged_bboxes, merged_scores[:, None]], -1)
return det_bboxes, merged_labels
return [
(det_bboxes, merged_labels),
]

det_bboxes, keep_idxs = batched_nms(merged_bboxes, merged_scores,
merged_labels, self.test_cfg.nms)
Expand Down