diff --git a/src/super_gradients/training/datasets/data_formats/__init__.py b/src/super_gradients/training/datasets/data_formats/__init__.py index aa8fc5862d..c0801e0af1 100644 --- a/src/super_gradients/training/datasets/data_formats/__init__.py +++ b/src/super_gradients/training/datasets/data_formats/__init__.py @@ -1,6 +1,6 @@ from .format_converter import ConcatenatedTensorFormatConverter from .output_adapters import DetectionOutputAdapter -from .formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, TensorSliceItem +from .formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, TensorSliceItem, LabelTensorSliceItem from .bbox_formats import ( CXCYWHCoordinateFormat, NormalizedCXCYWHCoordinateFormat, @@ -21,6 +21,7 @@ "NormalizedXYWHCoordinateFormat", "NormalizedXYXYCoordinateFormat", "TensorSliceItem", + "LabelTensorSliceItem", "XYWHCoordinateFormat", "XYXYCoordinateFormat", "YXYXCoordinateFormat", diff --git a/src/super_gradients/training/datasets/data_formats/default_formats.py b/src/super_gradients/training/datasets/data_formats/default_formats.py index 83439d8b37..e96d3abf90 100644 --- a/src/super_gradients/training/datasets/data_formats/default_formats.py +++ b/src/super_gradients/training/datasets/data_formats/default_formats.py @@ -1,5 +1,5 @@ from super_gradients.common.object_names import ConcatenatedTensorFormats -from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, TensorSliceItem +from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, LabelTensorSliceItem from super_gradients.training.datasets.data_formats.bbox_formats import ( XYXYCoordinateFormat, XYWHCoordinateFormat, @@ -12,72 +12,72 @@ XYXY_LABEL = ConcatenatedTensorFormat( layout=( BoundingBoxesTensorSliceItem(name="bboxes", format=XYXYCoordinateFormat()), - TensorSliceItem(length=1, name="labels"), + LabelTensorSliceItem(), ) ) XYWH_LABEL = ConcatenatedTensorFormat( layout=( BoundingBoxesTensorSliceItem(name="bboxes", format=XYWHCoordinateFormat()), - TensorSliceItem(length=1, name="labels"), + LabelTensorSliceItem(), ) ) CXCYWH_LABEL = ConcatenatedTensorFormat( layout=( BoundingBoxesTensorSliceItem(name="bboxes", format=CXCYWHCoordinateFormat()), - TensorSliceItem(length=1, name="labels"), + LabelTensorSliceItem(), ) ) LABEL_XYXY = ConcatenatedTensorFormat( layout=( - TensorSliceItem(length=1, name="labels"), + LabelTensorSliceItem(), BoundingBoxesTensorSliceItem(name="bboxes", format=XYXYCoordinateFormat()), ) ) LABEL_XYWH = ConcatenatedTensorFormat( layout=( - TensorSliceItem(length=1, name="labels"), + LabelTensorSliceItem(), BoundingBoxesTensorSliceItem(name="bboxes", format=XYWHCoordinateFormat()), ) ) LABEL_CXCYWH = ConcatenatedTensorFormat( layout=( - TensorSliceItem(length=1, name="labels"), + LabelTensorSliceItem(), BoundingBoxesTensorSliceItem(name="bboxes", format=CXCYWHCoordinateFormat()), ) ) NORMALIZED_XYXY_LABEL = ConcatenatedTensorFormat( layout=( BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYXYCoordinateFormat()), - TensorSliceItem(length=1, name="labels"), + LabelTensorSliceItem(), ) ) NORMALIZED_XYWH_LABEL = ConcatenatedTensorFormat( layout=( BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYWHCoordinateFormat()), - TensorSliceItem(length=1, name="labels"), + LabelTensorSliceItem(), ) ) NORMALIZED_CXCYWH_LABEL = ConcatenatedTensorFormat( layout=( BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedCXCYWHCoordinateFormat()), - TensorSliceItem(length=1, name="labels"), + LabelTensorSliceItem(), ) ) LABEL_NORMALIZED_XYXY = ConcatenatedTensorFormat( layout=( - TensorSliceItem(length=1, name="labels"), + LabelTensorSliceItem(), BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYXYCoordinateFormat()), ) ) LABEL_NORMALIZED_XYWH = ConcatenatedTensorFormat( layout=( - TensorSliceItem(length=1, name="labels"), + LabelTensorSliceItem(), BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYWHCoordinateFormat()), ) ) LABEL_NORMALIZED_CXCYWH = ConcatenatedTensorFormat( layout=( - TensorSliceItem(length=1, name="labels"), + LabelTensorSliceItem(), BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedCXCYWHCoordinateFormat()), ) ) diff --git a/src/super_gradients/training/datasets/data_formats/formats.py b/src/super_gradients/training/datasets/data_formats/formats.py index 610039be9a..052b9e22bf 100644 --- a/src/super_gradients/training/datasets/data_formats/formats.py +++ b/src/super_gradients/training/datasets/data_formats/formats.py @@ -34,6 +34,13 @@ def __repr__(self): return f"name={self.name} length={self.length} format={self.format}" +class LabelTensorSliceItem(TensorSliceItem): + NAME = "labels" + + def __init__(self): + super().__init__(name=self.NAME, length=1) + + class ConcatenatedTensorFormat(DetectionOutputFormat): """ Define the output format that return a single tensor of shape [N,M] (N - number of detections, diff --git a/src/super_gradients/training/datasets/detection_datasets/detection_dataset.py b/src/super_gradients/training/datasets/detection_datasets/detection_dataset.py index f7489f3c60..75f2764b50 100644 --- a/src/super_gradients/training/datasets/detection_datasets/detection_dataset.py +++ b/src/super_gradients/training/datasets/detection_datasets/detection_dataset.py @@ -16,14 +16,14 @@ from super_gradients.common.object_names import Datasets, Processings from super_gradients.common.registry.registry import register_dataset from super_gradients.common.decorators.factory_decorator import resolve_param -from super_gradients.training.utils.detection_utils import get_cls_posx_in_target +from super_gradients.training.utils.detection_utils import get_class_index_in_target from super_gradients.common.abstractions.abstract_logger import get_logger from super_gradients.training.transforms.transforms import DetectionTransform, DetectionTargetsFormatTransform, DetectionTargetsFormat from super_gradients.common.exceptions.dataset_exceptions import EmptyDatasetException, DatasetValidationException from super_gradients.common.factories.list_factory import ListFactory from super_gradients.common.factories.transforms_factory import TransformsFactory from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL -from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat +from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat, LabelTensorSliceItem from super_gradients.training.utils.utils import ensure_is_tuple_of_two logger = get_logger(__name__) @@ -298,26 +298,26 @@ def _sub_class_annotation(self, annotation: dict) -> Union[dict, None]: :param annotation: Dict representing the annotation of a specific image :return: Subclassed annotation if non-empty after subclassing, otherwise None """ - cls_posx = get_cls_posx_in_target(self.original_target_format) + class_index = _get_class_index_in_target(target_format=self.original_target_format) for field in self.target_fields: - annotation[field] = self._sub_class_target(targets=annotation[field], cls_posx=cls_posx) + annotation[field] = self._sub_class_target(targets=annotation[field], class_index=class_index) return annotation - def _sub_class_target(self, targets: np.ndarray, cls_posx: int) -> np.ndarray: + def _sub_class_target(self, targets: np.ndarray, class_index: int) -> np.ndarray: """Sublass targets of a specific image. :param targets: Target array to subclass of shape [n_targets, 5], 5 representing a bbox - :param cls_posx: Position of the class id in a bbox + :param class_index: Position of the class id in a bbox ex: 0 if bbox of format label_xyxy | -1 if bbox of format xyxy_label :return: Subclassed target """ targets_kept = [] for target in targets: - cls_id = int(target[cls_posx]) + cls_id = int(target[class_index]) cls_name = self.all_classes_list[cls_id] if cls_name in self.class_inclusion_list: # Replace the target cls_id in self.all_classes_list by cls_id in self.class_inclusion_list - target[cls_posx] = self.class_inclusion_list.index(cls_name) + target[class_index] = self.class_inclusion_list.index(cls_name) targets_kept.append(target) return np.array(targets_kept) if len(targets_kept) > 0 else np.zeros((0, 5), dtype=np.float32) @@ -568,5 +568,16 @@ def get_dataset_preprocessing_params(self): return params -# TODO -# - Integration Test +def _get_class_index_in_target(target_format: DetectionTargetsFormat) -> int: + """Get the index of the class in the target format. + :param target_format: format of the target. E.g. XYXY_LABEL, LABEL_NORMALIZED_XYXY, ect... + :return: index of the class in the target format. E.g. XYXY_LABEL -> 4, LABEL_NORMALIZED_XYXY -> 0, ect.... + """ + if isinstance(target_format, ConcatenatedTensorFormat): + return target_format.indexes[LabelTensorSliceItem.NAME][0] + elif isinstance(target_format, DetectionTargetsFormat): + return get_class_index_in_target(target_format) + else: + raise NotImplementedError( + f"{target_format} is not supported. Supported formats are: {ConcatenatedTensorFormat.__name__}, {DetectionTargetsFormat.__name__}" + ) diff --git a/src/super_gradients/training/utils/detection_utils.py b/src/super_gradients/training/utils/detection_utils.py index eb5d84066b..1158b19c8b 100755 --- a/src/super_gradients/training/utils/detection_utils.py +++ b/src/super_gradients/training/utils/detection_utils.py @@ -39,7 +39,7 @@ class DetectionTargetsFormat(Enum): NORMALIZED_CXCYWH_LABEL = "NORMALIZED_CXCYWH_LABEL" -def get_cls_posx_in_target(target_format: DetectionTargetsFormat) -> int: +def get_class_index_in_target(target_format: DetectionTargetsFormat) -> int: """Get the label of a given target :param target_format: Representation of the target (ex: LABEL_XYXY) :return: Position of the class id in a bbox diff --git a/tests/unit_tests/detection_output_adapter_test.py b/tests/unit_tests/detection_output_adapter_test.py index b4ace6cb28..76ab9eb93a 100644 --- a/tests/unit_tests/detection_output_adapter_test.py +++ b/tests/unit_tests/detection_output_adapter_test.py @@ -17,6 +17,7 @@ YXYXCoordinateFormat, NormalizedCXCYWHCoordinateFormat, DetectionOutputAdapter, + LabelTensorSliceItem, ) from super_gradients.training.datasets.data_formats.bbox_formats.normalized_cxcywh import xyxy_to_normalized_cxcywh @@ -25,7 +26,7 @@ layout=( BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYWHCoordinateFormat()), TensorSliceItem(length=1, name="scores"), - TensorSliceItem(length=1, name="labels"), + LabelTensorSliceItem(), ) ) @@ -33,14 +34,14 @@ layout=( BoundingBoxesTensorSliceItem(name="bboxes", format=CXCYWHCoordinateFormat()), TensorSliceItem(length=1, name="scores"), - TensorSliceItem(length=1, name="labels"), + LabelTensorSliceItem(), ) ) CXCYWH_LABELS_SCORES_DISTANCE_ATTR = ConcatenatedTensorFormat( layout=( BoundingBoxesTensorSliceItem(name="bboxes", format=CXCYWHCoordinateFormat()), - TensorSliceItem(length=1, name="labels"), + LabelTensorSliceItem(), TensorSliceItem(length=1, name="scores"), TensorSliceItem(length=1, name="distance"), TensorSliceItem(length=4, name="attributes"), diff --git a/tests/unit_tests/detection_sub_classing_test.py b/tests/unit_tests/detection_sub_classing_test.py index b4100e349b..9c6901b611 100644 --- a/tests/unit_tests/detection_sub_classing_test.py +++ b/tests/unit_tests/detection_sub_classing_test.py @@ -1,13 +1,17 @@ +from super_gradients.common.exceptions.dataset_exceptions import EmptyDatasetException, DatasetValidationException + import unittest import numpy as np +from typing import Union from super_gradients.training.datasets import DetectionDataset from super_gradients.training.utils.detection_utils import DetectionTargetsFormat -from super_gradients.common.exceptions.dataset_exceptions import EmptyDatasetException, DatasetValidationException +from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat +from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL class DummyDetectionDataset(DetectionDataset): - def __init__(self, input_dim, *args, **kwargs): + def __init__(self, input_dim, target_format: Union[DetectionTargetsFormat, ConcatenatedTensorFormat], *args, **kwargs): """Dummy Dataset testing subclassing, designed with no annotation that includes class_2.""" self.dummy_targets = [ @@ -17,7 +21,7 @@ def __init__(self, input_dim, *args, **kwargs): self.image_size = input_dim kwargs["all_classes_list"] = ["class_0", "class_1", "class_2"] - kwargs["original_target_format"] = DetectionTargetsFormat.XYXY_LABEL + kwargs["original_target_format"] = target_format super().__init__(data_dir="", input_dim=input_dim, *args, **kwargs) def _setup_data_source(self): @@ -53,28 +57,45 @@ def setUp(self) -> None: def test_subclass_keep_empty(self): """Check that subclassing only keeps annotations of wanted class""" for config in self.config_keep_empty_annotation: - test_dataset = DummyDetectionDataset(input_dim=(640, 512), ignore_empty_annotations=False, class_inclusion_list=config["class_inclusion_list"]) + test_dataset = DummyDetectionDataset( + input_dim=(640, 512), ignore_empty_annotations=False, class_inclusion_list=config["class_inclusion_list"], target_format=XYXY_LABEL + ) n_targets_after_subclass = _count_targets_after_subclass_per_index(test_dataset) self.assertListEqual(config["expected_n_targets_after_subclass"], n_targets_after_subclass) def test_subclass_drop_empty(self): """Check that empty annotations are not indexed (i.e. ignored) when ignore_empty_annotations=True""" for config in self.config_ignore_empty_annotation: - test_dataset = DummyDetectionDataset(input_dim=(640, 512), ignore_empty_annotations=True, class_inclusion_list=config["class_inclusion_list"]) + test_dataset = DummyDetectionDataset( + input_dim=(640, 512), ignore_empty_annotations=True, class_inclusion_list=config["class_inclusion_list"], target_format=XYXY_LABEL + ) n_targets_after_subclass = _count_targets_after_subclass_per_index(test_dataset) self.assertListEqual(config["expected_n_targets_after_subclass"], n_targets_after_subclass) # Check last case when class_2, which should raise EmptyDatasetException because not a single image has # a target in class_inclusion_list with self.assertRaises(EmptyDatasetException): - DummyDetectionDataset(input_dim=(640, 512), ignore_empty_annotations=True, class_inclusion_list=["class_2"]) + DummyDetectionDataset(input_dim=(640, 512), ignore_empty_annotations=True, class_inclusion_list=["class_2"], target_format=XYXY_LABEL) def test_wrong_subclass(self): """Check that ValueError is raised when class_inclusion_list includes a class that does not exist.""" with self.assertRaises(DatasetValidationException): - DummyDetectionDataset(input_dim=(640, 512), class_inclusion_list=["non_existing_class"]) + DummyDetectionDataset(input_dim=(640, 512), class_inclusion_list=["non_existing_class"], target_format=XYXY_LABEL) with self.assertRaises(DatasetValidationException): - DummyDetectionDataset(input_dim=(640, 512), class_inclusion_list=["class_0", "non_existing_class"]) + DummyDetectionDataset(input_dim=(640, 512), class_inclusion_list=["class_0", "non_existing_class"], target_format=XYXY_LABEL) + + def test_legacy_detection_targets_format(self): + """Check that ValueError is raised when class_inclusion_list includes a class that does not exist.""" + + for config in self.config_keep_empty_annotation: + test_dataset = DummyDetectionDataset( + input_dim=(640, 512), + ignore_empty_annotations=False, + class_inclusion_list=config["class_inclusion_list"], + target_format=DetectionTargetsFormat.XYXY_LABEL, + ) + n_targets_after_subclass = _count_targets_after_subclass_per_index(test_dataset) + self.assertListEqual(config["expected_n_targets_after_subclass"], n_targets_after_subclass) def _count_targets_after_subclass_per_index(test_dataset: DummyDetectionDataset):