Skip to content

Commit

Permalink
[Enhancement] More customizable fields in dataloaders (#933)
Browse files Browse the repository at this point in the history
* [Enhancement] More customizable fields in val and test dataloaders

* update default_loader_cfg
  • Loading branch information
gaotongxiao authored Apr 18, 2022
1 parent 20fc909 commit b4a9a87
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
30 changes: 15 additions & 15 deletions mmocr/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,30 +30,30 @@ def train_detector(model,
# prepare data loaders
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
# step 1: give default values and override (if exist) from cfg.data
loader_cfg = {
default_loader_cfg = {
**dict(
num_gpus=len(cfg.gpu_ids),
dist=distributed,
seed=cfg.get('seed'),
drop_last=False,
dist=distributed,
num_gpus=len(cfg.gpu_ids)),
persistent_workers=False),
**({} if torch.__version__ != 'parrots' else dict(
prefetch_num=2,
pin_memory=False,
)),
**dict((k, cfg.data[k]) for k in [
'samples_per_gpu',
'workers_per_gpu',
'shuffle',
'seed',
'drop_last',
'prefetch_num',
'pin_memory',
'persistent_workers',
] if k in cfg.data)
}
# update overall dataloader(for train, val and test) setting
default_loader_cfg.update({
k: v
for k, v in cfg.data.items() if k not in [
'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
'test_dataloader'
]
})

# step 2: cfg.data.train_dataloader has highest priority
train_loader_cfg = dict(loader_cfg, **cfg.data.get('train_dataloader', {}))
train_loader_cfg = dict(default_loader_cfg,
**cfg.data.get('train_dataloader', {}))

data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]

Expand Down Expand Up @@ -135,7 +135,7 @@ def train_detector(model,
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))

val_loader_cfg = {
**loader_cfg,
**default_loader_cfg,
**dict(shuffle=False, drop_last=False),
**cfg.data.get('val_dataloader', {}),
**dict(samples_per_gpu=val_samples_per_gpu)
Expand Down
20 changes: 10 additions & 10 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,22 +164,22 @@ def main():
# build the dataloader
dataset = build_dataset(cfg.data.test, dict(test_mode=True))
# step 1: give default values and override (if exist) from cfg.data
loader_cfg = {
default_loader_cfg = {
**dict(seed=cfg.get('seed'), drop_last=False, dist=distributed),
**({} if torch.__version__ != 'parrots' else dict(
prefetch_num=2,
pin_memory=False,
)),
**dict((k, cfg.data[k]) for k in [
'workers_per_gpu',
'seed',
'prefetch_num',
'pin_memory',
'persistent_workers',
] if k in cfg.data)
))
}
default_loader_cfg.update({
k: v
for k, v in cfg.data.items() if k not in [
'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
'test_dataloader'
]
})
test_loader_cfg = {
**loader_cfg,
**default_loader_cfg,
**dict(shuffle=False, drop_last=False),
**cfg.data.get('test_dataloader', {}),
**dict(samples_per_gpu=samples_per_gpu)
Expand Down

0 comments on commit b4a9a87

Please sign in to comment.