diff --git a/src/super_gradients/training/models/detection_models/yolo_base.py b/src/super_gradients/training/models/detection_models/yolo_base.py index dab2a8523f..4ce9058258 100755 --- a/src/super_gradients/training/models/detection_models/yolo_base.py +++ b/src/super_gradients/training/models/detection_models/yolo_base.py @@ -102,16 +102,19 @@ def __init__( self.class_agnostic_nms = class_agnostic_nms self.multi_label_per_box = multi_label_per_box - def forward(self, x, device: str = None): + def forward(self, x: Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]], device: str = None): """Apply NMS to the raw output of the model and keep only top `max_predictions` results. :param x: Raw output of the model, with x[0] expected to be a list of Tensors of shape (cx, cy, w, h, confidence, cls0, cls1, ...) :return: List of Tensors of shape (x1, y1, x2, y2, conf, cls) """ + # Use the main output features in case of multiple outputs. + if isinstance(x, (tuple, list)): + x = x[0] if self.nms_type == NMS_Type.ITERATIVE: nms_result = non_max_suppression( - x[0], + x, conf_thres=self.conf, iou_thres=self.iou, with_confidence=self.with_confidence, @@ -119,7 +122,7 @@ def forward(self, x, device: str = None): class_agnostic_nms=self.class_agnostic_nms, ) else: - nms_result = matrix_non_max_suppression(x[0], conf_thres=self.conf, max_num_of_detections=self.max_pred, class_agnostic_nms=self.class_agnostic_nms) + nms_result = matrix_non_max_suppression(x, conf_thres=self.conf, max_num_of_detections=self.max_pred, class_agnostic_nms=self.class_agnostic_nms) return self._filter_max_predictions(nms_result) diff --git a/tests/deci_core_unit_test_suite_runner.py b/tests/deci_core_unit_test_suite_runner.py index 28e3da17a4..c1d37602ca 100644 --- a/tests/deci_core_unit_test_suite_runner.py +++ b/tests/deci_core_unit_test_suite_runner.py @@ -21,6 +21,7 @@ CallTrainAfterTestTest, CrashTipTest, TestTransforms, + TestPostPredictionCallback, ) from tests.end_to_end_tests import TestTrainer from tests.unit_tests.detection_utils_test import TestDetectionUtils @@ -141,6 +142,7 @@ def _add_modules_to_unit_tests_suite(self): self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestYOLONAS)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DeprecationsUnitTest)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestMinSamplesSingleNode)) + self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestPostPredictionCallback)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestSegmentationMetricsMultipleIgnored)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TrainWithTorchSchedulerTest)) diff --git a/tests/unit_tests/__init__.py b/tests/unit_tests/__init__.py index 06100a4d4e..7656d036d0 100644 --- a/tests/unit_tests/__init__.py +++ b/tests/unit_tests/__init__.py @@ -23,6 +23,7 @@ from tests.unit_tests.training_params_factory_test import TrainingParamsTest from tests.unit_tests.config_inspector_test import ConfigInspectTest from tests.unit_tests.transforms_test import TestTransforms +from tests.unit_tests.post_prediction_callback_test import TestPostPredictionCallback __all__ = [ "CrashTipTest", @@ -49,4 +50,5 @@ "CallTrainAfterTestTest", "ConfigInspectTest", "TestTransforms", + "TestPostPredictionCallback", ] diff --git a/tests/unit_tests/post_prediction_callback_test.py b/tests/unit_tests/post_prediction_callback_test.py new file mode 100644 index 0000000000..aee212b884 --- /dev/null +++ b/tests/unit_tests/post_prediction_callback_test.py @@ -0,0 +1,50 @@ +import unittest + +import torch + +from super_gradients.training import models +from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback + + +class TestPostPredictionCallback(unittest.TestCase): + def _default_yolo_post_prediction_callback(self): + """ + Use low confidence to force a non-empty nms result. + """ + return YoloPostPredictionCallback(conf=1e-6) + + def _default_mock_decoded_output(self): + """ + mock output tensor after a final decode module, i.e DetectX, with shapes [B, Num anchors, 5 + num_classes] + """ + return torch.cat([torch.randn(1, 500, 4), torch.sigmoid(torch.randn(1, 500, 81))], dim=2) # localization # classification scores + + def test_yolo_post_prediction_callback_single_input(self): + callback = self._default_yolo_post_prediction_callback() + + mock_single_model_output = self._default_mock_decoded_output() + _ = callback(mock_single_model_output) + + def test_yolo_post_prediction_callback_multiple_input(self): + callback = self._default_yolo_post_prediction_callback() + + mock_multiple_model_outputs = [self._default_mock_decoded_output(), [torch.randn(1, 1, 10, 10, 85), torch.randn(1, 1, 20, 20, 85)]] # mock logits + # sanity check multiple input as list + _ = callback(mock_multiple_model_outputs) + # sanity check multiple input as tuple + _ = callback(tuple(mock_multiple_model_outputs)) + + def test_yolo_post_prediction_callback_yolox_output(self): + """ + Sanity check for yolox usage with YoloPostPredictionCallback. + """ + callback = self._default_yolo_post_prediction_callback() + model = models.get(model_name="yolox_s", num_classes=80).eval() + + x = torch.randn(1, 3, 320, 320) + output = model(x) + _ = callback(output) + + +if __name__ == "__main__": + unittest.main()