diff --git a/src/super_gradients/common/plugins/wandb/__init__.py b/src/super_gradients/common/plugins/wandb/__init__.py index 8614e61d98..f50cd75fad 100644 --- a/src/super_gradients/common/plugins/wandb/__init__.py +++ b/src/super_gradients/common/plugins/wandb/__init__.py @@ -1,4 +1,14 @@ -from super_gradients.common.plugins.wandb.log_predictions import log_detection_results_to_wandb +from super_gradients.common.plugins.wandb.log_predictions import ( + visualize_image_detection_prediction_on_wandb, + log_detection_results_to_wandb, + plot_detection_dataset_on_wandb, +) +from super_gradients.common.plugins.wandb.validation_logger import WandBDetectionValidationPredictionLoggerCallback -__all__ = ["log_detection_results_to_wandb"] +__all__ = [ + "visualize_image_detection_prediction_on_wandb", + "log_detection_results_to_wandb", + "plot_detection_dataset_on_wandb", + "WandBDetectionValidationPredictionLoggerCallback", +] diff --git a/src/super_gradients/common/plugins/wandb/log_predictions.py b/src/super_gradients/common/plugins/wandb/log_predictions.py index 6043bbc0d3..a242349b58 100644 --- a/src/super_gradients/common/plugins/wandb/log_predictions.py +++ b/src/super_gradients/common/plugins/wandb/log_predictions.py @@ -3,12 +3,27 @@ except (ModuleNotFoundError, ImportError, NameError): pass # no action or logging - this is normal in most cases +import numpy as np +from tqdm import tqdm + +from super_gradients.training.transforms.transforms import DetectionTargetsFormatTransform +from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL +from super_gradients.training.datasets.detection_datasets import DetectionDataset + from super_gradients.training.utils.predict import ImageDetectionPrediction, ImagesDetectionPrediction -def _visualize_image_detection_prediction_on_wandb(prediction: ImageDetectionPrediction, show_confidence: bool): +def visualize_image_detection_prediction_on_wandb(prediction: ImageDetectionPrediction, show_confidence: bool, reverse_channels: bool = False): + """Visualize detection results on a single image. + + :param prediction: Prediction results of a single image + (a `super_gradients.training.models.prediction_results.ImageDetectionPrediction` object) + :param show_confidence: Whether to log confidence scores to Weights & Biases or not. + :param reverse_channels: Reverse the order of channels on the images while plotting. + """ boxes = [] image = prediction.image.copy() + image = image[:, :, ::-1] if reverse_channels else image height, width, _ = image.shape class_id_to_labels = {int(_id): str(_class_name) for _id, _class_name in enumerate(prediction.class_names)} @@ -28,9 +43,7 @@ def _visualize_image_detection_prediction_on_wandb(prediction: ImageDetectionPre box["scores"] = {"confidence": float(round(prediction.prediction.confidence[pred_i], 2))} boxes.append(box) - wandb_image = wandb.Image(image, boxes={"predictions": {"box_data": boxes, "class_labels": class_id_to_labels}}) - - wandb.log({"Predictions": wandb_image}) + return wandb.Image(image, boxes={"predictions": {"box_data": boxes, "class_labels": class_id_to_labels}}) def log_detection_results_to_wandb(prediction: ImagesDetectionPrediction, show_confidence: bool = True): @@ -42,4 +55,47 @@ def log_detection_results_to_wandb(prediction: ImagesDetectionPrediction, show_c if wandb.run is None: raise wandb.Error("Images and bounding boxes cannot be visualized on Weights & Biases without initializing a run using `wandb.init()`") for prediction in prediction._images_prediction_lst: - _visualize_image_detection_prediction_on_wandb(prediction=prediction, show_confidence=show_confidence) + wandb_image = visualize_image_detection_prediction_on_wandb(prediction=prediction, show_confidence=show_confidence) + wandb.log({"Predictions": wandb_image}) + + +def plot_detection_dataset_on_wandb(detection_dataset: DetectionDataset, max_examples: int = None, dataset_name: str = None, reverse_channels: bool = True): + """Log a detection dataset to Weights & Biases Table. + + :param detection_dataset: The Detection Dataset (a `super_gradients.training.datasets.detection_datasets.DetectionDataset` object) + :param max_examples: Maximum number of examples from the detection dataset to plot (an `int`). + :param dataset_name: Name of the dataset (a `str`). + :param reverse_channels: Reverse the order of channels on the images while plotting. + """ + max_examples = len(detection_dataset) if max_examples is None else max_examples + wandb_table = wandb.Table(columns=["Images", "Class-Frequencies"]) + input_format = detection_dataset.output_target_format + target_format_transform = DetectionTargetsFormatTransform(input_format=input_format, output_format=XYXY_LABEL) + class_id_to_labels = {int(_id): str(_class_name) for _id, _class_name in enumerate(detection_dataset.classes)} + for data_idx in tqdm(range(max_examples), desc="Plotting Examples on Weights & Biases"): + image, targets, *_ = detection_dataset[data_idx] + image = image.transpose(1, 2, 0).astype(np.int32) + sample = target_format_transform({"image": image, "target": targets}) + boxes = sample["target"][:, 0:4] + boxes = boxes[(boxes != 0).any(axis=1)] + classes = targets[:, 0].tolist() + wandb_boxes = [] + class_frequencies = {str(_class_name): 0 for _id, _class_name in enumerate(detection_dataset.classes)} + for idx in range(boxes.shape[0]): + wandb_boxes.append( + { + "position": { + "minX": float(boxes[idx][0] / image.shape[1]), + "maxX": float(boxes[idx][2] / image.shape[1]), + "minY": float(boxes[idx][1] / image.shape[0]), + "maxY": float(boxes[idx][3] / image.shape[0]), + }, + "class_id": int(classes[idx]), + "box_caption": str(class_id_to_labels[int(classes[idx])]), + } + ) + class_frequencies[str(class_id_to_labels[int(classes[idx])])] += 1 + image = image[:, :, ::-1] if reverse_channels else image + wandb_table.add_data(wandb.Image(image, boxes={"ground_truth": {"box_data": wandb_boxes, "class_labels": class_id_to_labels}}), class_frequencies) + dataset_name = "Dataset" if dataset_name is None else dataset_name + wandb.log({dataset_name: wandb_table}, commit=False) diff --git a/src/super_gradients/common/plugins/wandb/validation_logger.py b/src/super_gradients/common/plugins/wandb/validation_logger.py new file mode 100644 index 0000000000..357e39f86d --- /dev/null +++ b/src/super_gradients/common/plugins/wandb/validation_logger.py @@ -0,0 +1,99 @@ +from typing import Optional + +import torch +import numpy as np + +from super_gradients.training.utils.callbacks import Callback, PhaseContext +from super_gradients.common.plugins.wandb.log_predictions import visualize_image_detection_prediction_on_wandb +from super_gradients.training.models.predictions import DetectionPrediction +from super_gradients.training.utils.predict import ImageDetectionPrediction +from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback +from super_gradients.module_interfaces import HasPredict +from super_gradients.training.utils.utils import unwrap_model + +try: + import wandb +except (ModuleNotFoundError, ImportError, NameError): + pass # no action or logging - this is normal in most cases + + +class WandBDetectionValidationPredictionLoggerCallback(Callback): + def __init__( + self, + class_names, + max_predictions_plotted: Optional[int] = None, + post_prediction_callback: Optional[DetectionPostPredictionCallback] = None, + reverse_channels: bool = True, + ) -> None: + """A callback for logging object detection predictions to Weights & Biases during training. This callback is logging images on each batch in validation + and accumulating generated images in a `wandb.Table` in the RAM. This could potentially cause OOM errors for very large datasets like COCO. In order to + avoid this, it is recommended to explicitly set the parameter `max_predictions_plotted` to a small value, thus limiting the number of images logged in + the table. + + :param class_names: A list of class names. + :param max_predictions_plotted: Maximum number of predictions to be plotted per epoch. This is set to `None` by default which means that the + predictions corresponding to all images from `context.inputs` is logged, otherwise only `max_predictions_plotted` + number of images is logged. Since `WandBDetectionValidationPredictionLoggerCallback` accumulates the generated + images in the RAM, it is advisable that the value of this parameter be explicitly specified for larger datasets in + order to avoid out-of-memory errors. + :param post_prediction_callback: `DetectionPostPredictionCallback` for post-processing outputs of the model. + :param reverse_channels: Reverse the order of channels on the images while plotting. + """ + super().__init__() + self.class_names = class_names + self.max_predictions_plotted = max_predictions_plotted + self.post_prediction_callback = post_prediction_callback + self.reverse_channels = reverse_channels + self.wandb_images = [] + self.epoch_count = 0 + self.mean_prediction_dicts = [] + self.wandb_table = wandb.Table(columns=["Epoch", "Prediction", "Mean-Confidence"]) + + def on_validation_batch_end(self, context: PhaseContext) -> None: + self.wandb_images = [] + mean_prediction_dict = {class_name: 0.0 for class_name in self.class_names} + if isinstance(context.net, HasPredict): + post_nms_predictions = context.net(context.inputs) + else: + self.post_prediction_callback = ( + unwrap_model(context.net).get_post_prediction_callback() if self.post_prediction_callback is None else self.post_prediction_callback + ) + self.post_prediction_callback.fuse_layers = False + post_nms_predictions = self.post_prediction_callback(context.preds, device=context.device) + if self.max_predictions_plotted is not None: + post_nms_predictions = post_nms_predictions[: self.max_predictions_plotted] + input_images = context.inputs[: self.max_predictions_plotted] + else: + input_images = context.inputs + for prediction, image in zip(post_nms_predictions, input_images): + prediction = prediction if prediction is not None else torch.zeros((0, 6), dtype=torch.float32) + prediction = prediction.detach().cpu().numpy() + postprocessed_image = image.detach().cpu().numpy().transpose(1, 2, 0).astype(np.int32) + image_prediction = ImageDetectionPrediction( + image=postprocessed_image, + class_names=self.class_names, + prediction=DetectionPrediction( + bboxes=prediction[:, :4], + confidence=prediction[:, 4], + labels=prediction[:, 5], + bbox_format="xyxy", + image_shape=image.shape, + ), + ) + for predicted_label, prediction_confidence in zip(prediction[:, 5], prediction[:, 4]): + mean_prediction_dict[self.class_names[int(predicted_label)]] += prediction_confidence + mean_prediction_dict = {k: v / len(prediction[:, 4]) for k, v in mean_prediction_dict.items()} + self.mean_prediction_dicts.append(mean_prediction_dict) + wandb_image = visualize_image_detection_prediction_on_wandb( + prediction=image_prediction, show_confidence=True, reverse_channels=self.reverse_channels + ) + self.wandb_images.append(wandb_image) + + def on_validation_loader_end(self, context: PhaseContext) -> None: + for wandb_image, mean_prediction_dict in zip(self.wandb_images, self.mean_prediction_dicts): + self.wandb_table.add_data(self.epoch_count, wandb_image, mean_prediction_dict) + self.wandb_images, self.mean_prediction_dicts = [], [] + self.epoch_count += 1 + + def on_training_end(self, context: PhaseContext) -> None: + wandb.log({"Validation-Prediction": self.wandb_table})