Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DagsHub Logger: Fix unsupported metric formats for MLflow, Add example notebook #915

Merged
merged 17 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 55 additions & 13 deletions src/super_gradients/common/sg_loggers/dagshub_sg_logger.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import re
from pathlib import Path
from typing import Optional
from typing import Optional, Mapping

import torch

Expand Down Expand Up @@ -114,7 +115,7 @@ def __init__(
def splitter(repo):
splitted = repo.split("/")
if len(splitted) != 2:
raise ValueError(f"Invalid input, should be owner_name/repo_name, but got {repo} instead")
raise Exception(f"Invalid input, should be owner_name/repo_name, but got {repo} instead")
return splitted[1], splitted[0]

def _init_env_dependency(self):
Expand Down Expand Up @@ -166,28 +167,69 @@ def _dvc_add(self, local_path="", remote_path=""):
def _dvc_commit(self, commit=""):
self.dvc_folder.commit(commit, versioning="dvc", force=True)

@multi_process_safe
def _get_nested_dict_values(self, d, parent_key="", sep="/"):
items = []
for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, Mapping):
items.extend(self._get_nested_dict_values(v, new_key, sep=sep))
else:
items.append((new_key, v))
return items

@multi_process_safe
def _contains_special_characters(self, text):
pattern = r"[!\"#$%&'()*+,:;<=>?@[\]^`{|}~\t\n\r\x0b\x0c]"
matches = re.findall(pattern, text)
if matches:
return True, ", ".join(matches)
return False, None

@multi_process_safe
def add_config(self, tag: str, config: dict):
super(DagsHubSGLogger, self).add_config(tag=tag, config=config)
param_keys = config.keys()
for pk in param_keys:
for k, v in config[pk].items():
try:
mlflow.log_params({k: v})
except Exception:
logger.warning(f"Skip to log {k}: {v}")
flatten_dict = self._get_nested_dict_values(d=config)
for k, v in flatten_dict:
try:
mlflow.log_params({k: v})
except Exception as e:
is_contain, spec_char = self._contains_special_characters(k)
if is_contain:
err_msg = f"Fail to log {k}, please remove the unsupported characters: {spec_char}"
nirbarazida marked this conversation as resolved.
Show resolved Hide resolved
else:
err_msg = f"Fail to log the config: {k}, got an expection: {e}"
logger.warning(err_msg)

@multi_process_safe
def add_scalar(self, tag: str, scalar_value: float, global_step: [int, TimeUnit] = 0):
super(DagsHubSGLogger, self).add_scalar(tag=tag, scalar_value=scalar_value, global_step=global_step)
if isinstance(global_step, TimeUnit):
global_step = global_step.get_value()
mlflow.log_metric(key=tag, value=scalar_value, step=global_step)
try:
mlflow.log_metric(key=tag, value=scalar_value, step=global_step)
nirbarazida marked this conversation as resolved.
Show resolved Hide resolved
except Exception as e:
is_contain, spec_char = self._contains_special_characters(tag)
if is_contain:
err_msg = f"Fail to log {tag}, please remove the unsupported characters: {spec_char}"
else:
err_msg = f"Fail to log the metric: {tag}, got an expection: {e}"
raise Exception(err_msg)

@multi_process_safe
def add_scalars(self, tag_scalar_dict: dict, global_step: int = 0):
super(DagsHubSGLogger, self).add_scalars(tag_scalar_dict=tag_scalar_dict, global_step=global_step)
mlflow.log_metrics(metrics=tag_scalar_dict, step=global_step)
try:
mlflow.log_metrics(metrics=tag_scalar_dict, step=global_step)
except Exception:
flatten_dicts = self._get_nested_dict_values(tag_scalar_dict)
for k, v in flatten_dicts:
try:
if isinstance(v, torch.Tensor):
v = v.item()
else:
v = float(v)
self.add_scalar(tag=k.replace("@", "at"), scalar_value=v, global_step=global_step)
nirbarazida marked this conversation as resolved.
Show resolved Hide resolved
except Exception as e:
logger.warning(e)

@multi_process_safe
def close(self):
Expand Down
Loading