Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support single output in YoloX nms #813

Merged
merged 9 commits into from
Jul 5, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -102,24 +102,27 @@ def __init__(
self.class_agnostic_nms = class_agnostic_nms
self.multi_label_per_box = multi_label_per_box

def forward(self, x, device: str = None):
def forward(self, x: Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]], device: str = None):
"""Apply NMS to the raw output of the model and keep only top `max_predictions` results.

:param x: Raw output of the model, with x[0] expected to be a list of Tensors of shape (cx, cy, w, h, confidence, cls0, cls1, ...)
:return: List of Tensors of shape (x1, y1, x2, y2, conf, cls)
"""
# Use the main output features in case of multiple outputs.
if isinstance(x, (tuple, list)):
x = x[0]

if self.nms_type == NMS_Type.ITERATIVE:
nms_result = non_max_suppression(
x[0],
x,
conf_thres=self.conf,
iou_thres=self.iou,
with_confidence=self.with_confidence,
multi_label_per_box=self.multi_label_per_box,
class_agnostic_nms=self.class_agnostic_nms,
)
else:
nms_result = matrix_non_max_suppression(x[0], conf_thres=self.conf, max_num_of_detections=self.max_pred, class_agnostic_nms=self.class_agnostic_nms)
nms_result = matrix_non_max_suppression(x, conf_thres=self.conf, max_num_of_detections=self.max_pred, class_agnostic_nms=self.class_agnostic_nms)

return self._filter_max_predictions(nms_result)

Expand Down
2 changes: 2 additions & 0 deletions tests/deci_core_unit_test_suite_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
CallTrainAfterTestTest,
CrashTipTest,
TestTransforms,
TestPostPredictionCallback,
)
from tests.end_to_end_tests import TestTrainer
from tests.unit_tests.detection_utils_test import TestDetectionUtils
Expand Down Expand Up @@ -141,6 +142,7 @@ def _add_modules_to_unit_tests_suite(self):
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestYOLONAS))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DeprecationsUnitTest))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestMinSamplesSingleNode))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestPostPredictionCallback))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestSegmentationMetricsMultipleIgnored))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TrainWithTorchSchedulerTest))

Expand Down
2 changes: 2 additions & 0 deletions tests/unit_tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tests.unit_tests.training_params_factory_test import TrainingParamsTest
from tests.unit_tests.config_inspector_test import ConfigInspectTest
from tests.unit_tests.transforms_test import TestTransforms
from tests.unit_tests.post_prediction_callback_test import TestPostPredictionCallback

__all__ = [
"CrashTipTest",
Expand All @@ -49,4 +50,5 @@
"CallTrainAfterTestTest",
"ConfigInspectTest",
"TestTransforms",
"TestPostPredictionCallback",
]
50 changes: 50 additions & 0 deletions tests/unit_tests/post_prediction_callback_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import unittest

import torch

from super_gradients.training import models
from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback


class TestPostPredictionCallback(unittest.TestCase):
def _default_yolo_post_prediction_callback(self):
"""
Use low confidence to force a non-empty nms result.
"""
return YoloPostPredictionCallback(conf=1e-6)

def _default_mock_decoded_output(self):
"""
mock output tensor after a final decode module, i.e DetectX, with shapes [B, Num anchors, 5 + num_classes]
"""
return torch.cat([torch.randn(1, 500, 4), torch.sigmoid(torch.randn(1, 500, 81))], dim=2) # localization # classification scores

def test_yolo_post_prediction_callback_single_input(self):
callback = self._default_yolo_post_prediction_callback()

mock_single_model_output = self._default_mock_decoded_output()
_ = callback(mock_single_model_output)

def test_yolo_post_prediction_callback_multiple_input(self):
callback = self._default_yolo_post_prediction_callback()

mock_multiple_model_outputs = [self._default_mock_decoded_output(), [torch.randn(1, 1, 10, 10, 85), torch.randn(1, 1, 20, 20, 85)]] # mock logits
# sanity check multiple input as list
_ = callback(mock_multiple_model_outputs)
# sanity check multiple input as tuple
_ = callback(tuple(mock_multiple_model_outputs))

def test_yolo_post_prediction_callback_yolox_output(self):
"""
Sanity check for yolox usage with YoloPostPredictionCallback.
"""
callback = self._default_yolo_post_prediction_callback()
model = models.get(model_name="yolox_s", num_classes=80).eval()

x = torch.randn(1, 3, 320, 320)
output = model(x)
_ = callback(output)


if __name__ == "__main__":
unittest.main()