Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support custom persistent_workers #6435

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mmdet/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ def train_detector(model,
num_gpus=len(cfg.gpu_ids),
dist=distributed,
seed=cfg.seed,
runner_type=runner_type) for ds in dataset
runner_type=runner_type,
persistent_workers=cfg.data.get('persistent_workers', False))
for ds in dataset
]

# put model on gpus
Expand Down
16 changes: 15 additions & 1 deletion mmdet/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import copy
import platform
import random
import warnings
from functools import partial

import numpy as np
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import Registry, build_from_cfg
from mmcv.utils import TORCH_VERSION, Registry, build_from_cfg, digit_version
from torch.utils.data import DataLoader

from .samplers import (DistributedGroupSampler, DistributedSampler,
Expand Down Expand Up @@ -90,6 +91,7 @@ def build_dataloader(dataset,
shuffle=True,
seed=None,
runner_type='EpochBasedRunner',
persistent_workers=False,
**kwargs):
"""Build PyTorch DataLoader.

Expand All @@ -106,7 +108,12 @@ def build_dataloader(dataset,
dist (bool): Distributed training/test or not. Default: True.
shuffle (bool): Whether to shuffle the data at every epoch.
Default: True.
seed (int, Optional): Seed to be used. Default: None.
runner_type (str): Type of runner. Default: `EpochBasedRunner`.
persistent_workers (bool): If True, the data loader will not shut down
the worker processes after a dataset has been consumed once.
This keeps the workers' `Dataset` instances alive.
This argument is only valid when PyTorch>=1.7.0. Default: False.
kwargs: any keyword argument to be used to initialize DataLoader

Returns:
Expand Down Expand Up @@ -163,6 +170,13 @@ def build_dataloader(dataset,
worker_init_fn, num_workers=num_workers, rank=rank,
seed=seed) if seed is not None else None

if (TORCH_VERSION != 'parrots'
and digit_version(TORCH_VERSION) >= digit_version('1.7.0')):
kwargs['persistent_workers'] = persistent_workers
hhaAndroid marked this conversation as resolved.
Show resolved Hide resolved
elif persistent_workers is True:
warnings.warn('persistent_workers is invalid because your pytorch '
'version is lower than 1.7.0')

data_loader = DataLoader(
dataset,
batch_size=batch_size,
Expand Down