Skip to content

Commit

Permalink
Fix class_inclusion_list in DetectionDataset (#1327)
Browse files Browse the repository at this point in the history
* fix

* add test

* rename
  • Loading branch information
Louis-Dupont authored Aug 1, 2023
1 parent 7afa83d commit 5792080
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 36 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .format_converter import ConcatenatedTensorFormatConverter
from .output_adapters import DetectionOutputAdapter
from .formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, TensorSliceItem
from .formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, TensorSliceItem, LabelTensorSliceItem
from .bbox_formats import (
CXCYWHCoordinateFormat,
NormalizedCXCYWHCoordinateFormat,
Expand All @@ -21,6 +21,7 @@
"NormalizedXYWHCoordinateFormat",
"NormalizedXYXYCoordinateFormat",
"TensorSliceItem",
"LabelTensorSliceItem",
"XYWHCoordinateFormat",
"XYXYCoordinateFormat",
"YXYXCoordinateFormat",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from super_gradients.common.object_names import ConcatenatedTensorFormats
from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, TensorSliceItem
from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, LabelTensorSliceItem
from super_gradients.training.datasets.data_formats.bbox_formats import (
XYXYCoordinateFormat,
XYWHCoordinateFormat,
Expand All @@ -12,72 +12,72 @@
XYXY_LABEL = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=XYXYCoordinateFormat()),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)
XYWH_LABEL = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=XYWHCoordinateFormat()),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)
CXCYWH_LABEL = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=CXCYWHCoordinateFormat()),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)
LABEL_XYXY = ConcatenatedTensorFormat(
layout=(
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
BoundingBoxesTensorSliceItem(name="bboxes", format=XYXYCoordinateFormat()),
)
)
LABEL_XYWH = ConcatenatedTensorFormat(
layout=(
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
BoundingBoxesTensorSliceItem(name="bboxes", format=XYWHCoordinateFormat()),
)
)
LABEL_CXCYWH = ConcatenatedTensorFormat(
layout=(
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
BoundingBoxesTensorSliceItem(name="bboxes", format=CXCYWHCoordinateFormat()),
)
)
NORMALIZED_XYXY_LABEL = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYXYCoordinateFormat()),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)
NORMALIZED_XYWH_LABEL = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYWHCoordinateFormat()),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)
NORMALIZED_CXCYWH_LABEL = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedCXCYWHCoordinateFormat()),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)
LABEL_NORMALIZED_XYXY = ConcatenatedTensorFormat(
layout=(
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYXYCoordinateFormat()),
)
)
LABEL_NORMALIZED_XYWH = ConcatenatedTensorFormat(
layout=(
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYWHCoordinateFormat()),
)
)
LABEL_NORMALIZED_CXCYWH = ConcatenatedTensorFormat(
layout=(
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedCXCYWHCoordinateFormat()),
)
)
Expand Down
7 changes: 7 additions & 0 deletions src/super_gradients/training/datasets/data_formats/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ def __repr__(self):
return f"name={self.name} length={self.length} format={self.format}"


class LabelTensorSliceItem(TensorSliceItem):
    """Tensor slice item describing the class-label column of a detection tensor.

    The slice is always one element long and is always registered under the name stored in ``NAME``,
    so callers can refer to ``LabelTensorSliceItem.NAME`` instead of repeating the string literal.
    """

    # Name under which the label slice is registered in a layout.
    NAME = "labels"

    def __init__(self):
        # A label occupies a single element of the concatenated tensor.
        super().__init__(length=1, name=self.NAME)


class ConcatenatedTensorFormat(DetectionOutputFormat):
"""
Define the output format that return a single tensor of shape [N,M] (N - number of detections,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
from super_gradients.common.object_names import Datasets, Processings
from super_gradients.common.registry.registry import register_dataset
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.training.utils.detection_utils import get_cls_posx_in_target
from super_gradients.training.utils.detection_utils import get_class_index_in_target
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training.transforms.transforms import DetectionTransform, DetectionTargetsFormatTransform, DetectionTargetsFormat
from super_gradients.common.exceptions.dataset_exceptions import EmptyDatasetException, DatasetValidationException
from super_gradients.common.factories.list_factory import ListFactory
from super_gradients.common.factories.transforms_factory import TransformsFactory
from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL
from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat
from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat, LabelTensorSliceItem
from super_gradients.training.utils.utils import ensure_is_tuple_of_two

logger = get_logger(__name__)
Expand Down Expand Up @@ -298,26 +298,26 @@ def _sub_class_annotation(self, annotation: dict) -> Union[dict, None]:
:param annotation: Dict representing the annotation of a specific image
:return: Subclassed annotation if non-empty after subclassing, otherwise None
"""
cls_posx = get_cls_posx_in_target(self.original_target_format)
class_index = _get_class_index_in_target(target_format=self.original_target_format)
for field in self.target_fields:
annotation[field] = self._sub_class_target(targets=annotation[field], cls_posx=cls_posx)
annotation[field] = self._sub_class_target(targets=annotation[field], class_index=class_index)
return annotation

def _sub_class_target(self, targets: np.ndarray, cls_posx: int) -> np.ndarray:
def _sub_class_target(self, targets: np.ndarray, class_index: int) -> np.ndarray:
"""Subclass targets of a specific image.
:param targets: Target array to subclass of shape [n_targets, 5], 5 representing a bbox
:param cls_posx: Position of the class id in a bbox
:param class_index: Position of the class id in a bbox
ex: 0 if bbox of format label_xyxy | -1 if bbox of format xyxy_label
:return: Subclassed target
"""
targets_kept = []
for target in targets:
cls_id = int(target[cls_posx])
cls_id = int(target[class_index])
cls_name = self.all_classes_list[cls_id]
if cls_name in self.class_inclusion_list:
# Replace the target cls_id in self.all_classes_list by cls_id in self.class_inclusion_list
target[cls_posx] = self.class_inclusion_list.index(cls_name)
target[class_index] = self.class_inclusion_list.index(cls_name)
targets_kept.append(target)

return np.array(targets_kept) if len(targets_kept) > 0 else np.zeros((0, 5), dtype=np.float32)
Expand Down Expand Up @@ -568,5 +568,16 @@ def get_dataset_preprocessing_params(self):
return params


# TODO
# - Integration Test
def _get_class_index_in_target(target_format: Union[DetectionTargetsFormat, ConcatenatedTensorFormat]) -> int:
    """Get the index of the class in the target format.

    :param target_format:           Format of the target, given either as a ConcatenatedTensorFormat or as a DetectionTargetsFormat.
                                    E.g. XYXY_LABEL, LABEL_NORMALIZED_XYXY, etc.
    :return:                        Index of the class in the target format. E.g. XYXY_LABEL -> 4, LABEL_NORMALIZED_XYXY -> 0, etc.
    :raises NotImplementedError:    If `target_format` is neither a ConcatenatedTensorFormat nor a DetectionTargetsFormat.
    """
    if isinstance(target_format, ConcatenatedTensorFormat):
        # Look up the entry stored under the label slice name and use its first element as the class index.
        return target_format.indexes[LabelTensorSliceItem.NAME][0]

    if isinstance(target_format, DetectionTargetsFormat):
        # DetectionTargetsFormat values are resolved by the dedicated helper from detection_utils.
        return get_class_index_in_target(target_format)

    raise NotImplementedError(
        f"{target_format} is not supported. Supported formats are: {ConcatenatedTensorFormat.__name__}, {DetectionTargetsFormat.__name__}"
    )
2 changes: 1 addition & 1 deletion src/super_gradients/training/utils/detection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class DetectionTargetsFormat(Enum):
NORMALIZED_CXCYWH_LABEL = "NORMALIZED_CXCYWH_LABEL"


def get_cls_posx_in_target(target_format: DetectionTargetsFormat) -> int:
def get_class_index_in_target(target_format: DetectionTargetsFormat) -> int:
"""Get the label of a given target
:param target_format: Representation of the target (ex: LABEL_XYXY)
:return: Position of the class id in a bbox
Expand Down
7 changes: 4 additions & 3 deletions tests/unit_tests/detection_output_adapter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
YXYXCoordinateFormat,
NormalizedCXCYWHCoordinateFormat,
DetectionOutputAdapter,
LabelTensorSliceItem,
)

from super_gradients.training.datasets.data_formats.bbox_formats.normalized_cxcywh import xyxy_to_normalized_cxcywh
Expand All @@ -25,22 +26,22 @@
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYWHCoordinateFormat()),
TensorSliceItem(length=1, name="scores"),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)

CXCYWH_SCORES_LABELS = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=CXCYWHCoordinateFormat()),
TensorSliceItem(length=1, name="scores"),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)

CXCYWH_LABELS_SCORES_DISTANCE_ATTR = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=CXCYWHCoordinateFormat()),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
TensorSliceItem(length=1, name="scores"),
TensorSliceItem(length=1, name="distance"),
TensorSliceItem(length=4, name="attributes"),
Expand Down
37 changes: 29 additions & 8 deletions tests/unit_tests/detection_sub_classing_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from super_gradients.common.exceptions.dataset_exceptions import EmptyDatasetException, DatasetValidationException

import unittest
import numpy as np
from typing import Union

from super_gradients.training.datasets import DetectionDataset
from super_gradients.training.utils.detection_utils import DetectionTargetsFormat
from super_gradients.common.exceptions.dataset_exceptions import EmptyDatasetException, DatasetValidationException
from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat
from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL


class DummyDetectionDataset(DetectionDataset):
def __init__(self, input_dim, *args, **kwargs):
def __init__(self, input_dim, target_format: Union[DetectionTargetsFormat, ConcatenatedTensorFormat], *args, **kwargs):
"""Dummy Dataset testing subclassing, designed with no annotation that includes class_2."""

self.dummy_targets = [
Expand All @@ -17,7 +21,7 @@ def __init__(self, input_dim, *args, **kwargs):

self.image_size = input_dim
kwargs["all_classes_list"] = ["class_0", "class_1", "class_2"]
kwargs["original_target_format"] = DetectionTargetsFormat.XYXY_LABEL
kwargs["original_target_format"] = target_format
super().__init__(data_dir="", input_dim=input_dim, *args, **kwargs)

def _setup_data_source(self):
Expand Down Expand Up @@ -53,28 +57,45 @@ def setUp(self) -> None:
def test_subclass_keep_empty(self):
"""Check that subclassing only keeps annotations of wanted class"""
for config in self.config_keep_empty_annotation:
test_dataset = DummyDetectionDataset(input_dim=(640, 512), ignore_empty_annotations=False, class_inclusion_list=config["class_inclusion_list"])
test_dataset = DummyDetectionDataset(
input_dim=(640, 512), ignore_empty_annotations=False, class_inclusion_list=config["class_inclusion_list"], target_format=XYXY_LABEL
)
n_targets_after_subclass = _count_targets_after_subclass_per_index(test_dataset)
self.assertListEqual(config["expected_n_targets_after_subclass"], n_targets_after_subclass)

def test_subclass_drop_empty(self):
"""Check that empty annotations are not indexed (i.e. ignored) when ignore_empty_annotations=True"""
for config in self.config_ignore_empty_annotation:
test_dataset = DummyDetectionDataset(input_dim=(640, 512), ignore_empty_annotations=True, class_inclusion_list=config["class_inclusion_list"])
test_dataset = DummyDetectionDataset(
input_dim=(640, 512), ignore_empty_annotations=True, class_inclusion_list=config["class_inclusion_list"], target_format=XYXY_LABEL
)
n_targets_after_subclass = _count_targets_after_subclass_per_index(test_dataset)
self.assertListEqual(config["expected_n_targets_after_subclass"], n_targets_after_subclass)

# Check last case when class_2, which should raise EmptyDatasetException because not a single image has
# a target in class_inclusion_list
with self.assertRaises(EmptyDatasetException):
DummyDetectionDataset(input_dim=(640, 512), ignore_empty_annotations=True, class_inclusion_list=["class_2"])
DummyDetectionDataset(input_dim=(640, 512), ignore_empty_annotations=True, class_inclusion_list=["class_2"], target_format=XYXY_LABEL)

def test_wrong_subclass(self):
"""Check that DatasetValidationException is raised when class_inclusion_list includes a class that does not exist."""
with self.assertRaises(DatasetValidationException):
DummyDetectionDataset(input_dim=(640, 512), class_inclusion_list=["non_existing_class"])
DummyDetectionDataset(input_dim=(640, 512), class_inclusion_list=["non_existing_class"], target_format=XYXY_LABEL)
with self.assertRaises(DatasetValidationException):
DummyDetectionDataset(input_dim=(640, 512), class_inclusion_list=["class_0", "non_existing_class"])
DummyDetectionDataset(input_dim=(640, 512), class_inclusion_list=["class_0", "non_existing_class"], target_format=XYXY_LABEL)

def test_legacy_detection_targets_format(self):
    """Check that subclassing keeps the expected number of targets when the target format is a DetectionTargetsFormat enum member."""

    # Same scenarios as the keep-empty test, but with the enum-based format instead of a ConcatenatedTensorFormat.
    for config in self.config_keep_empty_annotation:
        test_dataset = DummyDetectionDataset(
            input_dim=(640, 512),
            ignore_empty_annotations=False,
            class_inclusion_list=config["class_inclusion_list"],
            target_format=DetectionTargetsFormat.XYXY_LABEL,
        )
        n_targets_after_subclass = _count_targets_after_subclass_per_index(test_dataset)
        self.assertListEqual(config["expected_n_targets_after_subclass"], n_targets_after_subclass)


def _count_targets_after_subclass_per_index(test_dataset: DummyDetectionDataset):
Expand Down

0 comments on commit 5792080

Please sign in to comment.