Move maybe_instantiate_test_loaders method to cfg_utils
BloodAxe committed Nov 15, 2023
1 parent a141f66 commit df8abdd
Showing 2 changed files with 38 additions and 20 deletions.
37 changes: 36 additions & 1 deletion src/super_gradients/common/environment/cfg_utils.py
@@ -1,13 +1,14 @@
 import os
 from pathlib import Path
-from typing import List, Optional, Union, Dict, Any
+from typing import List, Optional, Union, Dict, Any, Mapping
 
 import hydra
 import pkg_resources
 
 from hydra import initialize_config_dir, compose
 from hydra.core.global_hydra import GlobalHydra
 from omegaconf import OmegaConf, open_dict, DictConfig
+from torch.utils.data import DataLoader
 
 from super_gradients.common.environment.omegaconf_utils import register_hydra_resolvers
 from super_gradients.common.environment.path_utils import normalize_path
@@ -195,3 +196,37 @@ def export_recipe(config_name: str, save_path: str, config_dir: str = pkg_resour
         cfg = compose(config_name=config_name)
         OmegaConf.save(config=cfg, f=save_path)
         logger.info(f"Successfully saved recipe at {save_path}. \n" f"Recipe content:\n {cfg}")
+
+
+def maybe_instantiate_test_loaders(cfg) -> Optional[Mapping[str, DataLoader]]:
+    """
+    Instantiate test loaders if they are defined in the config.
+
+    :param cfg: Recipe config
+    :return: A mapping from dataset name to test loader or None if no test loaders are defined.
+    """
+    from super_gradients.training.utils.utils import get_param
+    from super_gradients.training import dataloaders
+
+    test_loaders = None
+    if "test_dataset_params" in cfg.dataset_params:
+        test_dataloaders = get_param(cfg, "test_dataloaders")
+        test_dataset_params = cfg.dataset_params.test_dataset_params
+        test_dataloader_params = get_param(cfg.dataset_params, "test_dataloader_params")
+
+        if test_dataloaders is not None:
+            if not isinstance(test_dataloaders, Mapping):
+                raise ValueError("`test_dataloaders` should be a mapping from test_loader_name to test_loader_params.")
+
+        if test_dataloader_params is not None and test_dataloader_params.keys() != test_dataset_params.keys():
+            raise ValueError("test_dataloader_params and test_dataset_params should have the same keys.")
+
+        test_loaders = {}
+        for dataset_name, dataset_params in test_dataset_params.items():
+            loader_name = test_dataloaders[dataset_name] if test_dataloaders is not None else None
+            dataset_params = test_dataset_params[dataset_name]
+            dataloader_params = test_dataloader_params[dataset_name] if test_dataloader_params is not None else cfg.dataset_params.val_dataloader_params
+            loader = dataloaders.get(loader_name, dataset_params=dataset_params, dataloader_params=dataloader_params)
+            test_loaders[dataset_name] = loader
+
+    return test_loaders
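
Usage note (not part of this commit): a minimal, hypothetical sketch of how the new helper could be driven from a recipe-style config. The test set names, data paths, and the "coco2017_val" dataloader name below are illustrative assumptions, and running it requires super_gradients to be installed and the referenced data to exist locally.

# Hypothetical sketch: build a recipe-like config with two named test sets and
# let maybe_instantiate_test_loaders() turn them into DataLoaders.
from omegaconf import OmegaConf

from super_gradients.common.environment.cfg_utils import maybe_instantiate_test_loaders

cfg = OmegaConf.create(
    {
        # Optional top-level mapping from test set name to a registered dataloader name.
        # If omitted, dataloaders.get(None, ...) is used for every test set.
        "test_dataloaders": {"coco_val": "coco2017_val", "coco_extra": "coco2017_val"},  # assumed names
        "dataset_params": {
            # One entry per named test set; these keys become the keys of the returned mapping.
            "test_dataset_params": {
                "coco_val": {"data_dir": "/data/coco"},          # hypothetical path
                "coco_extra": {"data_dir": "/data/coco_extra"},  # hypothetical path
            },
            # No test_dataloader_params here, so val_dataloader_params is reused for every test set.
            "val_dataloader_params": {"batch_size": 16, "num_workers": 4},
        },
    }
)

test_loaders = maybe_instantiate_test_loaders(cfg)  # None when "test_dataset_params" is absent
if test_loaders is not None:
    for name, loader in test_loaders.items():
        print(name, type(loader).__name__)
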
21 changes: 2 additions & 19 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -43,6 +43,7 @@
 from super_gradients.common.factories.losses_factory import LossesFactory
 from super_gradients.common.factories.metrics_factory import MetricsFactory
 from super_gradients.common.environment.package_utils import get_installed_packages
+from super_gradients.common.environment.cfg_utils import maybe_instantiate_test_loaders
 
 from super_gradients.training import utils as core_utils, models, dataloaders
 from super_gradients.training.datasets.samplers import RepeatAugSampler
@@ -284,25 +285,7 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tup
             dataloader_params=cfg.dataset_params.val_dataloader_params,
         )
 
-        test_loaders = {}
-        if "test_dataset_params" in cfg.dataset_params:
-            test_dataloaders = get_param(cfg, "test_dataloaders")
-            test_dataset_params = cfg.dataset_params.test_dataset_params
-            test_dataloader_params = get_param(cfg.dataset_params, "test_dataloader_params")
-
-            if test_dataloaders is not None:
-                if not isinstance(test_dataloaders, Mapping):
-                    raise ValueError("`test_dataloaders` should be a mapping from test_loader_name to test_loader_params.")
-
-            if test_dataloader_params is not None and test_dataloader_params.keys() != test_dataset_params.keys():
-                raise ValueError("test_dataloader_params and test_dataset_params should have the same keys.")
-
-            for dataset_name, dataset_params in test_dataset_params.items():
-                loader_name = test_dataloaders[dataset_name] if test_dataloaders is not None else None
-                dataset_params = test_dataset_params[dataset_name]
-                dataloader_params = test_dataloader_params[dataset_name] if test_dataloader_params is not None else cfg.dataset_params.val_dataloader_params
-                loader = dataloaders.get(loader_name, dataset_params=dataset_params, dataloader_params=dataloader_params)
-                test_loaders[dataset_name] = loader
+        test_loaders = maybe_instantiate_test_loaders(cfg)
 
         recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
         # TRAIN