Skip to content

Commit

Permalink
DagsHub Logger: Fix the error when encountering an unsupported type for…
Browse files Browse the repository at this point in the history
… mlflow. Add a Colab Notebook using DagsHub Logger
  • Loading branch information
nirbarazida committed May 14, 2023
1 parent f5cfccc commit 35998cc
Show file tree
Hide file tree
Showing 2 changed files with 3,772 additions and 10 deletions.
45 changes: 35 additions & 10 deletions src/super_gradients/common/sg_loggers/dagshub_sg_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
def splitter(repo):
    """Split an "owner_name/repo_name" string into (repo_name, owner_name).

    :param repo: repository identifier in "owner_name/repo_name" form.
    :return: tuple of (repo_name, owner_name) — note the reversed order.
    :raises ValueError: if *repo* is not exactly two "/"-separated parts.
    """
    splitted = repo.split("/")
    if len(splitted) != 2:
        # ValueError is the conventional type for a malformed argument;
        # a bare Exception forces callers into overly broad handlers.
        raise ValueError(f"Invalid input, should be owner_name/repo_name, but got {repo} instead")
    return splitted[1], splitted[0]

def _init_env_dependency(self):
Expand Down Expand Up @@ -165,26 +165,51 @@ def _dvc_add(self, local_path="", remote_path=""):
def _dvc_commit(self, commit=""):
    """Commit staged files to the DVC-tracked folder with the given message.

    :param commit: commit message; empty string by default.
    """
    # force=True overwrites any existing staged state; versioning="dvc"
    # tracks the files via DVC rather than plain git.
    self.dvc_folder.commit(commit, versioning="dvc", force=True)

@multi_process_safe
def _get_nested_dict_values(self, d, parent_key="", sep="/"):
    """Flatten a nested dict into a list of (joined_key, value) pairs.

    Keys along the nesting path are joined with *sep*, e.g.
    {"a": {"b": 1}} -> [("a/b", 1)].

    :param d:          possibly nested dict to flatten.
    :param parent_key: prefix accumulated from enclosing levels.
    :param sep:        separator placed between nested key parts.
    :return: list of (flattened_key, leaf_value) tuples.
    """
    flat = []
    for key, value in d.items():
        path = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            # Recurse into sub-dicts, carrying the accumulated path prefix.
            flat.extend(self._get_nested_dict_values(value, path, sep=sep))
        else:
            flat.append((path, value))
    return flat

@multi_process_safe
def add_config(self, tag: str, config: dict):
    """Log a (possibly nested) config dict to MLflow as parameters.

    The config is flattened with "/" between nested keys, since
    mlflow.log_params only accepts a flat mapping. Entries MLflow cannot
    serialize are skipped (logged at debug level) rather than aborting.

    :param tag:    config section tag, forwarded to the base logger.
    :param config: configuration values to record.
    """
    super(DagsHubSGLogger, self).add_config(tag=tag, config=config)
    flatten_dict = self._get_nested_dict_values(d=config)
    for k, v in flatten_dict:
        try:
            mlflow.log_params({k: v})
        except Exception as e:
            # Best-effort: an unsupported param type must not kill the run.
            logger.debug(e)

@multi_process_safe
def add_scalar(self, tag: str, scalar_value: float, global_step: int = 0):
    """Log a single scalar metric to MLflow.

    :param tag:          metric name.
    :param scalar_value: metric value.
    :param global_step:  training step to associate with the value.
    """
    super(DagsHubSGLogger, self).add_scalar(tag=tag, scalar_value=scalar_value, global_step=global_step)
    try:
        mlflow.log_metric(key=tag, value=scalar_value, step=global_step)
    except Exception as e:
        # Best-effort: a value MLflow rejects must not abort training.
        logger.debug(e)

@multi_process_safe
def add_scalars(self, tag_scalar_dict: dict, global_step: int = 0):
    """Log a dict of scalar metrics to MLflow.

    Tries a single bulk mlflow.log_metrics call first; if that fails
    (e.g. nested dicts or non-float values), falls back to flattening the
    dict and logging each leaf individually via add_scalar, converting
    tensors/values to float and replacing "@" (invalid in MLflow metric
    names) with "at".

    :param tag_scalar_dict: mapping of metric name -> scalar value
                            (may be nested).
    :param global_step:     training step to associate with the values.
    """
    super(DagsHubSGLogger, self).add_scalars(tag_scalar_dict=tag_scalar_dict, global_step=global_step)
    try:
        mlflow.log_metrics(metrics=tag_scalar_dict, step=global_step)
    except Exception:
        # Fallback: flatten and log one metric at a time, skipping any
        # leaf that still cannot be coerced to float.
        flatten_dicts = self._get_nested_dict_values(tag_scalar_dict)
        for k, v in flatten_dicts:
            try:
                if isinstance(v, torch.Tensor):
                    v = v.item()
                else:
                    v = float(v)
                self.add_scalar(tag=k.replace("@", "at"), scalar_value=v, global_step=global_step)
            except Exception as e:
                logger.debug(f"error: {e}")

@multi_process_safe
def close(self):
Expand Down
Loading

0 comments on commit 35998cc

Please sign in to comment.