Skip to content

Commit

Permalink
PLFM-3331 Register experiments with model name (#585)
Browse files Browse the repository at this point in the history
* PLFM-3331 Register experiments with model name

* PLFM-3331 Fix black
  • Loading branch information
roikoren755 authored Dec 28, 2022
1 parent 8fc2fb4 commit 24772f4
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
4 changes: 2 additions & 2 deletions requirements.pro.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
deci-lab-client==3.1.1
deci-common==3.5.0
deci-lab-client==4.8.0
deci-common==3.15.0
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from typing import Optional

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
Expand Down Expand Up @@ -36,6 +37,7 @@ def __init__(
save_tensorboard_remote: bool = True,
save_logs_remote: bool = True,
monitor_system: bool = True,
model_name: Optional[str] = None,
):

if _imported_deci_lab_failure is not None:
Expand All @@ -59,7 +61,13 @@ def __init__(

self.platform_client = DeciPlatformClient()
self.platform_client.login(token=os.getenv("DECI_PLATFORM_TOKEN"))
self.platform_client.register_experiment(name=experiment_name)
if model_name is None:
logger.warning(
"'model_name' parameter not passed. "
"The experiment won't be connected to an architecture in the Deci platform. "
"To pass a model_name, please use the 'sg_logger_params.model_name' field in the training recipe."
)
self.platform_client.register_experiment(name=experiment_name, model_name=model_name if model_name else None)
self.checkpoints_dir_path = checkpoints_dir_path

@multi_process_safe
Expand Down

0 comments on commit 24772f4

Please sign in to comment.