From 3e780be616edfb4e93e239916df190851125be86 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Haian=20Huang=28=E6=B7=B1=E5=BA=A6=E7=9C=B8=29?= <1286304229@qq.com>
Date: Tue, 23 Nov 2021 20:28:07 +0800
Subject: [PATCH] [Feature] Support custom persistent_workers (#6435)

* Fix aug test error when the number of prediction bboxes is 0 (#6398)

* Fix aug test error when the number of prediction bboxes is 0

* test

* test

* fix lint

* Support custom pin_memory and persistent_workers

* fix comment

* fix docstr

* remove pin_memory
---
 mmdet/apis/train.py       |  4 +++-
 mmdet/datasets/builder.py | 16 +++++++++++++++-
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py
index a584b82c357..786a1cba619 100644
--- a/mmdet/apis/train.py
+++ b/mmdet/apis/train.py
@@ -106,7 +106,9 @@ def train_detector(model,
             num_gpus=len(cfg.gpu_ids),
             dist=distributed,
             seed=cfg.seed,
-            runner_type=runner_type) for ds in dataset
+            runner_type=runner_type,
+            persistent_workers=cfg.data.get('persistent_workers', False))
+        for ds in dataset
     ]
 
     # put model on gpus
diff --git a/mmdet/datasets/builder.py b/mmdet/datasets/builder.py
index 484710ab0cf..c5c4a63a390 100644
--- a/mmdet/datasets/builder.py
+++ b/mmdet/datasets/builder.py
@@ -2,11 +2,12 @@
 import copy
 import platform
 import random
+import warnings
 from functools import partial
 
 import numpy as np
 from mmcv.runner import get_dist_info
-from mmcv.utils import Registry, build_from_cfg
+from mmcv.utils import TORCH_VERSION, Registry, build_from_cfg, digit_version
 from torch.utils.data import DataLoader
 
 from .samplers import (DistributedGroupSampler, DistributedSampler,
@@ -89,6 +90,7 @@ def build_dataloader(dataset,
                      shuffle=True,
                      seed=None,
                      runner_type='EpochBasedRunner',
+                     persistent_workers=False,
                      **kwargs):
     """Build PyTorch DataLoader.
 
@@ -105,7 +107,12 @@ def build_dataloader(dataset,
         dist (bool): Distributed training/test or not. Default: True.
         shuffle (bool): Whether to shuffle the data at every epoch.
             Default: True.
+        seed (int, Optional): Seed to be used. Default: None.
        runner_type (str): Type of runner. Default: `EpochBasedRunner`
+        persistent_workers (bool): If True, the data loader will not shutdown
+            the worker processes after a dataset has been consumed once.
+            This allows to maintain the workers `Dataset` instances alive.
+            This argument is only valid when PyTorch>=1.7.0. Default: False.
         kwargs: any keyword argument to be used to initialize DataLoader
 
     Returns:
@@ -162,6 +169,13 @@ def build_dataloader(dataset,
         worker_init_fn, num_workers=num_workers, rank=rank,
         seed=seed) if seed is not None else None
 
+    if (TORCH_VERSION != 'parrots'
+            and digit_version(TORCH_VERSION) >= digit_version('1.7.0')):
+        kwargs['persistent_workers'] = persistent_workers
+    elif persistent_workers is True:
+        warnings.warn('persistent_workers is invalid because your pytorch '
+                      'version is lower than 1.7.0')
+
     data_loader = DataLoader(
         dataset,
         batch_size=batch_size,
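
Usage sketch (not part of the diff above): with this patch, `train_detector`
reads `persistent_workers` from `cfg.data` via
`cfg.data.get('persistent_workers', False)` and forwards it to
`build_dataloader`. The config fragment below illustrates how a user would
enable it; the dataset type and per-GPU numbers are placeholders, not values
taken from this patch.

    # Hypothetical MMDetection config fragment enabling persistent workers.
    data = dict(
        samples_per_gpu=2,         # placeholder batch size per GPU
        workers_per_gpu=2,         # placeholder number of workers per GPU
        persistent_workers=True,   # new flag picked up by train_detector()
        train=dict(type='CocoDataset'),
        val=dict(type='CocoDataset'),
        test=dict(type='CocoDataset'))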
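
The guard in `build_dataloader` only passes the flag through to
`torch.utils.data.DataLoader` on PyTorch >= 1.7.0 (and never under parrots);
otherwise it is dropped with a warning. A minimal sketch of calling the
builder directly, assuming a config has already been loaded and a
single-GPU, non-distributed setup:

    # Sketch only; `cfg` is assumed to be an already-loaded MMDetection config.
    from mmdet.datasets import build_dataset, build_dataloader

    dataset = build_dataset(cfg.data.train)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=2,          # placeholder
        workers_per_gpu=2,          # placeholder
        dist=False,
        shuffle=True,
        persistent_workers=True)    # ignored with a warning on PyTorch < 1.7.0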