Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hotfix/dg 000 fix empty batch detection #160

Merged
merged 6 commits into from
Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions src/data_gradients/batch_processors/formatters/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from data_gradients.batch_processors.utils import check_all_integers
from data_gradients.batch_processors.formatters.base import BatchFormatter
from data_gradients.batch_processors.formatters.utils import ensure_images_shape, ensure_channel_first, drop_nan
from data_gradients.batch_processors.formatters.utils import check_images_shape, ensure_channel_first, drop_nan
from data_gradients.config.data.data_config import DetectionDataConfig
from data_gradients.batch_processors.formatters.utils import DatasetFormatError

Expand Down Expand Up @@ -57,9 +57,9 @@ def format(self, images: Tensor, labels: Tensor) -> Tuple[Tensor, List[Tensor]]:
- labels: List of bounding boxes, each of shape (N_i, 5 [label_xyxy]) with N_i being the number of bounding boxes with class_id in class_ids
"""

# Might happen if the user passes tensors as [N, 5] with N=1; Depending on the Dataset implementation, it may actually return a [5] tensor instead
if labels.numel() == 0:
labels = torch.zeros((0, 5))
# First thing is to make sure that, if we have empty labels, they are in a correct format
labels = self.format_empty_labels(annotated_bboxes=labels)

# If the label is of shape [N, 5] we can assume that it represents the targets of a single sample (class_name + 4 bbox coordinates)
if labels.ndim == 2 and labels.shape[1] == 5:
Expand All @@ -69,7 +69,7 @@ def format(self, images: Tensor, labels: Tensor) -> Tuple[Tensor, List[Tensor]]:
labels = drop_nan(labels)

images = ensure_channel_first(images, n_image_channels=self.n_image_channels)
images = ensure_images_shape(images, n_image_channels=self.n_image_channels)
images = check_images_shape(images, n_image_channels=self.n_image_channels)
labels = self.ensure_labels_shape(annotated_bboxes=labels)

targets_sample_str = f"Here's a sample of how your labels look like:\nEach line corresponds to a bounding box.\n{labels[0, :4, :]}"
Expand All @@ -91,6 +91,22 @@ def format(self, images: Tensor, labels: Tensor) -> Tuple[Tensor, List[Tensor]]:

return images, labels

@staticmethod
def format_empty_labels(annotated_bboxes: Tensor) -> Tensor:
"""Ensure that empty labels have a valid shape, i.e. (N, 5) or (BS, N, 5)."""
if annotated_bboxes.numel() == 0:
# If we have an empty tensor, there is a risk that the target shape will be different.
# e.g. to tensor([]) for single sample or tensor([], size=(BS, 0)) for batch
# This breaks the expected target format which should be (N, 5), (BS, N, 5) or (N, 6)
if annotated_bboxes.shape[-1] == 0:
# This means that the last dim is N (representing the number of bounding boxes), i.e. (N, ) or (BS, N)
# In that case, we need to add the last dimension to represent each bounding box class/coordinates, i.e. (N, 5) or (BS, N, 5)
if annotated_bboxes.ndim == 1: # (N, ) -> (N, 5) with N=0
annotated_bboxes = torch.zeros((0, 5))
elif annotated_bboxes.ndim == 2: # (BS, N) -> (BS, N, 5) with N=0
annotated_bboxes = torch.zeros((annotated_bboxes.shape[0], 0, 5))
return annotated_bboxes

@staticmethod
def ensure_labels_shape(annotated_bboxes: Tensor) -> Tensor:
"""Make sure that the labels have the correct shape, i.e. (BS, N, 5)."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from data_gradients.config.data.questions import ask_user
from data_gradients.batch_processors.formatters.base import BatchFormatter
from data_gradients.batch_processors.utils import check_all_integers, to_one_hot
from data_gradients.batch_processors.formatters.utils import DatasetFormatError, ensure_images_shape, ensure_channel_first, drop_nan
from data_gradients.batch_processors.formatters.utils import DatasetFormatError, check_images_shape, ensure_channel_first, drop_nan


class SegmentationBatchFormatter(BatchFormatter):
Expand Down Expand Up @@ -79,7 +79,7 @@ def format(self, images: Tensor, labels: Tensor) -> Tuple[Tensor, Tensor]:
images = ensure_channel_first(images, n_image_channels=self.n_image_channels)
labels = ensure_channel_first(labels, n_image_channels=self.n_image_channels)

images = ensure_images_shape(images, n_image_channels=self.n_image_channels)
images = check_images_shape(images, n_image_channels=self.n_image_channels)
labels = self.validate_labels_dim(labels, n_classes=self.n_image_channels, ignore_labels=self.ignore_labels)

labels = self.ensure_hard_labels(labels, n_classes=len(self.class_names), threshold_value=self.threshold_value)
Expand Down
2 changes: 1 addition & 1 deletion src/data_gradients/batch_processors/formatters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def ensure_channel_first(images: Tensor, n_image_channels: int) -> Tensor:
return images


def ensure_images_shape(images: Tensor, n_image_channels: int) -> Tensor:
def check_images_shape(images: Tensor, n_image_channels: int) -> Tensor:
"""Validate images dimensions are (BS, C, H, W)

:param images: Tensor [BS, C, H, W]
Expand Down
128 changes: 126 additions & 2 deletions tests/unit_tests/batch_processor/test_detection_batch_formatter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,134 @@
import torch
import unittest
from data_gradients.batch_processors.formatters.detection import DetectionBatchFormatter
from data_gradients.config.data import DetectionDataConfig


class GroupDetectionBatchTest(unittest.TestCase):
def test_update_and_aggregate(self):
class DetectionBatchFormatterTest(unittest.TestCase):
def setUp(self):
    """Build a formatter config and one 64x32 RGB zero-image per layout."""
    self.empty_data_config = DetectionDataConfig(use_cache=False, is_label_first=True, xyxy_converter="xyxy")
    image_hwc = torch.zeros(64, 32, 3, dtype=torch.uint8)
    image_chw = torch.zeros(3, 64, 32, dtype=torch.uint8)
    # Single-sample variants and their batch-of-one counterparts.
    self.channel_last_image = image_hwc
    self.channel_first_image = image_chw
    self.channel_last_images = image_hwc.unsqueeze(0)
    self.channel_first_images = image_chw.unsqueeze(0)

def test_format_sample_image(self):
    """A single image (HWC or CHW) becomes a CHW batch of one; (N, 5) targets pass through untouched."""
    formatter = DetectionBatchFormatter(
        data_config=self.empty_data_config, class_names=["0", "1"], class_names_to_use=["0", "1"], n_image_channels=3
    )
    target_n5 = torch.Tensor(
        [
            [0, 10, 20, 15, 25],
            [0, 5, 10, 15, 25],
            [0, 5, 10, 15, 25],
            [1, 10, 20, 15, 25],
        ]
    )
    # Both input layouts must converge to the same channel-first batch.
    for single_image in (self.channel_last_image, self.channel_first_image):
        images, labels = formatter.format(single_image, target_n5)
        self.assertTrue(torch.equal(images, self.channel_first_images))
        self.assertTrue(torch.equal(labels[0], target_n5))

def test_format_batch_n5(self):
    """(BS, N, 5) targets are split per sample and all-zero padding rows are dropped."""
    formatter = DetectionBatchFormatter(
        data_config=self.empty_data_config, class_names=["0", "1"], class_names_to_use=["0", "1"], n_image_channels=3
    )
    sample_0 = [
        [0, 10, 20, 15, 25],
        [0, 5, 10, 15, 25],
        [1, 5, 10, 15, 25],
    ]
    # Last row of sample_1 is zero-padding and must not survive formatting.
    sample_1 = [
        [0, 5, 10, 15, 25],
        [1, 10, 20, 15, 25],
        [0, 0, 0, 0, 0],
    ]
    target_sample_n5 = torch.Tensor([sample_0, sample_1])

    images, labels = formatter.format(self.channel_last_images, target_sample_n5)

    self.assertTrue(torch.equal(images, self.channel_first_images))
    self.assertTrue(torch.equal(labels[0], torch.Tensor(sample_0)))
    self.assertTrue(torch.equal(labels[1], torch.Tensor(sample_1[:2])))

def test_format_batch_n6(self):
    """Flat (N, 6) targets — first column is the sample index — are regrouped into per-sample (N_i, 5) tensors."""
    formatter = DetectionBatchFormatter(
        data_config=self.empty_data_config, class_names=["0", "1"], class_names_to_use=["0", "1"], n_image_channels=3
    )
    target_n6 = torch.Tensor(
        [
            [0, 0, 10, 20, 15, 25],
            [0, 0, 5, 10, 15, 25],
            [0, 1, 5, 10, 15, 25],
            [1, 0, 5, 10, 15, 25],
            [1, 1, 10, 20, 15, 25],
        ]
    )

    images, labels = formatter.format(self.channel_last_images, target_n6)

    self.assertTrue(torch.equal(images, self.channel_first_images))
    expected_per_sample = [
        torch.Tensor(
            [
                [0, 10, 20, 15, 25],
                [0, 5, 10, 15, 25],
                [1, 5, 10, 15, 25],
            ]
        ),
        torch.Tensor(
            [
                [0, 5, 10, 15, 25],
                [1, 10, 20, 15, 25],
            ]
        ),
    ]
    for actual, expected in zip(labels, expected_per_sample):
        self.assertTrue(torch.equal(actual, expected))

def test_format_empty_sample(self):
    """Empty single-sample targets — degenerate tensor([]) or well-shaped (0, 5) — both normalize to one empty (0, 5) sample."""
    formatter = DetectionBatchFormatter(
        data_config=self.empty_data_config, class_names=["0", "1"], class_names_to_use=["0", "1"], n_image_channels=3
    )
    expected_output_target = torch.zeros(0, 5)

    cases = (
        (self.channel_last_image, torch.Tensor([])),
        (self.channel_first_image, torch.zeros(0, 5)),
    )
    for single_image, empty_labels in cases:
        images, labels = formatter.format(single_image, empty_labels)
        self.assertTrue(torch.equal(images, self.channel_first_images))
        self.assertTrue(torch.equal(labels[0], expected_output_target))

def test_format_empty_batch(self):
    """Empty batched targets — (BS, 0) or (BS, 0, 5) — yield one empty (0, 5) tensor per sample in the batch."""
    formatter = DetectionBatchFormatter(
        data_config=self.empty_data_config, class_names=["0", "1"], class_names_to_use=["0", "1"], n_image_channels=3
    )
    batch_size = 7
    expected_output_sample_target = torch.zeros(0, 5)

    cases = (
        (self.channel_last_images, torch.zeros(batch_size, 0)),
        (self.channel_first_images, torch.zeros(batch_size, 0, 5)),
    )
    for batch_images, empty_labels in cases:
        images, labels = formatter.format(batch_images, empty_labels)
        self.assertTrue(torch.equal(images, self.channel_first_images))
        self.assertEqual(len(labels), batch_size)
        for sample in labels:
            self.assertTrue(torch.equal(sample, expected_output_sample_target))

def test_group_detection_batch(self):
flat_batch = torch.Tensor(
[
[0, 2, 10, 20, 15, 25],
Expand Down