Skip to content

Commit

Permalink
Feature/sg 000 introduce class agnostic nms option (#1232)
Browse files Browse the repository at this point in the history
* Added class_agnostic_nms

* Rename YoloPostPredictionCallback -> YoloXPostPredictionCallback
Make YoloPostPredictionCallback as deprecated

* Update docstrings

* Fix table paddings

* Import deprecated class

* Specify version at which class was made deprecated.
The new class_agnostic_nms option won't be exposed in this class, to stimulate migration to the new class.

* Update docstring on multi_label_per_box parameter

* Update docstring on multi_label_per_box parameter
  • Loading branch information
BloodAxe authored Jul 3, 2023
1 parent 67970d4 commit 07ec5e5
Show file tree
Hide file tree
Showing 13 changed files with 123 additions and 65 deletions.
24 changes: 12 additions & 12 deletions documentation/source/ObjectDetection.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ In SuperGradients, we aim to collect such models and make them very convenient a

## Implemented models

| Model | Yaml | Model class | Loss Class | NMS Callback |
|------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [SSD](https://arxiv.org/abs/1512.02325) | [ssd_lite_mobilenetv2_arch_params](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/recipes/arch_params/ssd_lite_mobilenetv2_arch_params.yaml) | [SSDLiteMobileNetV2](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/training/models/detection_models/ssd.py) | [SSDLoss](https://docs.deci.ai/super-gradients/docstring/training/losses/#training.losses.ssd_loss.SSDLoss) | [SSDPostPredictCallback](https://docs.deci.ai/super-gradients/docstring/training/utils/#training.utils.ssd_utils.SSDPostPredictCallback) |
| [YOLOX](https://arxiv.org/abs/2107.08430) | [yolox_s_arch_params](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/recipes/arch_params/yolox_s_arch_params.yaml) | [YoloX_S](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/training/models/detection_models/yolox.py) | [YoloXFastDetectionLoss](https://docs.deci.ai/super-gradients/docstring/training/losses/#training.losses.yolox_loss.YoloXFastDetectionLoss) | [YoloPostPredictionCallback](https://docs.deci.ai/super-gradients/docstring/training/models/#training.models.detection_models.yolo_base.YoloPostPredictionCallback) |
| [PPYolo](https://arxiv.org/abs/2007.12099) | [ppyoloe_arch_params](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/recipes/arch_params/ppyoloe_arch_params.yaml) | [PPYoloE](https://docs.deci.ai/super-gradients/docstring/training/models/#training.models.detection_models.pp_yolo_e.pp_yolo_e.PPYoloE) | [PPYoloELoss](https://docs.deci.ai/super-gradients/docstring/training/losses/#training.losses.ppyolo_loss.PPYoloELoss) | [PPYoloEPostPredictionCallback](https://docs.deci.ai/super-gradients/docstring/training/models/#training.models.detection_models.pp_yolo_e.post_prediction_callback.PPYoloEPostPredictionCallback) |
| Model | Yaml | Model class | Loss Class | NMS Callback |
|--------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [SSD](https://arxiv.org/abs/1512.02325) | [ssd_lite_mobilenetv2_arch_params](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/recipes/arch_params/ssd_lite_mobilenetv2_arch_params.yaml) | [SSDLiteMobileNetV2](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/training/models/detection_models/ssd.py) | [SSDLoss](https://docs.deci.ai/super-gradients/docstring/training/losses/#training.losses.ssd_loss.SSDLoss) | [SSDPostPredictCallback](https://docs.deci.ai/super-gradients/docstring/training/utils/#training.utils.ssd_utils.SSDPostPredictCallback) |
| [YOLOX](https://arxiv.org/abs/2107.08430) | [yolox_s_arch_params](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/recipes/arch_params/yolox_s_arch_params.yaml) | [YoloX_S](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/training/models/detection_models/yolox.py) | [YoloXFastDetectionLoss](https://docs.deci.ai/super-gradients/docstring/training/losses/#training.losses.yolox_loss.YoloXFastDetectionLoss) | [YoloXPostPredictionCallback](https://docs.deci.ai/super-gradients/docstring/training/models/#training.models.detection_models.yolo_base.YoloXPostPredictionCallback) |
| [PPYolo](https://arxiv.org/abs/2007.12099) | [ppyoloe_arch_params](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/recipes/arch_params/ppyoloe_arch_params.yaml) | [PPYoloE](https://docs.deci.ai/super-gradients/docstring/training/models/#training.models.detection_models.pp_yolo_e.pp_yolo_e.PPYoloE) | [PPYoloELoss](https://docs.deci.ai/super-gradients/docstring/training/losses/#training.losses.ppyolo_loss.PPYoloELoss) | [PPYoloEPostPredictionCallback](https://docs.deci.ai/super-gradients/docstring/training/models/#training.models.detection_models.pp_yolo_e.post_prediction_callback.PPYoloEPostPredictionCallback) |


## Training
Expand Down Expand Up @@ -73,16 +73,16 @@ In order to use `DetectionMetrics` you have to pass a so-called `post_prediction
### Postprocessing

Postprocessing refers to a process of transforming the model's raw output into final predictions. Postprocessing is also model-specific and depends on the model's output format.
For `YOLOX` model, the postprocessing step is implemented in [YoloPostPredictionCallback](https://docs.deci.ai/super-gradients/docstring/training/models/#training.models.detection_models.yolo_base.YoloPostPredictionCallback) class.
For `YOLOX` model, the postprocessing step is implemented in [YoloXPostPredictionCallback](https://docs.deci.ai/super-gradients/docstring/training/models/#training.models.detection_models.yolo_base.YoloXPostPredictionCallback) class.
It can be passed into a `DetectionMetrics` as a `post_prediction_callback`.
The postprocessing of all detection models involves non-maximum suppression (NMS), which filters the model's dense predictions: it keeps only the boxes with the highest confidence and suppresses boxes that overlap them heavily,
based on the assumption that they likely belong to the same object. Thus, a confidence threshold and an IoU threshold must be passed into the postprocessing object.

```python
from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
from super_gradients.training.models.detection_models.yolo_base import YoloXPostPredictionCallback


post_prediction_callback = YoloPostPredictionCallback(conf=0.001, iou=0.6)
post_prediction_callback = YoloXPostPredictionCallback(conf=0.001, iou=0.6)
```

### Visualization
Expand Down Expand Up @@ -114,7 +114,7 @@ def my_undo_image_preprocessing(im_tensor: torch.Tensor) -> np.ndarray:

model = models.get("yolox_s", pretrained_weights="coco", num_classes=80)
imgs, targets = next(iter(train_dataloader))
preds = YoloPostPredictionCallback(conf=0.1, iou=0.6)(model(imgs))
preds = YoloXPostPredictionCallback(conf=0.1, iou=0.6)(model(imgs))
DetectionVisualization.visualize_batch(imgs, preds, targets, batch_name='train', class_names=COCO_DETECTION_CLASSES_LIST,
checkpoint_dir='/path/for/saved_images/', gt_alpha=0.5,
undo_preprocessing_func=my_undo_image_preprocessing)
Expand Down Expand Up @@ -148,13 +148,13 @@ valid_metrics_list:
- DetectionMetrics:
normalize_targets: True
post_prediction_callback:
_target_: super_gradients.training.models.detection_models.yolo_base.YoloPostPredictionCallback
_target_: super_gradients.training.models.detection_models.yolo_base.YoloXPostPredictionCallback
iou: 0.65
conf: 0.01
num_cls: 80
```

Notice how `YoloPostPredictionCallback` is passed as a `post_prediction_callback`.
Notice how `YoloXPostPredictionCallback` is passed as a `post_prediction_callback`.

A visualization belongs to `training_hyperparams` as well, specifically to the `phase_callbacks` list, as follows:
```yaml
Expand All @@ -165,7 +165,7 @@ phase_callbacks:
value: VALIDATION_EPOCH_END
freq: 1
post_prediction_callback:
_target_: super_gradients.training.models.detection_models.yolo_base.YoloPostPredictionCallback
_target_: super_gradients.training.models.detection_models.yolo_base.YoloXPostPredictionCallback
iou: 0.65
conf: 0.01
classes: [
Expand Down
4 changes: 2 additions & 2 deletions src/super_gradients/recipes/roboflow_yolox.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ training_hyperparams:
- DetectionMetrics:
normalize_targets: True
post_prediction_callback:
_target_: super_gradients.training.models.detection_models.yolo_base.YoloPostPredictionCallback
_target_: super_gradients.training.models.detection_models.yolo_base.YoloXPostPredictionCallback
iou: 0.65
conf: 0.01
num_cls: 80
valid_metrics_list:
- DetectionMetrics:
normalize_targets: True
post_prediction_callback:
_target_: super_gradients.training.models.detection_models.yolo_base.YoloPostPredictionCallback
_target_: super_gradients.training.models.detection_models.yolo_base.YoloXPostPredictionCallback
iou: 0.65
conf: 0.01
num_cls: 80
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ valid_metrics_list:
- DetectionMetrics:
normalize_targets: True
post_prediction_callback:
_target_: super_gradients.training.models.detection_models.yolo_base.YoloPostPredictionCallback
_target_: super_gradients.training.models.detection_models.yolo_base.YoloXPostPredictionCallback
iou: 0.65
conf: 0.01
num_cls: 80
Expand Down
3 changes: 2 additions & 1 deletion src/super_gradients/training/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloE, PPYoloE_S, PPYoloE_M, PPYoloE_L, PPYoloE_X
from super_gradients.training.models.detection_models.darknet53 import Darknet53, Darknet53Base
from super_gradients.training.models.detection_models.ssd import SSDMobileNetV1, SSDLiteMobileNetV2
from super_gradients.training.models.detection_models.yolo_base import YoloBase, YoloPostPredictionCallback
from super_gradients.training.models.detection_models.yolo_base import YoloBase, YoloXPostPredictionCallback, YoloPostPredictionCallback
from super_gradients.training.models.detection_models.yolox import YoloX_N, YoloX_T, YoloX_S, YoloX_M, YoloX_L, YoloX_X, CustomYoloX
from super_gradients.training.models.detection_models.customizable_detector import CustomizableDetector
from super_gradients.training.models.detection_models.yolo_nas import (
Expand Down Expand Up @@ -291,6 +291,7 @@ def inner(*args, **kwargs):
"YoloX_X",
"CustomYoloX",
"YoloPostPredictionCallback",
"YoloXPostPredictionCallback",
"CustomizableDetector",
"ShelfNet50",
"ShelfNet101",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ def __init__(self, score_threshold: float, nms_threshold: float, nms_top_k: int,
:param iou: IoU threshold for NMS step.
:param nms_top_k: Number of predictions participating in NMS step
:param max_predictions: maximum number of boxes to return after NMS step
:param multi_label_per_box: controls whether to decode multiple labels per box.
True - each anchor can produce multiple labels of different classes
that pass confidence threshold check (default).
False - each anchor can produce only one label of the class with the highest score.
"""
super(PPYoloEPostPredictionCallback, self).__init__()
self.score_threshold = score_threshold
Expand Down
57 changes: 51 additions & 6 deletions src/super_gradients/training/models/detection_models/yolo_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
import warnings
from typing import Union, Type, List, Tuple, Optional
from functools import lru_cache

Expand Down Expand Up @@ -58,8 +59,10 @@
}


class YoloPostPredictionCallback(DetectionPostPredictionCallback):
"""Non-Maximum Suppression (NMS) module"""
class YoloXPostPredictionCallback(DetectionPostPredictionCallback):
"""Post-prediction callback to decode YoloX model's output and apply Non-Maximum Suppression (NMS) to get
the final predictions.
"""

def __init__(
self,
Expand All @@ -69,6 +72,8 @@ def __init__(
nms_type: NMS_Type = NMS_Type.ITERATIVE,
max_predictions: int = 300,
with_confidence: bool = True,
class_agnostic_nms: bool = False,
multi_label_per_box: bool = True,
):
"""
:param conf: confidence threshold
Expand All @@ -78,14 +83,24 @@ def __init__(
:param max_predictions: maximum number of boxes to output (used in NMS_Type.MATRIX)
:param with_confidence: in NMS, whether to multiply objectness (used in NMS_Type.ITERATIVE)
score with class score
:param class_agnostic_nms: indicates how boxes of different classes will be treated during
NMS step (used in NMS_Type.ITERATIVE and NMS_Type.MATRIX)
True - NMS will be performed on all classes together.
False - NMS will be performed on each class separately (default).
:param multi_label_per_box: controls whether to decode multiple labels per box (used in NMS_Type.ITERATIVE)
True - each anchor can produce multiple labels of different classes
that pass confidence threshold check (default).
False - each anchor can produce only one label of the class with the highest score.
"""
super(YoloPostPredictionCallback, self).__init__()
super(YoloXPostPredictionCallback, self).__init__()
self.conf = conf
self.iou = iou
self.classes = classes
self.nms_type = nms_type
self.max_pred = max_predictions
self.with_confidence = with_confidence
self.class_agnostic_nms = class_agnostic_nms
self.multi_label_per_box = multi_label_per_box

def forward(self, x, device: str = None):
"""Apply NMS to the raw output of the model and keep only top `max_predictions` results.
Expand All @@ -95,9 +110,16 @@ def forward(self, x, device: str = None):
"""

if self.nms_type == NMS_Type.ITERATIVE:
nms_result = non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, with_confidence=self.with_confidence)
nms_result = non_max_suppression(
x[0],
conf_thres=self.conf,
iou_thres=self.iou,
with_confidence=self.with_confidence,
multi_label_per_box=self.multi_label_per_box,
class_agnostic_nms=self.class_agnostic_nms,
)
else:
nms_result = matrix_non_max_suppression(x[0], conf_thres=self.conf, max_num_of_detections=self.max_pred)
nms_result = matrix_non_max_suppression(x[0], conf_thres=self.conf, max_num_of_detections=self.max_pred, class_agnostic_nms=self.class_agnostic_nms)

return self._filter_max_predictions(nms_result)

Expand All @@ -106,6 +128,29 @@ def _filter_max_predictions(self, res: List) -> List:
return res


class YoloPostPredictionCallback(YoloXPostPredictionCallback):
    """Deprecated name of :class:`YoloXPostPredictionCallback`, kept for backward compatibility.

    Instantiating this class emits a ``DeprecationWarning`` and then forwards all arguments to
    ``YoloXPostPredictionCallback``. The newer ``class_agnostic_nms`` and ``multi_label_per_box``
    options are intentionally not exposed here; they are pinned to ``False`` and ``True``
    respectively. Use ``YoloXPostPredictionCallback`` directly to change them.
    """

    def __init__(
        self,
        conf: float = 0.001,
        iou: float = 0.6,
        classes: Optional[List[int]] = None,
        nms_type: NMS_Type = NMS_Type.ITERATIVE,
        max_predictions: int = 300,
        with_confidence: bool = True,
    ):
        """
        :param conf: confidence threshold
        :param iou: IoU threshold, forwarded unchanged to the parent class
        :param classes: forwarded unchanged to the parent class
        :param nms_type: the type of NMS to apply, forwarded unchanged to the parent class
        :param max_predictions: maximum number of boxes to output
        :param with_confidence: in NMS, whether to multiply objectness score with class score
        """
        # stacklevel=2 attributes the warning to the code that instantiates this class rather than
        # to this line, so the message points at the call site that needs to be migrated.
        warnings.warn(
            "YoloPostPredictionCallback is deprecated since SG 3.1.3, please use YoloXPostPredictionCallback instead",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(
            conf=conf,
            iou=iou,
            classes=classes,
            nms_type=nms_type,
            max_predictions=max_predictions,
            with_confidence=with_confidence,
            # Fixed values: the deprecated class does not expose the new NMS options.
            class_agnostic_nms=False,
            multi_label_per_box=True,
        )


class Concat(nn.Module):
"""CONCATENATE A LIST OF TENSORS ALONG DIMENSION"""

Expand Down Expand Up @@ -427,7 +472,7 @@ def __init__(self, backbone: Type[nn.Module], arch_params: HpmStruct, initialize

@staticmethod
def get_post_prediction_callback(conf: float, iou: float) -> DetectionPostPredictionCallback:
return YoloPostPredictionCallback(conf=conf, iou=iou)
return YoloXPostPredictionCallback(conf=conf, iou=iou)

@resolve_param("image_processor", ProcessingFactory())
def set_dataset_processing_params(
Expand Down
Loading

0 comments on commit 07ec5e5

Please sign in to comment.