diff --git a/src/super_gradients/training/datasets/Dataset_Setup_Instructions.md b/documentation/source/Dataset_Setup_Instructions.md
similarity index 91%
rename from src/super_gradients/training/datasets/Dataset_Setup_Instructions.md
rename to documentation/source/Dataset_Setup_Instructions.md
index fc8c3b8be9..e0620991b3 100644
--- a/src/super_gradients/training/datasets/Dataset_Setup_Instructions.md
+++ b/documentation/source/Dataset_Setup_Instructions.md
@@ -511,3 +511,63 @@ train_set = COCOKeypointsDataset(data_dir='.../coco', images_dir='images/train20
valid_set = COCOKeypointsDataset(data_dir='.../coco', images_dir='images/val2017', json_file='annotations/instances_val2017.json', ...)
```
+
+
+
+### Oriented Box Detection Datasets
+
+
+
+
+#### DOTA 2.0
+
+1. Download DOTA dataset: https://captain-whu.github.io/DOTA/dataset.html
+
+2. Unzip and organize it as below:
+```
+dota
+├── train
+│   ├── images
+│   │   ├── 000000000001.jpg
+│   │   └── ...
+│   └── ann
+│       └── 000000000001.txt
+└── val
+    ├── images
+    │   ├── 000000000002.jpg
+    │   └── ...
+    └── ann
+        └── 000000000002.txt
+```
+
+
+3. Run script to slice the dataset into tiles:
+
+```bash
+python src/super_gradients/examples/dota_prepare_dataset/dota_prepare_dataset.py --input_dir /dota --output_dir /dota_tiles
+```
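+
+If you prefer to do the slicing from Python instead of the helper script, a minimal sketch is shown below. It mirrors the parameters used by the example script for the train split; the paths and `num_workers` value are placeholders you should adapt (use `tile_step=1024` and `scale_factors=(1,)` for the val split).
+
+```python
+from super_gradients.training.datasets import DOTAOBBDataset
+
+# Slice the train split: multiple scales and overlapping tiles, as in the example script.
+DOTAOBBDataset.slice_dataset_into_tiles(
+    data_dir="/dota/train",          # placeholder: original DOTA train split (contains images/ and ann/)
+    output_dir="/dota_tiles/train",  # placeholder: where the sliced tiles will be written
+    input_ann_subdir_name="ann",
+    output_ann_subdir_name="ann-obb",
+    tile_size=1024,
+    tile_step=512,
+    scale_factors=(0.75, 1, 1.25),
+    min_visibility=0.4,
+    min_area=8,
+    num_workers=4,                   # placeholder: pick a value suitable for your machine
+)
+```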
+
+4. Specify the path to the sliced dataset when launching training. Via the CLI:
+```bash
+python -m super_gradients.train_from_recipe --config-name yolo_nas_r_s_dota dataset_params.data_dir=/dota_tiles
+```
+
+Alternatively, set the paths in the YAML recipe:
+```yaml
+dataset_params:
+ train_dataset_params:
+ data_dir: /dota_tiles/train
+ val_dataset_params:
+    data_dir: /dota_tiles/val
+```
+
+Or pass the path directly in code:
+
+```python
+
+from super_gradients.training.datasets import DOTAOBBDataset
+
+train_dataset = DOTAOBBDataset(data_dir="/dota_tiles/train", ...)
+```
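+
+To feed this dataset into a PyTorch `DataLoader`, the `OrientedBoxesCollate` collate function can be used (this mirrors the dataloader settings of the bundled recipe). A minimal sketch, with empty transforms for brevity and illustrative batch/worker values:
+
+```python
+from torch.utils.data import DataLoader
+
+from super_gradients.training.datasets import DOTAOBBDataset
+from super_gradients.training.datasets.datasets_conf import DOTA2_DEFAULT_CLASSES_LIST
+from super_gradients.training.datasets.obb import OrientedBoxesCollate
+
+# In practice, pass the OBB transforms from the recipe instead of an empty list.
+train_dataset = DOTAOBBDataset(
+    data_dir="/dota_tiles/train",
+    transforms=[],
+    class_names=DOTA2_DEFAULT_CLASSES_LIST,
+)
+
+train_loader = DataLoader(
+    train_dataset,
+    batch_size=16,
+    num_workers=8,
+    shuffle=True,
+    drop_last=True,
+    collate_fn=OrientedBoxesCollate(),
+)
+```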
+
+
diff --git a/mkdocs.yml b/mkdocs.yml
index b6badb491b..a28cbb9cb0 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -19,7 +19,7 @@ nav:
- Models: ./documentation/source/models.md
- Dataset:
- Data: ./documentation/source/Data.md
- - Computer Vision Datasets: ./src/super_gradients/training/datasets/Dataset_Setup_Instructions.md
+ - Computer Vision Datasets: ./documentation/source/Dataset_Setup_Instructions.md
- Dataset Adapter: ./documentation/source/dataloader_adapter.md
- Loss functions: ./documentation/source/Losses.md
- LR Assignment: ./documentation/source/LRAssignment.md
diff --git a/src/super_gradients/examples/dota_prepare_dataset/dota_prepare_dataset.py b/src/super_gradients/examples/dota_prepare_dataset/dota_prepare_dataset.py
new file mode 100644
index 0000000000..8392facf25
--- /dev/null
+++ b/src/super_gradients/examples/dota_prepare_dataset/dota_prepare_dataset.py
@@ -0,0 +1,79 @@
+"""
+This script slices the DOTA dataset into tiles of a usable size for training a model.
+The tiles are saved in the output directory with the same structure as the input directory.
+
+To use this script you should download the DOTA dataset from the official website:
+https://captain-whu.github.io/DOTA/dataset.html
+
+The dataset should be organized as follows:
+    dota
+    ├── train
+    │   ├── images
+    │   │   ├── 000000000001.jpg
+    │   │   └── ...
+    │   └── ann
+    │       └── 000000000001.txt
+    └── val
+        ├── images
+        │   ├── 000000000002.jpg
+        │   └── ...
+        └── ann
+            └── 000000000002.txt
+
+Example usage:
+ python dota_prepare_dataset.py --input_dir /path/to/dota --output_dir /path/to/dota-sliced
+
+After running this script you can use /path/to/dota-sliced as the data_dir argument for training a model on the DOTA dataset.
+"""
+
+import argparse
+from pathlib import Path
+
+import cv2
+from super_gradients.training.datasets import DOTAOBBDataset
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Slice DOTA dataset into tiles of usable size for training a model")
+    parser.add_argument("--input_dir", help="Where the full DOTA dataset is stored", required=True)
+ parser.add_argument("--output_dir", help="Where the resulting data should be stored", required=True)
+ parser.add_argument("--ann_subdir_name", default="ann", help="Name of the annotations subdirectory")
+ parser.add_argument("--output_ann_subdir_name", default="ann-obb", help="Name of the output annotations subdirectory")
+    parser.add_argument("--num_workers", type=int, default=cv2.getNumberOfCPUs() // 2, help="Number of worker processes to use for slicing")
+ args = parser.parse_args()
+
+ cv2.setNumThreads(cv2.getNumberOfCPUs() // 4)
+
+ input_dir = Path(args.input_dir)
+ output_dir = Path(args.output_dir)
+ ann_subdir_name = str(args.ann_subdir_name)
+ output_ann_subdir_name = str(args.output_ann_subdir_name)
+ DOTAOBBDataset.slice_dataset_into_tiles(
+ data_dir=input_dir / "train",
+ output_dir=output_dir / "train",
+ input_ann_subdir_name=ann_subdir_name,
+ output_ann_subdir_name=output_ann_subdir_name,
+ tile_size=1024,
+ tile_step=512,
+ scale_factors=(0.75, 1, 1.25),
+ min_visibility=0.4,
+ min_area=8,
+ num_workers=args.num_workers,
+ )
+
+ DOTAOBBDataset.slice_dataset_into_tiles(
+ data_dir=input_dir / "val",
+ output_dir=output_dir / "val",
+ input_ann_subdir_name=ann_subdir_name,
+ output_ann_subdir_name=output_ann_subdir_name,
+ tile_size=1024,
+ tile_step=1024,
+ scale_factors=(1,),
+ min_visibility=0.4,
+ min_area=8,
+ num_workers=args.num_workers,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/super_gradients/module_interfaces/__init__.py b/src/super_gradients/module_interfaces/__init__.py
index f9871c3825..cd8f30436d 100644
--- a/src/super_gradients/module_interfaces/__init__.py
+++ b/src/super_gradients/module_interfaces/__init__.py
@@ -12,6 +12,7 @@
SemanticSegmentationDecodingModule,
BinarySegmentationDecodingModule,
)
+from .obb_predictions import OBBPredictions, AbstractOBBPostPredictionCallback
__all__ = [
"HasPredict",
@@ -35,4 +36,6 @@
"AbstractSegmentationDecodingModule",
"SemanticSegmentationDecodingModule",
"BinarySegmentationDecodingModule",
+ "OBBPredictions",
+ "AbstractOBBPostPredictionCallback",
]
diff --git a/src/super_gradients/module_interfaces/obb_predictions.py b/src/super_gradients/module_interfaces/obb_predictions.py
new file mode 100644
index 0000000000..b074587b62
--- /dev/null
+++ b/src/super_gradients/module_interfaces/obb_predictions.py
@@ -0,0 +1,47 @@
+import abc
+import dataclasses
+from typing import Any, List
+from typing import Union
+
+import numpy as np
+from torch import Tensor
+
+__all__ = ["OBBPredictions", "AbstractOBBPostPredictionCallback"]
+
+
+@dataclasses.dataclass
+class OBBPredictions:
+ """
+ A data class that encapsulates oriented box predictions for a single image.
+
+ :param labels: Array of shape [N] with class indices
+ :param scores: Array of shape [N] with corresponding confidence scores.
+    :param rboxes_cxcywhr: Array of shape [N, 5] with rotated boxes for each detection in CXCYWHR format.
+ """
+
+ scores: Union[Tensor, np.ndarray]
+ labels: Union[Tensor, np.ndarray]
+ rboxes_cxcywhr: Union[Tensor, np.ndarray]
+
+ def __init__(self, rboxes_cxcywhr, scores, labels):
+ if len(rboxes_cxcywhr) != len(scores) or len(rboxes_cxcywhr) != len(labels):
+ raise ValueError(f"rboxes_cxcywhr, scores and labels must have the same length. Got: {len(rboxes_cxcywhr)}, {len(scores)}, {len(labels)}")
+ if rboxes_cxcywhr.ndim != 2 or rboxes_cxcywhr.shape[1] != 5:
+ raise ValueError(f"rboxes_cxcywhr must have shape [N, 5]. Got: {rboxes_cxcywhr.shape}")
+
+ self.scores = scores
+ self.labels = labels
+ self.rboxes_cxcywhr = rboxes_cxcywhr
+
+ def __len__(self):
+ return len(self.scores)
+
+
+class AbstractOBBPostPredictionCallback(abc.ABC):
+ """
+    A protocol interface of a post-prediction callback for OBB detection models.
+ """
+
+ @abc.abstractmethod
+ def __call__(self, predictions: Any) -> List[OBBPredictions]:
+ ...
diff --git a/src/super_gradients/recipes/dataset_params/dota2_yolo_nas_r_dataset_params.yaml b/src/super_gradients/recipes/dataset_params/dota2_yolo_nas_r_dataset_params.yaml
new file mode 100644
index 0000000000..46d9ce6981
--- /dev/null
+++ b/src/super_gradients/recipes/dataset_params/dota2_yolo_nas_r_dataset_params.yaml
@@ -0,0 +1,115 @@
+# Configuration file for the dataset parameters of the DOTA2 dataset for the YOLO-NAS-R model.
+# A data_dir parameter should be explicitly defined in the config file that includes this file.
+# Please check documentation/source/Dataset_Setup_Instructions.md for more information on how to set up the dataset.
+
+num_classes: 18
+class_names:
+ - plane
+ - ship
+ - storage-tank
+ - baseball-diamond
+ - tennis-court
+ - basketball-court
+ - ground-track-field
+ - harbor
+ - bridge
+ - large-vehicle
+ - small-vehicle
+ - helicopter
+ - roundabout
+ - soccer-ball-field
+ - swimming-pool
+ - container-crane
+ - airport
+ - helipad
+
+data_dir: ???
+
+mixup_prob: 0.5
+
+train_dataset_params:
+ data_dir: ${dataset_params.data_dir}/train
+ class_names: ${dataset_params.class_names}
+ ignore_empty_annotations: True
+ transforms:
+ - Albumentations:
+ Compose:
+ keypoint_params:
+ transforms:
+ - ShiftScaleRotate:
+ shift_limit: 0.1
+ scale_limit: 0.75
+ rotate_limit: 45
+ interpolation: 1
+ border_mode: 0
+ - RandomBrightnessContrast:
+ brightness_limit: 0.2
+ contrast_limit: 0.2
+ p: 0.5
+ - RandomCrop:
+ p: 1.0
+ height: 640
+ width: 640
+ - HueSaturationValue:
+ hue_shift_limit: 20
+ sat_shift_limit: 30
+ val_shift_limit: 20
+ p: 0.5
+ - RandomRotate90:
+ p: 1.0
+ - HorizontalFlip:
+ p: 0.5
+
+ - OBBRemoveSmallObjects:
+ min_size: 8
+ min_area: 64
+
+ - OBBDetectionMixup:
+ prob: ${dataset_params.mixup_prob}
+
+ - OBBDetectionStandardize:
+ max_value: 255.
+
+train_dataloader_params:
+ dataset: DOTAOBBDataset
+ batch_size: 16
+ num_workers: 8
+ shuffle: True
+ drop_last: True
+ pin_memory: True
+ persistent_workers: True
+ collate_fn: OrientedBoxesCollate
+ sampler:
+ ClassBalancedSampler:
+ num_samples: 65536
+ oversample_threshold: 0.99
+ oversample_aggressiveness: 0.9945267123516118
+
+val_dataset_params:
+ data_dir: ${dataset_params.data_dir}/val
+ class_names: ${dataset_params.class_names}
+ ignore_empty_annotations: True
+ transforms:
+ - OBBDetectionLongestMaxSize:
+ max_height: 1024
+ max_width: 1024
+ - OBBDetectionPadIfNeeded:
+ min_height: 1024
+ min_width: 1024
+ pad_value: 114
+ padding_mode: bottom_right
+ - OBBDetectionStandardize:
+ max_value: 255.
+
+
+val_dataloader_params:
+ dataset: DOTAOBBDataset
+ batch_size: 16
+ num_workers: 8
+ drop_last: False
+ shuffle: False
+ pin_memory: True
+ persistent_workers: True
+ collate_fn: OrientedBoxesCollate
+
+_convert_: all
diff --git a/src/super_gradients/training/datasets/__init__.py b/src/super_gradients/training/datasets/__init__.py
index 754fd09a79..16e4d3cf7d 100755
--- a/src/super_gradients/training/datasets/__init__.py
+++ b/src/super_gradients/training/datasets/__init__.py
@@ -25,7 +25,7 @@
BaseKeypointsDataset,
COCOPoseEstimationDataset,
)
-
+from .obb import DOTAOBBDataset
__all__ = [
"BaseKeypointsDataset",
@@ -50,6 +50,7 @@
"SuperviselyPersonsDataset",
"COCOKeypointsDataset",
"COCOPoseEstimationDataset",
+ "DOTAOBBDataset",
]
cv2.setNumThreads(0)
diff --git a/src/super_gradients/training/datasets/datasets_conf.py b/src/super_gradients/training/datasets/datasets_conf.py
index 85219a8450..18635eb7b3 100755
--- a/src/super_gradients/training/datasets/datasets_conf.py
+++ b/src/super_gradients/training/datasets/datasets_conf.py
@@ -1226,3 +1226,24 @@
"motorcycle",
"bicycle",
]
+
+DOTA2_DEFAULT_CLASSES_LIST = [
+ "plane",
+ "ship",
+ "storage-tank",
+ "baseball-diamond",
+ "tennis-court",
+ "basketball-court",
+ "ground-track-field",
+ "harbor",
+ "bridge",
+ "large-vehicle",
+ "small-vehicle",
+ "helicopter",
+ "roundabout",
+ "soccer-ball-field",
+ "swimming-pool",
+ "container-crane",
+ "airport",
+ "helipad",
+]
diff --git a/src/super_gradients/training/datasets/obb/__init__.py b/src/super_gradients/training/datasets/obb/__init__.py
new file mode 100644
index 0000000000..8a9a448273
--- /dev/null
+++ b/src/super_gradients/training/datasets/obb/__init__.py
@@ -0,0 +1,7 @@
+from .collate import OrientedBoxesCollate
+from .dota import DOTAOBBDataset
+
+__all__ = [
+ "DOTAOBBDataset",
+ "OrientedBoxesCollate",
+]
diff --git a/src/super_gradients/training/datasets/obb/abstract_obb_dataset.py b/src/super_gradients/training/datasets/obb/abstract_obb_dataset.py
new file mode 100644
index 0000000000..b28c5a18a0
--- /dev/null
+++ b/src/super_gradients/training/datasets/obb/abstract_obb_dataset.py
@@ -0,0 +1,20 @@
+import abc
+from abc import ABC
+
+from super_gradients.training.transforms.obb import OBBSample
+from torch.utils.data import Dataset
+
+
+class AbstractOBBDataset(Dataset, ABC):
+ """
+ Abstract class for OBB detection datasets.
+    This class declares the minimal interface for OBB detection datasets.
+ """
+
+ @abc.abstractmethod
+ def __len__(self):
+ pass
+
+ @abc.abstractmethod
+ def __getitem__(self, index) -> OBBSample:
+ pass
diff --git a/src/super_gradients/training/datasets/obb/collate.py b/src/super_gradients/training/datasets/obb/collate.py
new file mode 100644
index 0000000000..0ea3cc0b36
--- /dev/null
+++ b/src/super_gradients/training/datasets/obb/collate.py
@@ -0,0 +1,33 @@
+from typing import List
+
+import numpy as np
+import torch
+from super_gradients.common.registry import register_collate_function
+from super_gradients.training.transforms.obb import OBBSample
+
+
+@register_collate_function()
+class OrientedBoxesCollate:
+ def __call__(self, batch: List[OBBSample]):
+ from super_gradients.training.datasets.pose_estimation_datasets.yolo_nas_pose_collate_fn import flat_collate_tensors_with_batch_index
+
+ images = []
+ all_boxes = []
+ all_labels = []
+ all_crowd_masks = []
+
+ for sample in batch:
+ images.append(torch.from_numpy(np.transpose(sample.image, [2, 0, 1])))
+ all_boxes.append(torch.from_numpy(sample.rboxes_cxcywhr))
+ all_labels.append(torch.from_numpy(sample.labels.reshape((-1, 1))))
+ all_crowd_masks.append(torch.from_numpy(sample.is_crowd.reshape((-1, 1))))
+ sample.image = None
+
+ images = torch.stack(images)
+
+ boxes = flat_collate_tensors_with_batch_index(all_boxes).float()
+ labels = flat_collate_tensors_with_batch_index(all_labels).long()
+ is_crowd = flat_collate_tensors_with_batch_index(all_crowd_masks)
+
+ extras = {"gt_samples": batch}
+ return images, (boxes, labels, is_crowd), extras
diff --git a/src/super_gradients/training/datasets/obb/dota.py b/src/super_gradients/training/datasets/obb/dota.py
new file mode 100644
index 0000000000..0798d6c682
--- /dev/null
+++ b/src/super_gradients/training/datasets/obb/dota.py
@@ -0,0 +1,363 @@
+import multiprocessing
+import random
+import cv2
+import numpy as np
+
+from functools import partial
+from pathlib import Path
+from typing import Tuple, Iterable
+
+from super_gradients.module_interfaces import HasPreprocessingParams
+from super_gradients.training.datasets.data_formats.obb.cxcywhr import poly_to_cxcywhr
+from tqdm import tqdm
+
+from super_gradients.common.decorators.factory_decorator import resolve_param
+from super_gradients.common.object_names import Processings
+from super_gradients.common.registry import register_dataset
+from super_gradients.dataset_interfaces import HasClassesInformation
+from super_gradients.training.transforms import OBBDetectionCompose
+from super_gradients.training.transforms.obb import OBBSample
+from super_gradients.common.factories.transforms_factory import TransformsFactory
+from .abstract_obb_dataset import AbstractOBBDataset
+
+__all__ = ["DOTAOBBDataset"]
+
+
+@register_dataset()
+class DOTAOBBDataset(AbstractOBBDataset, HasPreprocessingParams, HasClassesInformation):
+ @resolve_param("transforms", TransformsFactory())
+ def __init__(
+ self,
+ data_dir,
+ transforms,
+ class_names: Iterable[str],
+ ignore_empty_annotations: bool = False,
+ difficult_labels_are_crowd: bool = False,
+ images_ext: str = ".jpg",
+ images_subdir="images",
+ ann_subdir="ann-obb",
+ ):
+ super().__init__()
+
+ images_dir = Path(data_dir) / images_subdir
+ ann_dir = Path(data_dir) / ann_subdir
+ images, labels = self.find_images_and_labels(images_dir, ann_dir, images_ext)
+ self.images = []
+ self.coords = []
+ self.classes = []
+ self.difficult = []
+ self.transforms = OBBDetectionCompose(transforms, load_sample_fn=self.load_random_sample)
+ self.class_names = list(class_names)
+ self.difficult_labels_are_crowd = difficult_labels_are_crowd
+
+ class_names_to_index = {name: i for i, name in enumerate(self.class_names)}
+ for image_path, label_path in tqdm(zip(images, labels), desc=f"Parsing annotations in {ann_dir}", total=len(images)):
+ coords, classes, difficult = self.parse_annotation_file(label_path)
+ if ignore_empty_annotations and len(coords) == 0:
+ continue
+ self.images.append(image_path)
+ self.coords.append(coords)
+ self.classes.append(np.array([class_names_to_index[c] for c in classes], dtype=int))
+ self.difficult.append(difficult)
+
+ def __len__(self):
+ return len(self.images)
+
+ def load_random_sample(self) -> OBBSample:
+ num_samples = len(self)
+ random_index = random.randrange(0, num_samples)
+ return self.load_sample(random_index)
+
+ def load_sample(self, index) -> OBBSample:
+ image = cv2.imread(str(self.images[index]))
+ coords = self.coords[index]
+ classes = self.classes[index]
+ difficult = self.difficult[index]
+
+ cxcywhr = poly_to_cxcywhr(coords)
+
+ is_crowd = difficult.reshape(-1) if self.difficult_labels_are_crowd else np.zeros_like(difficult, dtype=bool)
+ sample = OBBSample(
+ image=image,
+ rboxes_cxcywhr=cxcywhr.reshape(-1, 5),
+ labels=classes.reshape(-1),
+ is_crowd=is_crowd,
+ )
+ return sample.sanitize_sample()
+
+ def __getitem__(self, index) -> OBBSample:
+ sample = self.load_sample(index)
+ sample = self.transforms.apply_to_sample(sample)
+ return sample
+
+ def get_sample_classes_information(self, index) -> np.ndarray:
+ """
+ Returns a histogram of length `num_classes` with class occurrences at that index.
+ """
+ return np.bincount(self.classes[index], minlength=len(self.class_names))
+
+ def get_dataset_classes_information(self) -> np.ndarray:
+ """
+        Returns a matrix of shape (dataset_length, num_classes). Each row `i` is a histogram of length `num_classes` with class occurrences for sample `i`.
+ Example implementation, assuming __len__: `np.vstack([self.get_sample_classes_information(i) for i in range(len(self))])`
+ """
+ m = np.zeros((len(self), len(self.class_names)), dtype=int)
+ for i in range(len(self)):
+ m[i] = self.get_sample_classes_information(i)
+ return m
+
+ def get_dataset_preprocessing_params(self):
+ """
+        Return the preprocessing parameters of this dataset, including an adaptation for PIL.Image (RGB) image reading.
+        The image_processor is returned as a list of dicts to be resolved by the processing factory.
+        :return: Dictionary with class_names, image_processor, iou and conf entries.
+ """
+ pipeline = [Processings.ReverseImageChannels] + self.transforms.get_equivalent_preprocessing()
+ params = dict(
+ class_names=self.class_names,
+ image_processor={Processings.ComposeProcessing: {"processings": pipeline}},
+ iou=0.25,
+ conf=0.1,
+ )
+ return params
+
+ @classmethod
+ def find_images_and_labels(cls, images_dir, ann_dir, images_ext):
+ images_dir = Path(images_dir)
+ ann_dir = Path(ann_dir)
+
+ images = list(images_dir.glob(f"*{images_ext}"))
+ labels = list(sorted(ann_dir.glob("*.txt")))
+
+ if len(images) != len(labels):
+ raise ValueError(f"Number of images and labels do not match. There are {len(images)} images and {len(labels)} labels.")
+
+ images = []
+ for label_path in labels:
+ image_path = images_dir / (label_path.stem + images_ext)
+ if not image_path.exists():
+ raise ValueError(f"Image {image_path} does not exist")
+ images.append(image_path)
+ return images, labels
+
+ @classmethod
+ def parse_annotation_file(cls, annotation_file: Path):
+ with open(annotation_file, "r") as f:
+ lines = f.readlines()
+
+ coords = []
+ classes = []
+ difficult = []
+
+ for line in lines:
+ parts = line.strip().split(" ")
+ if len(parts) != 10:
+ raise ValueError(f"Invalid number of parts in line: {line}")
+
+ x1, y1, x2, y2, x3, y3, x4, y4 = map(float, parts[:8])
+ coords.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
+ classes.append(parts[8])
+ difficult.append(int(parts[9]))
+
+ return np.array(coords, dtype=np.float32).reshape(-1, 4, 2), np.array(classes, dtype=np.object_), np.array(difficult, dtype=int)
+
+ @classmethod
+ def chip_image(cls, img, coords, classes, difficult, tile_size: Tuple[int, int], tile_step: Tuple[int, int], min_visibility: float, min_area: int):
+ """
+        Chip an image and get relative coordinates and classes. Bounding boxes that span
+        multiple chips are clipped: each portion that falls inside a chip is labeled. For example,
+        half a building will be labeled if it is cut off in a chip.
+
+        :param img:            The image to be chipped, in array format (H, W, C)
+        :param coords:         An (N, 4, 2) array of oriented box coordinates for that image
+        :param classes:        An (N,) array of classes for each bounding box
+        :param difficult:      An (N,) array of difficult flags for each bounding box
+        :param tile_size:      A (W, H) tuple indicating the width and height of the chips
+        :param tile_step:      A (W, H) tuple indicating the step between consecutive chips
+        :param min_visibility: Minimum fraction of a box's area that must remain visible inside a chip for it to be kept
+        :param min_area:       Minimum clipped box area (in pixels) for a box to be kept
+        :return:               A list of chip images (each padded to tile size if needed), plus per-chip lists of
+                               oriented boxes, classes and difficult flags.
+ """
+ height, width, _ = img.shape
+
+ tile_size_width, tile_size_height = tile_size
+ tile_step_width, tile_step_height = tile_step
+
+ total_images = []
+ total_boxes = []
+ total_classes = []
+ total_difficult = []
+
+ start_x = 0
+ end_x = start_x + tile_size_width
+
+ all_areas = np.array(list(cv2.contourArea(cv2.convexHull(poly)) for poly in coords), dtype=np.float32)
+
+ centers = np.mean(coords, axis=1) # [N,2]
+
+ while start_x < width:
+ start_y = 0
+ end_y = start_y + tile_size_height
+ while start_y < height:
+ chip = img[start_y:end_y, start_x:end_x, :3]
+
+ # Skipping thin strips that are not useful
+ # For instance, if image is 1030px wide and our tile size is 1024, that would end up with
+ # two tiles of [1024, 1024] and [1024, 6] which is not useful at all
+                if chip.shape[0] > 8 and chip.shape[1] > 8:
+
+                    # Keep only boxes whose center falls inside the chip
+ offset = np.array([start_x, start_y], dtype=np.float32)
+ boxes_with_offset = coords - offset.reshape(1, 1, 2)
+ centers_with_offset = centers - offset.reshape(1, 2)
+
+ cond1 = (centers_with_offset >= 0).all(axis=1)
+ cond2 = (centers_with_offset[:, 0] < chip.shape[1]) & (centers_with_offset[:, 1] < chip.shape[0])
+ rboxes_inside_chip = cond1 & cond2
+
+ visible_coords = boxes_with_offset[rboxes_inside_chip]
+ visible_classes = classes[rboxes_inside_chip]
+ visible_difficult = difficult[rboxes_inside_chip]
+ visible_areas = all_areas[rboxes_inside_chip]
+
+ out_clipped = np.stack(
+ (
+ np.clip(visible_coords[:, :, 0], 0, chip.shape[1]),
+ np.clip(visible_coords[:, :, 1], 0, chip.shape[0]),
+ ),
+ axis=2,
+ )
+ areas_clipped = np.array(list(cv2.contourArea(cv2.convexHull(c)) for c in out_clipped), dtype=np.float32)
+
+ visibility_fraction = areas_clipped / (visible_areas + 1e-6)
+ visibility_mask = visibility_fraction >= min_visibility
+ min_area_mask = areas_clipped >= min_area
+
+ visible_coords = visible_coords[visibility_mask & min_area_mask]
+ visible_classes = visible_classes[visibility_mask & min_area_mask]
+ visible_difficult = visible_difficult[visibility_mask & min_area_mask]
+
+ total_boxes.append(visible_coords)
+ total_classes.append(visible_classes)
+ total_difficult.append(visible_difficult)
+
+ if chip.shape[0] < tile_size_height or chip.shape[1] < tile_size_width:
+ chip = cv2.copyMakeBorder(
+ chip,
+ top=0,
+ left=0,
+ bottom=tile_size_height - chip.shape[0],
+ right=tile_size_width - chip.shape[1],
+ value=0,
+ borderType=cv2.BORDER_CONSTANT,
+ )
+ total_images.append(chip)
+
+ start_y += tile_step_height
+ end_y += tile_step_height
+
+ start_x += tile_step_width
+ end_x += tile_step_width
+
+ return total_images, total_boxes, total_classes, total_difficult
+
+ @classmethod
+ def slice_dataset_into_tiles(
+ cls,
+ data_dir,
+ output_dir,
+ input_ann_subdir_name,
+ output_ann_subdir_name,
+ tile_size: int,
+ tile_step: int,
+ scale_factors: Tuple,
+ min_visibility,
+ min_area,
+ num_workers: int,
+ output_image_ext=".jpg",
+ ) -> None:
+ """
+ Helper function to slice a dataset into tiles of a given size and step.
+ Slicing dataset is a necessary preparation step in order to train a model on DOTA dataset.
+        This method is not meant to be used directly by the user.
+ Please check src/super_gradients/examples/dota_prepare_dataset/dota_prepare_dataset.py for an example usage.
+
+ :param data_dir: Input DOTA dataset directory
+ :param output_dir: Output location of the sliced dataset
+ :param input_ann_subdir_name: Name of the annotations subdirectory. Usually 'ann'
+ :param output_ann_subdir_name: Name of the output annotations subdirectory. Usually 'ann-obb'
+ This is to distinguish the OBB vs HBB annotations
+ :param tile_size: Size of the image tile (pixels)
+ :param tile_step: Step size in pixels between consecutive tiles
+ :param scale_factors: Tuple of scale factors to apply to the image before slicing
+ For training a range of scale factors is usually used, e.g. (0.75, 1, 1.25)
+ For validation a single scale factor is used, e.g. (1,)
+        :param min_visibility: A fraction of the bounding box that must be visible in the tile for it to be included
+ :param min_area: Minimum area of the bounding box that must be visible in the tile for it to be included
+ :param num_workers: Number of workers to use for multiprocessing
+ :param output_image_ext: Extension of the output image files. Default value is '.jpg' which is more
+ size efficient than '.png'.
+ """
+ data_dir = Path(data_dir)
+ input_images_dir = data_dir / "images"
+ input_ann_dir = data_dir / input_ann_subdir_name
+ images, labels = cls.find_images_and_labels(input_images_dir, input_ann_dir, ".png")
+
+ output_dir = Path(output_dir)
+ output_images_dir = output_dir / "images"
+ output_ann_dir = output_dir / output_ann_subdir_name
+
+ output_images_dir.mkdir(parents=True, exist_ok=True)
+ output_ann_dir.mkdir(parents=True, exist_ok=True)
+
+ with multiprocessing.Pool(num_workers) as wp:
+ payload = [(image_path, ann_path, scale) for image_path, ann_path in zip(images, labels) for scale in scale_factors]
+
+ worker_fn = partial(
+ cls._worker_fn,
+ tile_size=tile_size,
+ tile_step=tile_step,
+ min_visibility=min_visibility,
+ min_area=min_area,
+ output_images_dir=output_images_dir,
+ output_ann_dir=output_ann_dir,
+ output_image_ext=output_image_ext,
+ )
+ for _ in tqdm(wp.imap_unordered(worker_fn, payload), total=len(payload)):
+ pass
+
+ @classmethod
+ def _worker_fn(cls, args, tile_size, tile_step, min_visibility, min_area, output_images_dir, output_ann_dir, output_image_ext):
+ image_path, ann_path, scale = args
+ image = cv2.imread(str(image_path))
+ coords, classes, difficult = cls.parse_annotation_file(ann_path)
+ scaled_image = cv2.resize(image, (0, 0), fx=scale, fy=scale)
+
+ image_tiles, total_boxes, total_classes, total_difficult = cls.chip_image(
+ scaled_image,
+ coords * scale,
+ classes,
+ difficult,
+ tile_size=(tile_size, tile_size),
+ tile_step=(tile_step, tile_step),
+ min_visibility=min_visibility,
+ min_area=min_area,
+ )
+ num_tiles = len(image_tiles)
+
+ for i in range(num_tiles):
+ tile_image = image_tiles[i]
+ tile_boxes = total_boxes[i]
+ tile_classes = total_classes[i]
+ tile_difficult = total_difficult[i]
+
+ tile_image_path = output_images_dir / f"{ann_path.stem}_{scale:.3f}_{i:06d}{output_image_ext}"
+ tile_label_path = output_ann_dir / f"{ann_path.stem}_{scale:.3f}_{i:06d}.txt"
+
+ with tile_label_path.open("w") as f:
+ for poly, category, diff in zip(tile_boxes, tile_classes, tile_difficult):
+ f.write(
+ f"{poly[0, 0]:.2f} {poly[0, 1]:.2f} {poly[1, 0]:.2f} {poly[1, 1]:.2f} {poly[2, 0]:.2f} {poly[2, 1]:.2f} {poly[3, 0]:.2f} {poly[3, 1]:.2f} {category} {diff}\n" # noqa
+ )
+
+ cv2.imwrite(str(tile_image_path), tile_image)
diff --git a/src/super_gradients/training/metrics/__init__.py b/src/super_gradients/training/metrics/__init__.py
index 7916b13cd4..ba7d96e188 100755
--- a/src/super_gradients/training/metrics/__init__.py
+++ b/src/super_gradients/training/metrics/__init__.py
@@ -17,6 +17,7 @@
DepthRMSE,
DepthMSLE,
)
+from .obb_detection_metrics import OBBDetectionMetrics_050_095, OBBDetectionMetrics_050, OBBDetectionMetrics
__all__ = [
"METRICS",
@@ -45,4 +46,7 @@
"DepthMSE",
"DepthRMSE",
"DepthMSLE",
+ "OBBDetectionMetrics_050_095",
+ "OBBDetectionMetrics_050",
+ "OBBDetectionMetrics",
]
diff --git a/src/super_gradients/training/metrics/detection_metrics.py b/src/super_gradients/training/metrics/detection_metrics.py
index 10beb2a6c8..7125e7a569 100755
--- a/src/super_gradients/training/metrics/detection_metrics.py
+++ b/src/super_gradients/training/metrics/detection_metrics.py
@@ -1,4 +1,6 @@
import collections
+import numbers
+import typing
from typing import Dict, Optional, Union, Tuple, List
import numpy as np
@@ -109,11 +111,15 @@ def __init__(
if isinstance(iou_thres, IouThreshold):
self.iou_thresholds = iou_thres.to_tensor()
- if isinstance(iou_thres, tuple):
+ elif isinstance(iou_thres, tuple):
low, high = iou_thres
self.iou_thresholds = IouThreshold.from_bounds(low, high)
- else:
- self.iou_thresholds = torch.tensor([iou_thres])
+        elif isinstance(iou_thres, np.ndarray):
+            self.iou_thresholds = torch.from_numpy(iou_thres).float()
+        elif isinstance(iou_thres, numbers.Number):
+            self.iou_thresholds = torch.tensor([iou_thres], dtype=torch.float32)
+        elif isinstance(iou_thres, typing.Iterable):
+            self.iou_thresholds = torch.tensor(list(iou_thres)).float()
self.map_str = "mAP" + self._get_range_str()
self.include_classwise_ap = include_classwise_ap
diff --git a/src/super_gradients/training/metrics/obb_detection_metrics.py b/src/super_gradients/training/metrics/obb_detection_metrics.py
new file mode 100644
index 0000000000..4bf366751f
--- /dev/null
+++ b/src/super_gradients/training/metrics/obb_detection_metrics.py
@@ -0,0 +1,474 @@
+import typing
+from typing import Optional, Union, Tuple, List
+
+import cv2
+import torch
+import torchvision.ops
+from super_gradients.common.abstractions.abstract_logger import get_logger
+from super_gradients.common.registry.registry import register_metric
+from super_gradients.module_interfaces.obb_predictions import OBBPredictions
+from super_gradients.training.datasets.data_formats.obb.cxcywhr import cxcywhr_to_poly, poly_to_xyxy
+from super_gradients.training.transforms.obb import OBBSample
+from super_gradients.training.utils.detection_utils import (
+ DetectionMatching,
+ get_top_k_idx_per_cls,
+)
+from super_gradients.training.utils.detection_utils import IouThreshold
+from torch import Tensor
+
+from .detection_metrics import DetectionMetrics
+
+logger = get_logger(__name__)
+
+if typing.TYPE_CHECKING:
+ from super_gradients.module_interfaces import AbstractOBBPostPredictionCallback
+
+
+class OBBIoUMatching(DetectionMatching):
+ """
+    OBBIoUMatching is a subclass of DetectionMatching that uses Intersection over Union (IoU)
+    between oriented boxes for matching detections in OBB detection models.
+ """
+
+ def __init__(self, iou_thresholds: torch.Tensor):
+ """
+        Initializes the OBBIoUMatching instance with IoU thresholds.
+
+ :param iou_thresholds: (torch.Tensor) The IoU thresholds for matching.
+ """
+ self.iou_thresholds = iou_thresholds
+
+ def get_thresholds(self) -> torch.Tensor:
+ """
+ Returns the IoU thresholds used for detection matching.
+
+ :return: (torch.Tensor) The IoU thresholds.
+ """
+ return self.iou_thresholds
+
+ @classmethod
+ def pairwise_cxcywhr_iou_accurate(cls, obb1: Tensor, obb2: Tensor) -> Tensor:
+ """
+ Calculate the pairwise IoU between oriented bounding boxes.
+
+ :param obb1: First set of boxes. Tensor of shape (N, 5) representing ground truth boxes, with cxcywhr format.
+ :param obb2: Second set of boxes. Tensor of shape (M, 5) representing predicted boxes, with cxcywhr format.
+ :return: A tensor of shape (N, M) representing IoU scores between corresponding boxes.
+ """
+ import numpy as np
+
+ if len(obb1.shape) != 2 or len(obb2.shape) != 2:
+ raise ValueError("Expected obb1 and obb2 to be 2D tensors")
+
+ poly1 = cxcywhr_to_poly(obb1.detach().cpu().numpy())
+ poly2 = cxcywhr_to_poly(obb2.detach().cpu().numpy())
+
+ # Compute bounding boxes from polygons
+ xyxy1 = poly_to_xyxy(poly1)
+ xyxy2 = poly_to_xyxy(poly2)
+ bbox_iou = torchvision.ops.box_iou(torch.from_numpy(xyxy1), torch.from_numpy(xyxy2)).numpy()
+ iou = np.zeros((poly1.shape[0], poly2.shape[0]))
+
+        # We use bounding box IoU to filter out pairs of polygons that have no intersection
+ # Only polygons that have non-zero bounding box IoU are considered for polygon-polygon IoU calculation
+ nz_indexes = np.nonzero(bbox_iou)
+ for i, j in zip(*nz_indexes):
+ iou[i, j] = cls.polygon_polygon_iou(poly1[i], poly2[j])
+ return torch.from_numpy(iou).to(obb1.device)
+
+ @classmethod
+ def polygon_polygon_iou(cls, gt_rect, pred_rect):
+ """
+        Performs intersection over union calculation for two polygons using integer coordinates of
+        the vertices. This is a workaround for a bug in the cv2.intersectConvexConvex function, which returns
+        incorrect results for polygons with float coordinates that are almost identical.
+
+        :param gt_rect:   Ground-truth polygon vertices, array of shape [4, 2]
+        :param pred_rect: Predicted polygon vertices, array of shape [4, 2]
+        :return:          IoU between the two polygons (0 if they do not intersect)
+ """
+ # Multiply by 1000 to account for rounding errors when going from float to int. 1000 should be enough to get rid of any rounding errors
+        # It has no effect on the IoU since IoU is scale-invariant
+ pred_rect_int = (pred_rect * 1000).astype(int)
+ gt_rect_int = (gt_rect * 1000).astype(int)
+
+ try:
+ intersection, _ = cv2.intersectConvexConvex(pred_rect_int, gt_rect_int, handleNested=True)
+ except Exception as e:
+ raise RuntimeError(
+ "Detected error in cv2.intersectConvexConvex while calculating polygon_polygon_iou\n"
+ f"pred_rect_int: {pred_rect_int}\n"
+ f"gt_rect_int: {gt_rect_int}"
+ ) from e
+
+ gt_area = cv2.contourArea(gt_rect_int)
+ pred_area = cv2.contourArea(pred_rect_int)
+
+ # Second condition is to avoid division by zero when predicted polygon is degenerate (point or line)
+ if intersection > 0 and pred_area > 0:
+ union = gt_area + pred_area - intersection
+ if union == 0:
+ raise ZeroDivisionError(
+ f"ZeroDivisionError at polygon_polygon_iou_int\n"
+ f"Intersection is {intersection}\n"
+ f"Union is {union}\n"
+ f"gt_rect_int {gt_rect_int}\n"
+ f"pred_rect_int {pred_rect_int}"
+ )
+ return intersection / max(union, 1e-7)
+
+ return 0
+
+ def compute_targets(
+ self,
+ preds_cxcywhr: torch.Tensor,
+ preds_cls: torch.Tensor,
+ targets_cxcywhr: torch.Tensor,
+ targets_cls: torch.Tensor,
+ preds_matched: torch.Tensor,
+ targets_matched: torch.Tensor,
+ preds_idx_to_use: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Computes the matching targets based on IoU for regular scenarios.
+
+ :param preds_cxcywhr: (torch.Tensor) Predicted bounding boxes in CXCYWHR format.
+ :param preds_cls: (torch.Tensor) Predicted classes.
+ :param targets_cxcywhr: (torch.Tensor) Target bounding boxes in CXCYWHR format.
+ :param targets_cls: (torch.Tensor) Target classes.
+ :param preds_matched: (torch.Tensor) Tensor indicating which predictions are matched.
+ :param targets_matched: (torch.Tensor) Tensor indicating which targets are matched.
+ :param preds_idx_to_use: (torch.Tensor) Indices of predictions to use.
+ :return: (torch.Tensor) Computed matching targets.
+ """
+ # shape = (n_preds x n_targets)
+ iou = self.pairwise_cxcywhr_iou_accurate(preds_cxcywhr[preds_idx_to_use], targets_cxcywhr)
+
+ # Fill IoU values at index (i, j) with 0 when the prediction (i) and target(j) are of different class
+        # Filling with 0 is equivalent to ignoring these values since we want IoU > iou_threshold > 0
+ cls_mismatch = preds_cls[preds_idx_to_use].view(-1, 1) != targets_cls.view(1, -1)
+ iou[cls_mismatch] = 0
+
+ # The matching priority is first detection confidence and then IoU value.
+ # The detection is already sorted by confidence in NMS, so here for each prediction we order the targets by iou.
+ sorted_iou, target_sorted = iou.sort(descending=True, stable=True)
+
+ # Only iterate over IoU values higher than min threshold to speed up the process
+ for pred_selected_i, target_sorted_i in (sorted_iou > self.iou_thresholds[0]).nonzero(as_tuple=False):
+ # pred_selected_i and target_sorted_i are relative to filters/sorting, so we extract their absolute indexes
+ pred_i = preds_idx_to_use[pred_selected_i]
+ target_i = target_sorted[pred_selected_i, target_sorted_i]
+
+ # Vector[j], True when IoU(pred_i, target_i) is above the (j)th threshold
+ is_iou_above_threshold = sorted_iou[pred_selected_i, target_sorted_i] > self.iou_thresholds
+
+ # Vector[j], True when both pred_i and target_i are not matched yet for the (j)th threshold
+ are_candidates_free = torch.logical_and(~preds_matched[pred_i, :], ~targets_matched[target_i, :])
+
+ # Vector[j], True when (pred_i, target_i) can be matched for the (j)th threshold
+ are_candidates_good = torch.logical_and(is_iou_above_threshold, are_candidates_free)
+
+ # For every threshold (j) where target_i and pred_i can be matched together ( are_candidates_good[j]==True )
+ # fill the matching placeholders with True
+ targets_matched[target_i, are_candidates_good] = True
+ preds_matched[pred_i, are_candidates_good] = True
+
+ # When all the targets are matched with a prediction for every IoU Threshold, stop.
+ if targets_matched.all():
+ break
+
+ return preds_matched
+
+ def compute_crowd_targets(
+ self,
+ preds_cxcywhr: torch.Tensor,
+ preds_cls: torch.Tensor,
+ crowd_targets_cls: torch.Tensor,
+ crowd_targets_cxcywhr: torch.Tensor,
+ preds_matched: torch.Tensor,
+ preds_to_ignore: torch.Tensor,
+ preds_idx_to_use: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Computes the matching targets based on IoU for crowd scenarios.
+
+ :param preds_cxcywhr: (torch.Tensor) Predicted bounding boxes in CXCYWHR format.
+ :param preds_cls: (torch.Tensor) Predicted classes.
+ :param crowd_targets_cls: (torch.Tensor) Crowd target classes.
+ :param crowd_targets_cxcywhr: (torch.Tensor) Crowd target bounding boxes in CXCYWHR format.
+ :param preds_matched: (torch.Tensor) Tensor indicating which predictions are matched.
+ :param preds_to_ignore: (torch.Tensor) Tensor indicating which predictions to ignore.
+ :param preds_idx_to_use: (torch.Tensor) Indices of predictions to use.
+ :return: (Tuple[torch.Tensor, torch.Tensor]) Computed matching targets for crowd scenarios.
+ """
+ # Crowd targets can be matched with many predictions.
+ # Therefore, for every prediction we just need to check if it has IoU large enough with any crowd target.
+
+ # shape = (n_preds_to_use x n_crowd_targets)
+ iou = self.pairwise_cxcywhr_iou_accurate(preds_cxcywhr[preds_idx_to_use], crowd_targets_cxcywhr)
+
+ # Fill IoA values at index (i, j) with 0 when the prediction (i) and target(j) are of different class
+        # Filling with 0 is equivalent to ignoring these values since we want IoA > threshold > 0
+ cls_mismatch = preds_cls[preds_idx_to_use].view(-1, 1) != crowd_targets_cls.view(1, -1)
+ iou[cls_mismatch] = 0
+
+        # For each prediction, we keep its highest score with any crowd target (of the same class)
+ # shape = (n_preds_to_use)
+ best_ioa, _ = iou.max(1)
+
+ # If a prediction has IoA higher than threshold (with any target of same class), then there is a match
+ # shape = (n_preds_to_use x iou_thresholds)
+ is_matching_with_crowd = best_ioa.view(-1, 1) > self.iou_thresholds.view(1, -1)
+
+ preds_to_ignore[preds_idx_to_use] = torch.logical_or(preds_to_ignore[preds_idx_to_use], is_matching_with_crowd)
+
+ return preds_matched, preds_to_ignore
+
+
+def compute_obb_detection_matching(
+ preds: OBBPredictions,
+ targets: OBBSample,
+ matching_strategy: OBBIoUMatching,
+ top_k: Optional[int],
+ output_device: Optional[torch.device] = None,
+) -> Tuple:
+ """
+    Match the predictions (NMS output) against the targets (ground truth) of a given image, with respect to the
+    matching strategy and confidence scores.
+    :param preds: OBBPredictions for a single image (NMS output), holding rboxes_cxcywhr, scores and labels
+    :param targets: OBBSample with the ground-truth oriented boxes, labels and crowd masks for this image
+    :param matching_strategy: Strategy used to match predictions to ground-truth targets (IoU between oriented boxes)
+    :param top_k: Number of predictions to keep per class, ordered by confidence score
+    :param output_device: Optional device to move the resulting tensors to
+
+ :return:
+ :preds_matched: Tensor of shape (num_img_predictions, n_thresholds)
+ True when prediction (i) is matched with a target with respect to the (j)th threshold
+ :preds_to_ignore: Tensor of shape (num_img_predictions, n_thresholds)
+ True when prediction (i) is matched with a crowd target with respect to the (j)th threshold
+ :preds_scores: Tensor of shape (num_img_predictions), confidence score for every prediction
+ :preds_cls: Tensor of shape (num_img_predictions), predicted class for every prediction
+ :targets_cls: Tensor of shape (num_img_targets), ground truth class for every target
+ """
+ num_thresholds = len(matching_strategy.get_thresholds())
+ device = preds.scores.device
+ num_preds = len(preds.rboxes_cxcywhr)
+
+ targets_box = torch.from_numpy(targets.rboxes_cxcywhr[~targets.is_crowd]).to(device)
+ targets_cls = torch.from_numpy(targets.labels[~targets.is_crowd]).to(device)
+
+ crowd_target_box = torch.from_numpy(targets.rboxes_cxcywhr[targets.is_crowd]).to(device)
+ crowd_targets_cls = torch.from_numpy(targets.labels[targets.is_crowd]).to(device)
+
+ num_targets = len(targets_box)
+ num_crowd_targets = len(crowd_target_box)
+
+ if num_preds == 0:
+ preds_matched = torch.zeros((0, num_thresholds), dtype=torch.bool, device=device)
+ preds_to_ignore = torch.zeros((0, num_thresholds), dtype=torch.bool, device=device)
+ preds_scores = torch.tensor([], dtype=torch.float32, device=device)
+ preds_cls = torch.tensor([], dtype=torch.float32, device=device)
+ targets_cls = targets_cls.to(device=device)
+ return preds_matched, preds_to_ignore, preds_scores, preds_cls, targets_cls
+
+ preds_scores = preds.scores
+ preds_cls = preds.labels
+
+ preds_matched = torch.zeros(num_preds, num_thresholds, dtype=torch.bool, device=device)
+ targets_matched = torch.zeros(num_targets, num_thresholds, dtype=torch.bool, device=device)
+ preds_to_ignore = torch.zeros(num_preds, num_thresholds, dtype=torch.bool, device=device)
+
+ # Ignore all but the predictions that were top_k for their class
+ if top_k is not None:
+ preds_idx_to_use = get_top_k_idx_per_cls(preds_scores, preds_cls, top_k)
+ else:
+ preds_idx_to_use = torch.arange(num_preds, device=device)
+
+ preds_to_ignore[:, :] = True
+ preds_to_ignore[preds_idx_to_use] = False
+
+ if num_targets > 0 or num_crowd_targets > 0:
+ if num_targets > 0:
+ preds_matched = matching_strategy.compute_targets(
+ preds.rboxes_cxcywhr, preds_cls, targets_box, targets_cls, preds_matched, targets_matched, preds_idx_to_use
+ )
+
+ if num_crowd_targets > 0:
+ preds_matched, preds_to_ignore = matching_strategy.compute_crowd_targets(
+ preds.rboxes_cxcywhr, preds_cls, crowd_targets_cls, crowd_target_box, preds_matched, preds_to_ignore, preds_idx_to_use
+ )
+
+ if output_device is not None:
+ preds_matched = preds_matched.to(output_device)
+ preds_to_ignore = preds_to_ignore.to(output_device)
+ preds_scores = preds_scores.to(output_device)
+ preds_cls = preds_cls.to(output_device)
+ targets_cls = targets_cls.to(output_device)
+
+ return preds_matched, preds_to_ignore, preds_scores, preds_cls, targets_cls
+
+
+@register_metric()
+class OBBDetectionMetrics(DetectionMetrics):
+ """
+ Metric class for computing F1, Precision, Recall and Mean Average Precision for Oriented Bounding Box (OBB) detection tasks.
+
+ :param num_cls: Number of classes.
+ :param post_prediction_callback: A post-prediction callback to be applied on net's output prior to the metric computation (NMS).
+ :param iou_thres: IoU threshold to compute the mAP.
+ Could be either instance of IouThreshold, a tuple (lower bound, upper_bound) or single scalar.
+ :param recall_thres: Recall threshold to compute the mAP.
+ :param score_thres: Score threshold to compute Recall, Precision and F1.
+ :param top_k_predictions: Number of predictions per class used to compute metrics, ordered by confidence score
+ :param dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` before returning the value at the step.
+ :param accumulate_on_cpu: Run on CPU regardless of device used in other parts.
+ This is to avoid "CUDA out of memory" that might happen on GPU.
+    :param calc_best_score_thresholds: Whether to calculate the best score threshold overall and per class
+ If True, the compute() function will return a metrics dictionary that not
+ only includes the average metrics calculated across all classes,
+ but also the optimal score threshold overall and for each individual class.
+ :param include_classwise_ap: Whether to include the class-wise average precision in the returned metrics dictionary.
+ If enabled, output metrics dictionary will look similar to this:
+ {
+ 'Precision0.5:0.95': 0.5,
+ 'Recall0.5:0.95': 0.5,
+ 'F10.5:0.95': 0.5,
+ 'mAP0.5:0.95': 0.5,
+ 'AP0.5:0.95_person': 0.5,
+ 'AP0.5:0.95_car': 0.5,
+ 'AP0.5:0.95_bicycle': 0.5,
+ 'AP0.5:0.95_motorcycle': 0.5,
+ ...
+ }
+ Class names are either provided via the class_names parameter or are generated automatically.
+ :param class_names: Array of class names. When include_classwise_ap=True, will use these names to make
+ per-class APs keys in the output metrics dictionary.
+ If None, will use dummy names `class_{idx}` instead.
+ :param state_dict_prefix: A prefix to append to the state dict of the metric. A state dict used to synchronize metric in DDP mode.
+                              It was empirically found that if you have two metric classes A and B(A) that have the same state key, for
+                              some reason torchmetrics attempts to sync their states all together, which causes an error.
+ In this case adding a prefix to the name of the synchronized state seems to help,
+ but it is still unclear why it happens.
+
+
+ """
+
+ def __init__(
+ self,
+ num_cls: int,
+ post_prediction_callback: "AbstractOBBPostPredictionCallback",
+ iou_thres: Union[IouThreshold, Tuple[float, float], float],
+ top_k_predictions: Optional[int] = None,
+ recall_thres: Tuple[float, ...] = None,
+ score_thres: Optional[float] = 0.01,
+ dist_sync_on_step: bool = False,
+ accumulate_on_cpu: bool = True,
+ calc_best_score_thresholds: bool = True,
+ include_classwise_ap: bool = False,
+ class_names: List[str] = None,
+ state_dict_prefix: str = "",
+ ):
+ super().__init__(
+ num_cls=num_cls,
+ post_prediction_callback=post_prediction_callback,
+ normalize_targets=False,
+ iou_thres=iou_thres,
+ top_k_predictions=top_k_predictions,
+ recall_thres=recall_thres,
+ score_thres=score_thres,
+ dist_sync_on_step=dist_sync_on_step,
+ accumulate_on_cpu=accumulate_on_cpu,
+ calc_best_score_thresholds=calc_best_score_thresholds,
+ include_classwise_ap=include_classwise_ap,
+ class_names=class_names,
+ state_dict_prefix=state_dict_prefix,
+ )
+
+ def update(self, preds, gt_samples: List[OBBSample]) -> None:
+ """
+ Apply NMS and match all the predictions and targets of a given batch, and update the metric state accordingly.
+
+        :param preds: Raw output of the model. The exact format may differ from one model to another,
+                      but it has to match the input format expected by the post_prediction_callback.
+        :param gt_samples: List of OBBSample objects holding the ground-truth oriented boxes for each image in the batch.
+ """
+ preds: List[OBBPredictions] = self.post_prediction_callback(preds)
+ output_device = "cpu" if self.accumulate_on_cpu else None
+ matching_strategy = OBBIoUMatching(self.iou_thresholds.to(preds[0].scores.device))
+
+ for pred, trues in zip(preds, gt_samples):
+            image_matching = compute_obb_detection_matching(
+                pred, trues, matching_strategy=matching_strategy, top_k=self.top_k_predictions, output_device=output_device
+            )
+
+            accumulated_matching_info = getattr(self, self.state_key)
+            setattr(self, self.state_key, accumulated_matching_info + [image_matching])
+
+
+@register_metric()
+class OBBDetectionMetrics_050(OBBDetectionMetrics):
+ def __init__(
+ self,
+ num_cls: int,
+ post_prediction_callback: "AbstractOBBPostPredictionCallback",
+ recall_thres: torch.Tensor = None,
+ score_thres: float = 0.01,
+ top_k_predictions: Optional[int] = None,
+ dist_sync_on_step: bool = False,
+ accumulate_on_cpu: bool = True,
+ calc_best_score_thresholds: bool = True,
+ include_classwise_ap: bool = False,
+ class_names: List[str] = None,
+ ):
+ super().__init__(
+ num_cls=num_cls,
+ post_prediction_callback=post_prediction_callback,
+ iou_thres=IouThreshold.MAP_05,
+ recall_thres=recall_thres,
+ score_thres=score_thres,
+ top_k_predictions=top_k_predictions,
+ dist_sync_on_step=dist_sync_on_step,
+ accumulate_on_cpu=accumulate_on_cpu,
+ calc_best_score_thresholds=calc_best_score_thresholds,
+ include_classwise_ap=include_classwise_ap,
+ class_names=class_names,
+ state_dict_prefix="",
+ )
+
+
+@register_metric()
+class OBBDetectionMetrics_050_095(OBBDetectionMetrics):
+ def __init__(
+ self,
+ num_cls: int,
+ post_prediction_callback: "AbstractOBBPostPredictionCallback",
+ recall_thres: torch.Tensor = None,
+ score_thres: float = 0.01,
+ top_k_predictions: Optional[int] = None,
+ dist_sync_on_step: bool = False,
+ accumulate_on_cpu: bool = True,
+ calc_best_score_thresholds: bool = True,
+ include_classwise_ap: bool = False,
+ class_names: List[str] = None,
+ ):
+ super().__init__(
+ num_cls=num_cls,
+ post_prediction_callback=post_prediction_callback,
+ iou_thres=IouThreshold.MAP_05_TO_095,
+ recall_thres=recall_thres,
+ score_thres=score_thres,
+ top_k_predictions=top_k_predictions,
+ dist_sync_on_step=dist_sync_on_step,
+ accumulate_on_cpu=accumulate_on_cpu,
+ calc_best_score_thresholds=calc_best_score_thresholds,
+ include_classwise_ap=include_classwise_ap,
+ class_names=class_names,
+ state_dict_prefix="",
+ )
diff --git a/src/super_gradients/training/pretrained_models.py b/src/super_gradients/training/pretrained_models.py
index 9c938b4a4d..fabed62e72 100644
--- a/src/super_gradients/training/pretrained_models.py
+++ b/src/super_gradients/training/pretrained_models.py
@@ -69,6 +69,7 @@
"coco": 80,
"coco_pose": 17,
"cifar10": 10,
+ "dota2": 18,
}
DATASET_LICENSES = {
@@ -79,4 +80,5 @@
"coco_pose": "https://cocodataset.org/#termsofuse",
"cityscapes": "https://www.cs.toronto.edu/~kriz/cifar.html",
"objects365": "https://www.objects365.org/download.html",
+ "dota2": "https://captain-whu.github.io/DOTA/dataset.html",
}
diff --git a/src/super_gradients/training/processing/__init__.py b/src/super_gradients/training/processing/__init__.py
index b189baa2fa..d643eea626 100644
--- a/src/super_gradients/training/processing/__init__.py
+++ b/src/super_gradients/training/processing/__init__.py
@@ -15,6 +15,7 @@
SegmentationPadShortToCropSize,
SegmentationPadToDivisible,
)
+from .obb import OBBDetectionAutoPadding
from .defaults import get_pretrained_processing_params
__all__ = [
@@ -33,5 +34,6 @@
"SegmentationResize",
"SegmentationPadShortToCropSize",
"SegmentationPadToDivisible",
+ "OBBDetectionAutoPadding",
"get_pretrained_processing_params",
]
diff --git a/src/super_gradients/training/processing/defaults.py b/src/super_gradients/training/processing/defaults.py
index e8ff6a0e8e..43f4155329 100644
--- a/src/super_gradients/training/processing/defaults.py
+++ b/src/super_gradients/training/processing/defaults.py
@@ -1,9 +1,11 @@
from super_gradients.training.datasets.datasets_conf import (
COCO_DETECTION_CLASSES_LIST,
+ DOTA2_DEFAULT_CLASSES_LIST,
IMAGENET_CLASSES,
CITYSCAPES_DEFAULT_SEGMENTATION_CLASSES_LIST,
)
+from .obb import OBBDetectionCenterPadding, OBBDetectionLongestMaxSizeRescale
from .processing import (
ComposeProcessing,
ReverseImageChannels,
@@ -93,6 +95,28 @@ def default_yolo_nas_coco_processing_params() -> dict:
return params
+def default_yolo_nas_r_dota_processing_params() -> dict:
+    """Processing parameters commonly used for training YoloNAS-R on the DOTA dataset."""
+
+ image_processor = ComposeProcessing(
+ [
+ ReverseImageChannels(), # Model trained on BGR images
+ OBBDetectionLongestMaxSizeRescale(output_shape=(1024, 1024)),
+ OBBDetectionCenterPadding(output_shape=(1024, 1024), pad_value=114),
+ StandardizeImage(max_value=255.0),
+ ImagePermute(permutation=(2, 0, 1)),
+ ]
+ )
+
+ params = dict(
+ class_names=DOTA2_DEFAULT_CLASSES_LIST,
+ image_processor=image_processor,
+ iou=0.25,
+ conf=0.1,
+ )
+ return params
+
+
def default_dekr_coco_processing_params() -> dict:
"""Processing parameters commonly used for training DEKR on COCO dataset."""
diff --git a/src/super_gradients/training/processing/obb.py b/src/super_gradients/training/processing/obb.py
new file mode 100644
index 0000000000..b524585df2
--- /dev/null
+++ b/src/super_gradients/training/processing/obb.py
@@ -0,0 +1,70 @@
+from typing import Tuple
+
+import numpy as np
+from super_gradients.common.registry import register_processing
+from super_gradients.training.transforms.utils import (
+ _pad_image,
+ PaddingCoordinates,
+ _get_center_padding_coordinates,
+ _rescale_bboxes,
+ _get_bottom_right_padding_coordinates,
+)
+from super_gradients.training.utils.predict import OBBDetectionPrediction
+from .processing import AutoPadding, DetectionPadToSizeMetadata, _LongestMaxSizeRescale, RescaleMetadata, _DetectionPadding
+
+
+@register_processing()
+class OBBDetectionCenterPadding(_DetectionPadding):
+ def _get_padding_params(self, input_shape: Tuple[int, int]) -> PaddingCoordinates:
+ return _get_center_padding_coordinates(input_shape=input_shape, output_shape=self.output_shape)
+
+ def postprocess_predictions(self, predictions: OBBDetectionPrediction, metadata: DetectionPadToSizeMetadata) -> OBBDetectionPrediction:
+ offset = np.array([metadata.padding_coordinates.left, metadata.padding_coordinates.top, 0, 0, 0], dtype=np.float32).reshape(-1, 5)
+ predictions.rboxes_cxcywhr = predictions.rboxes_cxcywhr - offset
+ return predictions
+
+
+@register_processing()
+class OBBDetectionBottomRightPadding(_DetectionPadding):
+ def _get_padding_params(self, input_shape: Tuple[int, int]) -> PaddingCoordinates:
+ return _get_bottom_right_padding_coordinates(input_shape=input_shape, output_shape=self.output_shape)
+
+ def postprocess_predictions(self, predictions: OBBDetectionPrediction, metadata: DetectionPadToSizeMetadata) -> OBBDetectionPrediction:
+ return predictions
+
+
+@register_processing()
+class OBBDetectionLongestMaxSizeRescale(_LongestMaxSizeRescale):
+ def postprocess_predictions(self, predictions: OBBDetectionPrediction, metadata: RescaleMetadata) -> OBBDetectionPrediction:
+ predictions.rboxes_cxcywhr = _rescale_bboxes(
+ targets=predictions.rboxes_cxcywhr, scale_factors=(1 / metadata.scale_factor_h, 1 / metadata.scale_factor_w)
+ )
+ return predictions
+
+
+@register_processing()
+class OBBDetectionAutoPadding(AutoPadding):
+ def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, DetectionPadToSizeMetadata]:
+ padding_coordinates = self._get_padding_params(input_shape=image.shape[:2]) # HWC -> (H, W)
+ processed_image = _pad_image(image=image, padding_coordinates=padding_coordinates, pad_value=self.pad_value)
+ return processed_image, DetectionPadToSizeMetadata(padding_coordinates=padding_coordinates)
+
+ def _get_padding_params(self, input_shape: Tuple[int, int]) -> PaddingCoordinates:
+ input_height, input_width = input_shape
+ height_modulo, width_modulo = self.shape_multiple
+
+ # Calculate necessary padding to reach the modulo
+ padded_height = ((input_height + height_modulo - 1) // height_modulo) * height_modulo
+ padded_width = ((input_width + width_modulo - 1) // width_modulo) * width_modulo
+
+ padding_top = 0 # No padding at the top
+ padding_left = 0 # No padding on the left
+ padding_bottom = padded_height - input_height
+ padding_right = padded_width - input_width
+
+ return PaddingCoordinates(top=padding_top, left=padding_left, bottom=padding_bottom, right=padding_right)
+
+ def postprocess_predictions(self, predictions: OBBDetectionPrediction, metadata: DetectionPadToSizeMetadata) -> OBBDetectionPrediction:
+ offset = np.array([metadata.padding_coordinates.left, metadata.padding_coordinates.top, 0, 0, 0], dtype=np.float32).reshape(-1, 5)
+ predictions.rboxes_cxcywhr = predictions.rboxes_cxcywhr + offset
+ return predictions
diff --git a/src/super_gradients/training/utils/callbacks/__init__.py b/src/super_gradients/training/utils/callbacks/__init__.py
index f7be3e0a8c..53b292f89d 100644
--- a/src/super_gradients/training/utils/callbacks/__init__.py
+++ b/src/super_gradients/training/utils/callbacks/__init__.py
@@ -25,6 +25,7 @@
from super_gradients.common.object_names import Callbacks, LRSchedulers, LRWarmups
from super_gradients.common.registry.registry import CALLBACKS, LR_SCHEDULERS_CLS_DICT, LR_WARMUP_CLS_DICT
from super_gradients.training.utils.callbacks.extreme_batch_pose_visualization_callback import ExtremeBatchPoseEstimationVisualizationCallback
+from .extreme_batch_obb_visualization_callback import ExtremeBatchOBBVisualizationCallback
__all__ = [
"Callback",
@@ -60,4 +61,5 @@
"PPYoloETrainingStageSwitchCallback",
"TimerCallback",
"ExtremeBatchPoseEstimationVisualizationCallback",
+ "ExtremeBatchOBBVisualizationCallback",
]
diff --git a/src/super_gradients/training/utils/callbacks/extreme_batch_obb_visualization_callback.py b/src/super_gradients/training/utils/callbacks/extreme_batch_obb_visualization_callback.py
new file mode 100644
index 0000000000..bbaee97d88
--- /dev/null
+++ b/src/super_gradients/training/utils/callbacks/extreme_batch_obb_visualization_callback.py
@@ -0,0 +1,208 @@
+import typing
+from typing import Optional, Tuple, List, Union
+
+import numpy as np
+import torch
+from torch import Tensor
+from torchmetrics import Metric
+
+from super_gradients.common.registry.registry import register_callback
+from super_gradients.training.utils.callbacks.callbacks import ExtremeBatchCaseVisualizationCallback
+from super_gradients.training.utils.visualization.obb import OBBVisualization
+from super_gradients.training.utils.visualization.utils import generate_color_mapping
+
+# These imports are required for type hints and not used anywhere else
+# Wrapping them under typing.TYPE_CHECKING is a legit way to avoid circular imports
+# while still having type hints
+if typing.TYPE_CHECKING:
+ from super_gradients.training.datasets.obb.dota import OBBSample
+ from super_gradients.module_interfaces.obb_predictions import AbstractOBBPostPredictionCallback, OBBPredictions
+
+
+@register_callback()
+class ExtremeBatchOBBVisualizationCallback(ExtremeBatchCaseVisualizationCallback):
+ """
+ ExtremeBatchOBBVisualizationCallback
+
+ Visualizes worst/best batch in an epoch for OBB detection task.
+    This class visualizes horizontally-stacked GT and predicted boxes.
+    It requires a key 'gt_samples' (List[OBBSample]) to be present in the additional_batch_items dictionary.
+
+ Supported models: YoloNAS-R
+ Supported datasets: DOTAOBBDataset
+
+ Example usage in Yaml config:
+
+ training_hyperparams:
+ phase_callbacks:
+ - ExtremeBatchOBBVisualizationCallback:
+ loss_to_monitor: YoloNASRLoss/loss
+ max: True
+ freq: 1
+ max_images: 16
+ enable_on_train_loader: True
+ enable_on_valid_loader: True
+ post_prediction_callback:
+ _target_: super_gradients.training.models.detection_models.yolo_nas_r.yolo_nas_r_post_prediction_callback.YoloNASRPostPredictionCallback
+ score_threshold: 0.25
+ pre_nms_max_predictions: 4096
+ post_nms_max_predictions: 512
+ nms_iou_threshold: 0.6
+
+ :param metric: Metric, will be the metric which is monitored.
+
+ :param metric_component_name: In case metric returns multiple values (as Mapping),
+ the value at metric.compute()[metric_component_name] will be the one monitored.
+
+    :param loss_to_monitor: str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...).
+           Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss and should be:
+
+        if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple:
+            <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>.
+
+        If a single item is returned rather than a tuple:
+            <LOSS_CLASS.__name__>.
+
+        When there is no such attribute and criterion.forward(..) returns a tuple:
+            <LOSS_CLASS.__name__>"/"Loss_<IDX>
+
+    :param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or
+           the minimum (default=False).
+
+    :param freq: int, epoch frequency to perform all of the above (default=1).
+
+    :param post_prediction_callback: Post-prediction callback that decodes the raw model output into OBBPredictions (score filtering, NMS, etc.).
+
+    :param class_names: List of class names used to label the visualized boxes.
+
+    :param class_colors: Optional list of per-class colors. If None, a color mapping is generated from the number of classes.
+
+    :param max_images: Maximum number of images to visualize per batch (-1 visualizes the whole batch).
+
+    :param enable_on_train_loader: Whether to run the visualization on the train loader (default=False).
+
+    :param enable_on_valid_loader: Whether to run the visualization on the validation loader (default=True).
+    """
+
+ def __init__(
+ self,
+ post_prediction_callback: "AbstractOBBPostPredictionCallback",
+ class_names: List[str],
+ class_colors=None,
+ metric: Optional[Metric] = None,
+ metric_component_name: Optional[str] = None,
+ loss_to_monitor: Optional[str] = None,
+ max: bool = False,
+ freq: int = 1,
+ max_images: int = -1,
+ enable_on_train_loader: bool = False,
+ enable_on_valid_loader: bool = True,
+ ):
+ if class_colors is None:
+ class_colors = generate_color_mapping(num_classes=len(class_names))
+
+ super().__init__(
+ metric=metric,
+ metric_component_name=metric_component_name,
+ loss_to_monitor=loss_to_monitor,
+ max=max,
+ freq=freq,
+ enable_on_train_loader=enable_on_train_loader,
+ enable_on_valid_loader=enable_on_valid_loader,
+ )
+ self.class_names = list(class_names)
+ self.class_colors = class_colors
+ self.post_prediction_callback = post_prediction_callback
+ self.max_images = max_images
+
+ @classmethod
+ def universal_undo_preprocessing_fn(cls, inputs: torch.Tensor) -> np.ndarray:
+ """
+ A universal reversing of preprocessing to be passed to DetectionVisualization.visualize_batch's undo_preprocessing_func kwarg.
+        :param inputs: Batch of preprocessed images as a [B, C, H, W] tensor.
+        :return:       Batch of images as a uint8 numpy array of shape [B, H, W, C] with values in the [0, 255] range.
+ """
+ inputs = inputs - inputs.min()
+ inputs /= inputs.max()
+ inputs *= 255
+ inputs = inputs.to(torch.uint8)
+ inputs = inputs.cpu().numpy()
+ inputs = inputs[:, ::-1, :, :].transpose(0, 2, 3, 1)
+ inputs = np.ascontiguousarray(inputs, dtype=np.uint8)
+ return inputs
+
+ @classmethod
+ def _visualize_batch(
+ cls,
+ image_tensor: np.ndarray,
+ rboxes: List[Union[np.ndarray, Tensor]],
+ labels: List[Union[np.ndarray, Tensor]],
+ scores: Optional[List[Union[np.ndarray, Tensor]]],
+ class_colors: List[Tuple[int, int, int]],
+ class_names: List[str],
+ ) -> List[np.ndarray]:
+ """
+        Generate a list of visualizations for a batch of images with oriented bounding boxes.
+
+        :param image_tensor: Images batch of [Batch Size, H, W, 3] shape with values in [0, 255] range.
+                             The images should be scaled to [0, 255] range and converted to uint8 type beforehand.
+        :param rboxes:       Per-image rotated boxes in CXCYWHR format, each of shape [Num Instances, 5].
+        :param labels:       Per-image class indices, each of shape [Num Instances].
+        :param scores:       Per-image confidence scores, each of shape [Num Instances]. Can be None (e.g. for ground truth).
+        :param class_colors: List of per-class colors as (int, int, int) tuples.
+        :param class_names:  List of class names.
+        :return:             List of visualization images.
+ """
+
+ out_images = []
+ for i in range(image_tensor.shape[0]):
+ rboxes_i = rboxes[i]
+ labels_i = labels[i]
+ scores_i = scores[i] if scores is not None else None
+
+ if torch.is_tensor(rboxes_i):
+ rboxes_i = rboxes_i.detach().cpu().numpy()
+ if torch.is_tensor(labels_i):
+ labels_i = labels_i.detach().cpu().numpy()
+ if torch.is_tensor(scores_i):
+ scores_i = scores_i.detach().cpu().numpy()
+
+ res_image = image_tensor[i]
+ res_image = OBBVisualization.draw_obb(
+ image=res_image,
+ rboxes_cxcywhr=rboxes_i,
+ labels=labels_i,
+ scores=scores_i,
+ class_colors=class_colors,
+ class_names=class_names,
+ show_confidence=True,
+ show_labels=True,
+ )
+
+ out_images.append(res_image)
+
+ return out_images
+
+ @torch.no_grad()
+ def process_extreme_batch(self) -> np.ndarray:
+ """
+        Processes the extreme batch, and returns a batch of images for visualization - predictions and GT boxes stacked horizontally.
+
+ :return: np.ndarray - the visualization of predictions and GT
+ """
+ if "gt_samples" not in self.extreme_additional_batch_items:
+ raise RuntimeError(
+ "ExtremeBatchPoseEstimationVisualizationCallback requires 'gt_samples' to be present in additional_batch_items."
+ "Currently only YoloNASPose model is supported. Old DEKR recipe is not supported at the moment."
+ )
+
+ inputs = self.universal_undo_preprocessing_fn(self.extreme_batch)
+ gt_samples: List["OBBSample"] = self.extreme_additional_batch_items["gt_samples"]
+ predictions: List["OBBPredictions"] = self.post_prediction_callback(self.extreme_preds)
+
+ images_to_save_preds = self._visualize_batch(
+ image_tensor=inputs,
+ rboxes=[p.rboxes_cxcywhr for p in predictions],
+ labels=[p.labels for p in predictions],
+ scores=[p.scores for p in predictions],
+ class_colors=self.class_colors,
+ class_names=self.class_names,
+ )
+ images_to_save_preds = np.stack(images_to_save_preds)
+
+ images_to_save_gt = self._visualize_batch(
+ image_tensor=inputs,
+ rboxes=[gt.rboxes_cxcywhr for gt in gt_samples],
+ labels=[gt.labels for gt in gt_samples],
+ scores=None,
+ class_colors=self.class_colors,
+ class_names=self.class_names,
+ )
+ images_to_save_gt = np.stack(images_to_save_gt)
+
+ # Stack the predictions and GT images together
+ return np.concatenate([images_to_save_gt, images_to_save_preds], axis=2)
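+
+
+# Illustrative sketch (not part of this diff) of constructing the callback in code rather than YAML;
+# `class_names` is assumed to come from the training dataset, other values mirror the YAML example in the docstring:
+#
+#   callback = ExtremeBatchOBBVisualizationCallback(
+#       post_prediction_callback=YoloNASRPostPredictionCallback(
+#           score_threshold=0.25, pre_nms_max_predictions=4096, post_nms_max_predictions=512, nms_iou_threshold=0.6
+#       ),
+#       class_names=train_dataset.class_names,
+#       loss_to_monitor="YoloNASRLoss/loss",
+#       max=True,
+#   )
+#   # ...then append it to training_hyperparams["phase_callbacks"]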
diff --git a/src/super_gradients/training/utils/predict/__init__.py b/src/super_gradients/training/utils/predict/__init__.py
index be93ee930c..90065b309b 100644
--- a/src/super_gradients/training/utils/predict/__init__.py
+++ b/src/super_gradients/training/utils/predict/__init__.py
@@ -17,7 +17,7 @@
VideoPoseEstimationPrediction,
ImagesPoseEstimationPrediction,
)
-
+from .prediction_obb_detection_results import OBBDetectionPrediction, ImageOBBDetectionPrediction, ImagesOBBDetectionPrediction, VideoOBBDetectionPrediction
__all__ = [
"Prediction",
@@ -39,4 +39,8 @@
"ImageSegmentationPrediction",
"ImagesSegmentationPrediction",
"VideoSegmentationPrediction",
+ "OBBDetectionPrediction",
+ "ImageOBBDetectionPrediction",
+ "ImagesOBBDetectionPrediction",
+ "VideoOBBDetectionPrediction",
]
diff --git a/src/super_gradients/training/utils/predict/prediction_obb_detection_results.py b/src/super_gradients/training/utils/predict/prediction_obb_detection_results.py
new file mode 100644
index 0000000000..ca9b22b717
--- /dev/null
+++ b/src/super_gradients/training/utils/predict/prediction_obb_detection_results.py
@@ -0,0 +1,428 @@
+import os
+
+import numpy as np
+import cv2
+
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Iterator, Iterable, Union
+
+from super_gradients.training.utils.media.image import save_image, show_image
+from super_gradients.training.utils.media.video import show_video_from_frames, save_video
+from super_gradients.training.utils.visualization.obb import OBBVisualization
+from super_gradients.training.utils.visualization.utils import generate_color_mapping
+from tqdm import tqdm
+
+from .predictions import Prediction
+from .prediction_results import ImagePrediction, VideoPredictions, ImagesPredictions
+
+__all__ = ["OBBDetectionPrediction", "ImageOBBDetectionPrediction", "ImagesOBBDetectionPrediction", "VideoOBBDetectionPrediction"]
+
+
+@dataclass
+class OBBDetectionPrediction(Prediction):
+ """Represents an OBB detection prediction, with bboxes represented in cxycxwhr format."""
+
+ rboxes_cxcywhr: np.ndarray
+ confidence: np.ndarray
+ labels: np.ndarray
+
+ def __init__(self, rboxes_cxcywhr: np.ndarray, confidence: np.ndarray, labels: np.ndarray, image_shape: Tuple[int, int]):
+ """
+ :param rboxes_cxcywhr: Rboxes of [N,5] shape in the CXCYWHR format
+ :param confidence: Confidence scores for each bounding box
+ :param labels: Labels for each bounding box.
+        :param image_shape:   Shape of the image the prediction is made on, (H, W).
+ """
+ self._validate_input(rboxes_cxcywhr, confidence, labels)
+ self.rboxes_cxcywhr = rboxes_cxcywhr
+ self.confidence = confidence
+ self.labels = labels
+ self.image_shape = image_shape
+
+ def _validate_input(self, rboxes_cxcywhr: np.ndarray, confidence: np.ndarray, labels: np.ndarray) -> None:
+ n_bboxes, n_confidences, n_labels = rboxes_cxcywhr.shape[0], confidence.shape[0], labels.shape[0]
+        if not (n_bboxes == n_confidences == n_labels):
+ raise ValueError(
+ f"The number of bounding boxes ({n_bboxes}) does not match the number of confidence scores ({n_confidences}) and labels ({n_labels})."
+ )
+ if rboxes_cxcywhr.shape[1] != 5:
+ raise ValueError(f"Expected 5 columns in rboxes_cxcywhr, got {rboxes_cxcywhr.shape[1]}.")
+
+ def __len__(self):
+ return len(self.rboxes_cxcywhr)
+
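+# Illustrative sketch (hypothetical values) of wrapping raw model outputs in this dataclass:
+#
+#   prediction = OBBDetectionPrediction(
+#       rboxes_cxcywhr=np.array([[320.0, 240.0, 80.0, 40.0, 0.3]]),  # single box: cx, cy, w, h, rotation
+#       confidence=np.array([0.92]),
+#       labels=np.array([0]),
+#       image_shape=(480, 640),
+#   )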
+
+@dataclass
+class ImageOBBDetectionPrediction(ImagePrediction):
+ """Object wrapping an image and a detection model's prediction.
+
+ :param image: Input image
+ :param prediction: Predictions of the model
+ :param class_names: List of the class names to predict
+ """
+
+ image: np.ndarray
+ prediction: OBBDetectionPrediction
+ class_names: List[str]
+
+ def draw(
+ self,
+ box_thickness: Optional[int] = None,
+ show_confidence: bool = True,
+ color_mapping: Optional[List[Tuple[int, int, int]]] = None,
+ target_rboxes: Optional[np.ndarray] = None,
+ target_class_ids: Optional[np.ndarray] = None,
+ class_names: Optional[List[str]] = None,
+ ) -> np.ndarray:
+ """Draw the predicted bboxes on the image.
+
+ :param box_thickness: (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
+ :param show_confidence: Whether to show confidence scores on the image.
+ :param color_mapping: List of tuples representing the colors for each class.
+ Default is None, which generates a default color mapping based on the number of class names.
+        :param target_rboxes:    Optional ground truth rotated boxes in CXCYWHR format, as an np.ndarray of shape (num_objects, 5).
+                                 When not None, the predictions and the ground truth boxes are plotted side by side (i.e. 2 images stitched as one).
+        :param target_class_ids: Optional ground truth class indices, as an np.ndarray of shape (num_objects,).
+        :param class_names:      List of class names to show. By default, is None which shows all classes used during training.
+
+ :return: Image with predicted bboxes. Note that this does not modify the original image.
+ """
+ target_rboxes = target_rboxes if target_rboxes is not None else np.zeros((0, 5))
+ target_class_ids = target_class_ids if target_class_ids is not None else np.zeros((0, 1))
+
+ class_names_to_show = class_names if class_names else self.class_names
+ class_ids_to_show = [i for i, class_name in enumerate(self.class_names) if class_name in class_names_to_show]
+ invalid_class_names_to_show = set(class_names_to_show) - set(self.class_names)
+ if len(invalid_class_names_to_show) > 0:
+ raise ValueError(
+ "`class_names` includes class names that the model was not trained on.\n"
+ f" - Invalid class names: {list(invalid_class_names_to_show)}\n"
+ f" - Available class names: {list(self.class_names)}"
+ )
+
+ plot_targets = target_rboxes is not None and len(target_rboxes)
+ color_mapping = color_mapping or generate_color_mapping(len(class_names_to_show))
+
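+        # Keep only predictions whose class index corresponds to one of the requested class names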
+ keep_mask = np.isin(self.prediction.labels, class_ids_to_show)
+ image = OBBVisualization.draw_obb(
+ image=self.image.copy(),
+ rboxes_cxcywhr=self.prediction.rboxes_cxcywhr[keep_mask],
+ scores=self.prediction.confidence[keep_mask],
+ labels=self.prediction.labels[keep_mask],
+ class_names=class_names_to_show,
+ class_colors=color_mapping,
+ show_labels=True,
+ show_confidence=show_confidence,
+ thickness=box_thickness,
+ )
+
+ if plot_targets:
+ keep_mask = np.isin(target_class_ids, class_ids_to_show)
+ target_image = OBBVisualization.draw_obb(
+ image=self.image.copy(),
+ rboxes_cxcywhr=target_rboxes[keep_mask],
+ scores=None,
+ labels=target_class_ids[keep_mask],
+ class_names=class_names_to_show,
+ class_colors=color_mapping,
+ show_labels=True,
+ show_confidence=False,
+ thickness=box_thickness,
+ )
+
+ height, width, ch = target_image.shape
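+            # Add a margin around each image: ~5% extra width and ~12.5% extra height to leave room for the title text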
+ new_width, new_height = int(width + width / 20), int(height + height / 8)
+
+            # Create new canvases with the enlarged width and height.
+ canvas_image = np.ones((new_height, new_width, ch), dtype=np.uint8) * 255
+ canvas_target = np.ones((new_height, new_width, ch), dtype=np.uint8) * 255
+
+            # Place the original images inside the new canvases, leaving room for the title text at the top
+ padding_top, padding_left = 60, 10
+
+ canvas_image[padding_top : padding_top + height, padding_left : padding_left + width] = image
+ canvas_target[padding_top : padding_top + height, padding_left : padding_left + width] = target_image
+
+ img1 = cv2.putText(canvas_image, "Predictions", (int(0.25 * width), 30), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 0))
+ img2 = cv2.putText(canvas_target, "Ground Truth", (int(0.25 * width), 30), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 0))
+
+ image = cv2.hconcat((img1, img2))
+ return image
+
+ def show(
+ self,
+ box_thickness: Optional[int] = None,
+ show_confidence: bool = True,
+ color_mapping: Optional[List[Tuple[int, int, int]]] = None,
+ target_bboxes: Optional[np.ndarray] = None,
+ target_bboxes_format: Optional[str] = None,
+ target_class_ids: Optional[np.ndarray] = None,
+ class_names: Optional[List[str]] = None,
+ ) -> None:
+ """Display the image with predicted bboxes.
+
+ :param box_thickness: (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
+ :param show_confidence: Whether to show confidence scores on the image.
+ :param color_mapping: List of tuples representing the colors for each class.
+ Default is None, which generates a default color mapping based on the number of class names.
+        :param target_bboxes:   Optional[Union[np.ndarray, List[np.ndarray]]], ground truth rotated boxes in CXCYWHR format.
+                                Can either be an np.ndarray of shape (image_i_object_count, 5) when predicting a single image,
+ or a list of length len(target_bboxes), containing such arrays.
+ When not None, will plot the predictions and the ground truth bounding boxes side by side (i.e 2 images stitched as one)
+ :param target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape
+ (image_i_object_count) when predicting a single image, or a list of length len(target_bboxes), containing such arrays.
+ :param target_bboxes_format: Optional[str], bounding box format of target_bboxes, one of
+ ['xyxy','xywh', 'yxyx' 'cxcywh' 'normalized_xyxy' 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh'].
+ Will raise an error if not None and target_bboxes is None.
+        :param class_names:     List of class names to show. By default, is None which shows all classes used during training.
+ """
+ image = self.draw(
+ box_thickness=box_thickness,
+ show_confidence=show_confidence,
+ color_mapping=color_mapping,
+ target_rboxes=target_bboxes,
+ target_class_ids=target_class_ids,
+ class_names=class_names,
+ )
+ show_image(image)
+
+ def save(
+ self,
+ output_path: str,
+ box_thickness: Optional[int] = None,
+ show_confidence: bool = True,
+ color_mapping: Optional[List[Tuple[int, int, int]]] = None,
+ target_bboxes: Optional[np.ndarray] = None,
+ target_class_ids: Optional[np.ndarray] = None,
+ class_names: Optional[List[str]] = None,
+ ) -> None:
+ """Save the predicted bboxes on the images.
+
+        :param output_path:     Path to the output image file.
+ :param box_thickness: (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
+ :param show_confidence: Whether to show confidence scores on the image.
+ :param color_mapping: List of tuples representing the colors for each class.
+ Default is None, which generates a default color mapping based on the number of class names.
+        :param target_bboxes:   Optional[Union[np.ndarray, List[np.ndarray]]], ground truth rotated boxes in CXCYWHR format.
+                                Can either be an np.ndarray of shape (image_i_object_count, 5) when predicting a single image,
+ or a list of length len(target_bboxes), containing such arrays.
+ When not None, will plot the predictions and the ground truth bounding boxes side by side (i.e 2 images stitched as one)
+ :param target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape
+ (image_i_object_count) when predicting a single image, or a list of length len(target_bboxes), containing such arrays.
+        :param class_names:     List of class names to show. By default, is None which shows all classes used during training.
+ """
+ image = self.draw(
+ box_thickness=box_thickness,
+ show_confidence=show_confidence,
+ color_mapping=color_mapping,
+ target_rboxes=target_bboxes,
+ target_class_ids=target_class_ids,
+ class_names=class_names,
+ )
+ save_image(image=image, path=output_path)
+
+
+@dataclass
+class ImagesOBBDetectionPrediction(ImagesPredictions):
+ """Object wrapping the list of image detection predictions.
+
+    :attr _images_prediction_lst: List of the prediction results
+ """
+
+ _images_prediction_lst: List[ImageOBBDetectionPrediction]
+
+ def show(
+ self,
+ box_thickness: Optional[int] = None,
+ show_confidence: bool = True,
+ color_mapping: Optional[List[Tuple[int, int, int]]] = None,
+ target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
+ target_bboxes_format: Optional[str] = None,
+ target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
+ class_names: Optional[List[str]] = None,
+ ) -> None:
+ """Display the predicted bboxes on the images.
+
+ :param box_thickness: (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
+ :param show_confidence: Whether to show confidence scores on the image.
+ :param color_mapping: List of tuples representing the colors for each class.
+ Default is None, which generates a default color mapping based on the number of class names.
+        :param target_bboxes:   Optional[Union[np.ndarray, List[np.ndarray]]], ground truth rotated boxes in CXCYWHR format.
+                                Can either be an np.ndarray of shape (image_i_object_count, 5) when predicting a single image,
+ or a list of length len(target_bboxes), containing such arrays.
+ When not None, will plot the predictions and the ground truth bounding boxes side by side (i.e 2 images stitched as one)
+ :param target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape
+ (image_i_object_count) when predicting a single image, or a list of length len(target_bboxes), containing such arrays.
+ :param target_bboxes_format: Optional[str], bounding box format of target_bboxes, one of
+ ['xyxy','xywh', 'yxyx' 'cxcywh' 'normalized_xyxy' 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh'].
+ Will raise an error if not None and target_bboxes is None.
+        :param class_names:     List of class names to show. By default, is None which shows all classes used during training.
+ """
+ target_bboxes, target_class_ids = self._check_target_args(target_bboxes, target_bboxes_format, target_class_ids)
+
+ for prediction, target_bbox, target_class_id in zip(self._images_prediction_lst, target_bboxes, target_class_ids):
+ prediction.show(
+ box_thickness=box_thickness,
+ show_confidence=show_confidence,
+ color_mapping=color_mapping,
+ target_bboxes=target_bbox,
+ target_bboxes_format=target_bboxes_format,
+ target_class_ids=target_class_id,
+ class_names=class_names,
+ )
+
+ def _check_target_args(
+ self,
+ target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
+ target_bboxes_format: Optional[str] = None,
+ target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
+ ):
+ if not (
+ (target_bboxes is None and target_bboxes_format is None and target_class_ids is None)
+ or (target_bboxes is not None and target_bboxes_format is not None and target_class_ids is not None)
+ ):
+ raise ValueError("target_bboxes, target_bboxes_format, and target_class_ids should either all be None or all not None.")
+
+ if isinstance(target_bboxes, np.ndarray):
+ target_bboxes = [target_bboxes]
+ if isinstance(target_class_ids, np.ndarray):
+ target_class_ids = [target_class_ids]
+
+ if target_bboxes is not None and target_class_ids is not None and len(target_bboxes) != len(target_class_ids):
+ raise ValueError(f"target_bboxes and target_class_ids lengths should be equal, got: {len(target_bboxes)} and {len(target_class_ids)}.")
+ if target_bboxes is not None and target_class_ids is not None and len(target_bboxes) != len(self._images_prediction_lst):
+ raise ValueError(
+ f"target_bboxes and target_class_ids lengths should be equal, to the "
+ f"amount of images passed to predict(), got: {len(target_bboxes)} and {len(self._images_prediction_lst)}."
+ )
+ if target_bboxes is None:
+ target_bboxes = [None for _ in range(len(self._images_prediction_lst))]
+ target_class_ids = [None for _ in range(len(self._images_prediction_lst))]
+
+ return target_bboxes, target_class_ids
+
+ def save(
+ self,
+ output_folder: str,
+ box_thickness: Optional[int] = None,
+ show_confidence: bool = True,
+ color_mapping: Optional[List[Tuple[int, int, int]]] = None,
+ target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
+ target_bboxes_format: Optional[str] = None,
+ target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
+ class_names: Optional[List[str]] = None,
+ ) -> None:
+ """Save the predicted bboxes on the images.
+
+ :param output_folder: Folder path, where the images will be saved.
+ :param box_thickness: (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
+ :param show_confidence: Whether to show confidence scores on the image.
+ :param color_mapping: List of tuples representing the colors for each class.
+ Default is None, which generates a default color mapping based on the number of class names.
+        :param target_bboxes:   Optional[Union[np.ndarray, List[np.ndarray]]], ground truth rotated boxes in CXCYWHR format.
+                                Can either be an np.ndarray of shape (image_i_object_count, 5) when predicting a single image,
+ or a list of length len(target_bboxes), containing such arrays.
+ When not None, will plot the predictions and the ground truth bounding boxes side by side (i.e 2 images stitched as one)
+ :param target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape
+ (image_i_object_count) when predicting a single image, or a list of length len(target_bboxes), containing such arrays.
+ :param target_bboxes_format: Optional[str], bounding box format of target_bboxes, one of
+ ['xyxy','xywh', 'yxyx' 'cxcywh' 'normalized_xyxy' 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh'].
+ Will raise an error if not None and target_bboxes is None.
+        :param class_names:     List of class names to show. By default, is None which shows all classes used during training.
+ """
+ if output_folder:
+ os.makedirs(output_folder, exist_ok=True)
+
+ target_bboxes, target_class_ids = self._check_target_args(target_bboxes, target_bboxes_format, target_class_ids)
+
+ for i, (prediction, target_bbox, target_class_id) in enumerate(zip(self._images_prediction_lst, target_bboxes, target_class_ids)):
+ image_output_path = os.path.join(output_folder, f"pred_{i}.jpg")
+            prediction.save(
+                output_path=image_output_path,
+                box_thickness=box_thickness,
+                show_confidence=show_confidence,
+                color_mapping=color_mapping,
+                target_bboxes=target_bbox,
+                target_class_ids=target_class_id,
+                class_names=class_names,
+            )
+
+
+@dataclass
+class VideoOBBDetectionPrediction(VideoPredictions):
+ """Object wrapping the list of image detection predictions as a Video.
+
+    :attr _images_prediction_gen: Iterable of the prediction results
+    :attr fps:                    Frames per second of the video
+    :attr n_frames:               Number of frames in the video
+    """
+
+ _images_prediction_gen: Iterable[ImageOBBDetectionPrediction]
+ fps: int
+ n_frames: int
+
+ def draw(
+ self,
+ box_thickness: Optional[int] = None,
+ show_confidence: bool = True,
+ color_mapping: Optional[List[Tuple[int, int, int]]] = None,
+ class_names: Optional[List[str]] = None,
+ ) -> Iterator[np.ndarray]:
+ """Draw the predicted bboxes on the images.
+
+ :param box_thickness: (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
+ :param show_confidence: Whether to show confidence scores on the image.
+ :param color_mapping: List of tuples representing the colors for each class.
+ Default is None, which generates a default color mapping based on the number of class names.
+        :param class_names:     List of class names to show. By default, is None which shows all classes used during training.
+ :return: Iterable object of images with predicted bboxes. Note that this does not modify the original image.
+ """
+
+ for result in tqdm(self._images_prediction_gen, total=self.n_frames, desc="Processing Video"):
+ yield result.draw(
+ box_thickness=box_thickness,
+ show_confidence=show_confidence,
+ color_mapping=color_mapping,
+ class_names=class_names,
+ )
+
+ def show(
+ self,
+ box_thickness: Optional[int] = None,
+ show_confidence: bool = True,
+ color_mapping: Optional[List[Tuple[int, int, int]]] = None,
+ class_names: Optional[List[str]] = None,
+ ) -> None:
+ """Display the predicted bboxes on the images.
+
+ :param box_thickness: (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
+ :param show_confidence: Whether to show confidence scores on the image.
+ :param color_mapping: List of tuples representing the colors for each class.
+ Default is None, which generates a default color mapping based on the number of class names.
+        :param class_names:     List of class names to show. By default, is None which shows all classes used during training.
+ """
+ frames = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping, class_names=class_names)
+ show_video_from_frames(window_name="Detection", frames=frames, fps=self.fps)
+
+ def save(
+ self,
+ output_path: str,
+ box_thickness: Optional[int] = None,
+ show_confidence: bool = True,
+ color_mapping: Optional[List[Tuple[int, int, int]]] = None,
+ class_names: Optional[List[str]] = None,
+ ) -> None:
+ """Save the predicted bboxes on the images.
+
+ :param output_path: Path to the output video file.
+ :param box_thickness: (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
+ :param show_confidence: Whether to show confidence scores on the image.
+ :param color_mapping: List of tuples representing the colors for each class.
+ Default is None, which generates a default color mapping based on the number of class names.
+ :param class_names: List of class names to show. By default, is None which shows all classes using during training.
+ """
+ frames = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping, class_names=class_names)
+ save_video(output_path=output_path, frames=frames, fps=self.fps)
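+
+
+# Illustrative end-to-end sketch (not part of this diff); assumes a trained YoloNAS-R model whose predict() returns these wrappers:
+#
+#   predictions = model.predict("aerial_image.jpg")     # -> ImagesOBBDetectionPrediction
+#   predictions.show()                                   # display the rotated boxes
+#   predictions.save(output_folder="obb_predictions")    # one annotated image per input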