-
Notifications
You must be signed in to change notification settings - Fork 517
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update Weights & Biases Integration + Add Prediction Visualization fo…
…r Object Detection (#1010) * add: prediction logging on wandb * update: added wandb to requirements.txt * update: made wandb run initialization optional inside WandBSGLogger * update:added wandb artifacts * update: fix imports * update: artifact logging * update: artifact logging * update: added sync_tensorboard to WandBSGLogger * fix: unused imports * update: remove wandb from requirements.txt * update: refactored wandb module * update: renamed _save_artifact function * update: refactor artifact saving functionality * update: doctrings * update: refactor artifact saving functionality * fix: linting * update: remove sync_tensorboard * update: made visualize_image_detection_prediction_on_wandb private * update: remove wandb run check from _save_wandb_artifact since it is ensured that the run is always initialized * update: add error message for wandb initialization to log_detection_results_to_wandb * fix: imports * update: made saving model checkpoints as artifact as optional * update: wandb run initialization in WandBSGLogger * fix: linting * update: warning * fix: linting * fix: linting * fix: typo --------- Co-authored-by: Ran Rubin <ranrubin@gmail.com> Co-authored-by: Ofri Masad <ofrimasad@users.noreply.github.com>
- Loading branch information
1 parent
8abe887
commit e5e3167
Showing
3 changed files
with
81 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from super_gradients.common.plugins.wandb.log_predictions import log_detection_results_to_wandb | ||
|
||
|
||
__all__ = ["log_detection_results_to_wandb"] |
45 changes: 45 additions & 0 deletions
45
src/super_gradients/common/plugins/wandb/log_predictions.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
try: | ||
import wandb | ||
except (ModuleNotFoundError, ImportError, NameError): | ||
pass # no action or logging - this is normal in most cases | ||
|
||
from super_gradients.training.models.prediction_results import ImageDetectionPrediction, ImagesDetectionPrediction | ||
|
||
|
||
def _visualize_image_detection_prediction_on_wandb(prediction: ImageDetectionPrediction, show_confidence: bool): | ||
boxes = [] | ||
image = prediction.image.copy() | ||
height, width, _ = image.shape | ||
class_id_to_labels = {int(_id): str(_class_name) for _id, _class_name in enumerate(prediction.class_names)} | ||
|
||
for pred_i in range(len(prediction.prediction)): | ||
class_id = int(prediction.prediction.labels[pred_i]) | ||
box = { | ||
"position": { | ||
"minX": float(int(prediction.prediction.bboxes_xyxy[pred_i, 0]) / width), | ||
"maxX": float(int(prediction.prediction.bboxes_xyxy[pred_i, 2]) / width), | ||
"minY": float(int(prediction.prediction.bboxes_xyxy[pred_i, 1]) / height), | ||
"maxY": float(int(prediction.prediction.bboxes_xyxy[pred_i, 3]) / height), | ||
}, | ||
"class_id": int(class_id), | ||
"box_caption": str(prediction.class_names[class_id]), | ||
} | ||
if show_confidence: | ||
box["scores"] = {"confidence": float(round(prediction.prediction.confidence[pred_i], 2))} | ||
boxes.append(box) | ||
|
||
wandb_image = wandb.Image(image, boxes={"predictions": {"box_data": boxes, "class_labels": class_id_to_labels}}) | ||
|
||
wandb.log({"Predictions": wandb_image}) | ||
|
||
|
||
def log_detection_results_to_wandb(prediction: ImagesDetectionPrediction, show_confidence: bool = True): | ||
"""Log predictions for object detection to Weights & Biases using interactive bounding box overlays. | ||
:param prediction: The model predictions (a `super_gradients.training.models.prediction_results.ImagesDetectionPrediction` object) | ||
:param show_confidence: Whether to log confidence scores to Weights & Biases or not. | ||
""" | ||
if wandb.run is None: | ||
raise wandb.Error("Images and bounding boxes cannot be visualized on Weights & Biases without initializing a run using `wandb.init()`") | ||
for prediction in prediction._images_prediction_lst: | ||
_visualize_image_detection_prediction_on_wandb(prediction=prediction, show_confidence=show_confidence) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters