diff --git a/documentation/source/images/examples/countryside.jpg b/documentation/source/images/examples/countryside.jpg new file mode 100644 index 0000000000..acb18730d0 Binary files /dev/null and b/documentation/source/images/examples/countryside.jpg differ diff --git a/documentation/source/images/examples/street_busy.jpg b/documentation/source/images/examples/street_busy.jpg new file mode 100644 index 0000000000..041aff4e9e Binary files /dev/null and b/documentation/source/images/examples/street_busy.jpg differ diff --git a/documentation/source/images/examples/street_vehicles.jpg b/documentation/source/images/examples/street_vehicles.jpg new file mode 100644 index 0000000000..acc8ef112e Binary files /dev/null and b/documentation/source/images/examples/street_vehicles.jpg differ diff --git a/src/super_gradients/examples/predict/detection_predict.py b/src/super_gradients/examples/predict/detection_predict.py index 3ddf9c21b0..f1d8e4d297 100644 --- a/src/super_gradients/examples/predict/detection_predict.py +++ b/src/super_gradients/examples/predict/detection_predict.py @@ -5,9 +5,10 @@ model = models.get(Models.PP_YOLOE_S, pretrained_weights="coco") IMAGES = [ - "https://miro.medium.com/v2/resize:fit:500/0*w1s81z-Q72obhE_z", - "https://s.hs-data.com/bilder/spieler/gross/128069.jpg", - "https://datasets-server.huggingface.co/assets/Chris1/cityscapes/--/Chris1--cityscapes/train/28/image/image.jpg", + "../../../../documentation/source/images/examples/countryside.jpg", + "../../../../documentation/source/images/examples/street_busy.jpg", + "https://cdn-attachments.timesofmalta.com/cc1eceadde40d2940bc5dd20692901371622153217-1301777007-4d978a6f-620x348.jpg", ] -prediction = model.predict(IMAGES, iou=0.65, conf=0.5) + +prediction = model.predict(IMAGES) prediction.show() diff --git a/src/super_gradients/examples/predict/detection_predict_image_folder.py b/src/super_gradients/examples/predict/detection_predict_image_folder.py new file mode 100644 index 0000000000..27fb7dadf7 --- /dev/null +++ b/src/super_gradients/examples/predict/detection_predict_image_folder.py @@ -0,0 +1,9 @@ +from super_gradients.common.object_names import Models +from super_gradients.training import models + +# Note that currently only YoloX and PPYoloE are supported. +model = models.get(Models.YOLOX_N, pretrained_weights="coco") + +image_folder_path = "../../../../documentation/source/images/examples" +predictions = model.predict(image_folder_path) +predictions.show() diff --git a/src/super_gradients/examples/predict/detection_predict_streaming.py b/src/super_gradients/examples/predict/detection_predict_streaming.py new file mode 100644 index 0000000000..e5f912c02d --- /dev/null +++ b/src/super_gradients/examples/predict/detection_predict_streaming.py @@ -0,0 +1,6 @@ +from super_gradients.common.object_names import Models +from super_gradients.training import models + +# Note that currently only YoloX and PPYoloE are supported. +model = models.get(Models.YOLOX_N, pretrained_weights="coco") +model.predict_webcam() diff --git a/src/super_gradients/examples/predict/detection_predict_video.py b/src/super_gradients/examples/predict/detection_predict_video.py new file mode 100644 index 0000000000..0e902d9f1e --- /dev/null +++ b/src/super_gradients/examples/predict/detection_predict_video.py @@ -0,0 +1,9 @@ +from super_gradients.common.object_names import Models +from super_gradients.training import models + +# Note that currently only YoloX and PPYoloE are supported. 
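The explicit thresholds were dropped from the image example above because pretrained models now carry default NMS settings (iou=0.65, conf=0.5 for PPYoloE on COCO); both can still be overridden per call. A minimal sketch, assuming a hypothetical local image path:

from super_gradients.common.object_names import Models
from super_gradients.training import models

model = models.get(Models.PP_YOLOE_S, pretrained_weights="coco")

# Per-call override of the pretrained NMS defaults; omit iou/conf to use them as-is.
prediction = model.predict("path/to/image.jpg", iou=0.5, conf=0.25)
prediction.show()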
+model = models.get(Models.YOLOX_N, pretrained_weights="coco") + +video_path = "" +predictions = model.predict(video_path) +predictions.show() diff --git a/src/super_gradients/training/datasets/detection_datasets/yolo_format_detection.py b/src/super_gradients/training/datasets/detection_datasets/yolo_format_detection.py index 4e40ddf946..d217c2342c 100644 --- a/src/super_gradients/training/datasets/detection_datasets/yolo_format_detection.py +++ b/src/super_gradients/training/datasets/detection_datasets/yolo_format_detection.py @@ -5,7 +5,7 @@ from typing import List, Optional from super_gradients.common.abstractions.abstract_logger import get_logger -from super_gradients.training.utils.load_image import is_image +from super_gradients.training.utils.media.image import is_image from super_gradients.training.datasets.detection_datasets.detection_dataset import DetectionDataset from super_gradients.training.datasets.data_formats import ConcatenatedTensorFormatConverter from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL, LABEL_NORMALIZED_CXCYWH diff --git a/src/super_gradients/training/models/detection_models/customizable_detector.py b/src/super_gradients/training/models/detection_models/customizable_detector.py index 9d7e0055ae..8fc2e600b0 100644 --- a/src/super_gradients/training/models/detection_models/customizable_detector.py +++ b/src/super_gradients/training/models/detection_models/customizable_detector.py @@ -5,9 +5,7 @@ * each module accepts in_channels and other parameters * each module defines out_channels property on construction """ - - -from typing import Union, Optional +from typing import Union, Optional, List from torch import nn from omegaconf import DictConfig @@ -15,6 +13,11 @@ from super_gradients.training.utils.utils import HpmStruct from super_gradients.training.models.sg_module import SgModule import super_gradients.common.factories.detection_modules_factory as det_factory +from super_gradients.training.models.prediction_results import ImagesDetectionPrediction +from super_gradients.training.pipelines.pipelines import DetectionPipeline +from super_gradients.training.transforms.processing import Processing +from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback +from super_gradients.training.utils.media.image import ImageSource class CustomizableDetector(SgModule): @@ -67,6 +70,12 @@ def __init__( self._initialize_weights(bn_eps, bn_momentum, inplace_act) + # Processing params + self._class_names: Optional[List[str]] = None + self._image_processor: Optional[Processing] = None + self._default_nms_iou: Optional[float] = None + self._default_nms_conf: Optional[float] = None + def forward(self, x): x = self.backbone(x) x = self.neck(x) @@ -96,3 +105,70 @@ def replace_head(self, new_num_classes: Optional[int] = None, new_head: Optional self.heads_params = factory.insert_module_param(self.heads_params, "num_classes", new_num_classes) self.heads = factory.get(factory.insert_module_param(self.heads_params, "in_channels", self.neck.out_channels)) self._initialize_weights(self.bn_eps, self.bn_momentum, self.inplace_act) + + @staticmethod + def get_post_prediction_callback(conf: float, iou: float) -> DetectionPostPredictionCallback: + raise NotImplementedError + + def set_dataset_processing_params( + self, + class_names: Optional[List[str]] = None, + image_processor: Optional[Processing] = None, + iou: Optional[float] = None, + conf: Optional[float] = None, + ) -> None: + """Set the processing parameters for the 
dataset. + + :param class_names: (Optional) Names of the classes of the dataset the model was trained on. + :param image_processor: (Optional) Image processing objects to reproduce the dataset preprocessing used for training. + :param iou: (Optional) IoU threshold for the NMS algorithm + :param conf: (Optional) Below the confidence threshold, predictions are discarded + """ + self._class_names = class_names or self._class_names + self._image_processor = image_processor or self._image_processor + self._default_nms_iou = iou or self._default_nms_iou + self._default_nms_conf = conf or self._default_nms_conf + + def _get_pipeline(self, iou: Optional[float] = None, conf: Optional[float] = None) -> DetectionPipeline: + """Instantiate the prediction pipeline of this model. + + :param iou: (Optional) IoU threshold for the NMS algorithm. If None, the default value associated with the training is used. + :param conf: (Optional) Below the confidence threshold, predictions are discarded. + If None, the default value associated with the training is used. + """ + if None in (self._class_names, self._image_processor, self._default_nms_iou, self._default_nms_conf): + raise RuntimeError( + "You must set the dataset processing parameters before calling predict.\n" "Please call `model.set_dataset_processing_params(...)` first." + ) + + iou = iou or self._default_nms_iou + conf = conf or self._default_nms_conf + + pipeline = DetectionPipeline( + model=self, + image_processor=self._image_processor, + post_prediction_callback=self.get_post_prediction_callback(iou=iou, conf=conf), + class_names=self._class_names, + ) + return pipeline + + def predict(self, images: ImageSource, iou: Optional[float] = None, conf: Optional[float] = None) -> ImagesDetectionPrediction: + """Predict an image or a list of images. + + :param images: Images to predict. + :param iou: (Optional) IoU threshold for the NMS algorithm. If None, the default value associated with the training is used. + :param conf: (Optional) Below the confidence threshold, predictions are discarded. + If None, the default value associated with the training is used. + """ + pipeline = self._get_pipeline(iou=iou, conf=conf) + return pipeline(images) # type: ignore + + def predict_webcam(self, iou: Optional[float] = None, conf: Optional[float] = None): + """Predict using webcam. + + :param iou: (Optional) IoU threshold for the NMS algorithm. If None, the default value associated with the training is used. + :param conf: (Optional) Below the confidence threshold, predictions are discarded. + If None, the default value associated with the training is used.
+ """ + pipeline = self._get_pipeline(iou=iou, conf=conf) + pipeline.predict_webcam() diff --git a/src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_e.py b/src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_e.py index 94ebcb8591..3365de20e4 100644 --- a/src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_e.py +++ b/src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_e.py @@ -12,9 +12,10 @@ from super_gradients.training.utils import HpmStruct from super_gradients.training.models.arch_params_factory import get_arch_params from super_gradients.training.models.detection_models.pp_yolo_e.post_prediction_callback import PPYoloEPostPredictionCallback, DetectionPostPredictionCallback -from super_gradients.training.models.results import DetectionResults +from super_gradients.training.models.prediction_results import ImagesDetectionPrediction from super_gradients.training.pipelines.pipelines import DetectionPipeline from super_gradients.training.transforms.processing import Processing +from super_gradients.training.utils.media.image import ImageSource class PPYoloE(SgModule): @@ -29,34 +30,75 @@ def __init__(self, arch_params): self._class_names: Optional[List[str]] = None self._image_processor: Optional[Processing] = None + self._default_nms_iou: Optional[float] = None + self._default_nms_conf: Optional[float] = None @staticmethod def get_post_prediction_callback(conf: float, iou: float) -> DetectionPostPredictionCallback: return PPYoloEPostPredictionCallback(score_threshold=conf, nms_threshold=iou, nms_top_k=1000, max_predictions=300) - def set_dataset_processing_params(self, class_names: Optional[List[str]], image_processor: Optional[Processing]) -> None: + def set_dataset_processing_params( + self, + class_names: Optional[List[str]] = None, + image_processor: Optional[Processing] = None, + iou: Optional[float] = None, + conf: Optional[float] = None, + ) -> None: """Set the processing parameters for the dataset. :param class_names: (Optional) Names of the dataset the model was trained on. :param image_processor: (Optional) Image processing objects to reproduce the dataset preprocessing used for training. + :param iou: (Optional) IoU threshold for the nms algorithm + :param conf: (Optional) Below the confidence threshold, prediction are discarded """ self._class_names = class_names or self._class_names self._image_processor = image_processor or self._image_processor + self._default_nms_iou = iou or self._default_nms_iou + self._default_nms_conf = conf or self._default_nms_conf - def predict(self, images, iou: float = 0.65, conf: float = 0.01) -> DetectionResults: + def _get_pipeline(self, iou: Optional[float] = None, conf: Optional[float] = None) -> DetectionPipeline: + """Instantiate the prediction pipeline of this model. - if self._class_names is None or self._image_processor is None: + :param iou: (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used. + :param conf: (Optional) Below the confidence threshold, prediction are discarded. + If None, the default value associated to the training is used. + """ + if None in (self._class_names, self._image_processor, self._default_nms_iou, self._default_nms_conf): raise RuntimeError( "You must set the dataset processing parameters before calling predict.\n" "Please call `model.set_dataset_processing_params(...)` first." 
) + iou = iou or self._default_nms_iou + conf = conf or self._default_nms_conf + pipeline = DetectionPipeline( model=self, image_processor=self._image_processor, post_prediction_callback=self.get_post_prediction_callback(iou=iou, conf=conf), class_names=self._class_names, ) - return pipeline(images) + return pipeline + + def predict(self, images: ImageSource, iou: Optional[float] = None, conf: Optional[float] = None) -> ImagesDetectionPrediction: + """Predict an image or a list of images. + + :param images: Images to predict. + :param iou: (Optional) IoU threshold for the NMS algorithm. If None, the default value associated with the training is used. + :param conf: (Optional) Below the confidence threshold, predictions are discarded. + If None, the default value associated with the training is used. + """ + pipeline = self._get_pipeline(iou=iou, conf=conf) + return pipeline(images) # type: ignore + + def predict_webcam(self, iou: Optional[float] = None, conf: Optional[float] = None): + """Predict using webcam. + + :param iou: (Optional) IoU threshold for the NMS algorithm. If None, the default value associated with the training is used. + :param conf: (Optional) Below the confidence threshold, predictions are discarded. + If None, the default value associated with the training is used. + """ + pipeline = self._get_pipeline(iou=iou, conf=conf) + pipeline.predict_webcam() def forward(self, x: Tensor): features = self.backbone(x) diff --git a/src/super_gradients/training/models/detection_models/yolo_base.py b/src/super_gradients/training/models/detection_models/yolo_base.py index 6573ad4075..62ad295329 100755 --- a/src/super_gradients/training/models/detection_models/yolo_base.py +++ b/src/super_gradients/training/models/detection_models/yolo_base.py @@ -11,10 +11,10 @@ from super_gradients.training.utils import torch_version_is_greater_or_equal from super_gradients.training.utils.detection_utils import non_max_suppression, matrix_non_max_suppression, NMS_Type, DetectionPostPredictionCallback, Anchors from super_gradients.training.utils.utils import HpmStruct, check_img_size_divisibility, get_param -from super_gradients.training.models.results import DetectionResults +from super_gradients.training.models.prediction_results import ImagesDetectionPrediction from super_gradients.training.pipelines.pipelines import DetectionPipeline from super_gradients.training.transforms.processing import Processing - +from super_gradients.training.utils.media.image import ImageSource COCO_DETECTION_80_CLASSES_BBOX_ANCHORS = Anchors( [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], strides=[8, 16, 32] @@ -418,33 +418,75 @@ def __init__(self, backbone: Type[nn.Module], arch_params: HpmStruct, initialize self._class_names: Optional[List[str]] = None self._image_processor: Optional[Processing] = None + self._default_nms_iou: Optional[float] = None + self._default_nms_conf: Optional[float] = None @staticmethod def get_post_prediction_callback(conf: float, iou: float) -> DetectionPostPredictionCallback: return YoloPostPredictionCallback(conf=conf, iou=iou) - def set_dataset_processing_params(self, class_names: Optional[List[str]], image_processor: Optional[Processing]) -> None: + def set_dataset_processing_params( + self, + class_names: Optional[List[str]] = None, + image_processor: Optional[Processing] = None, + iou: Optional[float] = None, + conf: Optional[float] = None, + ) -> None: """Set the processing parameters for the dataset.
:param class_names: (Optional) Names of the classes of the dataset the model was trained on. :param image_processor: (Optional) Image processing objects to reproduce the dataset preprocessing used for training. + :param iou: (Optional) IoU threshold for the NMS algorithm + :param conf: (Optional) Below the confidence threshold, predictions are discarded """ self._class_names = class_names or self._class_names self._image_processor = image_processor or self._image_processor + self._default_nms_iou = iou or self._default_nms_iou + self._default_nms_conf = conf or self._default_nms_conf + + def _get_pipeline(self, iou: Optional[float] = None, conf: Optional[float] = None) -> DetectionPipeline: + """Instantiate the prediction pipeline of this model. - def predict(self, images, iou: float = 0.65, conf: float = 0.01) -> DetectionResults: - if self._class_names is None or self._image_processor is None: + :param iou: (Optional) IoU threshold for the NMS algorithm. If None, the default value associated with the training is used. + :param conf: (Optional) Below the confidence threshold, predictions are discarded. + If None, the default value associated with the training is used. + """ + if None in (self._class_names, self._image_processor, self._default_nms_iou, self._default_nms_conf): raise RuntimeError( "You must set the dataset processing parameters before calling predict.\n" "Please call `model.set_dataset_processing_params(...)` first." ) + iou = iou or self._default_nms_iou + conf = conf or self._default_nms_conf + pipeline = DetectionPipeline( model=self, image_processor=self._image_processor, post_prediction_callback=self.get_post_prediction_callback(iou=iou, conf=conf), class_names=self._class_names, ) - return pipeline(images) + return pipeline + + def predict(self, images: ImageSource, iou: Optional[float] = None, conf: Optional[float] = None) -> ImagesDetectionPrediction: + """Predict an image or a list of images. + + :param images: Images to predict. + :param iou: (Optional) IoU threshold for the NMS algorithm. If None, the default value associated with the training is used. + :param conf: (Optional) Below the confidence threshold, predictions are discarded. + If None, the default value associated with the training is used. + """ + pipeline = self._get_pipeline(iou=iou, conf=conf) + return pipeline(images) # type: ignore + + def predict_webcam(self, iou: Optional[float] = None, conf: Optional[float] = None): + """Predict using webcam. + + :param iou: (Optional) IoU threshold for the NMS algorithm. If None, the default value associated with the training is used. + :param conf: (Optional) Below the confidence threshold, predictions are discarded. + If None, the default value associated with the training is used.
+ """ + pipeline = self._get_pipeline(iou=iou, conf=conf) + pipeline.predict_webcam() def forward(self, x): out = self._backbone(x) diff --git a/src/super_gradients/training/models/model_factory.py b/src/super_gradients/training/models/model_factory.py index a1b3b98570..01de15febc 100644 --- a/src/super_gradients/training/models/model_factory.py +++ b/src/super_gradients/training/models/model_factory.py @@ -136,8 +136,9 @@ def instantiate_model( net.replace_head(new_num_classes=num_classes_new_head) arch_params.num_classes = num_classes_new_head - class_names, image_processor = get_pretrained_processing_params(model_name, pretrained_weights) - net.set_dataset_processing_params(class_names, image_processor) + # TODO: remove once we load it from the checkpoint + processing_params = get_pretrained_processing_params(model_name, pretrained_weights) + net.set_dataset_processing_params(**processing_params) _add_model_name_attribute(net, model_name) diff --git a/src/super_gradients/training/models/results.py b/src/super_gradients/training/models/prediction_results.py similarity index 57% rename from src/super_gradients/training/models/results.py rename to src/super_gradients/training/models/prediction_results.py index 34e09844ba..1396e34210 100644 --- a/src/super_gradients/training/models/results.py +++ b/src/super_gradients/training/models/prediction_results.py @@ -1,17 +1,18 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Iterator from dataclasses import dataclass -from matplotlib import pyplot as plt import numpy as np from super_gradients.training.utils.detection_utils import DetectionVisualization from super_gradients.training.models.predictions import Prediction, DetectionPrediction +from super_gradients.training.utils.media.video import show_video_from_frames +from super_gradients.training.utils.media.image import show_image @dataclass -class Result(ABC): - """Results of a given computer vision task (detection, classification, etc.). +class ImagePrediction(ABC): + """Object wrapping an image and a model's prediction. :attr image: Input image :attr predictions: Predictions of the model @@ -19,7 +20,7 @@ class Result(ABC): """ image: np.ndarray - predictions: Prediction + prediction: Prediction class_names: List[str] @abstractmethod @@ -34,28 +35,8 @@ def show(self) -> None: @dataclass -class Results(ABC): - """List of results of a given computer vision task (detection, classification, etc.). - - :attr results: List of results of the run - """ - - results: List[Result] - - @abstractmethod - def draw(self) -> List[np.ndarray]: - """Draw the predictions on the image.""" - pass - - @abstractmethod - def show(self) -> None: - """Display the predictions on the image.""" - pass - - -@dataclass -class DetectionResult(Result): - """Result of a detection task. +class ImageDetectionPrediction(ImagePrediction): + """Object wrapping an image and a detection model's prediction. 
:attr image: Input image :attr predictions: Predictions of the model @@ -63,7 +44,7 @@ class DetectionResult(Result): """ image: np.ndarray - predictions: DetectionPrediction + prediction: DetectionPrediction class_names: List[str] def draw(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int]]] = None) -> np.ndarray: @@ -78,18 +59,18 @@ def draw(self, box_thickness: int = 2, show_confidence: bool = True, color_mappi image_np = self.image.copy() color_mapping = color_mapping or DetectionVisualization._generate_color_mapping(len(self.class_names)) - for pred_i in range(len(self.predictions)): + for pred_i in range(len(self.prediction)): image_np = DetectionVisualization._draw_box_title( color_mapping=color_mapping, class_names=self.class_names, box_thickness=box_thickness, image_np=image_np, - x1=int(self.predictions.bboxes_xyxy[pred_i, 0]), - y1=int(self.predictions.bboxes_xyxy[pred_i, 1]), - x2=int(self.predictions.bboxes_xyxy[pred_i, 2]), - y2=int(self.predictions.bboxes_xyxy[pred_i, 3]), - class_id=int(self.predictions.labels[pred_i]), - pred_conf=self.predictions.confidence[pred_i] if show_confidence else None, + x1=int(self.prediction.bboxes_xyxy[pred_i, 0]), + y1=int(self.prediction.bboxes_xyxy[pred_i, 1]), + x2=int(self.prediction.bboxes_xyxy[pred_i, 2]), + y2=int(self.prediction.bboxes_xyxy[pred_i, 3]), + class_id=int(self.prediction.labels[pred_i]), + pred_conf=self.prediction.confidence[pred_i] if show_confidence else None, ) return image_np @@ -102,34 +83,80 @@ def show(self, box_thickness: int = 2, show_confidence: bool = True, color_mappi Default is None, which generates a default color mapping based on the number of class names. """ image_np = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping) + show_image(image_np) + + +@dataclass +class ImagesPredictions(ABC): + """Object wrapping the list of image predictions. - plt.imshow(image_np, interpolation="nearest") - plt.axis("off") - plt.show() + :attr _images_prediction_lst: List of results of the run + """ + + _images_prediction_lst: List[ImagePrediction] + + def __len__(self) -> int: + return len(self._images_prediction_lst) + + def __getitem__(self, index: int) -> ImagePrediction: + return self._images_prediction_lst[index] + + def __iter__(self) -> Iterator[ImagePrediction]: + return iter(self._images_prediction_lst) + + @abstractmethod + def show(self) -> None: + pass @dataclass -class DetectionResults(Results): - """Results of a detection task. +class VideoPredictions(ImagesPredictions, ABC): + """Object wrapping the list of image predictions as a Video. - :attr results: List of the predictions results + :attr _images_prediction_lst: List of results of the run + :att fps: Frames per second of the video """ - def __init__(self, images: List[np.ndarray], predictions: List[DetectionPrediction], class_names: List[str]): - self.results: List[DetectionResult] = [] - for image, prediction in zip(images, predictions): - self.results.append(DetectionResult(image=image, predictions=prediction, class_names=class_names)) + _images_prediction_lst: List[ImagePrediction] + fps: float + + @abstractmethod + def show(self, *args, **kwargs) -> None: + """Display the predictions on the image.""" + pass + + +@dataclass +class ImagesDetectionPrediction(ImagesPredictions): + """Object wrapping the list of image detection predictions. 
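Because ImagesPredictions implements __len__, __getitem__ and __iter__, per-image results can be consumed directly; a small sketch, assuming `model` was obtained as in the examples above and the paths are hypothetical:

predictions = model.predict(["path/to/a.jpg", "path/to/b.jpg"])

print(len(predictions))                  # number of images
first = predictions[0]                   # an ImageDetectionPrediction
annotated = first.draw(box_thickness=1)  # numpy array with boxes drawn; original image untouched

for image_prediction in predictions:
    print(image_prediction.prediction.labels)      # class ids, one per detected box
    print(image_prediction.prediction.confidence)  # confidence scores, one per detected box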
+ + :attr _images_prediction_lst: List of the prediction results + """ - def draw(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int]]] = None) -> List[np.ndarray]: - """Draw the predicted bboxes on the images. + _images_prediction_lst: List[ImageDetectionPrediction] + + def show(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int]]] = None) -> None: + """Display the predicted bboxes on the images. :param box_thickness: Thickness of bounding boxes. :param show_confidence: Whether to show confidence scores on the image. :param color_mapping: List of tuples representing the colors for each class. Default is None, which generates a default color mapping based on the number of class names. - :return: List of Images with predicted bboxes for each image. Note that this does not modify the original images. """ - return [prediction.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping) for prediction in self.results] + for prediction in self._images_prediction_lst: + prediction.show(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping) + + +@dataclass +class VideoDetectionPrediction(VideoPredictions): + """Object wrapping the list of image detection predictions as a Video. + + :attr _images_prediction_lst: List of the prediction results + :attr fps: Frames per second of the video + """ + + _images_prediction_lst: List[ImageDetectionPrediction] + fps: float def show(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int]]] = None) -> None: """Display the predicted bboxes on the images. @@ -139,5 +166,7 @@ def show(self, box_thickness: int = 2, show_confidence: bool = True, color_mappi :param color_mapping: List of tuples representing the colors for each class. Default is None, which generates a default color mapping based on the number of class names. """ - for prediction in self.results: - prediction.show(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping) + frames = [ + result.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping) for result in self._images_prediction_lst + ] + show_video_from_frames(window_name="Detection", frames=frames, fps=self.fps) diff --git a/src/super_gradients/training/models/sg_module.py b/src/super_gradients/training/models/sg_module.py index 560e834a24..dc4c38166f 100755 --- a/src/super_gradients/training/models/sg_module.py +++ b/src/super_gradients/training/models/sg_module.py @@ -3,7 +3,7 @@ from torch import nn from super_gradients.training.utils.utils import HpmStruct -from super_gradients.training.models.results import Result +from super_gradients.training.models.prediction_results import ImagesPredictions class SgModule(nn.Module): @@ -64,9 +64,12 @@ class to implement.
raise NotImplementedError - def predict(self, images, *args, **kwargs) -> Result: + def predict(self, images, *args, **kwargs) -> ImagesPredictions: raise NotImplementedError(f"`predict` is not implemented for {self.__class__.__name__}.") + def predict_webcam(self, *args, **kwargs) -> None: + raise NotImplementedError(f"`predict_webcam` is not implemented for {self.__class__.__name__}.") + def set_dataset_processing_params(self, *args, **kwargs) -> None: """Set the processing parameters for the dataset.""" pass diff --git a/src/super_gradients/training/pipelines/pipelines.py b/src/super_gradients/training/pipelines/pipelines.py index 55eb4ae4e3..eb050486cc 100644 --- a/src/super_gradients/training/pipelines/pipelines.py +++ b/src/super_gradients/training/pipelines/pipelines.py @@ -1,16 +1,30 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Iterable from contextlib import contextmanager +from tqdm import tqdm import numpy as np import torch - -from super_gradients.training.utils.load_image import load_images, ImageType +from super_gradients.training.utils.utils import generate_batch +from super_gradients.training.utils.media.video import load_video, is_video +from super_gradients.training.utils.media.image import ImageSource, check_image_typing +from super_gradients.training.utils.media.stream import WebcamStreaming from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback from super_gradients.training.models.sg_module import SgModule -from super_gradients.training.models.results import Results, DetectionResults +from super_gradients.training.models.prediction_results import ( + ImagesDetectionPrediction, + VideoDetectionPrediction, + ImagePrediction, + ImageDetectionPrediction, + ImagesPredictions, + VideoPredictions, +) from super_gradients.training.models.predictions import Prediction, DetectionPrediction from super_gradients.training.transforms.processing import Processing, ComposeProcessing +from super_gradients.common.abstractions.abstract_logger import get_logger + + +logger = get_logger(__name__) @contextmanager @@ -35,39 +49,101 @@ class Pipeline(ABC): :param device: The device on which the model will be run. Defaults to "cpu". Use "cuda" for GPU support. """ - def __init__(self, model: SgModule, image_processor: Union[Processing, List[Processing]], device: Optional[str] = "cpu"): + def __init__(self, model: SgModule, image_processor: Union[Processing, List[Processing]], class_names: List[str], device: Optional[str] = "cpu"): super().__init__() self.model = model.to(device) self.device = device + self.class_names = class_names if isinstance(image_processor, list): image_processor = ComposeProcessing(image_processor) self.image_processor = image_processor - @abstractmethod - def __call__(self, images: Union[ImageType, List[ImageType]]) -> Union[Results, Tuple[List[np.ndarray], List[Prediction]]]: - """Apply the pipeline on images and return the result. + def __call__(self, inputs: Union[str, ImageSource, List[ImageSource]], batch_size: Optional[int] = 32) -> ImagesPredictions: + """Predict an image or a list of images. + + Supported types include: + - str: A string representing either a video, an image or an URL. + - numpy.ndarray: A numpy array representing the image + - torch.Tensor: A PyTorch tensor representing the image + - PIL.Image.Image: A PIL Image object + - List: A list of images of any of the above image types (list of videos not supported). 
- :param images: Single image or a list of images of supported types. - :return Results object containing the results of the prediction and the image. + :param inputs: inputs to the model, which can be any of the above-mentioned types. + :param batch_size: Number of images to be processed at the same time. + :return: Results of the prediction. """ - return self._run(images=images) - - def _run(self, images: Union[ImageType, List[ImageType]]) -> Tuple[List[np.ndarray], List[Prediction]]: - """Run the pipeline and return (image, predictions). The pipeline is made of 4 steps: - 1. Load images - Loading the images into a list of numpy arrays. - 2. Preprocess - Encode the image in the shape/format expected by the model - 3. Predict - Run the model on the preprocessed image - 4. Postprocess - Decode the output of the model so that the predictions are in the shape/format of original image. - - :param images: Single image or a list of images of supported types. - :return: - - List of numpy arrays representing images. - - List of model predictions. + + if is_video(inputs): + return self.predict_video(inputs, batch_size) + elif check_image_typing(inputs): + return self.predict_images(inputs, batch_size) + else: + raise ValueError(f"Input {inputs} not supported for prediction.") + + def predict_images(self, images: Union[ImageSource, List[ImageSource]], batch_size: Optional[int] = 32) -> ImagesPredictions: + """Predict an image or a list of images. + + :param images: Images to predict. + :param batch_size: The size of each batch. + :return: Results of the prediction. """ - self.model = self.model.to(self.device) # Make sure the model is on the correct device, as it might have been moved after init + from super_gradients.training.utils.media.image import load_images images = load_images(images) + result_generator = self._generate_prediction_result(images=images, batch_size=batch_size) + return self._combine_image_prediction_to_images(result_generator, n_images=len(images)) + + def predict_video(self, video_path: str, batch_size: Optional[int] = 32) -> VideoPredictions: + """Predict on a video file, by processing the frames in batches. + + :param video_path: Path to the video file. + :param batch_size: The size of each batch. + :return: Results of the prediction. + """ + video_frames, fps = load_video(file_path=video_path) + result_generator = self._generate_prediction_result(images=video_frames, batch_size=batch_size) + return self._combine_image_prediction_to_video(result_generator, fps=fps, n_images=len(video_frames)) + + def predict_webcam(self) -> None: + """Predict using webcam""" + + def _draw_predictions(frame: np.ndarray) -> np.ndarray: + """Draw the predictions on a single frame from the stream.""" + frame_prediction = next(iter(self._generate_prediction_result(images=[frame]))) + return frame_prediction.draw() + + video_streaming = WebcamStreaming(frame_processing_fn=_draw_predictions, fps_update_frequency=1) + video_streaming.run() + + def _generate_prediction_result(self, images: Iterable[np.ndarray], batch_size: Optional[int] = None) -> Iterable[ImagePrediction]: + """Run the pipeline on the images as single batch or through multiple batches. + + NOTE: A core motivation to have this function as a generator is that it can be used in a lazy way (if images is generator itself), + i.e. without having to load all the images into memory. + + :param images: Iterable of numpy arrays representing images. + :param batch_size: The size of each batch. 
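A video path routed through predict_video comes back as a VideoDetectionPrediction, so the same per-frame draw()/show() API applies; a sketch with a hypothetical clip path:

video_predictions = model.predict("path/to/clip.mp4")  # dispatched to predict_video by is_video()
video_predictions.show(box_thickness=1, show_confidence=False)

# Or materialize annotated RGB frames for custom handling:
frames = [frame_prediction.draw() for frame_prediction in video_predictions]
print(video_predictions.fps, len(frames))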
+ :return: Iterable of Results object, each containing the results of the prediction and the image. + """ + if batch_size is None: + yield from self._generate_prediction_result_single_batch(images) + else: + for batch_images in generate_batch(images, batch_size): + yield from self._generate_prediction_result_single_batch(batch_images) + + def _generate_prediction_result_single_batch(self, images: Iterable[np.ndarray]) -> Iterable[ImagePrediction]: + """Run the pipeline on images. The pipeline is made of 4 steps: + 1. Load images - Loading the images into a list of numpy arrays. + 2. Preprocess - Encode the image in the shape/format expected by the model + 3. Predict - Run the model on the preprocessed image + 4. Postprocess - Decode the output of the model so that the predictions are in the shape/format of original image. + + :param images: Iterable of numpy arrays representing images. + :return: Iterable of Results object, each containing the results of the prediction and the image. + """ + images = list(images) # We need to load all the images into memory, and to reuse it afterwards. + self.model = self.model.to(self.device) # Make sure the model is on the correct device, as it might have been moved after init # Preprocess preprocessed_images, processing_metadatas = [], [] @@ -84,11 +160,13 @@ def _run(self, images: Union[ImageType, List[ImageType]]) -> Tuple[List[np.ndarr # Postprocess postprocessed_predictions = [] - for prediction, processing_metadata in zip(predictions, processing_metadatas): + for image, prediction, processing_metadata in zip(images, predictions, processing_metadatas): prediction = self.image_processor.postprocess_predictions(predictions=prediction, metadata=processing_metadata) postprocessed_predictions.append(prediction) - return images, postprocessed_predictions + # Yield results one by one + for image, prediction in zip(images, postprocessed_predictions): + yield self._instantiate_image_prediction(image=image, prediction=prediction) @abstractmethod def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], model_input: np.ndarray) -> List[Prediction]: @@ -98,7 +176,40 @@ def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], m :param model_input: Model input (i.e. images after preprocessing). :return: Model predictions, without any post-processing. """ - pass + raise NotImplementedError + + @abstractmethod + def _instantiate_image_prediction(self, image: np.ndarray, prediction: Prediction) -> ImagePrediction: + """Instantiate an object wrapping an image and the pipeline's prediction. + + :param image: Image to predict. + :param prediction: Model prediction on that image. + :return: Object wrapping an image and the pipeline's prediction. + """ + raise NotImplementedError + + @abstractmethod + def _combine_image_prediction_to_images(self, images_prediction_lst: Iterable[ImagePrediction], n_images: Optional[int] = None) -> ImagesPredictions: + """Instantiate an object wrapping the list of images and the pipeline's predictions on them. + + :param images_prediction_lst: List of image predictions. + :param n_images: (Optional) Number of images in the list. This used for tqdm progress bar to work with iterables, but is not required. + :return: Object wrapping the list of image predictions. 
+ """ + raise NotImplementedError + + @abstractmethod + def _combine_image_prediction_to_video( + self, images_prediction_lst: Iterable[ImagePrediction], fps: float, n_images: Optional[int] = None + ) -> VideoPredictions: + """Instantiate an object holding the video frames and the pipeline's predictions on it. + + :param images_prediction_lst: List of image predictions. + :param fps: Frames per second. + :param n_images: (Optional) Number of images in the list. This used for tqdm progress bar to work with iterables, but is not required. + :return: Object wrapping the list of image predictions as a Video. + """ + raise NotImplementedError class DetectionPipeline(Pipeline): @@ -120,18 +231,8 @@ def __init__( device: Optional[str] = "cpu", image_processor: Optional[Processing] = None, ): - super().__init__(model=model, device=device, image_processor=image_processor) + super().__init__(model=model, device=device, image_processor=image_processor, class_names=class_names) self.post_prediction_callback = post_prediction_callback - self.class_names = class_names - - def __call__(self, images: Union[List[ImageType], ImageType]) -> DetectionResults: - """Apply the pipeline on images and return the detection result. - - :param images: Single image or a list of images of supported types. - :return Results object containing the results of the prediction and the image. - """ - images, predictions = super().__call__(images=images) - return DetectionResults(images=images, predictions=predictions, class_names=self.class_names) def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], model_input: np.ndarray) -> List[DetectionPrediction]: """Decode the model output, by applying post prediction callback. This includes NMS. @@ -144,7 +245,7 @@ def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], m predictions = [] for prediction, image in zip(post_nms_predictions, model_input): - prediction if prediction is not None else torch.zeros((0, 6), dtype=torch.float32) + prediction = prediction if prediction is not None else torch.zeros((0, 6), dtype=torch.float32) prediction = prediction.detach().cpu().numpy() predictions.append( DetectionPrediction( @@ -157,3 +258,18 @@ def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], m ) return predictions + + def _instantiate_image_prediction(self, image: np.ndarray, prediction: DetectionPrediction) -> ImagePrediction: + return ImageDetectionPrediction(image=image, prediction=prediction, class_names=self.class_names) + + def _combine_image_prediction_to_images( + self, images_predictions: Iterable[ImageDetectionPrediction], n_images: Optional[int] = None + ) -> ImagesDetectionPrediction: + images_predictions = [image_predictions for image_predictions in tqdm(images_predictions, total=n_images, desc="Predicting Images")] + return ImagesDetectionPrediction(_images_prediction_lst=images_predictions) + + def _combine_image_prediction_to_video( + self, images_predictions: Iterable[ImageDetectionPrediction], fps: float, n_images: Optional[int] = None + ) -> VideoDetectionPrediction: + images_predictions = [image_predictions for image_predictions in tqdm(images_predictions, total=n_images, desc="Predicting Video")] + return VideoDetectionPrediction(_images_prediction_lst=images_predictions, fps=fps) diff --git a/src/super_gradients/training/transforms/processing.py b/src/super_gradients/training/transforms/processing.py index aa96d2cade..18eca83033 100644 --- 
a/src/super_gradients/training/transforms/processing.py +++ b/src/super_gradients/training/transforms/processing.py @@ -1,9 +1,10 @@ -from typing import Tuple, List, Union, Optional +from typing import Tuple, List, Union from abc import ABC, abstractmethod from dataclasses import dataclass import numpy as np +from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST from super_gradients.training.models.predictions import Prediction, DetectionPrediction from super_gradients.training.transforms.utils import ( _rescale_image, @@ -96,6 +97,49 @@ def postprocess_predictions(self, predictions: Prediction, metadata: None) -> Pr return predictions +class ReverseImageChannels(Processing): + """Reverse the order of the image channels (RGB -> BGR or BGR -> RGB).""" + + def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, None]: + """Reverse the channel order of an image. + + :param image: Image, in (H, W, C) format. + :return: Image with reversed channel order. (RGB if input was BGR, BGR if input was RGB) + """ + + if image.shape[2] != 3: + raise ValueError("ReverseImageChannels expects 3 channels, got: " + str(image.shape[2])) + + processed_image = image[..., ::-1] + return processed_image, None + + def postprocess_predictions(self, predictions: Prediction, metadata: None) -> Prediction: + return predictions + + +class StandardizeImage(Processing): + """Standardize image pixel values with image / max_value. + + :param max_value: Current maximum value of the image pixels. (usually 255) + """ + + def __init__(self, max_value: float = 255.0): + super().__init__() + self.max_value = max_value + + def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, None]: + """Standardize the pixel values of an image. + + :param image: Image, in (H, W, C) format. + :return: Image with pixel values divided by max_value, as float32 (values in [0, 1] for uint8 input). + """ + processed_image = (image / self.max_value).astype(np.float32) + return processed_image, None + + def postprocess_predictions(self, predictions: Prediction, metadata: None) -> Prediction: + return predictions + + class NormalizeImage(Processing): """Normalize an image based on means and standard deviation. @@ -204,41 +248,84 @@ def postprocess_predictions(self, predictions: DetectionPrediction, metadata: Re return predictions -def get_pretrained_processing_params(model_name: str, pretrained_weights: str) -> Tuple[Optional[List[str]], Optional[Processing]]: - """Get the processing parameters for a pretrained model.""" - if "yolox" in model_name and pretrained_weights == "coco": - return default_yolox_coco_processing_params() - elif "ppyoloe" in model_name and pretrained_weights == "coco": - return default_ppyoloe_coco_processing_params() - else: - return None, None - - -def default_yolox_coco_processing_params() -> Tuple[List[str], Processing]: - """Processing parameters commonly used for training YoloX on COCO dataset.""" - from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST +def default_yolox_coco_processing_params() -> dict: + """Processing parameters commonly used for training YoloX on COCO dataset.
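The new steps compose with the existing ones; a quick sketch of the preprocessing side, assuming ComposeProcessing applies each step's preprocess_image in order and returns the processed image plus metadata:

import numpy as np
from super_gradients.training.transforms.processing import (
    ComposeProcessing,
    ReverseImageChannels,
    StandardizeImage,
    ImagePermute,
)

processing = ComposeProcessing(
    [
        ReverseImageChannels(),             # RGB -> BGR (or the reverse)
        StandardizeImage(max_value=255.0),  # uint8 [0, 255] -> float32 [0, 1]
        ImagePermute((2, 0, 1)),            # (H, W, C) -> (C, H, W)
    ]
)
image = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
processed_image, metadata = processing.preprocess_image(image)
print(processed_image.shape, processed_image.dtype)  # (3, 640, 640) float32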
+ TODO: remove once we load it from the checkpoint + """ image_processor = ComposeProcessing( [ + ReverseImageChannels(), DetectionLongestMaxSizeRescale((640, 640)), DetectionBottomRightPadding((640, 640), 114), ImagePermute((2, 0, 1)), ] ) - class_names = COCO_DETECTION_CLASSES_LIST - return class_names, image_processor + + params = dict( + class_names=COCO_DETECTION_CLASSES_LIST, + image_processor=image_processor, + iou=0.65, + conf=0.1, + ) + return params -def default_ppyoloe_coco_processing_params() -> Tuple[List[str], Processing]: - """Processing parameters commonly used for training PPYoloE on COCO dataset.""" - from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST +def default_ppyoloe_coco_processing_params() -> dict: + """Processing parameters commonly used for training PPYoloE on COCO dataset. + TODO: remove once we load it from the checkpoint + """ image_processor = ComposeProcessing( [ + ReverseImageChannels(), DetectionRescale(output_shape=(640, 640)), NormalizeImage(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]), ImagePermute(permutation=(2, 0, 1)), ] ) - class_names = COCO_DETECTION_CLASSES_LIST - return class_names, image_processor + + params = dict( + class_names=COCO_DETECTION_CLASSES_LIST, + image_processor=image_processor, + iou=0.65, + conf=0.5, + ) + return params + + +def default_deciyolo_coco_processing_params() -> dict: + """Processing parameters commonly used for training DeciYolo on COCO dataset. + TODO: remove once we load it from the checkpoint + """ + + image_processor = ComposeProcessing( + [ + DetectionLongestMaxSizeRescale(output_shape=(636, 636)), + DetectionCenterPadding(output_shape=(640, 640), pad_value=114), + StandardizeImage(max_value=255.0), + ImagePermute(permutation=(2, 0, 1)), + ] + ) + + params = dict( + class_names=COCO_DETECTION_CLASSES_LIST, + image_processor=image_processor, + iou=0.65, + conf=0.5, + ) + return params + + +def get_pretrained_processing_params(model_name: str, pretrained_weights: str) -> dict: + """Get the processing parameters for a pretrained model. + TODO: remove once we load it from the checkpoint + """ + if pretrained_weights == "coco": + if "yolox" in model_name: + return default_yolox_coco_processing_params() + elif "ppyoloe" in model_name: + return default_ppyoloe_coco_processing_params() + elif "deciyolo" in model_name: + return default_deciyolo_coco_processing_params() + return dict() diff --git a/src/super_gradients/training/utils/load_image.py b/src/super_gradients/training/utils/load_image.py index 2173c76bcf..e69de29bb2 100644 --- a/src/super_gradients/training/utils/load_image.py +++ b/src/super_gradients/training/utils/load_image.py @@ -1,87 +0,0 @@ -from typing import Union, List -import PIL - -import numpy as np -import torch -import requests -from urllib.parse import urlparse - -IMG_EXTENSIONS = ("bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm") -ImageType = Union[str, np.ndarray, torch.Tensor, PIL.Image.Image] - - -def load_images(images: Union[List[ImageType], ImageType]) -> List[np.ndarray]: - """Load a single image or a list of images and return them as a list of numpy arrays. - - Supported image types include: - - numpy.ndarray: A numpy array representing the image - - torch.Tensor: A PyTorch tensor representing the image - - PIL.Image.Image: A PIL Image object - - str: A string representing either a local file path or a URL to an image - - :param images: Single image or a list of images of supported types. 
- :return: List of images as numpy arrays. If loaded from string, the image will be returned as RGB. - """ - if isinstance(images, list): - return [load_image(image=image) for image in images] - else: - return [load_image(image=images)] - - -def load_image(image: ImageType) -> np.ndarray: - """Load a single image and return it as a numpy arrays. - - Supported image types include: - - numpy.ndarray: A numpy array representing the image - - torch.Tensor: A PyTorch tensor representing the image - - PIL.Image.Image: A PIL Image object - - str: A string representing either a local file path or a URL to an image - - :param image: Single image of supported types. - :return: Image as numpy arrays. If loaded from string, the image will be returned as RGB. - """ - if isinstance(image, np.ndarray): - return image - elif isinstance(image, torch.Tensor): - return image.numpy() - elif isinstance(image, PIL.Image.Image): - return load_np_image_from_pil(image) - elif isinstance(image, str): - image = load_pil_image_from_str(image_str=image) - return load_np_image_from_pil(image) - else: - raise ValueError(f"Unsupported image type: {type(image)}") - - -def load_np_image_from_pil(image: PIL.Image.Image) -> np.ndarray: - """Convert a PIL image to numpy array in RGB format.""" - return np.asarray(image.convert("RGB")) - - -def load_pil_image_from_str(image_str: str) -> PIL.Image.Image: - """Load an image based on a string (local file path or URL).""" - - if is_url(image_str): - response = requests.get(image_str, stream=True) - response.raise_for_status() - return PIL.Image.open(response.raw) - else: - return PIL.Image.open(image_str) - - -def is_url(url: str) -> bool: - """Check if the given string is a URL.""" - try: - result = urlparse(url) - return all([result.scheme, result.netloc, result.path]) - except Exception: - return False - - -def is_image(filename: str) -> bool: - """Check if the given file name refers to image. - - :param filename: The filename to check. - :return: True if the file is an image, False otherwise. - """ - return filename.split(".")[-1].lower() in IMG_EXTENSIONS diff --git a/src/super_gradients/training/utils/media/__init__.py b/src/super_gradients/training/utils/media/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/super_gradients/training/utils/media/image.py b/src/super_gradients/training/utils/media/image.py new file mode 100644 index 0000000000..2d32a7e3c5 --- /dev/null +++ b/src/super_gradients/training/utils/media/image.py @@ -0,0 +1,158 @@ +from typing import Union, List, Iterable, Iterator +from typing_extensions import get_args +import PIL + +import os +from PIL import Image +import matplotlib.pyplot as plt + +import numpy as np +import torch +import requests +from urllib.parse import urlparse + + +IMG_EXTENSIONS = ("bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm") +SingleImageSource = Union[str, np.ndarray, torch.Tensor, PIL.Image.Image] +ImageSource = Union[SingleImageSource, List[SingleImageSource]] + + +def load_images(images: Union[List[ImageSource], ImageSource]) -> List[np.ndarray]: + """Load a single image or a list of images and return them as a list of numpy arrays. + + Supported types include: + - str: A string representing either an image or an URL. + - numpy.ndarray: A numpy array representing the image + - torch.Tensor: A PyTorch tensor representing the image + - PIL.Image.Image: A PIL Image object + - List: A list of images of any of the above types. 
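All supported source types can be mixed in a single call; a sketch, with the file path and URL as placeholders:

import numpy as np
from super_gradients.training.utils.media.image import load_images

images = load_images(
    [
        "path/to/local_image.jpg",
        "https://example.com/remote_image.jpg",
        np.zeros((480, 640, 3), dtype=np.uint8),
    ]
)
print(len(images))  # 3 numpy arrays; path and URL sources are decoded to RGB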
+ + :param images: Single image or a list of images of supported types. + :return: List of images as numpy arrays. If loaded from string, the image will be returned as RGB. + """ + return [image for image in generate_image_loader(images=images)] + + +def generate_image_loader(images: Union[List[ImageSource], ImageSource]) -> Iterable[np.ndarray]: + """Generator that loads images one at a time. + + Supported types include: + - str: A string representing either an image or an URL. + - numpy.ndarray: A numpy array representing the image + - torch.Tensor: A PyTorch tensor representing the image + - PIL.Image.Image: A PIL Image object + - List: A list of images of any of the above types. + + :param images: Single image or a list of images of supported types. + :return: Generator of images as numpy arrays. If loaded from string, the image will be returned as RGB. + """ + if isinstance(images, str) and os.path.isdir(images): + images_paths = list_images_in_folder(images) + for image_path in images_paths: + yield load_image(image=image_path) + elif isinstance(images, (list, Iterator)): + for image in images: + yield load_image(image=image) + else: + yield load_image(image=images) + + +def list_images_in_folder(directory: str) -> List[str]: + """List all the images in a directory. + :param directory: The path to the directory containing the images. + :return: A list of image file names. + """ + files = os.listdir(directory) + images_paths = [os.path.join(directory, f) for f in files if f.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".gif"))] + return images_paths + + +def load_image(image: ImageSource) -> np.ndarray: + """Load a single image and return it as a numpy arrays. + + Supported image types include: + - numpy.ndarray: A numpy array representing the image + - torch.Tensor: A PyTorch tensor representing the image + - PIL.Image.Image: A PIL Image object + - str: A string representing either a local file path or a URL to an image + + :param image: Single image of supported types. + :return: Image as numpy arrays. If loaded from string, the image will be returned as RGB. + """ + if isinstance(image, np.ndarray): + return image + elif isinstance(image, torch.Tensor): + return image.numpy() + elif isinstance(image, PIL.Image.Image): + return load_np_image_from_pil(image) + elif isinstance(image, str): + image = load_pil_image_from_str(image_str=image) + return load_np_image_from_pil(image) + else: + raise ValueError(f"Unsupported image type: {type(image)}") + + +def load_np_image_from_pil(image: PIL.Image.Image) -> np.ndarray: + """Convert a PIL image to numpy array in RGB format.""" + return np.asarray(image.convert("RGB")) + + +def load_pil_image_from_str(image_str: str) -> PIL.Image.Image: + """Load an image based on a string (local file path or URL).""" + + if is_url(image_str): + response = requests.get(image_str, stream=True) + response.raise_for_status() + return PIL.Image.open(response.raw) + else: + return PIL.Image.open(image_str) + + +def save_image(image: np.ndarray, path: str) -> None: + """Save a numpy array as an image. + :param image: Image to save, (H, W, C), RGB. + :param path: Path to save the image to. + """ + Image.fromarray(image).save(path) + + +def is_url(url: str) -> bool: + """Check if the given string is a URL. + :param url: String to check. + """ + try: + result = urlparse(url) + return all([result.scheme, result.netloc, result.path]) + except Exception: + return False + + +def show_image(image: np.ndarray) -> None: + """Show an image using matplotlib. 
+ :param image: Image to show in (H, W, C), RGB. + """ + plt.imshow(image, interpolation="nearest") + plt.axis("off") + plt.show() + + +def check_image_typing(image: ImageSource) -> bool: + """Check if the given object respects typing of image. + :param image: Image to check. + :return: True if the object is an image, False otherwise. + """ + if isinstance(image, get_args(SingleImageSource)): + return True + elif isinstance(image, list): + return all([isinstance(image_item, get_args(SingleImageSource)) for image_item in image]) + else: + return False + + +def is_image(filename: str) -> bool: + """Check if the given file name refers to image. + + :param filename: The filename to check. + :return: True if the file is an image, False otherwise. + """ + return filename.split(".")[-1].lower() in IMG_EXTENSIONS diff --git a/src/super_gradients/training/utils/media/stream.py b/src/super_gradients/training/utils/media/stream.py new file mode 100644 index 0000000000..73d969f85d --- /dev/null +++ b/src/super_gradients/training/utils/media/stream.py @@ -0,0 +1,117 @@ +import cv2 +import numpy as np +import time +from typing import Callable, Optional + + +__all__ = ["WebcamStreaming"] + + +class WebcamStreaming: + """Stream video from a webcam. Press 'q' to quit the streaming. + + :param window_name: Name of the window to display the video stream. + :param frame_processing_fn: Function to apply to each frame before displaying it. + If None, frames are displayed as is. + :param capture: ID of the video capture device to use. + Default is cv2.CAP_ANY (which selects the first available device). + :param fps_update_frequency: Minimum time (in seconds) between updates to the FPS counter. + If None, the counter is updated every frame. + """ + + def __init__( + self, + window_name: str = "", + frame_processing_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None, + capture: int = cv2.CAP_ANY, + fps_update_frequency: Optional[float] = None, + ): + self.window_name = window_name + self.frame_processing_fn = frame_processing_fn + self.cap = cv2.VideoCapture(capture) + if not self.cap.isOpened(): + raise ValueError("Could not open video capture device") + + self._fps_counter = FPSCounter(update_frequency=fps_update_frequency) + + def run(self) -> None: + """Start streaming video from the webcam and displaying it in a window. + + Press 'q' to quit the streaming. + """ + while not self._stop(): + self._display_single_frame() + + def _display_single_frame(self) -> None: + """Read a single frame from the video capture device, apply any specified frame processing, + and display the resulting frame in the window. + + Also updates the FPS counter and displays it in the frame. + """ + _ret, frame = self.cap.read() + + if self.frame_processing_fn: + frame = self.frame_processing_fn(frame) + + _write_fps_to_frame(frame, self.fps) + cv2.imshow(self.window_name, frame) + + def _stop(self) -> bool: + """Stopping condition for the streaming.""" + return cv2.waitKey(1) & 0xFF == ord("q") + + @property + def fps(self) -> float: + return self._fps_counter.fps + + def __del__(self): + """Release the video capture device and close the window.""" + self.cap.release() + cv2.destroyAllWindows() + + +def _write_fps_to_frame(frame: np.ndarray, fps: float) -> None: + """Write the current FPS value on the given frame. + + :param frame: Frame to write the FPS value on. + :param fps: Current FPS value to write. 
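WebcamStreaming accepts any frame-to-frame function, not just model predictions; a sketch with a simple Canny edge filter (requires an attached webcam):

import cv2
import numpy as np
from super_gradients.training.utils.media.stream import WebcamStreaming

def edges(frame: np.ndarray) -> np.ndarray:
    # Keep the output 3-channel so the FPS overlay and imshow work unchanged.
    return cv2.cvtColor(cv2.Canny(frame, 100, 200), cv2.COLOR_GRAY2BGR)

WebcamStreaming(window_name="Edges", frame_processing_fn=edges, fps_update_frequency=1.0).run()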
+ """ + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = 0.6 + font_color = (0, 255, 0) + line_type = 2 + cv2.putText(frame, "FPS: {:.2f}".format(fps), (10, 30), font, font_scale, font_color, line_type) + + +class FPSCounter: + """Class for calculating the FPS of a video stream.""" + + def __init__(self, update_frequency: Optional[float] = None): + """Create a new FPSCounter object. + + :param update_frequency: Minimum time (in seconds) between updates to the FPS counter. + If None, the counter is updated every frame. + """ + self._update_frequency = update_frequency + + self._start_time = time.time() + self._frame_count = 0 + self._fps = 0.0 + + def _update_fps(self, elapsed_time, current_time) -> None: + """Compute new value of FPS and reset the counter.""" + self._fps = self._frame_count / elapsed_time + self._start_time = current_time + self._frame_count = 0 + + @property + def fps(self) -> float: + """Current FPS value.""" + + self._frame_count += 1 + current_time, elapsed_time = time.time(), time.time() - self._start_time + + if self._update_frequency is None or elapsed_time > self._update_frequency: + self._update_fps(elapsed_time=elapsed_time, current_time=current_time) + + return self._fps diff --git a/src/super_gradients/training/utils/videos.py b/src/super_gradients/training/utils/media/video.py similarity index 63% rename from src/super_gradients/training/utils/videos.py rename to src/super_gradients/training/utils/media/video.py index d5e58a62f2..9ea78d4ae6 100644 --- a/src/super_gradients/training/utils/videos.py +++ b/src/super_gradients/training/utils/media/video.py @@ -4,7 +4,7 @@ import numpy as np -__all__ = ["load_video", "save_video"] +__all__ = ["load_video", "save_video", "is_video", "show_video_from_disk", "show_video_from_frames"] def load_video(file_path: str, max_frames: Optional[int] = None) -> Tuple[List[np.ndarray], int]: @@ -30,6 +30,7 @@ def _open_video(file_path: str) -> cv2.VideoCapture: :return: Opened video capture object """ cap = cv2.VideoCapture(file_path) + if not cap.isOpened(): raise ValueError(f"Failed to open video file: {file_path}") return cap @@ -97,3 +98,60 @@ def _validate_frames(frames: List[np.ndarray]) -> Tuple[float, float]: raise RuntimeError("Your frames must include 3 channels.") return max_height, max_width + + +def show_video_from_disk(video_path: str, window_name: str = "Prediction"): + """Display a video from disk using OpenCV. + + :param video_path: Path to the video file. + :param window_name: Name of the window to display the video + """ + cap = _open_video(video_path) + fps = cap.get(cv2.CAP_PROP_FPS) + + while cap.isOpened(): + ret, frame = cap.read() + + if ret: + # Display the frame + cv2.imshow(window_name, frame) + + # Wait for the specified number of milliseconds before displaying the next frame + if cv2.waitKey(int(1000 / fps)) & 0xFF == ord("q"): + break + else: + break + + # Release the VideoCapture object and destroy the window + cap.release() + cv2.destroyAllWindows() + cv2.waitKey(1) + + +def show_video_from_frames(frames: List[np.ndarray], fps: float, window_name: str = "Prediction") -> None: + """Display a video from a list of frames using OpenCV. + + :param frames: Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape. 
+
+
+def show_video_from_frames(frames: List[np.ndarray], fps: float, window_name: str = "Prediction") -> None:
+    """Display a video from a list of frames using OpenCV.
+
+    :param frames: Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
+    :param fps: Frames per second.
+    :param window_name: Name of the window to display the video in.
+    """
+    for frame in frames:
+        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+        cv2.imshow(window_name, frame)
+        cv2.waitKey(int(1000 / fps))
+    cv2.destroyAllWindows()
+    cv2.waitKey(1)
+
+
+def is_video(file_path: str) -> bool:
+    """Check if a file is a video file.
+    :param file_path: Path to the video file.
+    :return: True if the file is a video file, False otherwise.
+    """
+    try:
+        cap = cv2.VideoCapture(file_path, apiPreference=cv2.CAP_FFMPEG)
+        is_opened = cap.isOpened()
+        cap.release()
+        return is_opened
+    except Exception:
+        return False
diff --git a/src/super_gradients/training/utils/utils.py b/src/super_gradients/training/utils/utils.py
index 475d78ca2b..ec4eed01f7 100755
--- a/src/super_gradients/training/utils/utils.py
+++ b/src/super_gradients/training/utils/utils.py
@@ -1,18 +1,19 @@
+import os
+import tarfile
+import re
 import math
 import time
 from functools import lru_cache
 from pathlib import Path
-from typing import Mapping, Optional, Tuple, Union, List, Dict, Any
+from typing import Mapping, Optional, Tuple, Union, List, Dict, Any, Iterable
 from zipfile import ZipFile
-import os
 from jsonschema import validate
-import tarfile
+from itertools import islice
+
 from PIL import Image, ExifTags
-import re
 import torch
 import torch.nn as nn
-
 # These functions changed from torch 1.2 to torch 1.3
 import random
@@ -526,3 +527,14 @@ def override_default_params_without_nones(params: Dict, default_params: Mapping)
         if key not in params.keys() or params[key] is None:
             params[key] = val
     return params
+
+
+def generate_batch(iterable: Iterable, batch_size: int) -> Iterable:
+    """Batch data into tuples of length batch_size. The last batch may be shorter."""
+    it = iter(iterable)
+    while True:
+        batch = tuple(islice(it, batch_size))
+        if batch:
+            yield batch
+        else:
+            return
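A quick check of generate_batch semantics; the trailing partial batch is kept rather than dropped:

from super_gradients.training.utils.utils import generate_batch

print(list(generate_batch(range(7), batch_size=3)))  # [(0, 1, 2), (3, 4, 5), (6,)]
print(list(generate_batch([], batch_size=3)))  # [] (an empty iterable yields no batches)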