diff --git a/src/super_gradients/training/datasets/Dataset_Setup_Instructions.md b/documentation/source/Dataset_Setup_Instructions.md similarity index 91% rename from src/super_gradients/training/datasets/Dataset_Setup_Instructions.md rename to documentation/source/Dataset_Setup_Instructions.md index fc8c3b8be9..e0620991b3 100644 --- a/src/super_gradients/training/datasets/Dataset_Setup_Instructions.md +++ b/documentation/source/Dataset_Setup_Instructions.md @@ -511,3 +511,63 @@ train_set = COCOKeypointsDataset(data_dir='.../coco', images_dir='images/train20 valid_set = COCOKeypointsDataset(data_dir='.../coco', images_dir='images/val2017', json_file='annotations/instances_val2017.json', ...) ``` + + + +### Oriented Box Detection Datasets + + + +
+<summary>DOTA 2.0</summary>
+
+1. Download the DOTA dataset: https://captain-whu.github.io/DOTA/dataset.html
+
+2. Unzip and organize it as follows:
+```
+   dota
+   ├── train
+   │   ├── images
+   │   │   ├─ 000000000001.png
+   │   │   └─ ...
+   │   └── ann
+   │       └─ 000000000001.txt
+   └── val
+       ├── images
+       │   ├─ 000000000002.png
+       │   └─ ...
+       └── ann
+           └─ 000000000002.txt
+```
+
+3. Run the script to slice the dataset into tiles:
+
+```bash
+python src/super_gradients/examples/dota_prepare_dataset/dota_prepare_dataset.py --input_dir /dota --output_dir /dota_tiles
+```
+
+4. Specify the path to the sliced dataset when training. From the command line (CLI):
+
+```bash
+python -m super_gradients.train_from_recipe --config-name yolo_nas_r_s_dota dataset_params.data_dir=/dota_tiles
+```
+
+or in the recipe YAML:
+
+```yaml
+dataset_params:
+  train_dataset_params:
+    data_dir: /dota_tiles/train
+  val_dataset_params:
+    data_dir: /dota_tiles/val
+```
+
+or directly in code:
+
+```python
+from super_gradients.training.datasets import DOTAOBBDataset
+
+train_set = DOTAOBBDataset(data_dir="/dota_tiles/train", ...)
+```
+
+</details>
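For reference, a minimal sketch of building a training dataloader from the sliced tiles, using the `DOTAOBBDataset`, `OrientedBoxesCollate` and `DOTA2_DEFAULT_CLASSES_LIST` objects introduced in this change. The empty transform list and the loader settings are placeholders; the full augmentation pipeline and dataloader configuration live in `dota2_yolo_nas_r_dataset_params.yaml`.

```python
from torch.utils.data import DataLoader

from super_gradients.training.datasets import DOTAOBBDataset
from super_gradients.training.datasets.datasets_conf import DOTA2_DEFAULT_CLASSES_LIST
from super_gradients.training.datasets.obb import OrientedBoxesCollate

# Dataset of pre-sliced DOTA tiles (see step 3 above). Annotations are read from the
# "ann-obb" subdirectory that dota_prepare_dataset.py produces by default.
train_set = DOTAOBBDataset(
    data_dir="/dota_tiles/train",
    transforms=[],  # placeholder; see the recipe YAML for the full augmentation pipeline
    class_names=DOTA2_DEFAULT_CLASSES_LIST,
    ignore_empty_annotations=True,
)

# OrientedBoxesCollate flattens the per-sample rotated boxes, labels and crowd masks into
# batched tensors and passes the original OBBSample objects along in the extras dict.
train_loader = DataLoader(
    train_set,
    batch_size=16,
    shuffle=True,
    num_workers=8,
    collate_fn=OrientedBoxesCollate(),
)
```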
diff --git a/mkdocs.yml b/mkdocs.yml
index b6badb491b..a28cbb9cb0 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -19,7 +19,7 @@ nav:
   - Models: ./documentation/source/models.md
   - Dataset:
     - Data: ./documentation/source/Data.md
-    - Computer Vision Datasets: ./src/super_gradients/training/datasets/Dataset_Setup_Instructions.md
+    - Computer Vision Datasets: ./documentation/source/Dataset_Setup_Instructions.md
     - Dataset Adapter: ./documentation/source/dataloader_adapter.md
   - Loss functions: ./documentation/source/Losses.md
   - LR Assignment: ./documentation/source/LRAssignment.md
diff --git a/src/super_gradients/examples/dota_prepare_dataset/dota_prepare_dataset.py b/src/super_gradients/examples/dota_prepare_dataset/dota_prepare_dataset.py
new file mode 100644
index 0000000000..8392facf25
--- /dev/null
+++ b/src/super_gradients/examples/dota_prepare_dataset/dota_prepare_dataset.py
@@ -0,0 +1,79 @@
+"""
+This script slices the DOTA dataset into tiles of a usable size for training a model.
+The tiles are saved in the output directory with the same structure as the input directory.
+
+To use this script you should download the DOTA dataset from the official website:
+https://captain-whu.github.io/DOTA/dataset.html
+
+The dataset should be organized as follows:
+    dota
+    ├── train
+    │   ├── images
+    │   │   ├─ 000000000001.png
+    │   │   └─ ...
+    │   └── ann
+    │       └─ 000000000001.txt
+    └── val
+        ├── images
+        │   ├─ 000000000002.png
+        │   └─ ...
+        └── ann
+            └─ 000000000002.txt
+
+Example usage:
+    python dota_prepare_dataset.py --input_dir /path/to/dota --output_dir /path/to/dota-sliced
+
+After running this script you can use /path/to/dota-sliced as the data_dir argument for training a model on the DOTA dataset.
+"""
+
+import argparse
+from pathlib import Path
+
+import cv2
+from super_gradients.training.datasets import DOTAOBBDataset
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Slice DOTA dataset into tiles of usable size for training a model")
+    parser.add_argument("--input_dir", help="Where the original (full-size) DOTA dataset is stored", required=True)
+    parser.add_argument("--output_dir", help="Where the resulting data should be stored", required=True)
+    parser.add_argument("--ann_subdir_name", default="ann", help="Name of the annotations subdirectory")
+    parser.add_argument("--output_ann_subdir_name", default="ann-obb", help="Name of the output annotations subdirectory")
+    parser.add_argument("--num_workers", type=int, default=cv2.getNumberOfCPUs() // 2, help="Number of worker processes used for slicing")
+    args = parser.parse_args()
+
+    cv2.setNumThreads(cv2.getNumberOfCPUs() // 4)
+
+    input_dir = Path(args.input_dir)
+    output_dir = Path(args.output_dir)
+    ann_subdir_name = str(args.ann_subdir_name)
+    output_ann_subdir_name = str(args.output_ann_subdir_name)
+    DOTAOBBDataset.slice_dataset_into_tiles(
+        data_dir=input_dir / "train",
+        output_dir=output_dir / "train",
+        input_ann_subdir_name=ann_subdir_name,
+        output_ann_subdir_name=output_ann_subdir_name,
+        tile_size=1024,
+        tile_step=512,
+        scale_factors=(0.75, 1, 1.25),
+        min_visibility=0.4,
+        min_area=8,
+        num_workers=args.num_workers,
+    )
+
+    DOTAOBBDataset.slice_dataset_into_tiles(
+        data_dir=input_dir / "val",
+        output_dir=output_dir / "val",
+        input_ann_subdir_name=ann_subdir_name,
+        output_ann_subdir_name=output_ann_subdir_name,
+        tile_size=1024,
+        tile_step=1024,
+        scale_factors=(1,),
+        min_visibility=0.4,
+        min_area=8,
+        num_workers=args.num_workers,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/super_gradients/module_interfaces/__init__.py b/src/super_gradients/module_interfaces/__init__.py
index 
f9871c3825..cd8f30436d 100644
--- a/src/super_gradients/module_interfaces/__init__.py
+++ b/src/super_gradients/module_interfaces/__init__.py
@@ -12,6 +12,7 @@
     SemanticSegmentationDecodingModule,
     BinarySegmentationDecodingModule,
 )
+from .obb_predictions import OBBPredictions, AbstractOBBPostPredictionCallback
 
 __all__ = [
     "HasPredict",
@@ -35,4 +36,6 @@
     "AbstractSegmentationDecodingModule",
     "SemanticSegmentationDecodingModule",
     "BinarySegmentationDecodingModule",
+    "OBBPredictions",
+    "AbstractOBBPostPredictionCallback",
 ]
diff --git a/src/super_gradients/module_interfaces/obb_predictions.py b/src/super_gradients/module_interfaces/obb_predictions.py
new file mode 100644
index 0000000000..b074587b62
--- /dev/null
+++ b/src/super_gradients/module_interfaces/obb_predictions.py
@@ -0,0 +1,47 @@
+import abc
+import dataclasses
+from typing import Any, List
+from typing import Union
+
+import numpy as np
+from torch import Tensor
+
+__all__ = ["OBBPredictions", "AbstractOBBPostPredictionCallback"]
+
+
+@dataclasses.dataclass
+class OBBPredictions:
+    """
+    A data class that encapsulates oriented box predictions for a single image.
+
+    :param labels: Array of shape [N] with class indices
+    :param scores: Array of shape [N] with corresponding confidence scores.
+    :param rboxes_cxcywhr: Array of shape [N, 5] with a rotated box for each detected instance, in CXCYWHR format.
+    """
+
+    scores: Union[Tensor, np.ndarray]
+    labels: Union[Tensor, np.ndarray]
+    rboxes_cxcywhr: Union[Tensor, np.ndarray]
+
+    def __init__(self, rboxes_cxcywhr, scores, labels):
+        if len(rboxes_cxcywhr) != len(scores) or len(rboxes_cxcywhr) != len(labels):
+            raise ValueError(f"rboxes_cxcywhr, scores and labels must have the same length. Got: {len(rboxes_cxcywhr)}, {len(scores)}, {len(labels)}")
+        if rboxes_cxcywhr.ndim != 2 or rboxes_cxcywhr.shape[1] != 5:
+            raise ValueError(f"rboxes_cxcywhr must have shape [N, 5]. Got: {rboxes_cxcywhr.shape}")
+
+        self.scores = scores
+        self.labels = labels
+        self.rboxes_cxcywhr = rboxes_cxcywhr
+
+    def __len__(self):
+        return len(self.scores)
+
+
+class AbstractOBBPostPredictionCallback(abc.ABC):
+    """
+    A protocol interface of a post-prediction callback for OBB detection models.
+    """
+
+    @abc.abstractmethod
+    def __call__(self, predictions: Any) -> List[OBBPredictions]:
+        ...
diff --git a/src/super_gradients/recipes/dataset_params/dota2_yolo_nas_r_dataset_params.yaml b/src/super_gradients/recipes/dataset_params/dota2_yolo_nas_r_dataset_params.yaml
new file mode 100644
index 0000000000..46d9ce6981
--- /dev/null
+++ b/src/super_gradients/recipes/dataset_params/dota2_yolo_nas_r_dataset_params.yaml
@@ -0,0 +1,115 @@
+# Configuration file for the dataset parameters of the DOTA2 dataset for the YOLO-NAS-R model.
+# A data_dir parameter should be explicitly defined in the config file that includes this file.
+# Please check documentation/source/Dataset_Setup_Instructions.md for more information on how to set up the dataset.
+
+num_classes: 18
+class_names:
+  - plane
+  - ship
+  - storage-tank
+  - baseball-diamond
+  - tennis-court
+  - basketball-court
+  - ground-track-field
+  - harbor
+  - bridge
+  - large-vehicle
+  - small-vehicle
+  - helicopter
+  - roundabout
+  - soccer-ball-field
+  - swimming-pool
+  - container-crane
+  - airport
+  - helipad
+
+data_dir: ???
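+# Note: '???' is OmegaConf's marker for a mandatory value with no default, so training fails fast
+# if no data_dir is supplied. A sketch of how it is typically provided (either by the recipe that
+# includes this file, or as a CLI override as shown in the dataset setup documentation):
+#   python -m super_gradients.train_from_recipe --config-name yolo_nas_r_s_dota dataset_params.data_dir=/dota_tiles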
+ +mixup_prob: 0.5 + +train_dataset_params: + data_dir: ${dataset_params.data_dir}/train + class_names: ${dataset_params.class_names} + ignore_empty_annotations: True + transforms: + - Albumentations: + Compose: + keypoint_params: + transforms: + - ShiftScaleRotate: + shift_limit: 0.1 + scale_limit: 0.75 + rotate_limit: 45 + interpolation: 1 + border_mode: 0 + - RandomBrightnessContrast: + brightness_limit: 0.2 + contrast_limit: 0.2 + p: 0.5 + - RandomCrop: + p: 1.0 + height: 640 + width: 640 + - HueSaturationValue: + hue_shift_limit: 20 + sat_shift_limit: 30 + val_shift_limit: 20 + p: 0.5 + - RandomRotate90: + p: 1.0 + - HorizontalFlip: + p: 0.5 + + - OBBRemoveSmallObjects: + min_size: 8 + min_area: 64 + + - OBBDetectionMixup: + prob: ${dataset_params.mixup_prob} + + - OBBDetectionStandardize: + max_value: 255. + +train_dataloader_params: + dataset: DOTAOBBDataset + batch_size: 16 + num_workers: 8 + shuffle: True + drop_last: True + pin_memory: True + persistent_workers: True + collate_fn: OrientedBoxesCollate + sampler: + ClassBalancedSampler: + num_samples: 65536 + oversample_threshold: 0.99 + oversample_aggressiveness: 0.9945267123516118 + +val_dataset_params: + data_dir: ${dataset_params.data_dir}/val + class_names: ${dataset_params.class_names} + ignore_empty_annotations: True + transforms: + - OBBDetectionLongestMaxSize: + max_height: 1024 + max_width: 1024 + - OBBDetectionPadIfNeeded: + min_height: 1024 + min_width: 1024 + pad_value: 114 + padding_mode: bottom_right + - OBBDetectionStandardize: + max_value: 255. + + +val_dataloader_params: + dataset: DOTAOBBDataset + batch_size: 16 + num_workers: 8 + drop_last: False + shuffle: False + pin_memory: True + persistent_workers: True + collate_fn: OrientedBoxesCollate + +_convert_: all diff --git a/src/super_gradients/training/datasets/__init__.py b/src/super_gradients/training/datasets/__init__.py index 754fd09a79..16e4d3cf7d 100755 --- a/src/super_gradients/training/datasets/__init__.py +++ b/src/super_gradients/training/datasets/__init__.py @@ -25,7 +25,7 @@ BaseKeypointsDataset, COCOPoseEstimationDataset, ) - +from .obb import DOTAOBBDataset __all__ = [ "BaseKeypointsDataset", @@ -50,6 +50,7 @@ "SuperviselyPersonsDataset", "COCOKeypointsDataset", "COCOPoseEstimationDataset", + "DOTAOBBDataset", ] cv2.setNumThreads(0) diff --git a/src/super_gradients/training/datasets/datasets_conf.py b/src/super_gradients/training/datasets/datasets_conf.py index 85219a8450..18635eb7b3 100755 --- a/src/super_gradients/training/datasets/datasets_conf.py +++ b/src/super_gradients/training/datasets/datasets_conf.py @@ -1226,3 +1226,24 @@ "motorcycle", "bicycle", ] + +DOTA2_DEFAULT_CLASSES_LIST = [ + "plane", + "ship", + "storage-tank", + "baseball-diamond", + "tennis-court", + "basketball-court", + "ground-track-field", + "harbor", + "bridge", + "large-vehicle", + "small-vehicle", + "helicopter", + "roundabout", + "soccer-ball-field", + "swimming-pool", + "container-crane", + "airport", + "helipad", +] diff --git a/src/super_gradients/training/datasets/obb/__init__.py b/src/super_gradients/training/datasets/obb/__init__.py new file mode 100644 index 0000000000..8a9a448273 --- /dev/null +++ b/src/super_gradients/training/datasets/obb/__init__.py @@ -0,0 +1,7 @@ +from .collate import OrientedBoxesCollate +from .dota import DOTAOBBDataset + +__all__ = [ + "DOTAOBBDataset", + "OrientedBoxesCollate", +] diff --git a/src/super_gradients/training/datasets/obb/abstract_obb_dataset.py b/src/super_gradients/training/datasets/obb/abstract_obb_dataset.py new file 
mode 100644 index 0000000000..b28c5a18a0 --- /dev/null +++ b/src/super_gradients/training/datasets/obb/abstract_obb_dataset.py @@ -0,0 +1,20 @@ +import abc +from abc import ABC + +from super_gradients.training.transforms.obb import OBBSample +from torch.utils.data import Dataset + + +class AbstractOBBDataset(Dataset, ABC): + """ + Abstract class for OBB detection datasets. + This class declares minimal interface for OBB detection datasets. + """ + + @abc.abstractmethod + def __len__(self): + pass + + @abc.abstractmethod + def __getitem__(self, index) -> OBBSample: + pass diff --git a/src/super_gradients/training/datasets/obb/collate.py b/src/super_gradients/training/datasets/obb/collate.py new file mode 100644 index 0000000000..0ea3cc0b36 --- /dev/null +++ b/src/super_gradients/training/datasets/obb/collate.py @@ -0,0 +1,33 @@ +from typing import List + +import numpy as np +import torch +from super_gradients.common.registry import register_collate_function +from super_gradients.training.transforms.obb import OBBSample + + +@register_collate_function() +class OrientedBoxesCollate: + def __call__(self, batch: List[OBBSample]): + from super_gradients.training.datasets.pose_estimation_datasets.yolo_nas_pose_collate_fn import flat_collate_tensors_with_batch_index + + images = [] + all_boxes = [] + all_labels = [] + all_crowd_masks = [] + + for sample in batch: + images.append(torch.from_numpy(np.transpose(sample.image, [2, 0, 1]))) + all_boxes.append(torch.from_numpy(sample.rboxes_cxcywhr)) + all_labels.append(torch.from_numpy(sample.labels.reshape((-1, 1)))) + all_crowd_masks.append(torch.from_numpy(sample.is_crowd.reshape((-1, 1)))) + sample.image = None + + images = torch.stack(images) + + boxes = flat_collate_tensors_with_batch_index(all_boxes).float() + labels = flat_collate_tensors_with_batch_index(all_labels).long() + is_crowd = flat_collate_tensors_with_batch_index(all_crowd_masks) + + extras = {"gt_samples": batch} + return images, (boxes, labels, is_crowd), extras diff --git a/src/super_gradients/training/datasets/obb/dota.py b/src/super_gradients/training/datasets/obb/dota.py new file mode 100644 index 0000000000..0798d6c682 --- /dev/null +++ b/src/super_gradients/training/datasets/obb/dota.py @@ -0,0 +1,363 @@ +import multiprocessing +import random +import cv2 +import numpy as np + +from functools import partial +from pathlib import Path +from typing import Tuple, Iterable + +from super_gradients.module_interfaces import HasPreprocessingParams +from super_gradients.training.datasets.data_formats.obb.cxcywhr import poly_to_cxcywhr +from tqdm import tqdm + +from super_gradients.common.decorators.factory_decorator import resolve_param +from super_gradients.common.object_names import Processings +from super_gradients.common.registry import register_dataset +from super_gradients.dataset_interfaces import HasClassesInformation +from super_gradients.training.transforms import OBBDetectionCompose +from super_gradients.training.transforms.obb import OBBSample +from super_gradients.common.factories.transforms_factory import TransformsFactory +from .abstract_obb_dataset import AbstractOBBDataset + +__all__ = ["DOTAOBBDataset"] + + +@register_dataset() +class DOTAOBBDataset(AbstractOBBDataset, HasPreprocessingParams, HasClassesInformation): + @resolve_param("transforms", TransformsFactory()) + def __init__( + self, + data_dir, + transforms, + class_names: Iterable[str], + ignore_empty_annotations: bool = False, + difficult_labels_are_crowd: bool = False, + images_ext: str = ".jpg", + 
images_subdir="images", + ann_subdir="ann-obb", + ): + super().__init__() + + images_dir = Path(data_dir) / images_subdir + ann_dir = Path(data_dir) / ann_subdir + images, labels = self.find_images_and_labels(images_dir, ann_dir, images_ext) + self.images = [] + self.coords = [] + self.classes = [] + self.difficult = [] + self.transforms = OBBDetectionCompose(transforms, load_sample_fn=self.load_random_sample) + self.class_names = list(class_names) + self.difficult_labels_are_crowd = difficult_labels_are_crowd + + class_names_to_index = {name: i for i, name in enumerate(self.class_names)} + for image_path, label_path in tqdm(zip(images, labels), desc=f"Parsing annotations in {ann_dir}", total=len(images)): + coords, classes, difficult = self.parse_annotation_file(label_path) + if ignore_empty_annotations and len(coords) == 0: + continue + self.images.append(image_path) + self.coords.append(coords) + self.classes.append(np.array([class_names_to_index[c] for c in classes], dtype=int)) + self.difficult.append(difficult) + + def __len__(self): + return len(self.images) + + def load_random_sample(self) -> OBBSample: + num_samples = len(self) + random_index = random.randrange(0, num_samples) + return self.load_sample(random_index) + + def load_sample(self, index) -> OBBSample: + image = cv2.imread(str(self.images[index])) + coords = self.coords[index] + classes = self.classes[index] + difficult = self.difficult[index] + + cxcywhr = poly_to_cxcywhr(coords) + + is_crowd = difficult.reshape(-1) if self.difficult_labels_are_crowd else np.zeros_like(difficult, dtype=bool) + sample = OBBSample( + image=image, + rboxes_cxcywhr=cxcywhr.reshape(-1, 5), + labels=classes.reshape(-1), + is_crowd=is_crowd, + ) + return sample.sanitize_sample() + + def __getitem__(self, index) -> OBBSample: + sample = self.load_sample(index) + sample = self.transforms.apply_to_sample(sample) + return sample + + def get_sample_classes_information(self, index) -> np.ndarray: + """ + Returns a histogram of length `num_classes` with class occurrences at that index. + """ + return np.bincount(self.classes[index], minlength=len(self.class_names)) + + def get_dataset_classes_information(self) -> np.ndarray: + """ + Returns a matrix of shape (dataset_length, num_classes). Each row `i` is histogram of length `num_classes` with class occurrences for sample `i`. + Example implementation, assuming __len__: `np.vstack([self.get_sample_classes_information(i) for i in range(len(self))])` + """ + m = np.zeros((len(self), len(self.class_names)), dtype=int) + for i in range(len(self)): + m[i] = self.get_sample_classes_information(i) + return m + + def get_dataset_preprocessing_params(self): + """ + Return any hardcoded preprocessing + adaptation for PIL.Image image reading (RGB). + image_processor as returned as list of dicts to be resolved by processing factory. + :return: + """ + pipeline = [Processings.ReverseImageChannels] + self.transforms.get_equivalent_preprocessing() + params = dict( + class_names=self.class_names, + image_processor={Processings.ComposeProcessing: {"processings": pipeline}}, + iou=0.25, + conf=0.1, + ) + return params + + @classmethod + def find_images_and_labels(cls, images_dir, ann_dir, images_ext): + images_dir = Path(images_dir) + ann_dir = Path(ann_dir) + + images = list(images_dir.glob(f"*{images_ext}")) + labels = list(sorted(ann_dir.glob("*.txt"))) + + if len(images) != len(labels): + raise ValueError(f"Number of images and labels do not match. 
There are {len(images)} images and {len(labels)} labels.") + + images = [] + for label_path in labels: + image_path = images_dir / (label_path.stem + images_ext) + if not image_path.exists(): + raise ValueError(f"Image {image_path} does not exist") + images.append(image_path) + return images, labels + + @classmethod + def parse_annotation_file(cls, annotation_file: Path): + with open(annotation_file, "r") as f: + lines = f.readlines() + + coords = [] + classes = [] + difficult = [] + + for line in lines: + parts = line.strip().split(" ") + if len(parts) != 10: + raise ValueError(f"Invalid number of parts in line: {line}") + + x1, y1, x2, y2, x3, y3, x4, y4 = map(float, parts[:8]) + coords.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]]) + classes.append(parts[8]) + difficult.append(int(parts[9])) + + return np.array(coords, dtype=np.float32).reshape(-1, 4, 2), np.array(classes, dtype=np.object_), np.array(difficult, dtype=int) + + @classmethod + def chip_image(cls, img, coords, classes, difficult, tile_size: Tuple[int, int], tile_step: Tuple[int, int], min_visibility: float, min_area: int): + """ + Chip an image and get relative coordinates and classes. Bounding boxes that pass into + multiple chips are clipped: each portion that is in a chip is labeled. For example, + half a building will be labeled if it is cut off in a chip. + + :param img: the image to be chipped in array format + :param coords: an (N,4,2) array of oriented box coordinates for that image + :param classes: an (N,1) array of classes for each bounding box + :param tile_size: an (W,H) tuple indicating width and height of chips + + Output: + An image array of shape (M,W,H,C), where M is the number of chips, + W and H are the dimensions of the image, and C is the number of color + channels. Also returns boxes and classes dictionaries for each corresponding chip. 
+ """ + height, width, _ = img.shape + + tile_size_width, tile_size_height = tile_size + tile_step_width, tile_step_height = tile_step + + total_images = [] + total_boxes = [] + total_classes = [] + total_difficult = [] + + start_x = 0 + end_x = start_x + tile_size_width + + all_areas = np.array(list(cv2.contourArea(cv2.convexHull(poly)) for poly in coords), dtype=np.float32) + + centers = np.mean(coords, axis=1) # [N,2] + + while start_x < width: + start_y = 0 + end_y = start_y + tile_size_height + while start_y < height: + chip = img[start_y:end_y, start_x:end_x, :3] + + # Skipping thin strips that are not useful + # For instance, if image is 1030px wide and our tile size is 1024, that would end up with + # two tiles of [1024, 1024] and [1024, 6] which is not useful at all + if chip.shape[0] > 8 or chip.shape[1] > 8: + + # Filter out boxes that whose bounding box is definitely not in the chip + offset = np.array([start_x, start_y], dtype=np.float32) + boxes_with_offset = coords - offset.reshape(1, 1, 2) + centers_with_offset = centers - offset.reshape(1, 2) + + cond1 = (centers_with_offset >= 0).all(axis=1) + cond2 = (centers_with_offset[:, 0] < chip.shape[1]) & (centers_with_offset[:, 1] < chip.shape[0]) + rboxes_inside_chip = cond1 & cond2 + + visible_coords = boxes_with_offset[rboxes_inside_chip] + visible_classes = classes[rboxes_inside_chip] + visible_difficult = difficult[rboxes_inside_chip] + visible_areas = all_areas[rboxes_inside_chip] + + out_clipped = np.stack( + ( + np.clip(visible_coords[:, :, 0], 0, chip.shape[1]), + np.clip(visible_coords[:, :, 1], 0, chip.shape[0]), + ), + axis=2, + ) + areas_clipped = np.array(list(cv2.contourArea(cv2.convexHull(c)) for c in out_clipped), dtype=np.float32) + + visibility_fraction = areas_clipped / (visible_areas + 1e-6) + visibility_mask = visibility_fraction >= min_visibility + min_area_mask = areas_clipped >= min_area + + visible_coords = visible_coords[visibility_mask & min_area_mask] + visible_classes = visible_classes[visibility_mask & min_area_mask] + visible_difficult = visible_difficult[visibility_mask & min_area_mask] + + total_boxes.append(visible_coords) + total_classes.append(visible_classes) + total_difficult.append(visible_difficult) + + if chip.shape[0] < tile_size_height or chip.shape[1] < tile_size_width: + chip = cv2.copyMakeBorder( + chip, + top=0, + left=0, + bottom=tile_size_height - chip.shape[0], + right=tile_size_width - chip.shape[1], + value=0, + borderType=cv2.BORDER_CONSTANT, + ) + total_images.append(chip) + + start_y += tile_step_height + end_y += tile_step_height + + start_x += tile_step_width + end_x += tile_step_width + + return total_images, total_boxes, total_classes, total_difficult + + @classmethod + def slice_dataset_into_tiles( + cls, + data_dir, + output_dir, + input_ann_subdir_name, + output_ann_subdir_name, + tile_size: int, + tile_step: int, + scale_factors: Tuple, + min_visibility, + min_area, + num_workers: int, + output_image_ext=".jpg", + ) -> None: + """ + Helper function to slice a dataset into tiles of a given size and step. + Slicing dataset is a necessary preparation step in order to train a model on DOTA dataset. + This method is not meant to be used directly by user. + Please check src/super_gradients/examples/dota_prepare_dataset/dota_prepare_dataset.py for an example usage. + + :param data_dir: Input DOTA dataset directory + :param output_dir: Output location of the sliced dataset + :param input_ann_subdir_name: Name of the annotations subdirectory. 
Usually 'ann' + :param output_ann_subdir_name: Name of the output annotations subdirectory. Usually 'ann-obb' + This is to distinguish the OBB vs HBB annotations + :param tile_size: Size of the image tile (pixels) + :param tile_step: Step size in pixels between consecutive tiles + :param scale_factors: Tuple of scale factors to apply to the image before slicing + For training a range of scale factors is usually used, e.g. (0.75, 1, 1.25) + For validation a single scale factor is used, e.g. (1,) + :param min_visibility: A faction of the bounding box that must be visible in the tile for it to be included + :param min_area: Minimum area of the bounding box that must be visible in the tile for it to be included + :param num_workers: Number of workers to use for multiprocessing + :param output_image_ext: Extension of the output image files. Default value is '.jpg' which is more + size efficient than '.png'. + """ + data_dir = Path(data_dir) + input_images_dir = data_dir / "images" + input_ann_dir = data_dir / input_ann_subdir_name + images, labels = cls.find_images_and_labels(input_images_dir, input_ann_dir, ".png") + + output_dir = Path(output_dir) + output_images_dir = output_dir / "images" + output_ann_dir = output_dir / output_ann_subdir_name + + output_images_dir.mkdir(parents=True, exist_ok=True) + output_ann_dir.mkdir(parents=True, exist_ok=True) + + with multiprocessing.Pool(num_workers) as wp: + payload = [(image_path, ann_path, scale) for image_path, ann_path in zip(images, labels) for scale in scale_factors] + + worker_fn = partial( + cls._worker_fn, + tile_size=tile_size, + tile_step=tile_step, + min_visibility=min_visibility, + min_area=min_area, + output_images_dir=output_images_dir, + output_ann_dir=output_ann_dir, + output_image_ext=output_image_ext, + ) + for _ in tqdm(wp.imap_unordered(worker_fn, payload), total=len(payload)): + pass + + @classmethod + def _worker_fn(cls, args, tile_size, tile_step, min_visibility, min_area, output_images_dir, output_ann_dir, output_image_ext): + image_path, ann_path, scale = args + image = cv2.imread(str(image_path)) + coords, classes, difficult = cls.parse_annotation_file(ann_path) + scaled_image = cv2.resize(image, (0, 0), fx=scale, fy=scale) + + image_tiles, total_boxes, total_classes, total_difficult = cls.chip_image( + scaled_image, + coords * scale, + classes, + difficult, + tile_size=(tile_size, tile_size), + tile_step=(tile_step, tile_step), + min_visibility=min_visibility, + min_area=min_area, + ) + num_tiles = len(image_tiles) + + for i in range(num_tiles): + tile_image = image_tiles[i] + tile_boxes = total_boxes[i] + tile_classes = total_classes[i] + tile_difficult = total_difficult[i] + + tile_image_path = output_images_dir / f"{ann_path.stem}_{scale:.3f}_{i:06d}{output_image_ext}" + tile_label_path = output_ann_dir / f"{ann_path.stem}_{scale:.3f}_{i:06d}.txt" + + with tile_label_path.open("w") as f: + for poly, category, diff in zip(tile_boxes, tile_classes, tile_difficult): + f.write( + f"{poly[0, 0]:.2f} {poly[0, 1]:.2f} {poly[1, 0]:.2f} {poly[1, 1]:.2f} {poly[2, 0]:.2f} {poly[2, 1]:.2f} {poly[3, 0]:.2f} {poly[3, 1]:.2f} {category} {diff}\n" # noqa + ) + + cv2.imwrite(str(tile_image_path), tile_image) diff --git a/src/super_gradients/training/metrics/__init__.py b/src/super_gradients/training/metrics/__init__.py index 7916b13cd4..ba7d96e188 100755 --- a/src/super_gradients/training/metrics/__init__.py +++ b/src/super_gradients/training/metrics/__init__.py @@ -17,6 +17,7 @@ DepthRMSE, DepthMSLE, ) +from .obb_detection_metrics 
import OBBDetectionMetrics_050_095, OBBDetectionMetrics_050, OBBDetectionMetrics __all__ = [ "METRICS", @@ -45,4 +46,7 @@ "DepthMSE", "DepthRMSE", "DepthMSLE", + "OBBDetectionMetrics_050_095", + "OBBDetectionMetrics_050", + "OBBDetectionMetrics", ] diff --git a/src/super_gradients/training/metrics/detection_metrics.py b/src/super_gradients/training/metrics/detection_metrics.py index 10beb2a6c8..7125e7a569 100755 --- a/src/super_gradients/training/metrics/detection_metrics.py +++ b/src/super_gradients/training/metrics/detection_metrics.py @@ -1,4 +1,6 @@ import collections +import numbers +import typing from typing import Dict, Optional, Union, Tuple, List import numpy as np @@ -109,11 +111,15 @@ def __init__( if isinstance(iou_thres, IouThreshold): self.iou_thresholds = iou_thres.to_tensor() - if isinstance(iou_thres, tuple): + elif isinstance(iou_thres, tuple): low, high = iou_thres self.iou_thresholds = IouThreshold.from_bounds(low, high) - else: - self.iou_thresholds = torch.tensor([iou_thres]) + elif isinstance(iou_thres, typing.Iterable): + self.iou_thresholds = torch.tensor(list(iou_thres)).float() + elif isinstance(iou_thres, np.ndarray): + self.iou_thresholds = torch.from_numpy(iou_thres).float() + elif isinstance(iou_thres, numbers.Number): + self.iou_thresholds = torch.tensor([iou_thres], dtype=torch.float32) self.map_str = "mAP" + self._get_range_str() self.include_classwise_ap = include_classwise_ap diff --git a/src/super_gradients/training/metrics/obb_detection_metrics.py b/src/super_gradients/training/metrics/obb_detection_metrics.py new file mode 100644 index 0000000000..4bf366751f --- /dev/null +++ b/src/super_gradients/training/metrics/obb_detection_metrics.py @@ -0,0 +1,474 @@ +import typing +from typing import Optional, Union, Tuple, List + +import cv2 +import torch +import torchvision.ops +from super_gradients.common.abstractions.abstract_logger import get_logger +from super_gradients.common.registry.registry import register_metric +from super_gradients.module_interfaces.obb_predictions import OBBPredictions +from super_gradients.training.datasets.data_formats.obb.cxcywhr import cxcywhr_to_poly, poly_to_xyxy +from super_gradients.training.transforms.obb import OBBSample +from super_gradients.training.utils.detection_utils import ( + DetectionMatching, + get_top_k_idx_per_cls, +) +from super_gradients.training.utils.detection_utils import IouThreshold +from torch import Tensor + +from .detection_metrics import DetectionMetrics + +logger = get_logger(__name__) + +if typing.TYPE_CHECKING: + from super_gradients.module_interfaces import AbstractOBBPostPredictionCallback + + +class OBBIoUMatching(DetectionMatching): + """ + IoUMatching is a subclass of DetectionMatching that uses Intersection over Union (IoU) + for matching detections in object detection models. + """ + + def __init__(self, iou_thresholds: torch.Tensor): + """ + Initializes the IoUMatching instance with IoU thresholds. + + :param iou_thresholds: (torch.Tensor) The IoU thresholds for matching. + """ + self.iou_thresholds = iou_thresholds + + def get_thresholds(self) -> torch.Tensor: + """ + Returns the IoU thresholds used for detection matching. + + :return: (torch.Tensor) The IoU thresholds. + """ + return self.iou_thresholds + + @classmethod + def pairwise_cxcywhr_iou_accurate(cls, obb1: Tensor, obb2: Tensor) -> Tensor: + """ + Calculate the pairwise IoU between oriented bounding boxes. + + :param obb1: First set of boxes. Tensor of shape (N, 5) representing ground truth boxes, with cxcywhr format. 
+ :param obb2: Second set of boxes. Tensor of shape (M, 5) representing predicted boxes, with cxcywhr format. + :return: A tensor of shape (N, M) representing IoU scores between corresponding boxes. + """ + import numpy as np + + if len(obb1.shape) != 2 or len(obb2.shape) != 2: + raise ValueError("Expected obb1 and obb2 to be 2D tensors") + + poly1 = cxcywhr_to_poly(obb1.detach().cpu().numpy()) + poly2 = cxcywhr_to_poly(obb2.detach().cpu().numpy()) + + # Compute bounding boxes from polygons + xyxy1 = poly_to_xyxy(poly1) + xyxy2 = poly_to_xyxy(poly2) + bbox_iou = torchvision.ops.box_iou(torch.from_numpy(xyxy1), torch.from_numpy(xyxy2)).numpy() + iou = np.zeros((poly1.shape[0], poly2.shape[0])) + + # We use bounding box IoU to filter out pairs of polygons that has no intersection + # Only polygons that have non-zero bounding box IoU are considered for polygon-polygon IoU calculation + nz_indexes = np.nonzero(bbox_iou) + for i, j in zip(*nz_indexes): + iou[i, j] = cls.polygon_polygon_iou(poly1[i], poly2[j]) + return torch.from_numpy(iou).to(obb1.device) + + @classmethod + def polygon_polygon_iou(cls, gt_rect, pred_rect): + """ + Performs intersection over union calculation for two polygons using integer coordinates of + vertices. This is a workaround for a bug in cv2.intersectConvexConvex function that returns + incorrect results for polygons with float coordinates that are almost identical + + Args: + gt_rect: [4,2] + pred_rect: [4,2] + + Returns: + + """ + # Multiply by 1000 to account for rounding errors when going from float to int. 1000 should be enough to get rid of any rounding errors + # It has no effect on IOU since it is scale-less + pred_rect_int = (pred_rect * 1000).astype(int) + gt_rect_int = (gt_rect * 1000).astype(int) + + try: + intersection, _ = cv2.intersectConvexConvex(pred_rect_int, gt_rect_int, handleNested=True) + except Exception as e: + raise RuntimeError( + "Detected error in cv2.intersectConvexConvex while calculating polygon_polygon_iou\n" + f"pred_rect_int: {pred_rect_int}\n" + f"gt_rect_int: {gt_rect_int}" + ) from e + + gt_area = cv2.contourArea(gt_rect_int) + pred_area = cv2.contourArea(pred_rect_int) + + # Second condition is to avoid division by zero when predicted polygon is degenerate (point or line) + if intersection > 0 and pred_area > 0: + union = gt_area + pred_area - intersection + if union == 0: + raise ZeroDivisionError( + f"ZeroDivisionError at polygon_polygon_iou_int\n" + f"Intersection is {intersection}\n" + f"Union is {union}\n" + f"gt_rect_int {gt_rect_int}\n" + f"pred_rect_int {pred_rect_int}" + ) + return intersection / max(union, 1e-7) + + return 0 + + def compute_targets( + self, + preds_cxcywhr: torch.Tensor, + preds_cls: torch.Tensor, + targets_cxcywhr: torch.Tensor, + targets_cls: torch.Tensor, + preds_matched: torch.Tensor, + targets_matched: torch.Tensor, + preds_idx_to_use: torch.Tensor, + ) -> torch.Tensor: + """ + Computes the matching targets based on IoU for regular scenarios. + + :param preds_cxcywhr: (torch.Tensor) Predicted bounding boxes in CXCYWHR format. + :param preds_cls: (torch.Tensor) Predicted classes. + :param targets_cxcywhr: (torch.Tensor) Target bounding boxes in CXCYWHR format. + :param targets_cls: (torch.Tensor) Target classes. + :param preds_matched: (torch.Tensor) Tensor indicating which predictions are matched. + :param targets_matched: (torch.Tensor) Tensor indicating which targets are matched. + :param preds_idx_to_use: (torch.Tensor) Indices of predictions to use. 
+ :return: (torch.Tensor) Computed matching targets. + """ + # shape = (n_preds x n_targets) + iou = self.pairwise_cxcywhr_iou_accurate(preds_cxcywhr[preds_idx_to_use], targets_cxcywhr) + + # Fill IoU values at index (i, j) with 0 when the prediction (i) and target(j) are of different class + # Filling with 0 is equivalent to ignore these values since with want IoU > iou_threshold > 0 + cls_mismatch = preds_cls[preds_idx_to_use].view(-1, 1) != targets_cls.view(1, -1) + iou[cls_mismatch] = 0 + + # The matching priority is first detection confidence and then IoU value. + # The detection is already sorted by confidence in NMS, so here for each prediction we order the targets by iou. + sorted_iou, target_sorted = iou.sort(descending=True, stable=True) + + # Only iterate over IoU values higher than min threshold to speed up the process + for pred_selected_i, target_sorted_i in (sorted_iou > self.iou_thresholds[0]).nonzero(as_tuple=False): + # pred_selected_i and target_sorted_i are relative to filters/sorting, so we extract their absolute indexes + pred_i = preds_idx_to_use[pred_selected_i] + target_i = target_sorted[pred_selected_i, target_sorted_i] + + # Vector[j], True when IoU(pred_i, target_i) is above the (j)th threshold + is_iou_above_threshold = sorted_iou[pred_selected_i, target_sorted_i] > self.iou_thresholds + + # Vector[j], True when both pred_i and target_i are not matched yet for the (j)th threshold + are_candidates_free = torch.logical_and(~preds_matched[pred_i, :], ~targets_matched[target_i, :]) + + # Vector[j], True when (pred_i, target_i) can be matched for the (j)th threshold + are_candidates_good = torch.logical_and(is_iou_above_threshold, are_candidates_free) + + # For every threshold (j) where target_i and pred_i can be matched together ( are_candidates_good[j]==True ) + # fill the matching placeholders with True + targets_matched[target_i, are_candidates_good] = True + preds_matched[pred_i, are_candidates_good] = True + + # When all the targets are matched with a prediction for every IoU Threshold, stop. + if targets_matched.all(): + break + + return preds_matched + + def compute_crowd_targets( + self, + preds_cxcywhr: torch.Tensor, + preds_cls: torch.Tensor, + crowd_targets_cls: torch.Tensor, + crowd_targets_cxcywhr: torch.Tensor, + preds_matched: torch.Tensor, + preds_to_ignore: torch.Tensor, + preds_idx_to_use: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Computes the matching targets based on IoU for crowd scenarios. + + :param preds_cxcywhr: (torch.Tensor) Predicted bounding boxes in CXCYWHR format. + :param preds_cls: (torch.Tensor) Predicted classes. + :param crowd_targets_cls: (torch.Tensor) Crowd target classes. + :param crowd_targets_cxcywhr: (torch.Tensor) Crowd target bounding boxes in CXCYWHR format. + :param preds_matched: (torch.Tensor) Tensor indicating which predictions are matched. + :param preds_to_ignore: (torch.Tensor) Tensor indicating which predictions to ignore. + :param preds_idx_to_use: (torch.Tensor) Indices of predictions to use. + :return: (Tuple[torch.Tensor, torch.Tensor]) Computed matching targets for crowd scenarios. + """ + # Crowd targets can be matched with many predictions. + # Therefore, for every prediction we just need to check if it has IoU large enough with any crowd target. 
+ + # shape = (n_preds_to_use x n_crowd_targets) + iou = self.pairwise_cxcywhr_iou_accurate(preds_cxcywhr[preds_idx_to_use], crowd_targets_cxcywhr) + + # Fill IoA values at index (i, j) with 0 when the prediction (i) and target(j) are of different class + # Filling with 0 is equivalent to ignore these values since with want IoA > threshold > 0 + cls_mismatch = preds_cls[preds_idx_to_use].view(-1, 1) != crowd_targets_cls.view(1, -1) + iou[cls_mismatch] = 0 + + # For each prediction, we keep it's highest score with any crowd target (of same class) + # shape = (n_preds_to_use) + best_ioa, _ = iou.max(1) + + # If a prediction has IoA higher than threshold (with any target of same class), then there is a match + # shape = (n_preds_to_use x iou_thresholds) + is_matching_with_crowd = best_ioa.view(-1, 1) > self.iou_thresholds.view(1, -1) + + preds_to_ignore[preds_idx_to_use] = torch.logical_or(preds_to_ignore[preds_idx_to_use], is_matching_with_crowd) + + return preds_matched, preds_to_ignore + + +def compute_obb_detection_matching( + preds: OBBPredictions, + targets: OBBSample, + matching_strategy: OBBIoUMatching, + top_k: Optional[int], + output_device: Optional[torch.device] = None, +) -> Tuple: + """ + Match predictions (NMS output) and the targets (ground truth) with respect to metric and confidence score + for a given image. + :param preds: Tensor of shape (num_img_predictions, 6) + format: (x1, y1, x2, y2, confidence, class_label) where x1,y1,x2,y2 are according to image size + :param targets: targets for this image of shape (num_img_targets, 6) + format: (label, cx, cy, w, h) where cx,cy,w,h + :param top_k: Number of predictions to keep per class, ordered by confidence score + :param matching_strategy: Method to match predictions to ground truth targets: IoU, distance based + + :return: + :preds_matched: Tensor of shape (num_img_predictions, n_thresholds) + True when prediction (i) is matched with a target with respect to the (j)th threshold + :preds_to_ignore: Tensor of shape (num_img_predictions, n_thresholds) + True when prediction (i) is matched with a crowd target with respect to the (j)th threshold + :preds_scores: Tensor of shape (num_img_predictions), confidence score for every prediction + :preds_cls: Tensor of shape (num_img_predictions), predicted class for every prediction + :targets_cls: Tensor of shape (num_img_targets), ground truth class for every target + """ + num_thresholds = len(matching_strategy.get_thresholds()) + device = preds.scores.device + num_preds = len(preds.rboxes_cxcywhr) + + targets_box = torch.from_numpy(targets.rboxes_cxcywhr[~targets.is_crowd]).to(device) + targets_cls = torch.from_numpy(targets.labels[~targets.is_crowd]).to(device) + + crowd_target_box = torch.from_numpy(targets.rboxes_cxcywhr[targets.is_crowd]).to(device) + crowd_targets_cls = torch.from_numpy(targets.labels[targets.is_crowd]).to(device) + + num_targets = len(targets_box) + num_crowd_targets = len(crowd_target_box) + + if num_preds == 0: + preds_matched = torch.zeros((0, num_thresholds), dtype=torch.bool, device=device) + preds_to_ignore = torch.zeros((0, num_thresholds), dtype=torch.bool, device=device) + preds_scores = torch.tensor([], dtype=torch.float32, device=device) + preds_cls = torch.tensor([], dtype=torch.float32, device=device) + targets_cls = targets_cls.to(device=device) + return preds_matched, preds_to_ignore, preds_scores, preds_cls, targets_cls + + preds_scores = preds.scores + preds_cls = preds.labels + + preds_matched = torch.zeros(num_preds, num_thresholds, 
dtype=torch.bool, device=device) + targets_matched = torch.zeros(num_targets, num_thresholds, dtype=torch.bool, device=device) + preds_to_ignore = torch.zeros(num_preds, num_thresholds, dtype=torch.bool, device=device) + + # Ignore all but the predictions that were top_k for their class + if top_k is not None: + preds_idx_to_use = get_top_k_idx_per_cls(preds_scores, preds_cls, top_k) + else: + preds_idx_to_use = torch.arange(num_preds, device=device) + + preds_to_ignore[:, :] = True + preds_to_ignore[preds_idx_to_use] = False + + if num_targets > 0 or num_crowd_targets > 0: + if num_targets > 0: + preds_matched = matching_strategy.compute_targets( + preds.rboxes_cxcywhr, preds_cls, targets_box, targets_cls, preds_matched, targets_matched, preds_idx_to_use + ) + + if num_crowd_targets > 0: + preds_matched, preds_to_ignore = matching_strategy.compute_crowd_targets( + preds.rboxes_cxcywhr, preds_cls, crowd_targets_cls, crowd_target_box, preds_matched, preds_to_ignore, preds_idx_to_use + ) + + if output_device is not None: + preds_matched = preds_matched.to(output_device) + preds_to_ignore = preds_to_ignore.to(output_device) + preds_scores = preds_scores.to(output_device) + preds_cls = preds_cls.to(output_device) + targets_cls = targets_cls.to(output_device) + + return preds_matched, preds_to_ignore, preds_scores, preds_cls, targets_cls + + +@register_metric() +class OBBDetectionMetrics(DetectionMetrics): + """ + Metric class for computing F1, Precision, Recall and Mean Average Precision for Oriented Bounding Box (OBB) detection tasks. + + :param num_cls: Number of classes. + :param post_prediction_callback: A post-prediction callback to be applied on net's output prior to the metric computation (NMS). + :param iou_thres: IoU threshold to compute the mAP. + Could be either instance of IouThreshold, a tuple (lower bound, upper_bound) or single scalar. + :param recall_thres: Recall threshold to compute the mAP. + :param score_thres: Score threshold to compute Recall, Precision and F1. + :param top_k_predictions: Number of predictions per class used to compute metrics, ordered by confidence score + :param dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` before returning the value at the step. + :param accumulate_on_cpu: Run on CPU regardless of device used in other parts. + This is to avoid "CUDA out of memory" that might happen on GPU. + :param calc_best_score_thresholds Whether to calculate the best score threshold overall and per class + If True, the compute() function will return a metrics dictionary that not + only includes the average metrics calculated across all classes, + but also the optimal score threshold overall and for each individual class. + :param include_classwise_ap: Whether to include the class-wise average precision in the returned metrics dictionary. + If enabled, output metrics dictionary will look similar to this: + { + 'Precision0.5:0.95': 0.5, + 'Recall0.5:0.95': 0.5, + 'F10.5:0.95': 0.5, + 'mAP0.5:0.95': 0.5, + 'AP0.5:0.95_person': 0.5, + 'AP0.5:0.95_car': 0.5, + 'AP0.5:0.95_bicycle': 0.5, + 'AP0.5:0.95_motorcycle': 0.5, + ... + } + Class names are either provided via the class_names parameter or are generated automatically. + :param class_names: Array of class names. When include_classwise_ap=True, will use these names to make + per-class APs keys in the output metrics dictionary. + If None, will use dummy names `class_{idx}` instead. + :param state_dict_prefix: A prefix to append to the state dict of the metric. 
A state dict used to synchronize metric in DDP mode. + It was empirically found that if you have two metric classes A and B(A) that has same state key, for + some reason torchmetrics attempts to sync their states all toghether which causes an error. + In this case adding a prefix to the name of the synchronized state seems to help, + but it is still unclear why it happens. + + + """ + + def __init__( + self, + num_cls: int, + post_prediction_callback: "AbstractOBBPostPredictionCallback", + iou_thres: Union[IouThreshold, Tuple[float, float], float], + top_k_predictions: Optional[int] = None, + recall_thres: Tuple[float, ...] = None, + score_thres: Optional[float] = 0.01, + dist_sync_on_step: bool = False, + accumulate_on_cpu: bool = True, + calc_best_score_thresholds: bool = True, + include_classwise_ap: bool = False, + class_names: List[str] = None, + state_dict_prefix: str = "", + ): + super().__init__( + num_cls=num_cls, + post_prediction_callback=post_prediction_callback, + normalize_targets=False, + iou_thres=iou_thres, + top_k_predictions=top_k_predictions, + recall_thres=recall_thres, + score_thres=score_thres, + dist_sync_on_step=dist_sync_on_step, + accumulate_on_cpu=accumulate_on_cpu, + calc_best_score_thresholds=calc_best_score_thresholds, + include_classwise_ap=include_classwise_ap, + class_names=class_names, + state_dict_prefix=state_dict_prefix, + ) + + def update(self, preds, gt_samples: List[OBBSample]) -> None: + """ + Apply NMS and match all the predictions and targets of a given batch, and update the metric state accordingly. + + :param preds: Raw output of the model, the format might change from one model to another, + but has to fit the input format of the post_prediction_callback (cx,cy,wh) + :param target: Targets for all images of shape (total_num_targets, 6) LABEL_CXCYWH. 
format: (index, label, cx, cy, w, h) + :param device: Device to run on + :param inputs: Input image tensor of shape (batch_size, n_img, height, width) + :param crowd_targets: Crowd targets for all images of shape (total_num_targets, 6), LABEL_CXCYWH + """ + preds: List[OBBPredictions] = self.post_prediction_callback(preds) + output_device = "cpu" if self.accumulate_on_cpu else None + matching_strategy = OBBIoUMatching(self.iou_thresholds.to(preds[0].scores.device)) + + for pred, trues in zip(preds, gt_samples): + image_mathing = compute_obb_detection_matching( + pred, trues, matching_strategy=matching_strategy, top_k=self.top_k_predictions, output_device=output_device + ) + + accumulated_matching_info = getattr(self, self.state_key) + setattr(self, self.state_key, accumulated_matching_info + [image_mathing]) + + +@register_metric() +class OBBDetectionMetrics_050(OBBDetectionMetrics): + def __init__( + self, + num_cls: int, + post_prediction_callback: "AbstractOBBPostPredictionCallback", + recall_thres: torch.Tensor = None, + score_thres: float = 0.01, + top_k_predictions: Optional[int] = None, + dist_sync_on_step: bool = False, + accumulate_on_cpu: bool = True, + calc_best_score_thresholds: bool = True, + include_classwise_ap: bool = False, + class_names: List[str] = None, + ): + super().__init__( + num_cls=num_cls, + post_prediction_callback=post_prediction_callback, + iou_thres=IouThreshold.MAP_05, + recall_thres=recall_thres, + score_thres=score_thres, + top_k_predictions=top_k_predictions, + dist_sync_on_step=dist_sync_on_step, + accumulate_on_cpu=accumulate_on_cpu, + calc_best_score_thresholds=calc_best_score_thresholds, + include_classwise_ap=include_classwise_ap, + class_names=class_names, + state_dict_prefix="", + ) + + +@register_metric() +class OBBDetectionMetrics_050_095(OBBDetectionMetrics): + def __init__( + self, + num_cls: int, + post_prediction_callback: "AbstractOBBPostPredictionCallback", + recall_thres: torch.Tensor = None, + score_thres: float = 0.01, + top_k_predictions: Optional[int] = None, + dist_sync_on_step: bool = False, + accumulate_on_cpu: bool = True, + calc_best_score_thresholds: bool = True, + include_classwise_ap: bool = False, + class_names: List[str] = None, + ): + super().__init__( + num_cls=num_cls, + post_prediction_callback=post_prediction_callback, + iou_thres=IouThreshold.MAP_05_TO_095, + recall_thres=recall_thres, + score_thres=score_thres, + top_k_predictions=top_k_predictions, + dist_sync_on_step=dist_sync_on_step, + accumulate_on_cpu=accumulate_on_cpu, + calc_best_score_thresholds=calc_best_score_thresholds, + include_classwise_ap=include_classwise_ap, + class_names=class_names, + state_dict_prefix="", + ) diff --git a/src/super_gradients/training/pretrained_models.py b/src/super_gradients/training/pretrained_models.py index 9c938b4a4d..fabed62e72 100644 --- a/src/super_gradients/training/pretrained_models.py +++ b/src/super_gradients/training/pretrained_models.py @@ -69,6 +69,7 @@ "coco": 80, "coco_pose": 17, "cifar10": 10, + "dota2": 18, } DATASET_LICENSES = { @@ -79,4 +80,5 @@ "coco_pose": "https://cocodataset.org/#termsofuse", "cityscapes": "https://www.cs.toronto.edu/~kriz/cifar.html", "objects365": "https://www.objects365.org/download.html", + "dota2": "https://captain-whu.github.io/DOTA/dataset.html", } diff --git a/src/super_gradients/training/processing/__init__.py b/src/super_gradients/training/processing/__init__.py index b189baa2fa..d643eea626 100644 --- a/src/super_gradients/training/processing/__init__.py +++ 
b/src/super_gradients/training/processing/__init__.py @@ -15,6 +15,7 @@ SegmentationPadShortToCropSize, SegmentationPadToDivisible, ) +from .obb import OBBDetectionAutoPadding from .defaults import get_pretrained_processing_params __all__ = [ @@ -33,5 +34,6 @@ "SegmentationResize", "SegmentationPadShortToCropSize", "SegmentationPadToDivisible", + "OBBDetectionAutoPadding", "get_pretrained_processing_params", ] diff --git a/src/super_gradients/training/processing/defaults.py b/src/super_gradients/training/processing/defaults.py index e8ff6a0e8e..43f4155329 100644 --- a/src/super_gradients/training/processing/defaults.py +++ b/src/super_gradients/training/processing/defaults.py @@ -1,9 +1,11 @@ from super_gradients.training.datasets.datasets_conf import ( COCO_DETECTION_CLASSES_LIST, + DOTA2_DEFAULT_CLASSES_LIST, IMAGENET_CLASSES, CITYSCAPES_DEFAULT_SEGMENTATION_CLASSES_LIST, ) +from .obb import OBBDetectionCenterPadding, OBBDetectionLongestMaxSizeRescale from .processing import ( ComposeProcessing, ReverseImageChannels, @@ -93,6 +95,28 @@ def default_yolo_nas_coco_processing_params() -> dict: return params +def default_yolo_nas_r_dota_processing_params() -> dict: + """Processing parameters commonly used for training YoloNAS on COCO dataset.""" + + image_processor = ComposeProcessing( + [ + ReverseImageChannels(), # Model trained on BGR images + OBBDetectionLongestMaxSizeRescale(output_shape=(1024, 1024)), + OBBDetectionCenterPadding(output_shape=(1024, 1024), pad_value=114), + StandardizeImage(max_value=255.0), + ImagePermute(permutation=(2, 0, 1)), + ] + ) + + params = dict( + class_names=DOTA2_DEFAULT_CLASSES_LIST, + image_processor=image_processor, + iou=0.25, + conf=0.1, + ) + return params + + def default_dekr_coco_processing_params() -> dict: """Processing parameters commonly used for training DEKR on COCO dataset.""" diff --git a/src/super_gradients/training/processing/obb.py b/src/super_gradients/training/processing/obb.py new file mode 100644 index 0000000000..b524585df2 --- /dev/null +++ b/src/super_gradients/training/processing/obb.py @@ -0,0 +1,70 @@ +from typing import Tuple + +import numpy as np +from super_gradients.common.registry import register_processing +from super_gradients.training.transforms.utils import ( + _pad_image, + PaddingCoordinates, + _get_center_padding_coordinates, + _rescale_bboxes, + _get_bottom_right_padding_coordinates, +) +from super_gradients.training.utils.predict import OBBDetectionPrediction +from .processing import AutoPadding, DetectionPadToSizeMetadata, _LongestMaxSizeRescale, RescaleMetadata, _DetectionPadding + + +@register_processing() +class OBBDetectionCenterPadding(_DetectionPadding): + def _get_padding_params(self, input_shape: Tuple[int, int]) -> PaddingCoordinates: + return _get_center_padding_coordinates(input_shape=input_shape, output_shape=self.output_shape) + + def postprocess_predictions(self, predictions: OBBDetectionPrediction, metadata: DetectionPadToSizeMetadata) -> OBBDetectionPrediction: + offset = np.array([metadata.padding_coordinates.left, metadata.padding_coordinates.top, 0, 0, 0], dtype=np.float32).reshape(-1, 5) + predictions.rboxes_cxcywhr = predictions.rboxes_cxcywhr - offset + return predictions + + +@register_processing() +class OBBDetectionBottomRightPadding(_DetectionPadding): + def _get_padding_params(self, input_shape: Tuple[int, int]) -> PaddingCoordinates: + return _get_bottom_right_padding_coordinates(input_shape=input_shape, output_shape=self.output_shape) + + def postprocess_predictions(self, predictions: 
OBBDetectionPrediction, metadata: DetectionPadToSizeMetadata) -> OBBDetectionPrediction: + return predictions + + +@register_processing() +class OBBDetectionLongestMaxSizeRescale(_LongestMaxSizeRescale): + def postprocess_predictions(self, predictions: OBBDetectionPrediction, metadata: RescaleMetadata) -> OBBDetectionPrediction: + predictions.rboxes_cxcywhr = _rescale_bboxes( + targets=predictions.rboxes_cxcywhr, scale_factors=(1 / metadata.scale_factor_h, 1 / metadata.scale_factor_w) + ) + return predictions + + +@register_processing() +class OBBDetectionAutoPadding(AutoPadding): + def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, DetectionPadToSizeMetadata]: + padding_coordinates = self._get_padding_params(input_shape=image.shape[:2]) # HWC -> (H, W) + processed_image = _pad_image(image=image, padding_coordinates=padding_coordinates, pad_value=self.pad_value) + return processed_image, DetectionPadToSizeMetadata(padding_coordinates=padding_coordinates) + + def _get_padding_params(self, input_shape: Tuple[int, int]) -> PaddingCoordinates: + input_height, input_width = input_shape + height_modulo, width_modulo = self.shape_multiple + + # Calculate necessary padding to reach the modulo + padded_height = ((input_height + height_modulo - 1) // height_modulo) * height_modulo + padded_width = ((input_width + width_modulo - 1) // width_modulo) * width_modulo + + padding_top = 0 # No padding at the top + padding_left = 0 # No padding on the left + padding_bottom = padded_height - input_height + padding_right = padded_width - input_width + + return PaddingCoordinates(top=padding_top, left=padding_left, bottom=padding_bottom, right=padding_right) + + def postprocess_predictions(self, predictions: OBBDetectionPrediction, metadata: DetectionPadToSizeMetadata) -> OBBDetectionPrediction: + offset = np.array([metadata.padding_coordinates.left, metadata.padding_coordinates.top, 0, 0, 0], dtype=np.float32).reshape(-1, 5) + predictions.rboxes_cxcywhr = predictions.rboxes_cxcywhr + offset + return predictions diff --git a/src/super_gradients/training/utils/callbacks/__init__.py b/src/super_gradients/training/utils/callbacks/__init__.py index f7be3e0a8c..53b292f89d 100644 --- a/src/super_gradients/training/utils/callbacks/__init__.py +++ b/src/super_gradients/training/utils/callbacks/__init__.py @@ -25,6 +25,7 @@ from super_gradients.common.object_names import Callbacks, LRSchedulers, LRWarmups from super_gradients.common.registry.registry import CALLBACKS, LR_SCHEDULERS_CLS_DICT, LR_WARMUP_CLS_DICT from super_gradients.training.utils.callbacks.extreme_batch_pose_visualization_callback import ExtremeBatchPoseEstimationVisualizationCallback +from .extreme_batch_obb_visualization_callback import ExtremeBatchOBBVisualizationCallback __all__ = [ "Callback", @@ -60,4 +61,5 @@ "PPYoloETrainingStageSwitchCallback", "TimerCallback", "ExtremeBatchPoseEstimationVisualizationCallback", + "ExtremeBatchOBBVisualizationCallback", ] diff --git a/src/super_gradients/training/utils/callbacks/extreme_batch_obb_visualization_callback.py b/src/super_gradients/training/utils/callbacks/extreme_batch_obb_visualization_callback.py new file mode 100644 index 0000000000..bbaee97d88 --- /dev/null +++ b/src/super_gradients/training/utils/callbacks/extreme_batch_obb_visualization_callback.py @@ -0,0 +1,208 @@ +import typing +from typing import Optional, Tuple, List, Union + +import numpy as np +import torch +from torch import Tensor +from torchmetrics import Metric + +from super_gradients.common.registry.registry import 
register_callback
+from super_gradients.training.utils.callbacks.callbacks import ExtremeBatchCaseVisualizationCallback
+from super_gradients.training.utils.visualization.obb import OBBVisualization
+from super_gradients.training.utils.visualization.utils import generate_color_mapping
+
+# These imports are required for type hints and not used anywhere else.
+# Wrapping them under typing.TYPE_CHECKING is a legitimate way to avoid circular imports
+# while still having type hints.
+if typing.TYPE_CHECKING:
+    from super_gradients.training.datasets.obb.dota import OBBSample
+    from super_gradients.module_interfaces.obb_predictions import AbstractOBBPostPredictionCallback, OBBPredictions
+
+
+@register_callback()
+class ExtremeBatchOBBVisualizationCallback(ExtremeBatchCaseVisualizationCallback):
+    """
+    ExtremeBatchOBBVisualizationCallback
+
+    Visualizes the worst/best batch in an epoch for the OBB detection task.
+    This class visualizes horizontally-stacked GT and predicted boxes.
+    It requires a key 'gt_samples' (List[OBBSample]) to be present in the additional_batch_items dictionary.
+
+    Supported models: YoloNAS-R
+    Supported datasets: DOTAOBBDataset
+
+    Example usage in YAML config:
+
+        training_hyperparams:
+            phase_callbacks:
+                - ExtremeBatchOBBVisualizationCallback:
+                    loss_to_monitor: YoloNASRLoss/loss
+                    max: True
+                    freq: 1
+                    max_images: 16
+                    enable_on_train_loader: True
+                    enable_on_valid_loader: True
+                    post_prediction_callback:
+                        _target_: super_gradients.training.models.detection_models.yolo_nas_r.yolo_nas_r_post_prediction_callback.YoloNASRPostPredictionCallback
+                        score_threshold: 0.25
+                        pre_nms_max_predictions: 4096
+                        post_nms_max_predictions: 512
+                        nms_iou_threshold: 0.6
+
+    :param metric: Metric, will be the metric which is monitored.
+
+    :param metric_component_name: In case metric returns multiple values (as Mapping),
+        the value at metric.compute()[metric_component_name] will be the one monitored.
+
+    :param loss_to_monitor: str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...).
+        Monitoring loss follows the same logic as metric_to_watch in Trainer.train(...), when watching the loss, and should be:
+
+            if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple:
+                "<LOSS_CLASS.__name__>"/"<COMPONENT_NAME>".
+
+            If a single item is returned rather than a tuple:
+                "<LOSS_CLASS.__name__>".
+
+            When there are no such attributes and criterion.forward(..) returns a tuple:
+                "<LOSS_CLASS.__name__>"/"Loss_"<IDX>
+
+    :param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or
+        the minimum (default=False).
+
+    :param freq: int, epoch frequency to perform all of the above (default=1).
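+
+    :param post_prediction_callback: Post-prediction callback used to decode raw model outputs into per-image OBB predictions for visualization.
+
+    :param class_names: List of class names used to label the visualized boxes.
+
+    :param class_colors: Optional list of per-class colors. When None, a color mapping is generated from the number of class names.
+
+    :param max_images: Maximum number of images to visualize (default=-1).
+
+    :param enable_on_train_loader: Controls whether to enable this callback on the train loader (default=False).
+
+    :param enable_on_valid_loader: Controls whether to enable this callback on the valid loader (default=True).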
+ """ + + def __init__( + self, + post_prediction_callback: "AbstractOBBPostPredictionCallback", + class_names: List[str], + class_colors=None, + metric: Optional[Metric] = None, + metric_component_name: Optional[str] = None, + loss_to_monitor: Optional[str] = None, + max: bool = False, + freq: int = 1, + max_images: int = -1, + enable_on_train_loader: bool = False, + enable_on_valid_loader: bool = True, + ): + if class_colors is None: + class_colors = generate_color_mapping(num_classes=len(class_names)) + + super().__init__( + metric=metric, + metric_component_name=metric_component_name, + loss_to_monitor=loss_to_monitor, + max=max, + freq=freq, + enable_on_train_loader=enable_on_train_loader, + enable_on_valid_loader=enable_on_valid_loader, + ) + self.class_names = list(class_names) + self.class_colors = class_colors + self.post_prediction_callback = post_prediction_callback + self.max_images = max_images + + @classmethod + def universal_undo_preprocessing_fn(cls, inputs: torch.Tensor) -> np.ndarray: + """ + A universal reversing of preprocessing to be passed to DetectionVisualization.visualize_batch's undo_preprocessing_func kwarg. + :param inputs: + :return: + """ + inputs = inputs - inputs.min() + inputs /= inputs.max() + inputs *= 255 + inputs = inputs.to(torch.uint8) + inputs = inputs.cpu().numpy() + inputs = inputs[:, ::-1, :, :].transpose(0, 2, 3, 1) + inputs = np.ascontiguousarray(inputs, dtype=np.uint8) + return inputs + + @classmethod + def _visualize_batch( + cls, + image_tensor: np.ndarray, + rboxes: List[Union[np.ndarray, Tensor]], + labels: List[Union[np.ndarray, Tensor]], + scores: Optional[List[Union[np.ndarray, Tensor]]], + class_colors: List[Tuple[int, int, int]], + class_names: List[str], + ) -> List[np.ndarray]: + """ + Generate list of samples visualization of a batch of images with keypoints and bounding boxes. + + :param image_tensor: Images batch of [Batch Size, 3, H, W] shape with values in [0, 255] range. + The images should be scaled to [0, 255] range and converted to uint8 type beforehead. + :param scores: Keypoint scores. Shape [Num Instances, Num Joints]. Can be None. + :return: List of visualization images. + """ + + out_images = [] + for i in range(image_tensor.shape[0]): + rboxes_i = rboxes[i] + labels_i = labels[i] + scores_i = scores[i] if scores is not None else None + + if torch.is_tensor(rboxes_i): + rboxes_i = rboxes_i.detach().cpu().numpy() + if torch.is_tensor(labels_i): + labels_i = labels_i.detach().cpu().numpy() + if torch.is_tensor(scores_i): + scores_i = scores_i.detach().cpu().numpy() + + res_image = image_tensor[i] + res_image = OBBVisualization.draw_obb( + image=res_image, + rboxes_cxcywhr=rboxes_i, + labels=labels_i, + scores=scores_i, + class_colors=class_colors, + class_names=class_names, + show_confidence=True, + show_labels=True, + ) + + out_images.append(res_image) + + return out_images + + @torch.no_grad() + def process_extreme_batch(self) -> np.ndarray: + """ + Processes the extreme batch, and returns batche of images for visualization - predictions and GT poses stacked horizontally. + + :return: np.ndarray - the visualization of predictions and GT + """ + if "gt_samples" not in self.extreme_additional_batch_items: + raise RuntimeError( + "ExtremeBatchPoseEstimationVisualizationCallback requires 'gt_samples' to be present in additional_batch_items." + "Currently only YoloNASPose model is supported. Old DEKR recipe is not supported at the moment." 
+ ) + + inputs = self.universal_undo_preprocessing_fn(self.extreme_batch) + gt_samples: List["OBBSample"] = self.extreme_additional_batch_items["gt_samples"] + predictions: List["OBBPredictions"] = self.post_prediction_callback(self.extreme_preds) + + images_to_save_preds = self._visualize_batch( + image_tensor=inputs, + rboxes=[p.rboxes_cxcywhr for p in predictions], + labels=[p.labels for p in predictions], + scores=[p.scores for p in predictions], + class_colors=self.class_colors, + class_names=self.class_names, + ) + images_to_save_preds = np.stack(images_to_save_preds) + + images_to_save_gt = self._visualize_batch( + image_tensor=inputs, + rboxes=[gt.rboxes_cxcywhr for gt in gt_samples], + labels=[gt.labels for gt in gt_samples], + scores=None, + class_colors=self.class_colors, + class_names=self.class_names, + ) + images_to_save_gt = np.stack(images_to_save_gt) + + # Stack the predictions and GT images together + return np.concatenate([images_to_save_gt, images_to_save_preds], axis=2) diff --git a/src/super_gradients/training/utils/predict/__init__.py b/src/super_gradients/training/utils/predict/__init__.py index be93ee930c..90065b309b 100644 --- a/src/super_gradients/training/utils/predict/__init__.py +++ b/src/super_gradients/training/utils/predict/__init__.py @@ -17,7 +17,7 @@ VideoPoseEstimationPrediction, ImagesPoseEstimationPrediction, ) - +from .prediction_obb_detection_results import OBBDetectionPrediction, ImageOBBDetectionPrediction, ImagesOBBDetectionPrediction, VideoOBBDetectionPrediction __all__ = [ "Prediction", @@ -39,4 +39,8 @@ "ImageSegmentationPrediction", "ImagesSegmentationPrediction", "VideoSegmentationPrediction", + "OBBDetectionPrediction", + "ImageOBBDetectionPrediction", + "ImagesOBBDetectionPrediction", + "VideoOBBDetectionPrediction", ] diff --git a/src/super_gradients/training/utils/predict/prediction_obb_detection_results.py b/src/super_gradients/training/utils/predict/prediction_obb_detection_results.py new file mode 100644 index 0000000000..ca9b22b717 --- /dev/null +++ b/src/super_gradients/training/utils/predict/prediction_obb_detection_results.py @@ -0,0 +1,428 @@ +import os + +import numpy as np +import cv2 + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Iterator, Iterable, Union + +from super_gradients.training.utils.media.image import save_image, show_image +from super_gradients.training.utils.media.video import show_video_from_frames, save_video +from super_gradients.training.utils.visualization.obb import OBBVisualization +from super_gradients.training.utils.visualization.utils import generate_color_mapping +from tqdm import tqdm + +from .predictions import Prediction +from .prediction_results import ImagePrediction, VideoPredictions, ImagesPredictions + +__all__ = ["OBBDetectionPrediction", "ImageOBBDetectionPrediction", "ImagesOBBDetectionPrediction", "VideoOBBDetectionPrediction"] + + +@dataclass +class OBBDetectionPrediction(Prediction): + """Represents an OBB detection prediction, with bboxes represented in cxycxwhr format.""" + + rboxes_cxcywhr: np.ndarray + confidence: np.ndarray + labels: np.ndarray + + def __init__(self, rboxes_cxcywhr: np.ndarray, confidence: np.ndarray, labels: np.ndarray, image_shape: Tuple[int, int]): + """ + :param rboxes_cxcywhr: Rboxes of [N,5] shape in the CXCYWHR format + :param confidence: Confidence scores for each bounding box + :param labels: Labels for each bounding box. + :param image_shape: Shape of the image the prediction is made on, (H, W). 
This is used to convert bboxes to xyxy format
+        """
+        self._validate_input(rboxes_cxcywhr, confidence, labels)
+        self.rboxes_cxcywhr = rboxes_cxcywhr
+        self.confidence = confidence
+        self.labels = labels
+        self.image_shape = image_shape
+
+    def _validate_input(self, rboxes_cxcywhr: np.ndarray, confidence: np.ndarray, labels: np.ndarray) -> None:
+        n_bboxes, n_confidences, n_labels = rboxes_cxcywhr.shape[0], confidence.shape[0], labels.shape[0]
+        if not (n_bboxes == n_confidences == n_labels):
+            raise ValueError(
+                f"The number of bounding boxes ({n_bboxes}) does not match the number of confidence scores ({n_confidences}) and labels ({n_labels})."
+            )
+        if rboxes_cxcywhr.shape[1] != 5:
+            raise ValueError(f"Expected 5 columns in rboxes_cxcywhr, got {rboxes_cxcywhr.shape[1]}.")
+
+    def __len__(self):
+        return len(self.rboxes_cxcywhr)
+
+
+@dataclass
+class ImageOBBDetectionPrediction(ImagePrediction):
+    """Object wrapping an image and a detection model's prediction.
+
+    :param image:       Input image
+    :param prediction:  Predictions of the model
+    :param class_names: List of the class names to predict
+    """
+
+    image: np.ndarray
+    prediction: OBBDetectionPrediction
+    class_names: List[str]
+
+    def draw(
+        self,
+        box_thickness: Optional[int] = None,
+        show_confidence: bool = True,
+        color_mapping: Optional[List[Tuple[int, int, int]]] = None,
+        target_rboxes: Optional[np.ndarray] = None,
+        target_class_ids: Optional[np.ndarray] = None,
+        class_names: Optional[List[str]] = None,
+    ) -> np.ndarray:
+        """Draw the predicted bboxes on the image.
+
+        :param box_thickness:    (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
+        :param show_confidence:  Whether to show confidence scores on the image.
+        :param color_mapping:    List of tuples representing the colors for each class.
+                                 Default is None, which generates a default color mapping based on the number of class names.
+        :param target_rboxes:    Optional[np.ndarray], ground truth oriented boxes of shape (object_count, 5) in CXCYWHR format.
+                                 When not None, the predictions and the ground truth boxes are plotted side by side (i.e. 2 images stitched as one).
+        :param target_class_ids: Optional[np.ndarray], ground truth class indices of shape (object_count,).
+        :param class_names:      List of class names to show. By default None, which shows all classes used during training.
+
+        :return:                 Image with predicted bboxes. Note that this does not modify the original image.
+ """ + target_rboxes = target_rboxes if target_rboxes is not None else np.zeros((0, 5)) + target_class_ids = target_class_ids if target_class_ids is not None else np.zeros((0, 1)) + + class_names_to_show = class_names if class_names else self.class_names + class_ids_to_show = [i for i, class_name in enumerate(self.class_names) if class_name in class_names_to_show] + invalid_class_names_to_show = set(class_names_to_show) - set(self.class_names) + if len(invalid_class_names_to_show) > 0: + raise ValueError( + "`class_names` includes class names that the model was not trained on.\n" + f" - Invalid class names: {list(invalid_class_names_to_show)}\n" + f" - Available class names: {list(self.class_names)}" + ) + + plot_targets = target_rboxes is not None and len(target_rboxes) + color_mapping = color_mapping or generate_color_mapping(len(class_names_to_show)) + + keep_mask = np.isin(self.prediction.labels, class_ids_to_show) + image = OBBVisualization.draw_obb( + image=self.image.copy(), + rboxes_cxcywhr=self.prediction.rboxes_cxcywhr[keep_mask], + scores=self.prediction.confidence[keep_mask], + labels=self.prediction.labels[keep_mask], + class_names=class_names_to_show, + class_colors=color_mapping, + show_labels=True, + show_confidence=show_confidence, + thickness=box_thickness, + ) + + if plot_targets: + keep_mask = np.isin(target_class_ids, class_ids_to_show) + target_image = OBBVisualization.draw_obb( + image=self.image.copy(), + rboxes_cxcywhr=target_rboxes[keep_mask], + scores=None, + labels=target_class_ids[keep_mask], + class_names=class_names_to_show, + class_colors=color_mapping, + show_labels=True, + show_confidence=False, + thickness=box_thickness, + ) + + height, width, ch = target_image.shape + new_width, new_height = int(width + width / 20), int(height + height / 8) + + # Crate a new canvas with new width and height. + canvas_image = np.ones((new_height, new_width, ch), dtype=np.uint8) * 255 + canvas_target = np.ones((new_height, new_width, ch), dtype=np.uint8) * 255 + + # New replace the center of canvas with original image + padding_top, padding_left = 60, 10 + + canvas_image[padding_top : padding_top + height, padding_left : padding_left + width] = image + canvas_target[padding_top : padding_top + height, padding_left : padding_left + width] = target_image + + img1 = cv2.putText(canvas_image, "Predictions", (int(0.25 * width), 30), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 0)) + img2 = cv2.putText(canvas_target, "Ground Truth", (int(0.25 * width), 30), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 0)) + + image = cv2.hconcat((img1, img2)) + return image + + def show( + self, + box_thickness: Optional[int] = None, + show_confidence: bool = True, + color_mapping: Optional[List[Tuple[int, int, int]]] = None, + target_bboxes: Optional[np.ndarray] = None, + target_bboxes_format: Optional[str] = None, + target_class_ids: Optional[np.ndarray] = None, + class_names: Optional[List[str]] = None, + ) -> None: + """Display the image with predicted bboxes. + + :param box_thickness: (Optional) Thickness of bounding boxes. If None, will adapt to the box size. + :param show_confidence: Whether to show confidence scores on the image. + :param color_mapping: List of tuples representing the colors for each class. + Default is None, which generates a default color mapping based on the number of class names. + :param target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes. 
+ Can either be an np.ndarray of shape (image_i_object_count, 4) when predicting a single image, + or a list of length len(target_bboxes), containing such arrays. + When not None, will plot the predictions and the ground truth bounding boxes side by side (i.e 2 images stitched as one) + :param target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape + (image_i_object_count) when predicting a single image, or a list of length len(target_bboxes), containing such arrays. + :param target_bboxes_format: Optional[str], bounding box format of target_bboxes, one of + ['xyxy','xywh', 'yxyx' 'cxcywh' 'normalized_xyxy' 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh']. + Will raise an error if not None and target_bboxes is None. + :param class_names: List of class names to show. By default, is None which shows all classes using during training. + """ + image = self.draw( + box_thickness=box_thickness, + show_confidence=show_confidence, + color_mapping=color_mapping, + target_rboxes=target_bboxes, + target_class_ids=target_class_ids, + class_names=class_names, + ) + show_image(image) + + def save( + self, + output_path: str, + box_thickness: Optional[int] = None, + show_confidence: bool = True, + color_mapping: Optional[List[Tuple[int, int, int]]] = None, + target_bboxes: Optional[np.ndarray] = None, + target_class_ids: Optional[np.ndarray] = None, + class_names: Optional[List[str]] = None, + ) -> None: + """Save the predicted bboxes on the images. + + :param output_path: Path to the output video file. + :param box_thickness: (Optional) Thickness of bounding boxes. If None, will adapt to the box size. + :param show_confidence: Whether to show confidence scores on the image. + :param color_mapping: List of tuples representing the colors for each class. + Default is None, which generates a default color mapping based on the number of class names. + :param target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes. + Can either be an np.ndarray of shape (image_i_object_count, 4) when predicting a single image, + or a list of length len(target_bboxes), containing such arrays. + When not None, will plot the predictions and the ground truth bounding boxes side by side (i.e 2 images stitched as one) + :param target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape + (image_i_object_count) when predicting a single image, or a list of length len(target_bboxes), containing such arrays. + :param class_names: List of class names to show. By default, is None which shows all classes using during training. + """ + image = self.draw( + box_thickness=box_thickness, + show_confidence=show_confidence, + color_mapping=color_mapping, + target_rboxes=target_bboxes, + target_class_ids=target_class_ids, + class_names=class_names, + ) + save_image(image=image, path=output_path) + + +@dataclass +class ImagesOBBDetectionPrediction(ImagesPredictions): + """Object wrapping the list of image detection predictions. 
+ + :attr _images_prediction_lst: List of the predictions results + """ + + _images_prediction_lst: List[ImageOBBDetectionPrediction] + + def show( + self, + box_thickness: Optional[int] = None, + show_confidence: bool = True, + color_mapping: Optional[List[Tuple[int, int, int]]] = None, + target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]] = None, + target_bboxes_format: Optional[str] = None, + target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]] = None, + class_names: Optional[List[str]] = None, + ) -> None: + """Display the predicted bboxes on the images. + + :param box_thickness: (Optional) Thickness of bounding boxes. If None, will adapt to the box size. + :param show_confidence: Whether to show confidence scores on the image. + :param color_mapping: List of tuples representing the colors for each class. + Default is None, which generates a default color mapping based on the number of class names. + :param target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes. + Can either be an np.ndarray of shape (image_i_object_count, 4) when predicting a single image, + or a list of length len(target_bboxes), containing such arrays. + When not None, will plot the predictions and the ground truth bounding boxes side by side (i.e 2 images stitched as one) + :param target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape + (image_i_object_count) when predicting a single image, or a list of length len(target_bboxes), containing such arrays. + :param target_bboxes_format: Optional[str], bounding box format of target_bboxes, one of + ['xyxy','xywh', 'yxyx' 'cxcywh' 'normalized_xyxy' 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh']. + Will raise an error if not None and target_bboxes is None. + :param class_names: List of class names to show. By default, is None which shows all classes using during training. 
+ """ + target_bboxes, target_class_ids = self._check_target_args(target_bboxes, target_bboxes_format, target_class_ids) + + for prediction, target_bbox, target_class_id in zip(self._images_prediction_lst, target_bboxes, target_class_ids): + prediction.show( + box_thickness=box_thickness, + show_confidence=show_confidence, + color_mapping=color_mapping, + target_bboxes=target_bbox, + target_bboxes_format=target_bboxes_format, + target_class_ids=target_class_id, + class_names=class_names, + ) + + def _check_target_args( + self, + target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]] = None, + target_bboxes_format: Optional[str] = None, + target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]] = None, + ): + if not ( + (target_bboxes is None and target_bboxes_format is None and target_class_ids is None) + or (target_bboxes is not None and target_bboxes_format is not None and target_class_ids is not None) + ): + raise ValueError("target_bboxes, target_bboxes_format, and target_class_ids should either all be None or all not None.") + + if isinstance(target_bboxes, np.ndarray): + target_bboxes = [target_bboxes] + if isinstance(target_class_ids, np.ndarray): + target_class_ids = [target_class_ids] + + if target_bboxes is not None and target_class_ids is not None and len(target_bboxes) != len(target_class_ids): + raise ValueError(f"target_bboxes and target_class_ids lengths should be equal, got: {len(target_bboxes)} and {len(target_class_ids)}.") + if target_bboxes is not None and target_class_ids is not None and len(target_bboxes) != len(self._images_prediction_lst): + raise ValueError( + f"target_bboxes and target_class_ids lengths should be equal, to the " + f"amount of images passed to predict(), got: {len(target_bboxes)} and {len(self._images_prediction_lst)}." + ) + if target_bboxes is None: + target_bboxes = [None for _ in range(len(self._images_prediction_lst))] + target_class_ids = [None for _ in range(len(self._images_prediction_lst))] + + return target_bboxes, target_class_ids + + def save( + self, + output_folder: str, + box_thickness: Optional[int] = None, + show_confidence: bool = True, + color_mapping: Optional[List[Tuple[int, int, int]]] = None, + target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]] = None, + target_bboxes_format: Optional[str] = None, + target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]] = None, + class_names: Optional[List[str]] = None, + ) -> None: + """Save the predicted bboxes on the images. + + :param output_folder: Folder path, where the images will be saved. + :param box_thickness: (Optional) Thickness of bounding boxes. If None, will adapt to the box size. + :param show_confidence: Whether to show confidence scores on the image. + :param color_mapping: List of tuples representing the colors for each class. + Default is None, which generates a default color mapping based on the number of class names. + :param target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes. + Can either be an np.ndarray of shape (image_i_object_count, 4) when predicting a single image, + or a list of length len(target_bboxes), containing such arrays. + When not None, will plot the predictions and the ground truth bounding boxes side by side (i.e 2 images stitched as one) + :param target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. 
Can either be an np.ndarray of shape + (image_i_object_count) when predicting a single image, or a list of length len(target_bboxes), containing such arrays. + :param target_bboxes_format: Optional[str], bounding box format of target_bboxes, one of + ['xyxy','xywh', 'yxyx' 'cxcywh' 'normalized_xyxy' 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh']. + Will raise an error if not None and target_bboxes is None. + :param class_names: List of class names to show. By default, is None which shows all classes using during training. + """ + if output_folder: + os.makedirs(output_folder, exist_ok=True) + + target_bboxes, target_class_ids = self._check_target_args(target_bboxes, target_bboxes_format, target_class_ids) + + for i, (prediction, target_bbox, target_class_id) in enumerate(zip(self._images_prediction_lst, target_bboxes, target_class_ids)): + image_output_path = os.path.join(output_folder, f"pred_{i}.jpg") + prediction.save( + output_path=image_output_path, + box_thickness=box_thickness, + show_confidence=show_confidence, + color_mapping=color_mapping, + class_names=class_names, + ) + + +@dataclass +class VideoOBBDetectionPrediction(VideoPredictions): + """Object wrapping the list of image detection predictions as a Video. + + :attr _images_prediction_gen: Iterable object of the predictions results + :att fps: Frames per second of the video + """ + + _images_prediction_gen: Iterable[ImageOBBDetectionPrediction] + fps: int + n_frames: int + + def draw( + self, + box_thickness: Optional[int] = None, + show_confidence: bool = True, + color_mapping: Optional[List[Tuple[int, int, int]]] = None, + class_names: Optional[List[str]] = None, + ) -> Iterator[np.ndarray]: + """Draw the predicted bboxes on the images. + + :param box_thickness: (Optional) Thickness of bounding boxes. If None, will adapt to the box size. + :param show_confidence: Whether to show confidence scores on the image. + :param color_mapping: List of tuples representing the colors for each class. + Default is None, which generates a default color mapping based on the number of class names. + :param class_names: List of class names to show. By default, is None which shows all classes using during training. + :return: Iterable object of images with predicted bboxes. Note that this does not modify the original image. + """ + + for result in tqdm(self._images_prediction_gen, total=self.n_frames, desc="Processing Video"): + yield result.draw( + box_thickness=box_thickness, + show_confidence=show_confidence, + color_mapping=color_mapping, + class_names=class_names, + ) + + def show( + self, + box_thickness: Optional[int] = None, + show_confidence: bool = True, + color_mapping: Optional[List[Tuple[int, int, int]]] = None, + class_names: Optional[List[str]] = None, + ) -> None: + """Display the predicted bboxes on the images. + + :param box_thickness: (Optional) Thickness of bounding boxes. If None, will adapt to the box size. + :param show_confidence: Whether to show confidence scores on the image. + :param color_mapping: List of tuples representing the colors for each class. + Default is None, which generates a default color mapping based on the number of class names. + :param class_names: List of class names to show. By default, is None which shows all classes using during training. 
+ """ + frames = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping, class_names=class_names) + show_video_from_frames(window_name="Detection", frames=frames, fps=self.fps) + + def save( + self, + output_path: str, + box_thickness: Optional[int] = None, + show_confidence: bool = True, + color_mapping: Optional[List[Tuple[int, int, int]]] = None, + class_names: Optional[List[str]] = None, + ) -> None: + """Save the predicted bboxes on the images. + + :param output_path: Path to the output video file. + :param box_thickness: (Optional) Thickness of bounding boxes. If None, will adapt to the box size. + :param show_confidence: Whether to show confidence scores on the image. + :param color_mapping: List of tuples representing the colors for each class. + Default is None, which generates a default color mapping based on the number of class names. + :param class_names: List of class names to show. By default, is None which shows all classes using during training. + """ + frames = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping, class_names=class_names) + save_video(output_path=output_path, frames=frames, fps=self.fps)