Skip to content

Commit

Permalink
[Enhance] Add extra dataloader settings in configs (open-mmlab#1435)
Browse files Browse the repository at this point in the history
* [Enhance] Add extra dataloader settings in configs

* val default samples

* val default samples

* del unuse

* del unused
  • Loading branch information
MeowZheng authored Apr 13, 2022
1 parent add835b commit f50bfe3
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 21 deletions.
44 changes: 27 additions & 17 deletions mmseg/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,25 @@ def train_segmentor(model,

# prepare data loaders
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
data_loaders = [
build_dataloader(
ds,
cfg.data.samples_per_gpu,
cfg.data.workers_per_gpu,
# cfg.gpus will be ignored if distributed
len(cfg.gpu_ids),
dist=distributed,
seed=cfg.seed,
drop_last=True) for ds in dataset
]
# The default loader config
loader_cfg = dict(
# cfg.gpus will be ignored if distributed
num_gpus=len(cfg.gpu_ids),
dist=distributed,
seed=cfg.seed,
drop_last=True)
# The overall dataloader settings
loader_cfg.update({
k: v
for k, v in cfg.data.items() if k not in [
'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
'test_dataloader'
]
})

# The specific dataloader settings
train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})}
data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]

# put model on gpus
if distributed:
Expand Down Expand Up @@ -142,12 +150,14 @@ def train_segmentor(model,
# register eval hooks
if validate:
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
val_dataloader = build_dataloader(
val_dataset,
samples_per_gpu=1,
workers_per_gpu=cfg.data.workers_per_gpu,
dist=distributed,
shuffle=False)
# The specific dataloader settings
val_loader_cfg = {
**loader_cfg,
'samples_per_gpu': 1,
'shuffle': False, # Not shuffle by default
**cfg.data.get('val_dataloader', {}),
}
val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)
eval_cfg = cfg.get('evaluation', {})
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
eval_hook = DistEvalHook if distributed else EvalHook
Expand Down
24 changes: 20 additions & 4 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,28 @@ def main():
# build the dataloader
# TODO: support multiple images per gpu (only minor changes are needed)
dataset = build_dataset(cfg.data.test)
data_loader = build_dataloader(
dataset,
samples_per_gpu=1,
workers_per_gpu=cfg.data.workers_per_gpu,
# The default loader config
loader_cfg = dict(
# cfg.gpus will be ignored if distributed
num_gpus=len(cfg.gpu_ids),
dist=distributed,
shuffle=False)
# The overall dataloader settings
loader_cfg.update({
k: v
for k, v in cfg.data.items() if k not in [
'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
'test_dataloader'
]
})
test_loader_cfg = {
**loader_cfg,
'samples_per_gpu': 1,
'shuffle': False, # Not shuffle by default
**cfg.data.get('test_dataloader', {})
}
# build the dataloader
data_loader = build_dataloader(dataset, **test_loader_cfg)

# build the model and load checkpoint
cfg.model.train_cfg = None
Expand Down

0 comments on commit f50bfe3

Please sign in to comment.