diff --git a/documentation/source/ObjectDetection.md b/documentation/source/ObjectDetection.md index 3befaddaba..2b73159b5a 100644 --- a/documentation/source/ObjectDetection.md +++ b/documentation/source/ObjectDetection.md @@ -10,11 +10,11 @@ In SuperGradients, we aim to collect such models and make them very convenient a ## Implemented models -| Model | Yaml | Model class | Loss Class | NMS Callback | -|------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| [SSD](https://arxiv.org/abs/1512.02325) | [ssd_lite_mobilenetv2_arch_params](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/recipes/arch_params/ssd_lite_mobilenetv2_arch_params.yaml) | [SSDLiteMobileNetV2](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/training/models/detection_models/ssd.py) | [SSDLoss](https://docs.deci.ai/super-gradients/docstring/training/losses/#training.losses.ssd_loss.SSDLoss) | [SSDPostPredictCallback](https://docs.deci.ai/super-gradients/docstring/training/utils/#training.utils.ssd_utils.SSDPostPredictCallback) | -| [YOLOX](https://arxiv.org/abs/2107.08430) | [yolox_s_arch_params](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/recipes/arch_params/yolox_s_arch_params.yaml) | [YoloX_S](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/training/models/detection_models/yolox.py) | 
[YoloXFastDetectionLoss](https://docs.deci.ai/super-gradients/docstring/training/losses/#training.losses.yolox_loss.YoloXFastDetectionLoss) | [YoloPostPredictionCallback](https://docs.deci.ai/super-gradients/docstring/training/models/#training.models.detection_models.yolo_base.YoloPostPredictionCallback) | -| [PPYolo](https://arxiv.org/abs/2007.12099) | [ppyoloe_arch_params](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/recipes/arch_params/ppyoloe_arch_params.yaml) | [PPYoloE](https://docs.deci.ai/super-gradients/docstring/training/models/#training.models.detection_models.pp_yolo_e.pp_yolo_e.PPYoloE) | [PPYoloELoss](https://docs.deci.ai/super-gradients/docstring/training/losses/#training.losses.ppyolo_loss.PPYoloELoss) | [PPYoloEPostPredictionCallback](https://docs.deci.ai/super-gradients/docstring/training/models/#training.models.detection_models.pp_yolo_e.post_prediction_callback.PPYoloEPostPredictionCallback) | +| Model | Yaml | Model class | Loss Class | NMS Callback | +|--------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [SSD](https://arxiv.org/abs/1512.02325) | [ssd_lite_mobilenetv2_arch_params](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/recipes/arch_params/ssd_lite_mobilenetv2_arch_params.yaml) | 
[SSDLiteMobileNetV2](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/training/models/detection_models/ssd.py) | [SSDLoss](https://docs.deci.ai/super-gradients/docstring/training/losses/#training.losses.ssd_loss.SSDLoss) | [SSDPostPredictCallback](https://docs.deci.ai/super-gradients/docstring/training/utils/#training.utils.ssd_utils.SSDPostPredictCallback) | +| [YOLOX](https://arxiv.org/abs/2107.08430) | [yolox_s_arch_params](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/recipes/arch_params/yolox_s_arch_params.yaml) | [YoloX_S](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/training/models/detection_models/yolox.py) | [YoloXFastDetectionLoss](https://docs.deci.ai/super-gradients/docstring/training/losses/#training.losses.yolox_loss.YoloXFastDetectionLoss) | [YoloXPostPredictionCallback](https://docs.deci.ai/super-gradients/docstring/training/models/#training.models.detection_models.yolo_base.YoloXPostPredictionCallback) | +| [PPYolo](https://arxiv.org/abs/2007.12099) | [ppyoloe_arch_params](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/recipes/arch_params/ppyoloe_arch_params.yaml) | [PPYoloE](https://docs.deci.ai/super-gradients/docstring/training/models/#training.models.detection_models.pp_yolo_e.pp_yolo_e.PPYoloE) | [PPYoloELoss](https://docs.deci.ai/super-gradients/docstring/training/losses/#training.losses.ppyolo_loss.PPYoloELoss) | [PPYoloEPostPredictionCallback](https://docs.deci.ai/super-gradients/docstring/training/models/#training.models.detection_models.pp_yolo_e.post_prediction_callback.PPYoloEPostPredictionCallback) | ## Training @@ -73,16 +73,16 @@ In order to use `DetectionMetrics` you have to pass a so-called `post_prediction ### Postprocessing Postprocessing refers to a process of transforming the model's raw output into final predictions. Postprocessing is also model-specific and depends on the model's output format. 
-For `YOLOX` model, the postprocessing step is implemented in [YoloPostPredictionCallback](https://docs.deci.ai/super-gradients/docstring/training/models/#training.models.detection_models.yolo_base.YoloPostPredictionCallback) class. +For `YOLOX` model, the postprocessing step is implemented in [YoloXPostPredictionCallback](https://docs.deci.ai/super-gradients/docstring/training/models/#training.models.detection_models.yolo_base.YoloXPostPredictionCallback) class. It can be passed into a `DetectionMetrics` as a `post_prediction_callback`. The postprocessing of all detection models involves non-maximum suppression (NMS) which filters dense model's predictions and leaves only boxes with the highest confidence and suppresses boxes with very high overlap based on the assumption that they likely belong to the same object. Thus, a confidence threshold and an IoU threshold must be passed into the postprocessing object. ```python -from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback +from super_gradients.training.models.detection_models.yolo_base import YoloXPostPredictionCallback -post_prediction_callback = YoloPostPredictionCallback(conf=0.001, iou=0.6) +post_prediction_callback = YoloXPostPredictionCallback(conf=0.001, iou=0.6) ``` ### Visualization @@ -114,7 +114,7 @@ def my_undo_image_preprocessing(im_tensor: torch.Tensor) -> np.ndarray: model = models.get("yolox_s", pretrained_weights="coco", num_classes=80) imgs, targets = next(iter(train_dataloader)) -preds = YoloPostPredictionCallback(conf=0.1, iou=0.6)(model(imgs)) +preds = YoloXPostPredictionCallback(conf=0.1, iou=0.6)(model(imgs)) DetectionVisualization.visualize_batch(imgs, preds, targets, batch_name='train', class_names=COCO_DETECTION_CLASSES_LIST, checkpoint_dir='/path/for/saved_images/', gt_alpha=0.5, undo_preprocessing_func=my_undo_image_preprocessing) @@ -148,13 +148,13 @@ valid_metrics_list: - DetectionMetrics: normalize_targets: True post_prediction_callback: 
- _target_: super_gradients.training.models.detection_models.yolo_base.YoloPostPredictionCallback + _target_: super_gradients.training.models.detection_models.yolo_base.YoloXPostPredictionCallback iou: 0.65 conf: 0.01 num_cls: 80 ``` -Notice how `YoloPostPredictionCallback` is passed as a `post_prediction_callback`. +Notice how `YoloXPostPredictionCallback` is passed as a `post_prediction_callback`. A visualization belongs to `training_hyperparams` as well, specifically to the `phase_callbacks` list, as follows: ```yaml @@ -165,7 +165,7 @@ phase_callbacks: value: VALIDATION_EPOCH_END freq: 1 post_prediction_callback: - _target_: super_gradients.training.models.detection_models.yolo_base.YoloPostPredictionCallback + _target_: super_gradients.training.models.detection_models.yolo_base.YoloXPostPredictionCallback iou: 0.65 conf: 0.01 classes: [ diff --git a/src/super_gradients/recipes/roboflow_yolox.yaml b/src/super_gradients/recipes/roboflow_yolox.yaml index c95bc7b3a4..38ad33fa43 100644 --- a/src/super_gradients/recipes/roboflow_yolox.yaml +++ b/src/super_gradients/recipes/roboflow_yolox.yaml @@ -45,7 +45,7 @@ training_hyperparams: - DetectionMetrics: normalize_targets: True post_prediction_callback: - _target_: super_gradients.training.models.detection_models.yolo_base.YoloPostPredictionCallback + _target_: super_gradients.training.models.detection_models.yolo_base.YoloXPostPredictionCallback iou: 0.65 conf: 0.01 num_cls: 80 @@ -53,7 +53,7 @@ training_hyperparams: - DetectionMetrics: normalize_targets: True post_prediction_callback: - _target_: super_gradients.training.models.detection_models.yolo_base.YoloPostPredictionCallback + _target_: super_gradients.training.models.detection_models.yolo_base.YoloXPostPredictionCallback iou: 0.65 conf: 0.01 num_cls: 80 diff --git a/src/super_gradients/recipes/training_hyperparams/coco2017_yolox_train_params.yaml b/src/super_gradients/recipes/training_hyperparams/coco2017_yolox_train_params.yaml index 4c1a0fa7de..cb0df61965 
100644 --- a/src/super_gradients/recipes/training_hyperparams/coco2017_yolox_train_params.yaml +++ b/src/super_gradients/recipes/training_hyperparams/coco2017_yolox_train_params.yaml @@ -33,7 +33,7 @@ valid_metrics_list: - DetectionMetrics: normalize_targets: True post_prediction_callback: - _target_: super_gradients.training.models.detection_models.yolo_base.YoloPostPredictionCallback + _target_: super_gradients.training.models.detection_models.yolo_base.YoloXPostPredictionCallback iou: 0.65 conf: 0.01 num_cls: 80 diff --git a/src/super_gradients/training/models/__init__.py b/src/super_gradients/training/models/__init__.py index 9437ec1526..fcbf815aae 100755 --- a/src/super_gradients/training/models/__init__.py +++ b/src/super_gradients/training/models/__init__.py @@ -70,7 +70,7 @@ from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloE, PPYoloE_S, PPYoloE_M, PPYoloE_L, PPYoloE_X from super_gradients.training.models.detection_models.darknet53 import Darknet53, Darknet53Base from super_gradients.training.models.detection_models.ssd import SSDMobileNetV1, SSDLiteMobileNetV2 -from super_gradients.training.models.detection_models.yolo_base import YoloBase, YoloPostPredictionCallback +from super_gradients.training.models.detection_models.yolo_base import YoloBase, YoloXPostPredictionCallback, YoloPostPredictionCallback from super_gradients.training.models.detection_models.yolox import YoloX_N, YoloX_T, YoloX_S, YoloX_M, YoloX_L, YoloX_X, CustomYoloX from super_gradients.training.models.detection_models.customizable_detector import CustomizableDetector from super_gradients.training.models.detection_models.yolo_nas import ( @@ -291,6 +291,7 @@ def inner(*args, **kwargs): "YoloX_X", "CustomYoloX", "YoloPostPredictionCallback", + "YoloXPostPredictionCallback", "CustomizableDetector", "ShelfNet50", "ShelfNet101", diff --git a/src/super_gradients/training/models/detection_models/pp_yolo_e/post_prediction_callback.py 
b/src/super_gradients/training/models/detection_models/pp_yolo_e/post_prediction_callback.py index 2690bfa57e..6582e6665a 100644 --- a/src/super_gradients/training/models/detection_models/pp_yolo_e/post_prediction_callback.py +++ b/src/super_gradients/training/models/detection_models/pp_yolo_e/post_prediction_callback.py @@ -15,7 +15,10 @@ def __init__(self, score_threshold: float, nms_threshold: float, nms_top_k: int, :param iou: IoU threshold for NMS step. :param nms_top_k: Number of predictions participating in NMS step :param max_predictions: maximum number of boxes to return after NMS step - + :param multi_label_per_box: controls whether to decode multiple labels per box. + True - each anchor can produce multiple labels of different classes + that pass confidence threshold check (default). + False - each anchor can produce only one label of the class with the highest score. """ super(PPYoloEPostPredictionCallback, self).__init__() self.score_threshold = score_threshold diff --git a/src/super_gradients/training/models/detection_models/yolo_base.py b/src/super_gradients/training/models/detection_models/yolo_base.py index 69d2ba0165..dab2a8523f 100755 --- a/src/super_gradients/training/models/detection_models/yolo_base.py +++ b/src/super_gradients/training/models/detection_models/yolo_base.py @@ -1,4 +1,5 @@ import math +import warnings from typing import Union, Type, List, Tuple, Optional from functools import lru_cache @@ -58,8 +59,10 @@ } -class YoloPostPredictionCallback(DetectionPostPredictionCallback): - """Non-Maximum Suppression (NMS) module""" +class YoloXPostPredictionCallback(DetectionPostPredictionCallback): + """Post-prediction callback to decode YoloX model's output and apply Non-Maximum Suppression (NMS) to get + the final predictions. 
+ """ def __init__( self, @@ -69,6 +72,8 @@ def __init__( nms_type: NMS_Type = NMS_Type.ITERATIVE, max_predictions: int = 300, with_confidence: bool = True, + class_agnostic_nms: bool = False, + multi_label_per_box: bool = True, ): """ :param conf: confidence threshold @@ -78,14 +83,24 @@ def __init__( :param max_predictions: maximum number of boxes to output (used in NMS_Type.MATRIX) :param with_confidence: in NMS, whether to multiply objectness (used in NMS_Type.ITERATIVE) score with class score + :param class_agnostic_nms: indicates how boxes of different classes will be treated during + NMS step (used in NMS_Type.ITERATIVE and NMS_Type.MATRIX) + True - NMS will be performed on all classes together. + False - NMS will be performed on each class separately (default). + :param multi_label_per_box: controls whether to decode multiple labels per box (used in NMS_Type.ITERATIVE) + True - each anchor can produce multiple labels of different classes + that pass confidence threshold check (default). + False - each anchor can produce only one label of the class with the highest score. """ - super(YoloPostPredictionCallback, self).__init__() + super(YoloXPostPredictionCallback, self).__init__() self.conf = conf self.iou = iou self.classes = classes self.nms_type = nms_type self.max_pred = max_predictions self.with_confidence = with_confidence + self.class_agnostic_nms = class_agnostic_nms + self.multi_label_per_box = multi_label_per_box def forward(self, x, device: str = None): """Apply NMS to the raw output of the model and keep only top `max_predictions` results. 
@@ -95,9 +110,16 @@ def forward(self, x, device: str = None): """ if self.nms_type == NMS_Type.ITERATIVE: - nms_result = non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, with_confidence=self.with_confidence) + nms_result = non_max_suppression( + x[0], + conf_thres=self.conf, + iou_thres=self.iou, + with_confidence=self.with_confidence, + multi_label_per_box=self.multi_label_per_box, + class_agnostic_nms=self.class_agnostic_nms, + ) else: - nms_result = matrix_non_max_suppression(x[0], conf_thres=self.conf, max_num_of_detections=self.max_pred) + nms_result = matrix_non_max_suppression(x[0], conf_thres=self.conf, max_num_of_detections=self.max_pred, class_agnostic_nms=self.class_agnostic_nms) return self._filter_max_predictions(nms_result) @@ -106,6 +128,29 @@ def _filter_max_predictions(self, res: List) -> List: return res +class YoloPostPredictionCallback(YoloXPostPredictionCallback): + def __init__( + self, + conf: float = 0.001, + iou: float = 0.6, + classes: List[int] = None, + nms_type: NMS_Type = NMS_Type.ITERATIVE, + max_predictions: int = 300, + with_confidence: bool = True, + ): + warnings.warn("YoloPostPredictionCallback is deprecated since SG 3.1.3, please use YoloXPostPredictionCallback instead", DeprecationWarning) + super().__init__( + conf=conf, + iou=iou, + classes=classes, + nms_type=nms_type, + max_predictions=max_predictions, + with_confidence=with_confidence, + class_agnostic_nms=False, + multi_label_per_box=True, + ) + + class Concat(nn.Module): """CONCATENATE A LIST OF TENSORS ALONG DIMENSION""" @@ -427,7 +472,7 @@ def __init__(self, backbone: Type[nn.Module], arch_params: HpmStruct, initialize @staticmethod def get_post_prediction_callback(conf: float, iou: float) -> DetectionPostPredictionCallback: - return YoloPostPredictionCallback(conf=conf, iou=iou) + return YoloXPostPredictionCallback(conf=conf, iou=iou) @resolve_param("image_processor", ProcessingFactory()) def set_dataset_processing_params( diff --git 
a/src/super_gradients/training/utils/detection_utils.py b/src/super_gradients/training/utils/detection_utils.py index ae9728ccbf..b3e565d9f1 100755 --- a/src/super_gradients/training/utils/detection_utils.py +++ b/src/super_gradients/training/utils/detection_utils.py @@ -242,18 +242,26 @@ def box_area(box): return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter) -def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label_per_box: bool = True, with_confidence: bool = False): +def non_max_suppression( + prediction, conf_thres=0.1, iou_thres=0.6, multi_label_per_box: bool = True, with_confidence: bool = False, class_agnostic_nms: bool = False +): """ Performs Non-Maximum Suppression (NMS) on inference results - :param prediction: raw model prediction. Should be a list of Tensors of shape (cx, cy, w, h, confidence, cls0, cls1, ...) - :param conf_thres: below the confidence threshold - prediction are discarded - :param iou_thres: IoU threshold for the nms algorithm - :param multi_label_per_box: whether to use re-use each box with all possible labels - (instead of the maximum confidence all confidences above threshold - will be sent to NMS); by default is set to True - :param with_confidence: whether to multiply objectness score with class score. - usually valid for Yolo models only. - :return: detections with shape nx6 (x1, y1, x2, y2, object_conf, class_conf, class) + + :param prediction: raw model prediction. Should be a list of Tensors of shape (cx, cy, w, h, confidence, cls0, cls1, ...) + :param conf_thres: below the confidence threshold - prediction are discarded + :param iou_thres: IoU threshold for the nms algorithm + :param multi_label_per_box: controls whether to decode multiple labels per box. + True - each anchor can produce multiple labels of different classes + that pass confidence threshold check (default). + False - each anchor can produce only one label of the class with the highest score. 
+ :param with_confidence: whether to multiply objectness score with class score. + usually valid for Yolo models only. + :param class_agnostic_nms: indicates how boxes of different classes will be treated during NMS + True - NMS will be performed on all classes together. + False - NMS will be performed on each class separately (default). + :return: detections with shape nx6 (x1, y1, x2, y2, confidence, class) + """ candidates_above_thres = prediction[..., 4] > conf_thres # filter by confidence output = [None] * prediction.shape[0] @@ -284,18 +292,17 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label_p # Apply torch batched NMS algorithm boxes, scores, cls_idx = pred[:, :4], pred[:, 4], pred[:, 5] - idx_to_keep = torchvision.ops.boxes.batched_nms(boxes, scores, cls_idx, iou_thres) + if class_agnostic_nms: + idx_to_keep = torchvision.ops.boxes.nms(boxes, scores, iou_thres) + else: + idx_to_keep = torchvision.ops.boxes.batched_nms(boxes, scores, cls_idx, iou_thres) output[image_idx] = pred[idx_to_keep] return output def matrix_non_max_suppression( - pred, - conf_thres: float = 0.1, - kernel: str = "gaussian", - sigma: float = 3.0, - max_num_of_detections: int = 500, + pred, conf_thres: float = 0.1, kernel: str = "gaussian", sigma: float = 3.0, max_num_of_detections: int = 500, class_agnostic_nms: bool = False ) -> List[torch.Tensor]: """Performs Matrix Non-Maximum Suppression (NMS) on inference results https://arxiv.org/pdf/1912.04488.pdf @@ -326,11 +333,12 @@ def matrix_non_max_suppression( ious = ious.triu(1) - # CREATE A LABELS MASK, WE WANT ONLY BOXES WITH THE SAME LABEL TO AFFECT EACH OTHER - labels = pred[:, :, 5:] - labeles_matrix = (labels == labels.transpose(2, 1)).float().triu(1) + if not class_agnostic_nms: + # CREATE A LABELS MASK, WE WANT ONLY BOXES WITH THE SAME LABEL TO AFFECT EACH OTHER + labels = pred[:, :, 5:] + labeles_matrix = (labels == labels.transpose(2, 1)).float().triu(1) + ious *= labeles_matrix -
ious *= labeles_matrix ious_cmax, _ = ious.max(1) ious_cmax = ious_cmax.unsqueeze(2).repeat(1, 1, max_num_of_detections) diff --git a/src/super_gradients/training/utils/ssd_utils.py b/src/super_gradients/training/utils/ssd_utils.py index e82578fe06..e9c2e544b5 100755 --- a/src/super_gradients/training/utils/ssd_utils.py +++ b/src/super_gradients/training/utils/ssd_utils.py @@ -121,9 +121,10 @@ def __init__( :param iou: IoU threshold :param classes: (optional list) filter by class :param nms_type: the type of nms to use (iterative or matrix) - :param multi_label_per_box: whether to use re-use each box with all possible labels - (instead of the maximum confidence all confidences above threshold - will be sent to NMS) + :param multi_label_per_box: controls whether to decode multiple labels per box. + True - each anchor can produce multiple labels of different classes + that pass confidence threshold check (default). + False - each anchor can produce only one label of the class with the highest score. 
""" super(SSDPostPredictCallback, self).__init__() self.conf = conf diff --git a/tests/integration_tests/pretrained_models_test.py b/tests/integration_tests/pretrained_models_test.py index 270635da24..6138bfa213 100644 --- a/tests/integration_tests/pretrained_models_test.py +++ b/tests/integration_tests/pretrained_models_test.py @@ -27,7 +27,7 @@ from super_gradients.training.losses.ddrnet_loss import DDRNetLoss from super_gradients.training.metrics import DetectionMetrics from super_gradients.training.losses.stdc_loss import STDCLoss -from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback +from super_gradients.training.models.detection_models.yolo_base import YoloXPostPredictionCallback from super_gradients.training import models import super_gradients @@ -152,7 +152,7 @@ def setUp(self) -> None: "loss": "yolox_loss", "criterion_params": {"strides": [8, 16, 32], "num_classes": 5}, # output strides of all yolo outputs "train_metrics_list": [], - "valid_metrics_list": [DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), normalize_targets=True, num_cls=5)], + "valid_metrics_list": [DetectionMetrics(post_prediction_callback=YoloXPostPredictionCallback(), normalize_targets=True, num_cls=5)], "metric_to_watch": "mAP@0.50:0.95", "greater_metric_to_watch_is_better": True, } @@ -522,7 +522,7 @@ def test_pretrained_yolox_s_coco(self): res = trainer.test( model=model, test_loader=self.coco_dataset["yolox"], - test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)], + test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloXPostPredictionCallback(), num_cls=80, normalize_targets=True)], ) self.assertAlmostEqual(res["mAP@0.50:0.95"].cpu().item(), self.coco_pretrained_maps[Models.YOLOX_S], delta=0.001) @@ -532,7 +532,7 @@ def test_pretrained_yolox_m_coco(self): res = trainer.test( model=model, test_loader=self.coco_dataset["yolox"], - 
test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)], + test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloXPostPredictionCallback(), num_cls=80, normalize_targets=True)], ) self.assertAlmostEqual(res["mAP@0.50:0.95"].cpu().item(), self.coco_pretrained_maps[Models.YOLOX_M], delta=0.001) @@ -542,7 +542,7 @@ def test_pretrained_yolox_l_coco(self): res = trainer.test( model=model, test_loader=self.coco_dataset["yolox"], - test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)], + test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloXPostPredictionCallback(), num_cls=80, normalize_targets=True)], ) self.assertAlmostEqual(res["mAP@0.50:0.95"].cpu().item(), self.coco_pretrained_maps[Models.YOLOX_L], delta=0.001) @@ -553,7 +553,7 @@ def test_pretrained_yolox_n_coco(self): res = trainer.test( model=model, test_loader=self.coco_dataset["yolox"], - test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)], + test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloXPostPredictionCallback(), num_cls=80, normalize_targets=True)], ) self.assertAlmostEqual(res["mAP@0.50:0.95"].cpu().item(), self.coco_pretrained_maps[Models.YOLOX_N], delta=0.001) @@ -563,7 +563,7 @@ def test_pretrained_yolox_t_coco(self): res = trainer.test( model=model, test_loader=self.coco_dataset["yolox"], - test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)], + test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloXPostPredictionCallback(), num_cls=80, normalize_targets=True)], ) self.assertAlmostEqual(res["mAP@0.50:0.95"].cpu().item(), self.coco_pretrained_maps[Models.YOLOX_T], delta=0.001) diff --git a/tests/unit_tests/dataset_statistics_test.py 
b/tests/unit_tests/dataset_statistics_test.py index f610e5a764..b3881b4c9a 100644 --- a/tests/unit_tests/dataset_statistics_test.py +++ b/tests/unit_tests/dataset_statistics_test.py @@ -5,7 +5,7 @@ from super_gradients.training.metrics.detection_metrics import DetectionMetrics from super_gradients.training import Trainer, models -from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback +from super_gradients.training.models.detection_models.yolo_base import YoloXPostPredictionCallback class TestDatasetStatisticsTensorboardLogger(unittest.TestCase): @@ -30,7 +30,7 @@ def test_dataset_statistics_tensorboard_logger(self): "criterion_params": {"strides": [8, 16, 32], "num_classes": 80}, "dataset_statistics": True, "launch_tensorboard": True, - "valid_metrics_list": [DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), normalize_targets=True, num_cls=80)], + "valid_metrics_list": [DetectionMetrics(post_prediction_callback=YoloXPostPredictionCallback(), normalize_targets=True, num_cls=80)], "metric_to_watch": "mAP@0.50:0.95", } trainer.train(model=model, training_params=training_params, train_loader=coco2017_train(), valid_loader=coco2017_val()) diff --git a/tests/unit_tests/detection_utils_test.py b/tests/unit_tests/detection_utils_test.py index 4af450c116..38da4f2e68 100644 --- a/tests/unit_tests/detection_utils_test.py +++ b/tests/unit_tests/detection_utils_test.py @@ -9,7 +9,7 @@ from super_gradients.training.dataloaders.dataloaders import coco2017_val from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST from super_gradients.training.metrics import DetectionMetrics, DetectionMetrics_050 -from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback +from super_gradients.training.models.detection_models.yolo_base import YoloXPostPredictionCallback from super_gradients.training.utils.detection_utils import DetectionVisualization from 
tests.core_test_utils import is_data_available @@ -25,7 +25,7 @@ def test_visualization(self): valid_loader = coco2017_val(dataloader_params={"batch_size": 16, "num_workers": 0}) trainer = Trainer("visualization_test") - post_prediction_callback = YoloPostPredictionCallback() + post_prediction_callback = YoloXPostPredictionCallback() # Simulate one iteration of validation subset batch_i, batch = 0, next(iter(valid_loader)) @@ -50,9 +50,9 @@ def test_detection_metrics(self): valid_loader = coco2017_val(dataloader_params={"batch_size": 16, "num_workers": 0}) metrics = [ - DetectionMetrics(num_cls=80, post_prediction_callback=YoloPostPredictionCallback(), normalize_targets=True), - DetectionMetrics_050(num_cls=80, post_prediction_callback=YoloPostPredictionCallback(), normalize_targets=True), - DetectionMetrics(num_cls=80, post_prediction_callback=YoloPostPredictionCallback(conf=2), normalize_targets=True), + DetectionMetrics(num_cls=80, post_prediction_callback=YoloXPostPredictionCallback(), normalize_targets=True), + DetectionMetrics_050(num_cls=80, post_prediction_callback=YoloXPostPredictionCallback(), normalize_targets=True), + DetectionMetrics(num_cls=80, post_prediction_callback=YoloXPostPredictionCallback(conf=2), normalize_targets=True), ] ref_values = [ diff --git a/tests/unit_tests/preprocessing_unit_test.py b/tests/unit_tests/preprocessing_unit_test.py index dff1434e95..1f020554fc 100644 --- a/tests/unit_tests/preprocessing_unit_test.py +++ b/tests/unit_tests/preprocessing_unit_test.py @@ -6,7 +6,7 @@ from super_gradients.training import models from super_gradients.training.datasets import COCODetectionDataset from super_gradients.training.metrics import DetectionMetrics -from super_gradients.training.models import YoloPostPredictionCallback +from super_gradients.training.models import YoloXPostPredictionCallback from super_gradients.training.processing import ReverseImageChannels, DetectionLongestMaxSizeRescale, DetectionBottomRightPadding, ImagePermute 
from super_gradients.training.utils.detection_utils import DetectionCollateFN, CrowdDetectionCollateFN from super_gradients.training import dataloaders @@ -90,7 +90,7 @@ def test_setting_preprocessing_params_from_validation_set(self): "loss": "yolox_loss", "criterion_params": {"strides": [8, 16, 32], "num_classes": 80}, # output strides of all yolo outputs "train_metrics_list": [], - "valid_metrics_list": [DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), normalize_targets=True, num_cls=5)], + "valid_metrics_list": [DetectionMetrics(post_prediction_callback=YoloXPostPredictionCallback(), normalize_targets=True, num_cls=5)], "metric_to_watch": "mAP@0.50:0.95", "greater_metric_to_watch_is_better": True, "average_best_models": False, @@ -157,7 +157,7 @@ def test_setting_preprocessing_params_from_checkpoint(self): "loss": "yolox_loss", "criterion_params": {"strides": [8, 16, 32], "num_classes": 80}, # output strides of all yolo outputs "train_metrics_list": [], - "valid_metrics_list": [DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), normalize_targets=True, num_cls=5)], + "valid_metrics_list": [DetectionMetrics(post_prediction_callback=YoloXPostPredictionCallback(), normalize_targets=True, num_cls=5)], "metric_to_watch": "mAP@0.50:0.95", "greater_metric_to_watch_is_better": True, "average_best_models": False, diff --git a/tests/unit_tests/test_without_train_test.py b/tests/unit_tests/test_without_train_test.py index b4d066c887..6cf7d13a64 100644 --- a/tests/unit_tests/test_without_train_test.py +++ b/tests/unit_tests/test_without_train_test.py @@ -7,7 +7,7 @@ from super_gradients.training import models from super_gradients.training.metrics.detection_metrics import DetectionMetrics from super_gradients.training.metrics.segmentation_metrics import PixelAccuracy, IoU -from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback +from super_gradients.training.models.detection_models.yolo_base 
import YoloXPostPredictionCallback from super_gradients.common.object_names import Models @@ -51,7 +51,7 @@ def test_test_without_train(self): trainer, model = self.get_detection_trainer(self.folder_names[1]) - test_metrics = [DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=5)] + test_metrics = [DetectionMetrics(post_prediction_callback=YoloXPostPredictionCallback(), num_cls=5)] assert isinstance( trainer.test(model=model, silent_mode=True, test_metrics_list=test_metrics, test_loader=detection_test_dataloader(image_size=320)), dict