Skip to content

Commit

Permalink
Update Weights & Biases Integration + Add Prediction Visualization fo…
Browse files Browse the repository at this point in the history
…r Object Detection (#1010)

* add: prediction logging on wandb

* update: added wandb to requirements.txt

* update: made wandb run initialization optional inside WandBSGLogger

* update: added wandb artifacts

* update: fix imports

* update: artifact logging

* update: artifact logging

* update: added sync_tensorboard to WandBSGLogger

* fix: unused imports

* update: remove wandb from requirements.txt

* update: refactored wandb module

* update: renamed _save_artifact function

* update: refactor artifact saving functionality

* update: docstrings

* update: refactor artifact saving functionality

* fix: linting

* update: remove sync_tensorboard

* update: made visualize_image_detection_prediction_on_wandb private

* update: remove wandb run check from _save_wandb_artifact since it is ensured that the run is always initialized

* update: add error message for wandb initialization to log_detection_results_to_wandb

* fix: imports

* update: made saving model checkpoints as artifact as optional

* update: wandb run initialization in WandBSGLogger

* fix: linting

* update: warning

* fix: linting

* fix: linting

* fix: typo

---------

Co-authored-by: Ran Rubin <ranrubin@gmail.com>
Co-authored-by: Ofri Masad <ofrimasad@users.noreply.github.com>
  • Loading branch information
3 people authored Jun 5, 2023
1 parent 8abe887 commit e5e3167
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 2 deletions.
4 changes: 4 additions & 0 deletions src/super_gradients/common/plugins/wandb/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from super_gradients.common.plugins.wandb.log_predictions import log_detection_results_to_wandb


__all__ = ["log_detection_results_to_wandb"]
45 changes: 45 additions & 0 deletions src/super_gradients/common/plugins/wandb/log_predictions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
try:
import wandb
except (ModuleNotFoundError, ImportError, NameError):
pass # no action or logging - this is normal in most cases

from super_gradients.training.models.prediction_results import ImageDetectionPrediction, ImagesDetectionPrediction


def _visualize_image_detection_prediction_on_wandb(prediction: ImageDetectionPrediction, show_confidence: bool):
    """Log a single image's detection prediction to the active Weights & Biases run.

    Builds `wandb.Image` bounding-box overlays from the prediction (box coordinates
    are normalized to [0, 1], as required by the W&B box-data format) and logs the
    result under the "Predictions" key. Must be called with an active wandb run.

    :param prediction:      A single-image detection prediction to visualize.
    :param show_confidence: If True, attach each box's confidence score to its overlay.
    """
    image = prediction.image.copy()
    height, width, _ = image.shape
    class_id_to_labels = {int(_id): str(_class_name) for _id, _class_name in enumerate(prediction.class_names)}

    # Hoist the repeated `prediction.prediction` attribute lookup out of the loop.
    detection = prediction.prediction
    boxes = []
    for pred_i in range(len(detection)):
        class_id = int(detection.labels[pred_i])
        # Coordinates are truncated to whole pixels before normalization,
        # matching the pixel-level precision of the drawn overlay.
        box = {
            "position": {
                "minX": float(int(detection.bboxes_xyxy[pred_i, 0]) / width),
                "maxX": float(int(detection.bboxes_xyxy[pred_i, 2]) / width),
                "minY": float(int(detection.bboxes_xyxy[pred_i, 1]) / height),
                "maxY": float(int(detection.bboxes_xyxy[pred_i, 3]) / height),
            },
            "class_id": class_id,
            "box_caption": str(prediction.class_names[class_id]),
        }
        if show_confidence:
            box["scores"] = {"confidence": float(round(detection.confidence[pred_i], 2))}
        boxes.append(box)

    wandb_image = wandb.Image(image, boxes={"predictions": {"box_data": boxes, "class_labels": class_id_to_labels}})

    wandb.log({"Predictions": wandb_image})


def log_detection_results_to_wandb(prediction: ImagesDetectionPrediction, show_confidence: bool = True):
    """Log predictions for object detection to Weights & Biases using interactive bounding box overlays.

    :param prediction:      The model predictions (a `super_gradients.training.models.prediction_results.ImagesDetectionPrediction` object)
    :param show_confidence: Whether to log confidence scores to Weights & Biases or not.
    :raises wandb.Error:    If no Weights & Biases run has been initialized via `wandb.init()`.
    """
    if wandb.run is None:
        raise wandb.Error("Images and bounding boxes cannot be visualized on Weights & Biases without initializing a run using `wandb.init()`")
    # Use a distinct loop variable: the original code shadowed the `prediction` parameter.
    for image_prediction in prediction._images_prediction_lst:
        _visualize_image_detection_prediction_on_wandb(prediction=image_prediction, show_confidence=show_confidence)
34 changes: 32 additions & 2 deletions src/super_gradients/common/sg_loggers/wandb_sg_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
api_server: Optional[str] = None,
save_code: bool = False,
monitor_system: bool = None,
save_checkpoint_as_artifact: bool = False,
**kwargs,
):
"""
Expand All @@ -63,6 +64,9 @@ def __init__(
:param save_logs_remote: Saves log files in s3.
:param monitor_system: Not Available for WandB logger. Save the system statistics (GPU utilization, CPU, ...) in the tensorboard
:param save_code: Save current code to wandb
:param save_checkpoint_as_artifact: Save model checkpoint using Weights & Biases Artifact. Note that setting this option to True would save model
                             checkpoints every epoch as a versioned artifact, which will result in increased storage usage on
                             Weights & Biases.
"""
if monitor_system is not None:
logger.warning("monitor_system not available on WandBSGLogger. To remove this warning, please don't set monitor_system in your logger parameters")
Expand Down Expand Up @@ -102,14 +106,24 @@ def __init__(
)
wandb_id = self._get_wandb_id()

run = wandb.init(project=project_name, name=experiment_name, entity=entity, resume=resumed, id=wandb_id, **kwargs)
if wandb.run is None:
run = wandb.init(project=project_name, name=experiment_name, entity=entity, resume=resumed, id=wandb_id, **kwargs)
else:
logger.warning(
"A Weights & Biases run was initialized before initializing `WandBSGLogger`. "
"This means that `super-gradients` cannot control the run ID to which this session will be logged."
)
logger.warning(f"In order to resume this run please call `wandb.init(id={wandb.run.id}, resume='must')` before reinitializing `WandBSGLogger`.")
run = wandb.run

if save_code:
self._save_code_lines()

self._set_wandb_id(run.id)
self.save_checkpoints_wandb = save_checkpoints_remote
self.save_tensorboard_wandb = save_tensorboard_remote
self.save_logs_wandb = save_logs_remote
self.save_checkpoint_as_artifact = save_checkpoint_as_artifact

@multi_process_safe
def _save_code_lines(self):
Expand Down Expand Up @@ -234,6 +248,19 @@ def upload(self):
if self.save_logs_wandb:
wandb.save(glob_str=self.experiment_log_path, base_path=self._local_dir, policy="now")

def _save_wandb_artifact(self, path):
    """Upload a file or a directory as a Weights & Biases Artifact.

    Note that this function can be called only after wandb.init()

    :param path: the local full path to the pth file to be uploaded
    """
    checkpoint_artifact = wandb.Artifact(f"{wandb.run.id}-checkpoint", type="model")
    # A path is added as a single file or as a whole directory, depending on what it points to.
    if os.path.isfile(path):
        checkpoint_artifact.add_file(path)
    elif os.path.isdir(path):
        checkpoint_artifact.add_dir(path)
    wandb.log_artifact(checkpoint_artifact)

@multi_process_safe
def add_checkpoint(self, tag: str, state_dict: dict, global_step: int = 0):
name = f"ckpt_{global_step}.pth" if tag is None else tag
Expand All @@ -246,7 +273,10 @@ def add_checkpoint(self, tag: str, state_dict: dict, global_step: int = 0):
if self.save_checkpoints_wandb:
if self.s3_location_available:
self.model_checkpoints_data_interface.save_remote_checkpoints_file(self.experiment_name, self._local_dir, name)
wandb.save(glob_str=path, base_path=self._local_dir, policy="now")
if self.save_checkpoint_as_artifact:
self._save_wandb_artifact(path)
else:
wandb.save(glob_str=path, base_path=self._local_dir, policy="now")

def _get_tensorboard_file_name(self):
try:
Expand Down

0 comments on commit e5e3167

Please sign in to comment.