Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add utilities to plot datasets to Weights & Biases + Add callback to log validation predictions to Weights & Biases #1167

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
a81fddb
add: plot_detection_dataset_on_wandb
soumik12345 Jun 9, 2023
76dc7b3
add: doctring for plot_detection_dataset_on_wandb
soumik12345 Jun 9, 2023
07f8c0c
update: bbox type
soumik12345 Jun 9, 2023
8f7b5c6
update: visualize_image_detection_prediction_on_wandb
soumik12345 Jun 9, 2023
71a52af
update: fix wandb module and linting
soumik12345 Jun 9, 2023
de56724
add: WandBDetectionValidationPredictionLoggerCallback
soumik12345 Jun 9, 2023
facad14
update: docstring for WandBDetectionValidationPredictionLoggerCallback
soumik12345 Jun 9, 2023
4c3e328
update: WandBDetectionValidationPredictionLoggerCallback
soumik12345 Jun 12, 2023
ddcd630
update: plot_detection_dataset_on_wandb
soumik12345 Jun 13, 2023
65fb521
update: imports
soumik12345 Jun 13, 2023
2b9bdbf
Merge branch 'master' into soumik12345/wandb-validation-logging
soumik12345 Jun 14, 2023
59edcb2
fix: linting
soumik12345 Jun 14, 2023
7271db4
fix: linting
soumik12345 Jun 14, 2023
abf274d
fix: linting
soumik12345 Jun 14, 2023
568c749
Merge branch 'master' into soumik12345/wandb-validation-logging
soumik12345 Jun 18, 2023
ae2e712
Merge branch 'master' into soumik12345/wandb-validation-logging
soumik12345 Jun 19, 2023
c8c17a2
Merge branch 'master' into soumik12345/wandb-validation-logging
soumik12345 Jun 19, 2023
06e9134
Merge branch 'master' into soumik12345/wandb-validation-logging
soumik12345 Jun 20, 2023
00755a8
Merge branch 'master' into soumik12345/wandb-validation-logging
soumik12345 Jun 21, 2023
60baa55
Merge branch 'master' into soumik12345/wandb-validation-logging
soumik12345 Jun 22, 2023
ae339d6
fix: imports in validation_logger.py
soumik12345 Jun 22, 2023
d9b5bf3
Update log_predictions.py
soumik12345 Jun 22, 2023
8f3911a
Merge branch 'master' into soumik12345/wandb-validation-logging
soumik12345 Jun 27, 2023
d6c75e7
Merge branch 'master' into soumik12345/wandb-validation-logging
soumik12345 Jun 27, 2023
1600fd7
Merge branch 'master' into soumik12345/wandb-validation-logging
BloodAxe Jun 28, 2023
7375d5f
Merge branch 'master' into soumik12345/wandb-validation-logging
soumik12345 Jun 28, 2023
344c4e8
update: add max_predictions_plotted parameter to WandBDetectionValida…
soumik12345 Jun 28, 2023
8522be2
update: docstring for WandBDetectionValidationPredictionLoggerCallback
soumik12345 Jun 28, 2023
73cb61f
Merge branch 'master' into soumik12345/wandb-validation-logging
soumik12345 Jul 7, 2023
a9d0abe
Merge branch 'master' into soumik12345/wandb-validation-logging
BloodAxe Jul 14, 2023
290d4cd
Merge branch 'master' into soumik12345/wandb-validation-logging
soumik12345 Jul 20, 2023
a8e8c42
Merge branch 'master' into soumik12345/wandb-validation-logging
soumik12345 Jul 24, 2023
b7225be
update: validation logger
soumik12345 Jul 27, 2023
caef860
update: WandBDetectionValidationPredictionLoggerCallback + visualize_…
soumik12345 Jul 27, 2023
880f55d
update: wandb import + docstring
soumik12345 Jul 27, 2023
a5c4328
update: doctring
soumik12345 Jul 27, 2023
49ba703
update: docstrings + channel reversal
soumik12345 Jul 28, 2023
b6acd23
update: make ci happy
soumik12345 Jul 28, 2023
b7285a6
Merge branch 'master' into soumik12345/wandb-validation-logging
soumik12345 Jul 31, 2023
da2d0fb
Merge branch 'master' into soumik12345/wandb-validation-logging
BloodAxe Aug 7, 2023
970c730
Merge branch 'master' into soumik12345/wandb-validation-logging
soumik12345 Aug 7, 2023
06d59c5
Merge branch 'master' into soumik12345/wandb-validation-logging
BloodAxe Aug 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/super_gradients/common/plugins/wandb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from super_gradients.common.plugins.wandb.log_predictions import log_detection_results_to_wandb
from super_gradients.common.plugins.wandb.log_predictions import (
visualize_image_detection_prediction_on_wandb,
log_detection_results_to_wandb,
plot_detection_dataset_on_wandb,
)
from super_gradients.common.plugins.wandb.validation_logger import WandBDetectionValidationPredictionLoggerCallback


__all__ = ["log_detection_results_to_wandb"]
__all__ = [
"visualize_image_detection_prediction_on_wandb",
"log_detection_results_to_wandb",
"plot_detection_dataset_on_wandb",
"WandBDetectionValidationPredictionLoggerCallback",
]
66 changes: 61 additions & 5 deletions src/super_gradients/common/plugins/wandb/log_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,27 @@
except (ModuleNotFoundError, ImportError, NameError):
pass # no action or logging - this is normal in most cases

import numpy as np
from tqdm import tqdm

from super_gradients.training.transforms.transforms import DetectionTargetsFormatTransform
from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL
from super_gradients.training.datasets.detection_datasets import DetectionDataset

from super_gradients.training.utils.predict import ImageDetectionPrediction, ImagesDetectionPrediction


def _visualize_image_detection_prediction_on_wandb(prediction: ImageDetectionPrediction, show_confidence: bool):
def visualize_image_detection_prediction_on_wandb(prediction: ImageDetectionPrediction, show_confidence: bool, reverse_channels: bool = False):
"""Visualize detection results on a single image.

:param prediction: Prediction results of a single image
(a `super_gradients.training.models.prediction_results.ImageDetectionPrediction` object)
:param show_confidence: Whether to log confidence scores to Weights & Biases or not.
:param reverse_channels: Reverse the order of channels on the images while plotting.
"""
boxes = []
image = prediction.image.copy()
image = image[:, :, ::-1] if reverse_channels else image
height, width, _ = image.shape
class_id_to_labels = {int(_id): str(_class_name) for _id, _class_name in enumerate(prediction.class_names)}

Expand All @@ -28,9 +43,7 @@ def _visualize_image_detection_prediction_on_wandb(prediction: ImageDetectionPre
box["scores"] = {"confidence": float(round(prediction.prediction.confidence[pred_i], 2))}
boxes.append(box)

wandb_image = wandb.Image(image, boxes={"predictions": {"box_data": boxes, "class_labels": class_id_to_labels}})

wandb.log({"Predictions": wandb_image})
return wandb.Image(image, boxes={"predictions": {"box_data": boxes, "class_labels": class_id_to_labels}})


def log_detection_results_to_wandb(prediction: ImagesDetectionPrediction, show_confidence: bool = True):
Expand All @@ -42,4 +55,47 @@ def log_detection_results_to_wandb(prediction: ImagesDetectionPrediction, show_c
if wandb.run is None:
raise wandb.Error("Images and bounding boxes cannot be visualized on Weights & Biases without initializing a run using `wandb.init()`")
for prediction in prediction._images_prediction_lst:
_visualize_image_detection_prediction_on_wandb(prediction=prediction, show_confidence=show_confidence)
wandb_image = visualize_image_detection_prediction_on_wandb(prediction=prediction, show_confidence=show_confidence)
wandb.log({"Predictions": wandb_image})


def plot_detection_dataset_on_wandb(detection_dataset: DetectionDataset, max_examples: int = None, dataset_name: str = None, reverse_channels: bool = True):
"""Log a detection dataset to Weights & Biases Table.

:param detection_dataset: The Detection Dataset (a `super_gradients.training.datasets.detection_datasets.DetectionDataset` object)
:param max_examples: Maximum number of examples from the detection dataset to plot (an `int`).
:param dataset_name: Name of the dataset (a `str`).
:param reverse_channels: Reverse the order of channels on the images while plotting.
"""
max_examples = len(detection_dataset) if max_examples is None else max_examples
wandb_table = wandb.Table(columns=["Images", "Class-Frequencies"])
input_format = detection_dataset.output_target_format
target_format_transform = DetectionTargetsFormatTransform(input_format=input_format, output_format=XYXY_LABEL)
class_id_to_labels = {int(_id): str(_class_name) for _id, _class_name in enumerate(detection_dataset.classes)}
for data_idx in tqdm(range(max_examples), desc="Plotting Examples on Weights & Biases"):
image, targets, *_ = detection_dataset[data_idx]
image = image.transpose(1, 2, 0).astype(np.int32)
sample = target_format_transform({"image": image, "target": targets})
boxes = sample["target"][:, 0:4]
boxes = boxes[(boxes != 0).any(axis=1)]
classes = targets[:, 0].tolist()
wandb_boxes = []
class_frequencies = {str(_class_name): 0 for _id, _class_name in enumerate(detection_dataset.classes)}
for idx in range(boxes.shape[0]):
wandb_boxes.append(
{
"position": {
"minX": float(boxes[idx][0] / image.shape[1]),
"maxX": float(boxes[idx][2] / image.shape[1]),
"minY": float(boxes[idx][1] / image.shape[0]),
"maxY": float(boxes[idx][3] / image.shape[0]),
},
"class_id": int(classes[idx]),
"box_caption": str(class_id_to_labels[int(classes[idx])]),
}
)
class_frequencies[str(class_id_to_labels[int(classes[idx])])] += 1
image = image[:, :, ::-1] if reverse_channels else image
wandb_table.add_data(wandb.Image(image, boxes={"ground_truth": {"box_data": wandb_boxes, "class_labels": class_id_to_labels}}), class_frequencies)
dataset_name = "Dataset" if dataset_name is None else dataset_name
wandb.log({dataset_name: wandb_table}, commit=False)
99 changes: 99 additions & 0 deletions src/super_gradients/common/plugins/wandb/validation_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import Optional

import torch
import numpy as np

from super_gradients.training.utils.callbacks import Callback, PhaseContext
from super_gradients.common.plugins.wandb.log_predictions import visualize_image_detection_prediction_on_wandb
from super_gradients.training.models.predictions import DetectionPrediction
from super_gradients.training.utils.predict import ImageDetectionPrediction
from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
from super_gradients.module_interfaces import HasPredict
from super_gradients.training.utils.utils import unwrap_model

try:
import wandb
except (ModuleNotFoundError, ImportError, NameError):
pass # no action or logging - this is normal in most cases


class WandBDetectionValidationPredictionLoggerCallback(Callback):
def __init__(
self,
class_names,
max_predictions_plotted: Optional[int] = None,
post_prediction_callback: Optional[DetectionPostPredictionCallback] = None,
reverse_channels: bool = True,
) -> None:
"""A callback for logging object detection predictions to Weights & Biases during training. This callback is logging images on each batch in validation
and accumulating generated images in a `wandb.Table` in the RAM. This could potentially cause OOM errors for very large datasets like COCO. In order to
avoid this, it is recommended to explicitly set the parameter `max_predictions_plotted` to a small value, thus limiting the number of images logged in
the table.

:param class_names: A list of class names.
:param max_predictions_plotted: Maximum number of predictions to be plotted per epoch. This is set to `None` by default which means that the
predictions corresponding to all images from `context.inputs` is logged, otherwise only `max_predictions_plotted`
number of images is logged. Since `WandBDetectionValidationPredictionLoggerCallback` accumulates the generated
images in the RAM, it is advisable that the value of this parameter be explicitly specified for larger datasets in
order to avoid out-of-memory errors.
:param post_prediction_callback: `DetectionPostPredictionCallback` for post-processing outputs of the model.
:param reverse_channels: Reverse the order of channels on the images while plotting.
"""
super().__init__()
self.class_names = class_names
self.max_predictions_plotted = max_predictions_plotted
self.post_prediction_callback = post_prediction_callback
self.reverse_channels = reverse_channels
self.wandb_images = []
self.epoch_count = 0
self.mean_prediction_dicts = []
self.wandb_table = wandb.Table(columns=["Epoch", "Prediction", "Mean-Confidence"])

def on_validation_batch_end(self, context: PhaseContext) -> None:
self.wandb_images = []
mean_prediction_dict = {class_name: 0.0 for class_name in self.class_names}
if isinstance(context.net, HasPredict):
post_nms_predictions = context.net(context.inputs)
else:
self.post_prediction_callback = (
unwrap_model(context.net).get_post_prediction_callback() if self.post_prediction_callback is None else self.post_prediction_callback
)
self.post_prediction_callback.fuse_layers = False
post_nms_predictions = self.post_prediction_callback(context.preds, device=context.device)
if self.max_predictions_plotted is not None:
post_nms_predictions = post_nms_predictions[: self.max_predictions_plotted]
input_images = context.inputs[: self.max_predictions_plotted]
else:
input_images = context.inputs
for prediction, image in zip(post_nms_predictions, input_images):
prediction = prediction if prediction is not None else torch.zeros((0, 6), dtype=torch.float32)
prediction = prediction.detach().cpu().numpy()
postprocessed_image = image.detach().cpu().numpy().transpose(1, 2, 0).astype(np.int32)
image_prediction = ImageDetectionPrediction(
image=postprocessed_image,
class_names=self.class_names,
prediction=DetectionPrediction(
bboxes=prediction[:, :4],
confidence=prediction[:, 4],
labels=prediction[:, 5],
bbox_format="xyxy",
image_shape=image.shape,
),
)
for predicted_label, prediction_confidence in zip(prediction[:, 5], prediction[:, 4]):
mean_prediction_dict[self.class_names[int(predicted_label)]] += prediction_confidence
mean_prediction_dict = {k: v / len(prediction[:, 4]) for k, v in mean_prediction_dict.items()}
self.mean_prediction_dicts.append(mean_prediction_dict)
wandb_image = visualize_image_detection_prediction_on_wandb(
prediction=image_prediction, show_confidence=True, reverse_channels=self.reverse_channels
)
self.wandb_images.append(wandb_image)

def on_validation_loader_end(self, context: PhaseContext) -> None:
for wandb_image, mean_prediction_dict in zip(self.wandb_images, self.mean_prediction_dicts):
self.wandb_table.add_data(self.epoch_count, wandb_image, mean_prediction_dict)
self.wandb_images, self.mean_prediction_dicts = [], []
self.epoch_count += 1

def on_training_end(self, context: PhaseContext) -> None:
wandb.log({"Validation-Prediction": self.wandb_table})