Skip to content

Commit

Permalink
Fix a bug in implementation of DetectionRGB2BGR.get_equivalent_prepro…
Browse files Browse the repository at this point in the history
…cessing (#1352)
  • Loading branch information
BloodAxe authored Aug 7, 2023
1 parent de435cf commit bc1b24d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/super_gradients/training/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,7 +979,7 @@ def __call__(self, sample: dict) -> dict:
def get_equivalent_preprocessing(self) -> List:
if self.prob < 1:
raise RuntimeError("Cannot set preprocessing pipeline with randomness. Set prob to 1.")
return [{Processings.ReverseImageChannels}]
return [{Processings.ReverseImageChannels: {}}]


@register_transform(Transforms.DetectionHSV)
Expand Down
25 changes: 24 additions & 1 deletion tests/unit_tests/preprocessing_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,23 @@
import unittest
from pathlib import Path

import numpy as np

from super_gradients import Trainer
from super_gradients.common.factories.list_factory import ListFactory
from super_gradients.common.factories.processing_factory import ProcessingFactory
from super_gradients.training import models
from super_gradients.training.datasets import COCODetectionDataset
from super_gradients.training.metrics import DetectionMetrics
from super_gradients.training.models import YoloXPostPredictionCallback
from super_gradients.training.processing import ReverseImageChannels, DetectionLongestMaxSizeRescale, DetectionBottomRightPadding, ImagePermute
from super_gradients.training.processing import (
ReverseImageChannels,
DetectionLongestMaxSizeRescale,
DetectionBottomRightPadding,
ImagePermute,
ComposeProcessing,
)
from super_gradients.training.transforms import DetectionPaddedRescale, DetectionRGB2BGR
from super_gradients.training.utils.detection_utils import DetectionCollateFN, CrowdDetectionCollateFN
from super_gradients.training import dataloaders

Expand Down Expand Up @@ -176,6 +187,18 @@ def test_setting_preprocessing_params_from_checkpoint(self):
self.assertEqual(model._default_nms_iou, 0.65)
self.assertEqual(model._default_nms_conf, 0.5)

def test_processings_from_dataset_params(self):
transforms = [DetectionRGB2BGR(prob=1), DetectionPaddedRescale(input_dim=(512, 512))]

processings = []
for t in transforms:
processings += t.get_equivalent_preprocessing()

instantiated_processing = ListFactory(ProcessingFactory()).get(processings)
processing_pipeline = ComposeProcessing(instantiated_processing)
result = processing_pipeline.preprocess_image(np.zeros((480, 640, 3)))
print(result)


if __name__ == "__main__":
unittest.main()

0 comments on commit bc1b24d

Please sign in to comment.