Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix class_inclusion_list in DetectionDataset #1327

Merged
merged 5 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .format_converter import ConcatenatedTensorFormatConverter
from .output_adapters import DetectionOutputAdapter
from .formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, TensorSliceItem
from .formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, TensorSliceItem, LabelTensorSliceItem
from .bbox_formats import (
CXCYWHCoordinateFormat,
NormalizedCXCYWHCoordinateFormat,
Expand All @@ -21,6 +21,7 @@
"NormalizedXYWHCoordinateFormat",
"NormalizedXYXYCoordinateFormat",
"TensorSliceItem",
"LabelTensorSliceItem",
"XYWHCoordinateFormat",
"XYXYCoordinateFormat",
"YXYXCoordinateFormat",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from super_gradients.common.object_names import ConcatenatedTensorFormats
from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, TensorSliceItem
from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, LabelTensorSliceItem
from super_gradients.training.datasets.data_formats.bbox_formats import (
XYXYCoordinateFormat,
XYWHCoordinateFormat,
Expand All @@ -12,72 +12,72 @@
XYXY_LABEL = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=XYXYCoordinateFormat()),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)
XYWH_LABEL = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=XYWHCoordinateFormat()),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)
CXCYWH_LABEL = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=CXCYWHCoordinateFormat()),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)
LABEL_XYXY = ConcatenatedTensorFormat(
layout=(
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
BoundingBoxesTensorSliceItem(name="bboxes", format=XYXYCoordinateFormat()),
)
)
LABEL_XYWH = ConcatenatedTensorFormat(
layout=(
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
BoundingBoxesTensorSliceItem(name="bboxes", format=XYWHCoordinateFormat()),
)
)
LABEL_CXCYWH = ConcatenatedTensorFormat(
layout=(
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
BoundingBoxesTensorSliceItem(name="bboxes", format=CXCYWHCoordinateFormat()),
)
)
NORMALIZED_XYXY_LABEL = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYXYCoordinateFormat()),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)
NORMALIZED_XYWH_LABEL = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYWHCoordinateFormat()),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)
NORMALIZED_CXCYWH_LABEL = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedCXCYWHCoordinateFormat()),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)
LABEL_NORMALIZED_XYXY = ConcatenatedTensorFormat(
layout=(
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYXYCoordinateFormat()),
)
)
LABEL_NORMALIZED_XYWH = ConcatenatedTensorFormat(
layout=(
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYWHCoordinateFormat()),
)
)
LABEL_NORMALIZED_CXCYWH = ConcatenatedTensorFormat(
layout=(
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedCXCYWHCoordinateFormat()),
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ def __repr__(self):
return f"name={self.name} length={self.length} format={self.format}"


class LabelTensorSliceItem(TensorSliceItem):
NAME = "labels"

def __init__(self):
super().__init__(name=self.NAME, length=1)


class ConcatenatedTensorFormat(DetectionOutputFormat):
"""
Define the output format that return a single tensor of shape [N,M] (N - number of detections,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from super_gradients.common.factories.list_factory import ListFactory
from super_gradients.common.factories.transforms_factory import TransformsFactory
from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL
from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat
from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat, LabelTensorSliceItem
from super_gradients.training.utils.utils import ensure_is_tuple_of_two

logger = get_logger(__name__)
Expand Down Expand Up @@ -232,9 +232,9 @@ def _sub_class_annotation(self, annotation: dict) -> Union[dict, None]:
:param annotation: Dict representing the annotation of a specific image
:return: Subclassed annotation if non empty after subclassing, otherwise None
"""
cls_posx = get_cls_posx_in_target(self.original_target_format)
cls_index = _get_cls_index_in_target(target_format=self.original_target_format)
for field in self.target_fields:
annotation[field] = self._sub_class_target(targets=annotation[field], cls_posx=cls_posx)
annotation[field] = self._sub_class_target(targets=annotation[field], cls_posx=cls_index)
return annotation

def _sub_class_target(self, targets: np.ndarray, cls_posx: int) -> np.ndarray:
Expand Down Expand Up @@ -505,3 +505,14 @@ def get_dataset_preprocessing_params(self):
conf=0.5,
)
return params


def _get_cls_index_in_target(target_format: DetectionTargetsFormat) -> int:
if isinstance(target_format, ConcatenatedTensorFormat):
return target_format.indexes[LabelTensorSliceItem.NAME][0]
elif isinstance(target_format, DetectionTargetsFormat):
return get_cls_posx_in_target(target_format)
else:
raise NotImplementedError(
f"{target_format} is not supported. Supported formats are: {ConcatenatedTensorFormat.__name__}, {DetectionTargetsFormat.__name__}"
)
7 changes: 4 additions & 3 deletions tests/unit_tests/detection_output_adapter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
YXYXCoordinateFormat,
NormalizedCXCYWHCoordinateFormat,
DetectionOutputAdapter,
LabelTensorSliceItem,
)

from super_gradients.training.datasets.data_formats.bbox_formats.normalized_cxcywh import xyxy_to_normalized_cxcywh
Expand All @@ -25,22 +26,22 @@
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYWHCoordinateFormat()),
TensorSliceItem(length=1, name="scores"),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)

CXCYWH_SCORES_LABELS = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=CXCYWHCoordinateFormat()),
TensorSliceItem(length=1, name="scores"),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)

CXCYWH_LABELS_SCORES_DISTANCE_ATTR = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=CXCYWHCoordinateFormat()),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
TensorSliceItem(length=1, name="scores"),
TensorSliceItem(length=1, name="distance"),
TensorSliceItem(length=4, name="attributes"),
Expand Down