Feature/sg 1047 predict od with labels #1365

Merged: 21 commits, Aug 15, 2023 (changes shown from 19 commits)
@@ -0,0 +1,27 @@
from super_gradients.common.object_names import Models
from super_gradients.training import models
from pathlib import Path
from super_gradients.training.datasets import COCODetectionDataset

# Note that currently only YoloX, PPYoloE and YOLO-NAS are supported.
model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco")
mini_coco_data_dir = str(Path(__file__).parent.parent.parent.parent.parent / "tests" / "data" / "tinycoco")

dataset = COCODetectionDataset(
data_dir=mini_coco_data_dir, subdir="images/val2017", json_file="instances_val2017.json", input_dim=None, transforms=[], cache_annotations=False
)

# x's are np.ndarray images of shape (H, W, 3)
# y's are np.ndarrays of shape (num_boxes, 5), where each row is (x1, y1, x2, y2, class_id)
image1, target1, _ = dataset[0]
image2, target2, _ = dataset[1]

# Images from COCODetectionDataset are RGB, while np.ndarray images passed to predict() are expected to be BGR, so flip the channel order.
image2 = image2[:, :, ::-1]
image1 = image1[:, :, ::-1]

predictions = model.predict(
[image1, image2], target_bboxes=[target1[:, :4], target2[:, :4]], target_class_ids=[target1[:, 4], target2[:, 4]], target_bboxes_format="xyxy"
)
predictions.show()
predictions.save(output_folder="") # Save in working directory
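For readers without the tiny-COCO test fixtures on disk, here is a minimal, self-contained sketch of the same call with hand-built targets. The random image and the box/class values are placeholders for illustration, not part of the PR:

```python
import numpy as np

from super_gradients.common.object_names import Models
from super_gradients.training import models

model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco")

# A single BGR image of shape (H, W, 3); random pixels stand in for real data.
image = np.random.randint(0, 255, size=(480, 640, 3), dtype=np.uint8)

# Two ground-truth boxes in pixel xyxy coordinates, plus their class ids.
target_bboxes = np.array([[10.0, 20.0, 100.0, 200.0], [50.0, 60.0, 300.0, 400.0]])  # shape (2, 4)
target_class_ids = np.array([0, 17])  # shape (2,)

# For a single image, bare np.ndarrays are accepted in place of one-element lists.
predictions = model.predict(image, target_bboxes=target_bboxes, target_class_ids=target_class_ids, target_bboxes_format="xyxy")
predictions.show()
```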
@@ -8,6 +8,7 @@
from typing import Union, Optional, List
from functools import lru_cache

import numpy as np
import torch
from torch import nn
from omegaconf import DictConfig
@@ -182,18 +183,39 @@ def predict(
conf: Optional[float] = None,
batch_size: int = 32,
fuse_model: bool = True,
target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
target_bboxes_format: Optional[str] = None,
target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
) -> ImagesDetectionPrediction:
"""Predict an image or a list of images.

:param images: Images to predict.
:param iou: (Optional) IoU threshold for the NMS algorithm. If None, the default value associated with the training is used.
:param conf: (Optional) Predictions below the confidence threshold are discarded.
If None, the default value associated with the training is used.
:param batch_size: Maximum number of images to process at the same time.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
:param target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes. Can either be an np.ndarray of shape
(image_i_object_count, 4) when predicting a single image, or a list of length len(images) containing such arrays.
When not None, the predictions and the ground truth bounding boxes are plotted side by side (i.e. two images stitched as one).
:param target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth class indices. Can either be an np.ndarray of shape
(image_i_object_count,) when predicting a single image, or a list of length len(images) containing such arrays (default=None).
:param target_bboxes_format: Optional[str], bounding box format of target_bboxes, one of ['xyxy', 'xywh', 'yxyx', 'cxcywh',
'normalized_xyxy', 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh']. An error is raised if target_bboxes_format
is not None while target_bboxes is None.
"""
pipeline = self._get_pipeline(iou=iou, conf=conf, fuse_model=fuse_model)
- return pipeline(images, batch_size=batch_size)  # type: ignore
+ return pipeline(
+     images, batch_size=batch_size, target_bboxes=target_bboxes, target_bboxes_format=target_bboxes_format, target_class_ids=target_class_ids
+ )  # type: ignore

def predict_webcam(self, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True):
"""Predict using webcam.
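Since target_bboxes_format already covers the common conventions, converting targets up front is rarely necessary; still, a standalone NumPy sketch of the xywh to xyxy mapping (assuming 'xywh' means top-left corner plus width/height, as in COCO annotations) may clarify what the format strings refer to:

```python
import numpy as np

def xywh_to_xyxy(boxes: np.ndarray) -> np.ndarray:
    """Convert (x_min, y_min, width, height) rows to (x1, y1, x2, y2) rows."""
    out = boxes.astype(np.float32, copy=True)
    out[:, 2] = out[:, 0] + out[:, 2]  # x2 = x_min + width
    out[:, 3] = out[:, 1] + out[:, 3]  # y2 = y_min + height
    return out

coco_style = np.array([[10.0, 20.0, 90.0, 180.0]])  # one box in xywh
print(xywh_to_xyxy(coco_style))  # -> [[ 10.  20. 100. 200.]]
```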
@@ -1,6 +1,7 @@
from functools import lru_cache
from typing import Union, Optional, List, Tuple

import numpy as np
import torch
from torch import Tensor

@@ -165,6 +166,9 @@ def predict(
conf: Optional[float] = None,
batch_size: int = 32,
fuse_model: bool = True,
target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
target_bboxes_format: Optional[str] = None,
target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
) -> ImagesDetectionPrediction:
"""Predict an image or a list of images.

@@ -174,9 +178,22 @@
If None, the default value associated with the training is used.
:param batch_size: Maximum number of images to process at the same time.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.

:param target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes. Can either be an np.ndarray of shape
(image_i_object_count, 4) when predicting a single image, or a list of length len(images) containing such arrays.
When not None, the predictions and the ground truth bounding boxes are plotted side by side (i.e. two images stitched as one).
:param target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth class indices. Can either be an np.ndarray of shape
(image_i_object_count,) when predicting a single image, or a list of length len(images) containing such arrays (default=None).
:param target_bboxes_format: Optional[str], bounding box format of target_bboxes, one of ['xyxy', 'xywh', 'yxyx', 'cxcywh',
'normalized_xyxy', 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh']. An error is raised if target_bboxes_format
is not None while target_bboxes is None.
"""
pipeline = self._get_pipeline(iou=iou, conf=conf, fuse_model=fuse_model)
- return pipeline(images, batch_size=batch_size)  # type: ignore
+ return pipeline(
+     images, batch_size=batch_size, target_bboxes=target_bboxes, target_bboxes_format=target_bboxes_format, target_class_ids=target_class_ids
+ )

def predict_webcam(self, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True):
"""Predict using webcam.
@@ -154,8 +154,7 @@ def __init__(

@torch.jit.ignore
def cache_anchors(self, input_size: Tuple[int, int]):
- b, c, h, w = input_size
- self.eval_size = (h, w)
+ self.eval_size = list(input_size)[-2:]
device = infer_model_device(self.pred_cls)
dtype = infer_model_dtype(self.pred_cls)
anchor_points, stride_tensor = self._generate_anchors(dtype=dtype, device=device)
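The change above swaps strict four-way tuple unpacking for slicing, so cache_anchors no longer assumes a batched NCHW size: any input_size whose last two entries are (H, W) now works. One subtle difference is that eval_size becomes a list rather than a tuple. A tiny standalone illustration (values are hypothetical):

```python
# Old: `b, c, h, w = input_size` requires exactly four elements and
# raises ValueError on a bare (H, W) pair. New: slicing accepts both.
for input_size in [(1, 3, 640, 640), (640, 640)]:
    assert list(input_size)[-2:] == [640, 640]
```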
@@ -3,6 +3,7 @@
from typing import Union, Type, List, Tuple, Optional
from functools import lru_cache

import numpy as np
import torch
import torch.nn as nn

@@ -550,6 +551,9 @@ def predict(
conf: Optional[float] = None,
batch_size: int = 32,
fuse_model: bool = True,
target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
target_bboxes_format: Optional[str] = None,
target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
) -> ImagesDetectionPrediction:
"""Predict an image or a list of images.

@@ -559,9 +563,22 @@
If None, the default value associated with the training is used.
:param batch_size: Maximum number of images to process at the same time.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.

:param target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes. Can either be an np.ndarray of shape
(image_i_object_count, 4) when predicting a single image, or a list of length len(images) containing such arrays.
When not None, the predictions and the ground truth bounding boxes are plotted side by side (i.e. two images stitched as one).
:param target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth class indices. Can either be an np.ndarray of shape
(image_i_object_count,) when predicting a single image, or a list of length len(images) containing such arrays (default=None).
:param target_bboxes_format: Optional[str], bounding box format of target_bboxes, one of ['xyxy', 'xywh', 'yxyx', 'cxcywh',
'normalized_xyxy', 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh']. An error is raised if target_bboxes_format
is not None while target_bboxes is None.
"""
pipeline = self._get_pipeline(iou=iou, conf=conf, fuse_model=fuse_model)
- return pipeline(images, batch_size=batch_size)  # type: ignore
+ return pipeline(
+     images, batch_size=batch_size, target_bboxes=target_bboxes, target_bboxes_format=target_bboxes_format, target_class_ids=target_class_ids
+ )

def predict_webcam(self, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True):
"""Predict using webcam.
79 changes: 66 additions & 13 deletions src/super_gradients/training/pipelines/pipelines.py
@@ -86,7 +86,7 @@ def _fuse_model(self, input_example: torch.Tensor):
self.model.prep_model_for_conversion(input_size=input_example.shape[-2:])
self.fuse_model = False

- def __call__(self, inputs: Union[str, ImageSource, List[ImageSource]], batch_size: Optional[int] = 32) -> ImagesPredictions:
+ def __call__(self, inputs: Union[str, ImageSource, List[ImageSource]], batch_size: Optional[int] = 32, **kwargs) -> ImagesPredictions:
"""Predict an image or a list of images.

Supported types include:
@@ -102,13 +102,13 @@ def __call__(self, inputs: Union[str, ImageSource, List[ImageSource]], batch_size...
"""

if includes_video_extension(inputs):
- return self.predict_video(inputs, batch_size)
+ return self.predict_video(inputs, batch_size, **kwargs)
elif check_image_typing(inputs):
- return self.predict_images(inputs, batch_size)
+ return self.predict_images(inputs, batch_size, **kwargs)
else:
raise ValueError(f"Input {inputs} not supported for prediction.")

- def predict_images(self, images: Union[ImageSource, List[ImageSource]], batch_size: Optional[int] = 32) -> ImagesPredictions:
+ def predict_images(self, images: Union[ImageSource, List[ImageSource]], batch_size: Optional[int] = 32, **kwargs) -> ImagesPredictions:
"""Predict an image or a list of images.

:param images: Images to predict.
@@ -118,7 +118,7 @@ def predict_images(self, images: Union[ImageSource, List[ImageSource]], batch_size...
from super_gradients.training.utils.media.image import load_images

images = load_images(images)
- result_generator = self._generate_prediction_result(images=images, batch_size=batch_size)
+ result_generator = self._generate_prediction_result(images=images, batch_size=batch_size, **kwargs)
return self._combine_image_prediction_to_images(result_generator, n_images=len(images))

def predict_video(self, video_path: str, batch_size: Optional[int] = 32) -> VideoPredictions:
@@ -143,7 +143,7 @@ def _draw_predictions(frame: np.ndarray) -> np.ndarray:
video_streaming = WebcamStreaming(frame_processing_fn=_draw_predictions, fps_update_frequency=1)
video_streaming.run()

- def _generate_prediction_result(self, images: Iterable[np.ndarray], batch_size: Optional[int] = None) -> Iterable[ImagePrediction]:
+ def _generate_prediction_result(self, images: Iterable[np.ndarray], batch_size: Optional[int] = None, **kwargs) -> Iterable[ImagePrediction]:
"""Run the pipeline on the images as single batch or through multiple batches.

NOTE: A core motivation to have this function as a generator is that it can be used in a lazy way (if images is generator itself),
@@ -154,12 +154,12 @@ def _generate_prediction_result(self, images: Iterable[np.ndarray], batch_size:
:return: Iterable of Results object, each containing the results of the prediction and the image.
"""
if batch_size is None:
- yield from self._generate_prediction_result_single_batch(images)
+ yield from self._generate_prediction_result_single_batch(images, **kwargs)
else:
for batch_images in generate_batch(images, batch_size):
- yield from self._generate_prediction_result_single_batch(batch_images)
+ yield from self._generate_prediction_result_single_batch(batch_images, **kwargs)

- def _generate_prediction_result_single_batch(self, images: Iterable[np.ndarray]) -> Iterable[ImagePrediction]:
+ def _generate_prediction_result_single_batch(self, images: Iterable[np.ndarray], **kwargs) -> Iterable[ImagePrediction]:
"""Run the pipeline on images. The pipeline is made of 4 steps:
1. Load images - Loading the images into a list of numpy arrays.
2. Preprocess - Encode the image in the shape/format expected by the model
@@ -186,7 +186,7 @@ def _generate_prediction_result_single_batch(self, images: Iterable[np.ndarray])
if self.fuse_model:
self._fuse_model(torch_inputs)
model_output = self.model(torch_inputs)
- predictions = self._decode_model_output(model_output, model_input=torch_inputs)
+ predictions = self._decode_model_output(model_output, model_input=torch_inputs, **kwargs)

# Postprocess
postprocessed_predictions = []
@@ -199,7 +199,7 @@ def _generate_prediction_result_single_batch(self, images: Iterable[np.ndarray])
yield self._instantiate_image_prediction(image=image, prediction=prediction)
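The **kwargs threading above keeps the base Pipeline agnostic to detection-specific arguments: each generator layer forwards **kwargs verbatim until the subclass's _decode_model_output consumes them. A toy, self-contained sketch of that pattern (hypothetical names, not the real classes):

```python
from typing import Iterable, List

def _decode(outputs: List[int], target_bboxes=None, target_class_ids=None) -> List[str]:
    # Innermost consumer: only this level actually reads the targets.
    return [f"pred={o} / gt={target_bboxes}" for o in outputs]

def _single_batch(images: List[int], **kwargs) -> Iterable[str]:
    yield from _decode(images, **kwargs)  # kwargs forwarded unchanged

def _generate(images: List[int], batch_size=None, **kwargs) -> Iterable[str]:
    yield from _single_batch(images, **kwargs)

print(list(_generate([1, 2], target_bboxes="boxes-placeholder", target_class_ids="ids-placeholder")))
```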

@abstractmethod
- def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], model_input: np.ndarray) -> List[Prediction]:
+ def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], model_input: np.ndarray, **kwargs) -> List[Prediction]:
"""Decode the model outputs, move each prediction to numpy and store it in a Prediction object.

:param model_output: Direct output of the model, without any post-processing.
@@ -266,31 +266,84 @@ def __init__(
super().__init__(model=model, device=device, image_processor=image_processor, class_names=class_names, fuse_model=fuse_model)
self.post_prediction_callback = post_prediction_callback

- def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], model_input: np.ndarray) -> List[DetectionPrediction]:
+ def _decode_model_output(
+     self,
+     model_output: Union[List, Tuple, torch.Tensor],
+     model_input: np.ndarray,
+     target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
+     target_bboxes_format: Optional[str] = None,
+     target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
+ ) -> List[DetectionPrediction]:

"""Decode the model output, by applying post prediction callback. This includes NMS.

:param model_output: Direct output of the model, without any post-processing.
:param model_input: Model input (i.e. images after preprocessing).

:param target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes. Can either be an np.ndarray of shape
(image_i_object_count, 4) when predicting a single image, or a list of length len(images) containing such arrays.
When not None, the predictions and the ground truth bounding boxes are plotted side by side (i.e. two images stitched as one).
:param target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]], ground truth class indices. Can either be an np.ndarray of shape
(image_i_object_count,) when predicting a single image, or a list of length len(images) containing such arrays (default=None).
:param target_bboxes_format: Optional[str], bounding box format of target_bboxes, one of ['xyxy', 'xywh', 'yxyx', 'cxcywh',
'normalized_xyxy', 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh']. An error is raised if target_bboxes_format
is not None while target_bboxes is None.

:return: Predicted bounding boxes.
"""
target_bboxes, target_class_ids = self._check_target_args(target_bboxes, target_bboxes_format, target_class_ids)

post_nms_predictions = self.post_prediction_callback(model_output, device=self.device)
if target_bboxes is None:
target_bboxes = [None for _ in range(len(model_input))]
target_class_ids = [None for _ in range(len(model_input))]

predictions = []
- for prediction, image in zip(post_nms_predictions, model_input):
+ for prediction, image, target_bbox, target_class_id in zip(post_nms_predictions, model_input, target_bboxes, target_class_ids):
prediction = prediction if prediction is not None else torch.zeros((0, 6), dtype=torch.float32)
target_bbox = target_bbox if target_bbox is not None else np.zeros((0, 4))
target_class_id = target_class_id if target_class_id is not None else np.zeros((0, 1))
prediction = prediction.detach().cpu().numpy()

predictions.append(
DetectionPrediction(
bboxes=prediction[:, :4],
confidence=prediction[:, 4],
labels=prediction[:, 5],
bbox_format="xyxy",
target_bboxes=target_bbox,
target_labels=target_class_id,
target_bbox_format=target_bboxes_format,
image_shape=image.shape,
)
)

return predictions
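Note that images without supplied targets are paired with empty (0, 4) and (0, 1) arrays rather than None, so downstream visualization can treat every image uniformly. A standalone restatement of that per-image fallback:

```python
import numpy as np

target_bbox, target_class_id = None, None  # no ground truth for this image
target_bbox = target_bbox if target_bbox is not None else np.zeros((0, 4))
target_class_id = target_class_id if target_class_id is not None else np.zeros((0, 1))
assert target_bbox.shape == (0, 4)  # empty but well-shaped, so drawing loops need no special case
```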

@staticmethod
def _check_target_args(
target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
target_bboxes_format: Optional[str] = None,
target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
):
if not (
(target_bboxes is None and target_bboxes_format is None and target_class_ids is None)
or (target_bboxes is not None and target_bboxes_format is not None and target_class_ids is not None)
):
raise ValueError("target_bboxes, target_bboxes_format, and target_class_ids should either all be None or all not None.")

if isinstance(target_bboxes, np.ndarray):
target_bboxes = [target_bboxes]
if isinstance(target_class_ids, np.ndarray):
target_class_ids = [target_class_ids]

if target_bboxes is not None and target_class_ids is not None and len(target_bboxes) != len(target_class_ids):
raise ValueError(f"target_bboxes and target_class_ids lengths should be equal, got: {len(target_bboxes)} and {len(target_class_ids)}.")

return target_bboxes, target_class_ids
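A runnable mirror of the contract this helper enforces (a standalone reimplementation for illustration only, with hypothetical inputs):

```python
import numpy as np

def check_target_args(target_bboxes, target_bboxes_format, target_class_ids):
    """Standalone mirror of the static method above, for illustration only."""
    all_none = target_bboxes is None and target_bboxes_format is None and target_class_ids is None
    all_set = target_bboxes is not None and target_bboxes_format is not None and target_class_ids is not None
    if not (all_none or all_set):
        raise ValueError("target_bboxes, target_bboxes_format, and target_class_ids should either all be None or all not None.")
    if isinstance(target_bboxes, np.ndarray):
        target_bboxes = [target_bboxes]  # single image: wrap bare array in a list
    if isinstance(target_class_ids, np.ndarray):
        target_class_ids = [target_class_ids]
    if target_bboxes is not None and len(target_bboxes) != len(target_class_ids):
        raise ValueError(f"target_bboxes and target_class_ids lengths should be equal, got: {len(target_bboxes)} and {len(target_class_ids)}.")
    return target_bboxes, target_class_ids

boxes, ids = np.zeros((2, 4)), np.zeros((2,))
assert len(check_target_args(boxes, "xyxy", ids)[0]) == 1  # bare arrays wrapped into lists
try:
    check_target_args(boxes, None, ids)  # partial arguments -> rejected
except ValueError as err:
    print(err)
```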

def _instantiate_image_prediction(self, image: np.ndarray, prediction: DetectionPrediction) -> ImagePrediction:
return ImageDetectionPrediction(image=image, prediction=prediction, class_names=self.class_names)

1 change: 1 addition & 0 deletions src/super_gradients/training/processing/processing.py
@@ -400,6 +400,7 @@ def infer_image_input_shape(self) -> Optional[Tuple[int, int]]:
class DetectionRescale(_Rescale):
def postprocess_predictions(self, predictions: DetectionPrediction, metadata: RescaleMetadata) -> DetectionPrediction:
predictions.bboxes_xyxy = _rescale_bboxes(targets=predictions.bboxes_xyxy, scale_factors=(1 / metadata.scale_factor_h, 1 / metadata.scale_factor_w))

return predictions

