Skip to content

Commit

Permalink
[Enhancement] Implement Distance-Based Detection Matching in Detectio…
Browse files Browse the repository at this point in the history
…nMetrics (#1463)

* Signing unsigned commits in new branch

* Fix Missing `iou_thresholds` Argument for Distance-Based Metrics

- Updated `compute_detection_matching` to conditionally require `iou_thresholds` only for IoU-based strategies.

* Check for matching_strategy object

Returned check if matching_strategy object is passed

* Fix thresholds len for Distance-Based and IoU Metrics

- Updated `compute_img_detection_matching` to conditionally get thresholds using matching_strategy get_thresholds() method.

Note: Since thresholds passed via matching_strategy class, consider to remove iou_thresholds param
  • Loading branch information
DimaBir authored Dec 8, 2023
1 parent a466697 commit 953a671
Show file tree
Hide file tree
Showing 5 changed files with 1,114 additions and 84 deletions.
1 change: 1 addition & 0 deletions src/super_gradients/common/object_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class Metrics:
BINARY_DICE = "BinaryDice"
PIXEL_ACCURACY = "PixelAccuracy"
POSE_ESTIMATION_METRICS = "PoseEstimationMetrics"
DETECTION_METRICS_DISTANCE_BASED = "DetectionMetricsDistanceBased"


class Transforms:
Expand Down
96 changes: 91 additions & 5 deletions src/super_gradients/training/metrics/detection_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@
from super_gradients.common.object_names import Metrics
from super_gradients.common.registry.registry import register_metric
from super_gradients.training.utils import tensor_container_to_device
from super_gradients.training.utils.detection_utils import compute_detection_matching, compute_detection_metrics
from super_gradients.training.utils.detection_utils import (
compute_detection_matching,
compute_detection_metrics,
DistanceMetric,
EuclideanDistance,
IoUMatching,
DistanceMatching,
)
from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback, IouThreshold
from super_gradients.common.abstractions.abstract_logger import get_logger

Expand Down Expand Up @@ -144,6 +151,7 @@ def update(self, preds, target: torch.Tensor, device: str, inputs: torch.tensor,
"""
self.iou_thresholds = self.iou_thresholds.to(device)
_, _, height, width = inputs.shape
iou_matcher = IoUMatching(self.iou_thresholds)

targets = target.clone()
crowd_targets = torch.zeros(size=(0, 6), device=device) if crowd_targets is None else crowd_targets.clone()
Expand All @@ -155,7 +163,8 @@ def update(self, preds, target: torch.Tensor, device: str, inputs: torch.tensor,
targets,
height,
width,
self.iou_thresholds,
iou_thresholds=iou_matcher.get_thresholds(),
matching_strategy=iou_matcher,
crowd_targets=crowd_targets,
top_k=self.top_k_predictions,
denormalize_targets=self.denormalize_targets,
Expand Down Expand Up @@ -243,6 +252,86 @@ def _get_range_str(self):
return "@%.2f" % self.iou_thresholds[0] if not len(self.iou_thresholds) > 1 else "@%.2f:%.2f" % (self.iou_thresholds[0], self.iou_thresholds[-1])


@register_metric(Metrics.DETECTION_METRICS_DISTANCE_BASED)
class DetectionMetricsDistanceBased(DetectionMetrics):
def __init__(
self,
num_cls: int,
post_prediction_callback: DetectionPostPredictionCallback,
normalize_targets: bool = False,
distance_thresholds: List[float] = [5.0],
distance_metric: DistanceMetric = EuclideanDistance(),
recall_thres: torch.Tensor = None,
score_thres: float = 0.1,
top_k_predictions: int = 100,
dist_sync_on_step: bool = False,
accumulate_on_cpu: bool = True,
calc_best_score_thresholds: bool = False,
include_classwise_ap: bool = False,
class_names: List[str] = None,
):
self.distance_thresholds = distance_thresholds
self.distance_metric = distance_metric
super().__init__(
num_cls=num_cls,
post_prediction_callback=post_prediction_callback,
normalize_targets=normalize_targets,
recall_thres=recall_thres,
score_thres=score_thres,
top_k_predictions=top_k_predictions,
dist_sync_on_step=dist_sync_on_step,
accumulate_on_cpu=accumulate_on_cpu,
calc_best_score_thresholds=calc_best_score_thresholds,
include_classwise_ap=include_classwise_ap,
class_names=class_names,
)

def update(self, preds: torch.Tensor, target: torch.Tensor, device: str, inputs: torch.tensor, crowd_targets: Optional[torch.Tensor] = None) -> None:
"""
Apply NMS and match all the predictions and targets of a given batch, and update the metric state accordingly.
Use distance-based definition of true positives.
:param preds: torch.Tensor: Raw output of the model. The format might change from one model to another,
but has to fit the input format of the post_prediction_callback (cx, cy, wh).
:param target: torch.Tensor: Targets for all images of shape (total_num_targets, 6) LABEL_CXCYWH.
Format: (index, label, cx, cy, w, h)
:param device: str: Device to run on.
:param inputs: torch.Tensor: Input image tensor of shape (batch_size, n_img, height, width).
:param crowd_targets: Optional[torch.Tensor]: Crowd targets for all images of shape (total_num_targets, 6), LABEL_CXCYWH.
"""
_, _, height, width = inputs.shape

distance_matcher = DistanceMatching(self.distance_metric, self.distance_thresholds)

targets = target.clone()
crowd_targets = torch.zeros(size=(0, 6), device=device) if crowd_targets is None else crowd_targets.clone()

preds = self.post_prediction_callback(preds, device=device)

new_matching_info = compute_detection_matching(
output=preds,
targets=targets,
height=height,
width=width,
crowd_targets=crowd_targets,
top_k=self.top_k_predictions,
denormalize_targets=self.denormalize_targets,
device=self.device,
return_on_cpu=self.accumulate_on_cpu,
matching_strategy=distance_matcher,
)

accumulated_matching_info = getattr(self, f"matching_info{self._get_range_str()}")
setattr(self, f"matching_info{self._get_range_str()}", accumulated_matching_info + new_matching_info)

def _get_range_str(self):
return (
"@DIST%.2f" % self.distance_thresholds[0]
if not len(self.distance_thresholds) > 1
else "@DIST%.2f:%.2f" % (self.distance_thresholds[0], self.distance_thresholds[-1])
)


@register_metric(Metrics.DETECTION_METRICS_050)
class DetectionMetrics_050(DetectionMetrics):
def __init__(
Expand All @@ -259,7 +348,6 @@ def __init__(
include_classwise_ap: bool = False,
class_names: List[str] = None,
):

super().__init__(
num_cls=num_cls,
post_prediction_callback=post_prediction_callback,
Expand Down Expand Up @@ -292,7 +380,6 @@ def __init__(
include_classwise_ap: bool = False,
class_names: List[str] = None,
):

super().__init__(
num_cls=num_cls,
post_prediction_callback=post_prediction_callback,
Expand Down Expand Up @@ -325,7 +412,6 @@ def __init__(
include_classwise_ap: bool = False,
class_names: List[str] = None,
):

super().__init__(
num_cls=num_cls,
post_prediction_callback=post_prediction_callback,
Expand Down
Loading

0 comments on commit 953a671

Please sign in to comment.