diff --git a/mmflow/apis/train.py b/mmflow/apis/train.py index 26cb22c..65256d9 100644 --- a/mmflow/apis/train.py +++ b/mmflow/apis/train.py @@ -96,6 +96,17 @@ def train_model(model: Module, # prepare data loaders dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] + + # The overall dataloader settings + loader_cfg = { + k: v + for k, v in cfg.data.items() if k not in [ + 'train', 'val', 'test', 'train_dataloader', 'val_dataloader', + 'test_dataloader' + ] + } + # The specific training dataloader settings + train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})} data_loaders = [ build_dataloader( ds, @@ -103,7 +114,7 @@ def train_model(model: Module, num_gpus=len(cfg.gpu_ids), dist=distributed, seed=cfg.seed, - **cfg.data.train_dataloader) for ds in dataset + **train_loader_cfg) for ds in dataset ] # put model on gpus @@ -160,6 +171,9 @@ def train_model(model: Module, else: optimizer_config = cfg.optimizer_config + # The specific validation dataloader settings + val_loader_cfg = {**loader_cfg, **cfg.data.get('val_dataloader', {})} + # register hooks runner.register_training_hooks(cfg.lr_config, optimizer_config, cfg.checkpoint_config, cfg.log_config, @@ -174,7 +188,7 @@ def train_model(model: Module, ] val_dataloader = [ build_dataloader( - _val_dataset, **cfg.data.val_dataloader, dist=distributed) + _val_dataset, **val_loader_cfg, dist=distributed) for _val_dataset in val_dataset ] val_dataset_name = [ds.dataset_name for ds in val_dataset] @@ -182,7 +196,7 @@ def train_model(model: Module, val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) val_dataloader = build_dataloader( - val_dataset, **cfg.data.val_dataloader, dist=distributed) + val_dataset, **val_loader_cfg, dist=distributed) val_dataset_name = val_dataset.dataset_name eval_cfg = cfg.get('evaluation', {}) diff --git a/tools/test.py b/tools/test.py index 783d685..48466c2 100644 --- a/tools/test.py +++ b/tools/test.py @@ -120,6 +120,17 @@ def main(): # set 
multi-process settings setup_multi_processes(cfg) + # The overall dataloader settings + loader_cfg = { + k: v + for k, v in cfg.data.items() if k not in [ + 'train', 'val', 'test', 'train_dataloader', 'val_dataloader', + 'test_dataloader' + ] + } + # The specific testing dataloader settings + test_loader_cfg = {**loader_cfg, **cfg.data.get('test_dataloader', {})} + # build the dataloader separate_eval = cfg.data.test.get('separate_eval', False) if separate_eval: @@ -133,7 +144,7 @@ def main(): data_loader = [ build_dataloader( _dataset, - **cfg.data.test_dataloader, + **test_loader_cfg, dist=distributed, ) for _dataset in dataset ]