Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/sg 1047 predict od with labels #1365

Merged
merged 21 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from super_gradients.common.object_names import Models
shaydeci marked this conversation as resolved.
Show resolved Hide resolved
from super_gradients.training import models
from pathlib import Path
from super_gradients.training.datasets import COCODetectionDataset

# Note that currently only YoloX, PPYoloE and YOLO-NAS are supported.
model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco")
mini_coco_data_dir = str(Path(__file__).parent.parent.parent.parent.parent / "tests" / "data" / "tinycoco")

dataset = COCODetectionDataset(
data_dir=mini_coco_data_dir, subdir="images/val2017", json_file="instances_val2017.json", input_dim=None, transforms=[], cache_annotations=False
)

# the loaded images are np.ndarrays images of shape (H,W,3)
# the loaded targets are np.ndarrays of shape (num_boxes,x1,y1,x2,y2,class_id)
image1, target1, _ = dataset[0]
image2, target2, _ = dataset[1]

# images from COCODetectionDataset are RGB and images as np.ndarrays are expected to be BGR
image2 = image2[:, :, ::-1]
image1 = image1[:, :, ::-1]

predictions = model.predict([image1, image2])
predictions.show(target_bboxes=[target1[:, :4], target2[:, :4]], target_class_ids=[target1[:, 4], target2[:, 4]], target_bboxes_format="xyxy")
predictions.save(
output_folder="", target_bboxes=[target1[:, :4], target2[:, :4]], target_class_ids=[target1[:, 4], target2[:, 4]], target_bboxes_format="xyxy"
) # Save in working directory
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,7 @@ def __init__(

@torch.jit.ignore
def cache_anchors(self, input_size: Tuple[int, int]):
b, c, h, w = input_size
self.eval_size = (h, w)
self.eval_size = list(input_size)[-2:]
device = infer_model_device(self.pred_cls)
dtype = infer_model_dtype(self.pred_cls)
anchor_points, stride_tensor = self._generate_anchors(dtype=dtype, device=device)
Expand Down
225 changes: 213 additions & 12 deletions src/super_gradients/training/utils/predict/prediction_results.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Tuple, Iterator
from typing import List, Optional, Tuple, Iterator, Union

import cv2
import numpy as np

from super_gradients.common.factories.bbox_format_factory import BBoxFormatFactory
from super_gradients.training.utils.media.image import show_image, save_image
from super_gradients.training.utils.media.video import show_video_from_frames, save_video
from super_gradients.training.utils.visualization.detection import draw_bbox
from super_gradients.training.utils.visualization.classification import draw_label

from super_gradients.training.utils.visualization.utils import generate_color_mapping
from .predictions import Prediction, DetectionPrediction, ClassificationPrediction
from ...datasets.data_formats.bbox_formats import convert_bboxes


@dataclass
Expand Down Expand Up @@ -102,22 +105,58 @@ class ImageDetectionPrediction(ImagePrediction):
prediction: DetectionPrediction
class_names: List[str]

def draw(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> np.ndarray:
def draw(
self,
box_thickness: int = 2,
show_confidence: bool = True,
color_mapping: Optional[List[Tuple[int, int, int]]] = None,
target_bboxes: Optional[np.ndarray] = None,
target_bboxes_format: Optional[str] = None,
target_class_ids: Optional[np.ndarray] = None,
) -> np.ndarray:
"""Draw the predicted bboxes on the image.

:param box_thickness: Thickness of bounding boxes.
:param show_confidence: Whether to show confidence scores on the image.
:param color_mapping: List of tuples representing the colors for each class.
Default is None, which generates a default color mapping based on the number of class names.

:param target_bboxes: Optional[np.ndarray], ground truth bounding boxes represented as an np.ndarray of shape
(image_i_object_count, 4). When not None, will plot the predictions and the ground truth bounding boxes side
by side (i.e 2 images stitched as one). (default=None).

:param target_class_ids: Optional[np.ndarray], ground truth target class indices
represented as an np.ndarray of shape (object_count). (default=None).

:param target_bboxes_format: Optional[str], bounding box format of target_bboxes, one of ['xyxy','xywh',
'yxyx' 'cxcywh' 'normalized_xyxy' 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh']. Ignored if not
None and target_bboxes is None.

:return: Image with predicted bboxes. Note that this does not modify the original image.

"""
image = self.image.copy()

target_bboxes = target_bboxes if target_bboxes is not None else np.zeros((0, 4))
target_class_ids = target_class_ids if target_class_ids is not None else np.zeros((0, 1))
bbox_format_factory = BBoxFormatFactory()
if len(target_bboxes):
target_bboxes_xyxy = convert_bboxes(
bboxes=target_bboxes,
image_shape=self.prediction.image_shape,
source_format=bbox_format_factory.get(target_bboxes_format),
target_format=bbox_format_factory.get("xyxy"),
inplace=False,
)
else:
target_bboxes_xyxy = target_bboxes

plot_targets = any([len(tbbx) > 0 for tbbx in target_bboxes_xyxy])
color_mapping = color_mapping or generate_color_mapping(len(self.class_names))

for pred_i in np.argsort(self.prediction.confidence):
class_id = int(self.prediction.labels[pred_i])
score = "" if not show_confidence else str(round(self.prediction.confidence[pred_i], 2))

image = draw_bbox(
image=image,
title=f"{self.class_names[class_id]} {score}",
Expand All @@ -129,29 +168,115 @@ def draw(self, box_thickness: int = 2, show_confidence: bool = True, color_mappi
y2=int(self.prediction.bboxes_xyxy[pred_i, 3]),
)

if plot_targets:
target_image = self.image.copy()
for target_idx in range(len(target_bboxes_xyxy)):
class_id = int(target_class_ids[target_idx])
target_image = draw_bbox(
image=target_image,
title=f"{self.class_names[class_id]}",
color=color_mapping[class_id],
box_thickness=box_thickness,
x1=int(target_bboxes_xyxy[target_idx, 0]),
y1=int(target_bboxes_xyxy[target_idx, 1]),
x2=int(target_bboxes_xyxy[target_idx, 2]),
y2=int(target_bboxes_xyxy[target_idx, 3]),
)

height, width, ch = target_image.shape
new_width, new_height = int(width + width / 20), int(height + height / 8)

# Crate a new canvas with new width and height.
canvas_image = np.ones((new_height, new_width, ch), dtype=np.uint8) * 255
canvas_target = np.ones((new_height, new_width, ch), dtype=np.uint8) * 255

# New replace the center of canvas with original image
padding_top, padding_left = 60, 10

canvas_image[padding_top : padding_top + height, padding_left : padding_left + width] = image
canvas_target[padding_top : padding_top + height, padding_left : padding_left + width] = target_image

img1 = cv2.putText(canvas_image, "Predictions", (int(0.25 * width), 30), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 0))
img2 = cv2.putText(canvas_target, "Ground Truth", (int(0.25 * width), 30), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 0))

image = cv2.hconcat((img1, img2))
return image

def show(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
def show(
self,
box_thickness: int = 2,
show_confidence: bool = True,
color_mapping: Optional[List[Tuple[int, int, int]]] = None,
target_bboxes: Optional[np.ndarray] = None,
target_bboxes_format: Optional[str] = None,
target_class_ids: Optional[np.ndarray] = None,
) -> None:

"""Display the image with predicted bboxes.

:param box_thickness: Thickness of bounding boxes.
:param show_confidence: Whether to show confidence scores on the image.
:param color_mapping: List of tuples representing the colors for each class.
Default is None, which generates a default color mapping based on the number of class names.

:param target_bboxes: Optional[np.ndarray], ground truth bounding boxes represented as an np.ndarray of shape
(image_i_object_count, 4). When not None, will plot the predictions and the ground truth bounding boxes side
by side (i.e 2 images stitched as one). (default=None).

:param target_class_ids: Optional[np.ndarray], ground truth target class indices
represented as an np.ndarray of shape (object_count). (default=None).

:param target_bboxes_format: Optional[str], bounding box format of target_bboxes, one of ['xyxy','xywh',
'yxyx' 'cxcywh' 'normalized_xyxy' 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh']. Ignored if not
None and target_bboxes is None.
"""
image = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
image = self.draw(
box_thickness=box_thickness,
show_confidence=show_confidence,
color_mapping=color_mapping,
target_bboxes=target_bboxes,
target_bboxes_format=target_bboxes_format,
target_class_ids=target_class_ids,
)
show_image(image)

def save(self, output_path: str, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
def save(
self,
output_path: str,
box_thickness: int = 2,
show_confidence: bool = True,
color_mapping: Optional[List[Tuple[int, int, int]]] = None,
target_bboxes: Optional[np.ndarray] = None,
target_bboxes_format: Optional[str] = None,
target_class_ids: Optional[np.ndarray] = None,
) -> None:
"""Save the predicted bboxes on the images.

:param output_path: Path to the output video file.
:param box_thickness: Thickness of bounding boxes.
:param show_confidence: Whether to show confidence scores on the image.
:param color_mapping: List of tuples representing the colors for each class.
Default is None, which generates a default color mapping based on the number of class names.

:param target_bboxes: Optional[np.ndarray], ground truth bounding boxes represented as an np.ndarray of shape
(image_i_object_count, 4). When not None, will plot the predictions and the ground truth bounding boxes side
by side (i.e 2 images stitched as one). (default=None).

:param target_class_ids: Optional[np.ndarray], ground truth target class indices
represented as an np.ndarray of shape (object_count). (default=None).

:param target_bboxes_format: Optional[str], bounding box format of target_bboxes, one of ['xyxy','xywh',
'yxyx' 'cxcywh' 'normalized_xyxy' 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh']. Ignored if not
None and target_bboxes is None.
"""
image = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
image = self.draw(
box_thickness=box_thickness,
show_confidence=show_confidence,
color_mapping=color_mapping,
target_bboxes=target_bboxes,
target_bboxes_format=target_bboxes_format,
target_class_ids=target_class_ids,
)
save_image(image=image, path=output_path)


Expand Down Expand Up @@ -245,19 +370,83 @@ class ImagesDetectionPrediction(ImagesPredictions):

_images_prediction_lst: List[ImageDetectionPrediction]

def show(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
def show(
self,
box_thickness: int = 2,
show_confidence: bool = True,
color_mapping: Optional[List[Tuple[int, int, int]]] = None,
target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
target_bboxes_format: Optional[str] = None,
target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
) -> None:
"""Display the predicted bboxes on the images.

:param box_thickness: Thickness of bounding boxes.
:param show_confidence: Whether to show confidence scores on the image.
:param color_mapping: List of tuples representing the colors for each class.
Default is None, which generates a default color mapping based on the number of class names.
:param target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes. Can either be an np.ndarray of shape
(image_i_object_count, 4) when predicting a single image, or a list of length len(target_bboxes), containing such arrays.
When not None, will plot the predictions and the ground truth bounding boxes side by side (i.e 2 images stitched as one).

:param target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape
(image_i_object_count) when predicting a single image, or a list of length len(target_bboxes), containing such arrays (default=None).

:param target_bboxes_format: Optional[str], bounding box format of target_bboxes, one of ['xyxy','xywh',
'yxyx' 'cxcywh' 'normalized_xyxy' 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh']. Will raise an
error if not None and target_bboxes is None.
"""
for prediction in self._images_prediction_lst:
prediction.show(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
target_bboxes, target_class_ids = self._check_target_args(target_bboxes, target_bboxes_format, target_class_ids)

for prediction, target_bbox, target_class_id in zip(self._images_prediction_lst, target_bboxes, target_class_ids):
prediction.show(
box_thickness=box_thickness,
show_confidence=show_confidence,
color_mapping=color_mapping,
target_bboxes=target_bbox,
target_bboxes_format=target_bboxes_format,
target_class_ids=target_class_id,
)

def _check_target_args(
self,
target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
target_bboxes_format: Optional[str] = None,
target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
):
if not (
(target_bboxes is None and target_bboxes_format is None and target_class_ids is None)
or (target_bboxes is not None and target_bboxes_format is not None and target_class_ids is not None)
):
raise ValueError("target_bboxes, target_bboxes_format, and target_class_ids should either all be None or all not None.")

if isinstance(target_bboxes, np.ndarray):
target_bboxes = [target_bboxes]
if isinstance(target_class_ids, np.ndarray):
target_class_ids = [target_class_ids]

if target_bboxes is not None and target_class_ids is not None and len(target_bboxes) != len(target_class_ids):
raise ValueError(f"target_bboxes and target_class_ids lengths should be equal, got: {len(target_bboxes)} and {len(target_class_ids)}.")
if target_bboxes is not None and target_class_ids is not None and len(target_bboxes) != len(self._images_prediction_lst):
raise ValueError(
f"target_bboxes and target_class_ids lengths should be equal, to the "
f"amount of images passed to predict(), got: {len(target_bboxes)} and {len(self._images_prediction_lst)}."
)
if target_bboxes is None:
target_bboxes = [None for _ in range(len(self._images_prediction_lst))]
target_class_ids = [None for _ in range(len(self._images_prediction_lst))]

return target_bboxes, target_class_ids

def save(
self, output_folder: str, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None
self,
output_folder: str,
box_thickness: int = 2,
show_confidence: bool = True,
color_mapping: Optional[List[Tuple[int, int, int]]] = None,
target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
target_bboxes_format: Optional[str] = None,
target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
) -> None:
"""Save the predicted bboxes on the images.

Expand All @@ -266,11 +455,23 @@ def save(
:param show_confidence: Whether to show confidence scores on the image.
:param color_mapping: List of tuples representing the colors for each class.
Default is None, which generates a default color mapping based on the number of class names.
:param target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes. Can either be an np.ndarray of shape
(image_i_object_count, 4) when predicting a single image, or a list of length len(target_bboxes), containing such arrays.
When not None, will plot the predictions and the ground truth bounding boxes side by side (i.e 2 images stitched as one).

:param target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape
(image_i_object_count) when predicting a single image, or a list of length len(target_bboxes), containing such arrays (default=None).

:param target_bboxes_format: Optional[str], bounding box format of target_bboxes, one of ['xyxy','xywh',
'yxyx' 'cxcywh' 'normalized_xyxy' 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh']. Will raise an
error if not None and target_bboxes is None.
"""
if output_folder:
os.makedirs(output_folder, exist_ok=True)

for i, prediction in enumerate(self._images_prediction_lst):
target_bboxes, target_class_ids = self._check_target_args(target_bboxes, target_bboxes_format, target_class_ids)

for i, (prediction, target_bbox, target_class_id) in enumerate(zip(self._images_prediction_lst, target_bboxes, target_class_ids)):
image_output_path = os.path.join(output_folder, f"pred_{i}.jpg")
prediction.save(output_path=image_output_path, box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)

Expand Down
Loading