
Bug/sg 000 merge failure for datasetparams #1140

Merged: 11 commits, Jun 7, 2023
15 changes: 11 additions & 4 deletions src/super_gradients/training/dataloaders/dataloaders.py
@@ -3,7 +3,7 @@
import hydra
import numpy as np
import torch
from omegaconf import OmegaConf, UnsupportedValueType
from omegaconf import OmegaConf, UnsupportedValueType, DictConfig, open_dict
from torch.utils.data import BatchSampler, DataLoader, TensorDataset, RandomSampler

import super_gradients
@@ -101,12 +101,19 @@ def _process_dataset_params(cfg, dataset_params, train: bool):
# >>> dataset_params = OmegaConf.merge(default_dataset_params, dataset_params)
# >>> return hydra.utils.instantiate(dataset_params)
# For some reason this breaks interpolation :shrug:

if not isinstance(dataset_params, DictConfig):
dataset_params = OmegaConf.create(dataset_params)
if train:
cfg.train_dataset_params = OmegaConf.merge(cfg.train_dataset_params, dataset_params)
train_dataset_params = cfg.train_dataset_params
with open_dict(train_dataset_params):
train_dataset_params.merge_with(dataset_params)
cfg.train_dataset_params = train_dataset_params
return hydra.utils.instantiate(cfg.train_dataset_params)
else:
cfg.val_dataset_params = OmegaConf.merge(cfg.val_dataset_params, dataset_params)
val_dataset_params = cfg.val_dataset_params
with open_dict(val_dataset_params):
val_dataset_params.merge_with(dataset_params)
cfg.val_dataset_params = val_dataset_params
return hydra.utils.instantiate(cfg.val_dataset_params)

except UnsupportedValueType:
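The two hunks above replace the one-line `OmegaConf.merge(...)` calls with an in-place `merge_with` wrapped in `open_dict`, and coerce plain-dict overrides into a `DictConfig` first. Because the recipe defaults are composed by Hydra in struct mode, the old merge rejected any key that was not already declared in `train_dataset_params`/`val_dataset_params`; temporarily lifting the struct flag lets user overrides add new keys. A minimal sketch of the difference, assuming a small struct config standing in for the recipe defaults and an illustrative `dummy_field` key:

```python
from omegaconf import OmegaConf, open_dict

# Stand-in for the default recipe params. Hydra-composed configs are in struct mode,
# so keys not declared here cannot normally be added.
defaults = OmegaConf.create({"data_dir": "/data/coco", "input_dim": 640})
OmegaConf.set_struct(defaults, True)

# User override that introduces a key unknown to the defaults (illustrative name).
override = OmegaConf.create({"input_dim": 384, "dummy_field": 10})

# Old behaviour: OmegaConf.merge(defaults, override) raises here, because
# "dummy_field" is not a key of the struct config.

# New behaviour: lift the struct flag for the duration of the merge.
with open_dict(defaults):
    defaults.merge_with(override)

print(defaults.input_dim, defaults.dummy_field)  # 384 10
```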
40 changes: 38 additions & 2 deletions tests/unit_tests/detection_dataset_test.py
@@ -1,13 +1,31 @@
import unittest
from pathlib import Path
from typing import Dict

from super_gradients.training.dataloaders import coco2017_train_yolo_nas
from torch.utils.data import DataLoader
from super_gradients.training.dataloaders import coco2017_train_yolo_nas, get_data_loader
from super_gradients.training.datasets import COCODetectionDataset
from super_gradients.training.datasets.data_formats.default_formats import LABEL_CXCYWH
from super_gradients.training.exceptions.dataset_exceptions import DatasetValidationException, ParameterMismatchException
from super_gradients.training.transforms import DetectionMosaic, DetectionTargetsFormatTransform, DetectionPaddedRescale


class DummyCOCODetectionDatasetInheritor(COCODetectionDataset):
def __init__(self, json_file: str, subdir: str, dummy_field: int, *args, **kwargs):
super(DummyCOCODetectionDatasetInheritor, self).__init__(json_file=json_file, subdir=subdir, *args, **kwargs)
self.dummy_field = dummy_field


def dummy_coco2017_inheritor_train_yolo_nas(dataset_params: Dict = None, dataloader_params: Dict = None) -> DataLoader:
return get_data_loader(
config_name="coco_detection_yolo_nas_dataset_params",
dataset_cls=DummyCOCODetectionDatasetInheritor,
train=True,
dataset_params=dataset_params,
dataloader_params=dataloader_params,
)


class DetectionDatasetTest(unittest.TestCase):
def setUp(self) -> None:
self.mini_coco_data_dir = str(Path(__file__).parent.parent / "data" / "tinycoco")
@@ -23,7 +41,6 @@ def test_normal_coco_dataset_creation(self):
COCODetectionDataset(**train_dataset_params)

def test_coco_dataset_creation_with_wrong_classes(self):

train_dataset_params = {
"data_dir": self.mini_coco_data_dir,
"subdir": "images/train2017",
@@ -88,6 +105,25 @@ def test_coco_detection_dataset_override_with_objects(self):
self.assertEqual(batch[0].shape[2], 384)
self.assertEqual(batch[0].shape[3], 384)

def test_coco_detection_dataset_override_with_new_entries(self):
train_dataset_params = {
"data_dir": self.mini_coco_data_dir,
"input_dim": 384,
"transforms": [
DetectionMosaic(input_dim=384),
DetectionPaddedRescale(input_dim=384, max_targets=10),
DetectionTargetsFormatTransform(max_targets=10, output_format=LABEL_CXCYWH),
],
"dummy_field": 10,
}
train_dataloader_params = {"num_workers": 0}
dataloader = dummy_coco2017_inheritor_train_yolo_nas(dataset_params=train_dataset_params, dataloader_params=train_dataloader_params)
batch = next(iter(dataloader))
print(batch[0].shape)
self.assertEqual(batch[0].shape[2], 384)
self.assertEqual(batch[0].shape[3], 384)
self.assertEqual(dataloader.dataset.dummy_field, 10)


if __name__ == "__main__":
unittest.main()
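The new test drives exactly this path end to end: a `COCODetectionDataset` subclass with an extra constructor argument receives that argument through `dataset_params`, which previously failed at the merge step. The same pattern can be used outside the test suite; a minimal sketch, with a hypothetical subclass, field name, and data path:

```python
from typing import Dict, Optional

from torch.utils.data import DataLoader

from super_gradients.training.dataloaders import get_data_loader
from super_gradients.training.datasets import COCODetectionDataset


class ThresholdedCOCODataset(COCODetectionDataset):
    """Hypothetical subclass with a constructor argument the default recipe does not declare."""

    def __init__(self, json_file: str, subdir: str, score_threshold: float, *args, **kwargs):
        super().__init__(json_file=json_file, subdir=subdir, *args, **kwargs)
        self.score_threshold = score_threshold


def thresholded_coco_train(dataset_params: Optional[Dict] = None, dataloader_params: Optional[Dict] = None) -> DataLoader:
    # Same call shape as the test helper above, pointed at the custom dataset class.
    return get_data_loader(
        config_name="coco_detection_yolo_nas_dataset_params",
        dataset_cls=ThresholdedCOCODataset,
        train=True,
        dataset_params=dataset_params,
        dataloader_params=dataloader_params,
    )


# "score_threshold" is not part of the default recipe; after this fix it is merged
# into the dataset params and forwarded to ThresholdedCOCODataset.__init__.
loader = thresholded_coco_train(
    dataset_params={"data_dir": "/path/to/coco", "score_threshold": 0.5},
    dataloader_params={"num_workers": 0},
)
```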