diff --git a/src/super_gradients/training/dataloaders/dataloaders.py b/src/super_gradients/training/dataloaders/dataloaders.py index 8526bd73b4..4e5e6ad236 100644 --- a/src/super_gradients/training/dataloaders/dataloaders.py +++ b/src/super_gradients/training/dataloaders/dataloaders.py @@ -3,7 +3,7 @@ import hydra import numpy as np import torch -from omegaconf import OmegaConf, UnsupportedValueType +from omegaconf import OmegaConf, UnsupportedValueType, DictConfig, open_dict from torch.utils.data import BatchSampler, DataLoader, TensorDataset, RandomSampler import super_gradients @@ -101,12 +101,19 @@ def _process_dataset_params(cfg, dataset_params, train: bool): # >>> dataset_params = OmegaConf.merge(default_dataset_params, dataset_params) # >>> return hydra.utils.instantiate(dataset_params) # For some reason this breaks interpolation :shrug: - + if not isinstance(dataset_params, DictConfig): + dataset_params = OmegaConf.create(dataset_params) if train: - cfg.train_dataset_params = OmegaConf.merge(cfg.train_dataset_params, dataset_params) + train_dataset_params = cfg.train_dataset_params + with open_dict(train_dataset_params): + train_dataset_params.merge_with(dataset_params) + cfg.train_dataset_params = train_dataset_params return hydra.utils.instantiate(cfg.train_dataset_params) else: - cfg.val_dataset_params = OmegaConf.merge(cfg.val_dataset_params, dataset_params) + val_dataset_params = cfg.val_dataset_params + with open_dict(val_dataset_params): + val_dataset_params.merge_with(dataset_params) + cfg.val_dataset_params = val_dataset_params return hydra.utils.instantiate(cfg.val_dataset_params) except UnsupportedValueType: diff --git a/tests/unit_tests/detection_dataset_test.py b/tests/unit_tests/detection_dataset_test.py index fda8597f3d..3103364359 100644 --- a/tests/unit_tests/detection_dataset_test.py +++ b/tests/unit_tests/detection_dataset_test.py @@ -1,13 +1,31 @@ import unittest from pathlib import Path +from typing import Dict -from super_gradients.training.dataloaders import coco2017_train_yolo_nas +from torch.utils.data import DataLoader +from super_gradients.training.dataloaders import coco2017_train_yolo_nas, get_data_loader from super_gradients.training.datasets import COCODetectionDataset from super_gradients.training.datasets.data_formats.default_formats import LABEL_CXCYWH from super_gradients.training.exceptions.dataset_exceptions import DatasetValidationException, ParameterMismatchException from super_gradients.training.transforms import DetectionMosaic, DetectionTargetsFormatTransform, DetectionPaddedRescale +class DummyCOCODetectionDatasetInheritor(COCODetectionDataset): + def __init__(self, json_file: str, subdir: str, dummy_field: int, *args, **kwargs): + super(DummyCOCODetectionDatasetInheritor, self).__init__(json_file=json_file, subdir=subdir, *args, **kwargs) + self.dummy_field = dummy_field + + +def dummy_coco2017_inheritor_train_yolo_nas(dataset_params: Dict = None, dataloader_params: Dict = None) -> DataLoader: + return get_data_loader( + config_name="coco_detection_yolo_nas_dataset_params", + dataset_cls=DummyCOCODetectionDatasetInheritor, + train=True, + dataset_params=dataset_params, + dataloader_params=dataloader_params, + ) + + class DetectionDatasetTest(unittest.TestCase): def setUp(self) -> None: self.mini_coco_data_dir = str(Path(__file__).parent.parent / "data" / "tinycoco") @@ -23,7 +41,6 @@ def test_normal_coco_dataset_creation(self): COCODetectionDataset(**train_dataset_params) def test_coco_dataset_creation_with_wrong_classes(self): - train_dataset_params = { "data_dir": self.mini_coco_data_dir, "subdir": "images/train2017", @@ -88,6 +105,25 @@ def test_coco_detection_dataset_override_with_objects(self): self.assertEqual(batch[0].shape[2], 384) self.assertEqual(batch[0].shape[3], 384) + def test_coco_detection_dataset_override_with_new_entries(self): + train_dataset_params = { + "data_dir": self.mini_coco_data_dir, + "input_dim": 384, + "transforms": [ + DetectionMosaic(input_dim=384), + DetectionPaddedRescale(input_dim=384, max_targets=10), + DetectionTargetsFormatTransform(max_targets=10, output_format=LABEL_CXCYWH), + ], + "dummy_field": 10, + } + train_dataloader_params = {"num_workers": 0} + dataloader = dummy_coco2017_inheritor_train_yolo_nas(dataset_params=train_dataset_params, dataloader_params=train_dataloader_params) + batch = next(iter(dataloader)) + print(batch[0].shape) + self.assertEqual(batch[0].shape[2], 384) + self.assertEqual(batch[0].shape[3], 384) + self.assertEqual(dataloader.dataset.dummy_field, 10) + if __name__ == "__main__": unittest.main()