Skip to content

Commit

Permalink
Allow passing tuple of [low, high] to iou_thresholds parameter in Det…
Browse files Browse the repository at this point in the history
…ectionMetrics. (#1223)

* Allow passing tuple of [low, high] to iou_thresholds parameter in DetectionMetrics

* Added docs
  • Loading branch information
BloodAxe authored Jun 28, 2023
1 parent 81dbed1 commit 1889bc8
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
8 changes: 6 additions & 2 deletions src/super_gradients/training/metrics/detection_metrics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Optional, Union
from typing import Dict, Optional, Union, Tuple
import torch
from torchmetrics import Metric

Expand All @@ -24,6 +24,7 @@ class DetectionMetrics(Metric):
:param post_prediction_callback: DetectionPostPredictionCallback to be applied on net's output prior to the metric computation (NMS).
:param normalize_targets: Whether to normalize bbox coordinates by image size.
:param iou_thres: IoU threshold to compute the mAP.
Could be either instance of IouThreshold, a tuple (lower bound, upper_bound) or single scalar.
:param recall_thres: Recall threshold to compute the mAP.
:param score_thres: Score threshold to compute Recall, Precision and F1.
:param top_k_predictions: Number of predictions per class used to compute metrics, ordered by confidence score
Expand All @@ -37,7 +38,7 @@ def __init__(
num_cls: int,
post_prediction_callback: DetectionPostPredictionCallback,
normalize_targets: bool = False,
iou_thres: Union[IouThreshold, float] = IouThreshold.MAP_05_TO_095,
iou_thres: Union[IouThreshold, Tuple[float, float], float] = IouThreshold.MAP_05_TO_095,
recall_thres: torch.Tensor = None,
score_thres: float = 0.1,
top_k_predictions: int = 100,
Expand All @@ -50,6 +51,9 @@ def __init__(

if isinstance(iou_thres, IouThreshold):
self.iou_thresholds = iou_thres.to_tensor()
if isinstance(iou_thres, tuple):
low, high = iou_thres
self.iou_thresholds = IouThreshold.from_bounds(low, high)
else:
self.iou_thresholds = torch.tensor([iou_thres])

Expand Down
16 changes: 13 additions & 3 deletions src/super_gradients/training/utils/detection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,22 @@ def is_range(self):

def to_tensor(self):
if self.is_range():
n_iou_thresh = int(round((self[1] - self[0]) / 0.05)) + 1
return torch.linspace(self[0], self[1], n_iou_thresh)
return self.from_bounds(self[0], self[1], step=0.05)
else:
n_iou_thresh = 1
return torch.tensor([self[0]])

@classmethod
def from_bounds(cls, low: float, high: float, step: float = 0.05) -> torch.Tensor:
"""
Create a tensor with values from low (including) to high (including) with a given step size.
:param low: Lower bound
:param high: Upper bound
:param step: Step size
:return: Tensor of [low, low + step, low + 2 * step, ..., high]
"""
n_iou_thresh = int(round((high - low) / step)) + 1
return torch.linspace(low, high, n_iou_thresh)


def box_iou(box1: torch.Tensor, box2: torch.Tensor) -> torch.Tensor:
# https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
Expand Down

0 comments on commit 1889bc8

Please sign in to comment.