From 128058c8f5c1fb5f073a7bc77674b02f806f5ab5 Mon Sep 17 00:00:00 2001 From: manuel Date: Mon, 4 Dec 2023 23:06:57 +0100 Subject: [PATCH] avoid double reporting metrics --- experiment_buddy/experiment_buddy.py | 21 +++++++++++++++++++-- setup.py | 2 +- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/experiment_buddy/experiment_buddy.py b/experiment_buddy/experiment_buddy.py index 19ae526..d959ce6 100644 --- a/experiment_buddy/experiment_buddy.py +++ b/experiment_buddy/experiment_buddy.py @@ -52,9 +52,26 @@ def __init__(self, experiment_id, debug, wandb_kwargs): wandb_kwargs["settings"] = wandb_kwargs.get("settings", wandb.Settings(start_method="fork")) self.run = wandb.init(name=experiment_id, **wandb_kwargs) + self.already_logged = set() + + def log(self, metrics_dict, **kwargs): + # args, = args + if isinstance(metrics_dict, dict): + new_keys = set(metrics_dict.keys()) + if new_keys.issubset(self.already_logged): + raise ValueError(f"Keys {new_keys - self.already_logged} already logged") + self.already_logged.update(new_keys) + elif isinstance(metrics_dict, str): + if metrics_dict in self.already_logged: + raise ValueError(f"Key {metrics_dict} already logged") + self.already_logged.add(metrics_dict) + else: + raise ValueError(f"Invalid type {type(metrics_dict)}") + + self.run.log(metrics_dict, **kwargs) - def log(self, *args, **kwargs): - self.run.log(*args, **kwargs) + if "commit" in kwargs and kwargs["commit"]: + self.already_logged = set() def deploy(url: str = "", sweep_definition: str = "", proc_num: int = 1, wandb_kwargs=None, diff --git a/setup.py b/setup.py index cec3870..47f7797 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ setup( name='experiment_buddy', - version='0.0.19', + version='0.0.20', packages=["experiment_buddy", "experiment_buddy.buddy_init", "scripts"], package_data={'scripts': ['*/*.sh']}, url='https://github.com/ministry-of-silly-code/experiment_buddy/',