From 07e802dd59f508280e8887715ef626263a8fbded Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Tue, 3 Sep 2024 21:14:16 +0100 Subject: [PATCH 01/25] remove deprecated rest api key --- src/lightning/pytorch/loggers/comet.py | 11 +++-------- tests/tests_pytorch/loggers/test_comet.py | 5 ----- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index 277af5c85f539..96c0cfa151215 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -35,7 +35,7 @@ from comet_ml import ExistingExperiment, Experiment, OfflineExperiment log = logging.getLogger(__name__) -_COMET_AVAILABLE = RequirementCache("comet-ml>=3.31.0", module="comet_ml") +_COMET_AVAILABLE = RequirementCache("comet-ml>=3.44.4", module="comet_ml") class CometLogger(Logger): @@ -64,7 +64,6 @@ class CometLogger(Logger): workspace=os.environ.get("COMET_WORKSPACE"), # Optional save_dir=".", # Optional project_name="default_project", # Optional - rest_api_key=os.environ.get("COMET_REST_API_KEY"), # Optional experiment_key=os.environ.get("COMET_EXPERIMENT_KEY"), # Optional experiment_name="lightning_logs", # Optional ) @@ -81,7 +80,6 @@ class CometLogger(Logger): save_dir=".", workspace=os.environ.get("COMET_WORKSPACE"), # Optional project_name="default_project", # Optional - rest_api_key=os.environ.get("COMET_REST_API_KEY"), # Optional experiment_name="lightning_logs", # Optional ) trainer = Trainer(logger=comet_logger) @@ -172,10 +170,8 @@ def __init__(self, *args, **kwarg): save_dir: Required in offline mode. The path for the directory to save local comet logs. If given, this also sets the directory for saving checkpoints. project_name: Optional. Send your experiment to a specific project. - Otherwise will be sent to Uncategorized Experiments. + Otherwise, will be sent to Uncategorized Experiments. If the project name does not already exist, Comet.ml will create a new project. - rest_api_key: Optional. Rest API key found in Comet.ml settings. - This is used to determine version number experiment_name: Optional. String representing the name for this particular experiment on Comet.ml. experiment_key: Optional. If set, restores from existing experiment. offline: If api_key and save_dir are both given, this determines whether @@ -201,7 +197,6 @@ def __init__( api_key: Optional[str] = None, save_dir: Optional[str] = None, project_name: Optional[str] = None, - rest_api_key: Optional[str] = None, experiment_name: Optional[str] = None, experiment_key: Optional[str] = None, offline: bool = False, @@ -213,9 +208,9 @@ def __init__( super().__init__() self._experiment = None self._save_dir: Optional[str] - self.rest_api_key: Optional[str] # needs to be set before the first `comet_ml` import + # because comet_ml imported after another machine learning libraries (Torch) os.environ["COMET_DISABLE_AUTO_LOGGING"] = "1" import comet_ml diff --git a/tests/tests_pytorch/loggers/test_comet.py b/tests/tests_pytorch/loggers/test_comet.py index e467c63543ede..224509faabdf1 100644 --- a/tests/tests_pytorch/loggers/test_comet.py +++ b/tests/tests_pytorch/loggers/test_comet.py @@ -60,11 +60,6 @@ def test_comet_logger_online(comet_mock): ) comet_existing().set_name.assert_called_once_with("experiment") - # API experiment - api = comet_mock.api.API - CometLogger(api_key="key", workspace="dummy-test", project_name="general", rest_api_key="rest") - api.assert_called_once_with("rest") - @mock.patch.dict(os.environ, {}) def test_comet_experiment_resets_if_not_alive(comet_mock): From 075b581f827f7e62d17009932a632877ccf00bc2 Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Tue, 3 Sep 2024 21:20:04 +0100 Subject: [PATCH 02/25] do not end experiment automatically --- src/lightning/pytorch/loggers/comet.py | 24 ++++++++++------------- tests/tests_pytorch/loggers/test_comet.py | 2 +- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index 96c0cfa151215..ff651cf944217 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -164,15 +164,15 @@ def __init__(self, *args, **kwarg): - `Comet Documentation `__ Args: - api_key: Required in online mode. API key, found on Comet.ml. If not given, this + api_key: Required in online mode. API key, found on Comet.com. If not given, this will be loaded from the environment variable COMET_API_KEY or ~/.comet.config if either exists. save_dir: Required in offline mode. The path for the directory to save local comet logs. If given, this also sets the directory for saving checkpoints. project_name: Optional. Send your experiment to a specific project. Otherwise, will be sent to Uncategorized Experiments. - If the project name does not already exist, Comet.ml will create a new project. - experiment_name: Optional. String representing the name for this particular experiment on Comet.ml. + If the project name does not already exist, Comet.com will create a new project. + experiment_name: Optional. String representing the name for this particular experiment on Comet.com. experiment_key: Optional. If set, restores from existing experiment. offline: If api_key and save_dir are both given, this determines whether the experiment will be in online or offline mode. This is useful if you use @@ -309,7 +309,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: @rank_zero_only def log_metrics(self, metrics: Mapping[str, Union[Tensor, float]], step: Optional[int] = None) -> None: assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" - # Comet.ml expects metrics to be a dictionary of detached tensors on CPU + # Comet.com expects metrics to be a dictionary of detached tensors on CPU metrics_without_epoch = metrics.copy() for key, val in metrics_without_epoch.items(): if isinstance(val, Tensor): @@ -319,26 +319,22 @@ def log_metrics(self, metrics: Mapping[str, Union[Tensor, float]], step: Optiona metrics_without_epoch = _add_prefix(metrics_without_epoch, self._prefix, self.LOGGER_JOIN_CHAR) self.experiment.log_metrics(metrics_without_epoch, step=step, epoch=epoch) - def reset_experiment(self) -> None: - self._experiment = None @override @rank_zero_only def finalize(self, status: str) -> None: - r"""When calling ``self.experiment.end()``, that experiment won't log any more data to Comet. That's why, if you - need to log any more data, you need to create an ExistingCometExperiment. For example, to log data when testing - your model after training, because when training is finalized :meth:`CometLogger.finalize` is called. - - This happens automatically in the :meth:`~CometLogger.experiment` property, when - ``self._experiment`` is set to ``None``, i.e. ``self.reset_experiment()``. + """ + We will not end experiment (self._experiment.end()) here + to have an ability to continue using it after training is complete + but instead of ending we will upload/save all the data """ if self._experiment is None: # When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been # initialized there return - self.experiment.end() - self.reset_experiment() + + self.experiment.flush() @property @override diff --git a/tests/tests_pytorch/loggers/test_comet.py b/tests/tests_pytorch/loggers/test_comet.py index 224509faabdf1..75504ace87ce2 100644 --- a/tests/tests_pytorch/loggers/test_comet.py +++ b/tests/tests_pytorch/loggers/test_comet.py @@ -203,7 +203,7 @@ def test_comet_version_without_experiment(comet_mock): _ = logger.experiment - logger.reset_experiment() + logger._experiment = None second_version = logger.version == "1234" assert second_version is not None From be371cd05f54d439a3d69f05b775fa75249de8e5 Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Tue, 3 Sep 2024 21:27:04 +0100 Subject: [PATCH 03/25] support nested hyperparams --- src/lightning/pytorch/loggers/comet.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index ff651cf944217..844d014de59aa 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -26,7 +26,7 @@ from torch.nn import Module from typing_extensions import override -from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict +from lightning.fabric.utilities.logger import _add_prefix, _convert_params from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.rank_zero import rank_zero_only @@ -104,6 +104,9 @@ def __init__(self, *args, **kwarg): # log multiple parameters logger.log_hyperparams({"batch_size": 16, "learning_rate": 0.001}) + # log nested parameters + logger.log_hyperparams({"specific": {'param': {'subparam': "value"}}}) + **Log Metrics:** .. code-block:: python @@ -114,6 +117,9 @@ def __init__(self, *args, **kwarg): # add multiple metrics logger.log_metrics({"train/loss": 0.001, "val/loss": 0.002}) + # add nested metrics + logger.log_hyperparams({"specific": {'metric': {'submetric': "value"}}}) + **Access the Comet Experiment object:** You can gain access to the underlying Comet @@ -302,7 +308,6 @@ def experiment(self) -> Union["Experiment", "ExistingExperiment", "OfflineExperi @rank_zero_only def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: params = _convert_params(params) - params = _flatten_dict(params) self.experiment.log_parameters(params) @override From 7481e33639e09c2a2737d7d6af2b2982a97379d0 Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Tue, 3 Sep 2024 21:36:34 +0100 Subject: [PATCH 04/25] use special internal methods while logging --- src/lightning/pytorch/loggers/comet.py | 27 ++++++++++++++++------- tests/tests_pytorch/loggers/conftest.py | 17 +++++++------- tests/tests_pytorch/loggers/test_comet.py | 8 ++++++- 3 files changed, 34 insertions(+), 18 deletions(-) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index 844d014de59aa..8c5eb7462114e 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -26,7 +26,7 @@ from torch.nn import Module from typing_extensions import override -from lightning.fabric.utilities.logger import _add_prefix, _convert_params +from lightning.fabric.utilities.logger import _convert_params from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.rank_zero import rank_zero_only @@ -37,6 +37,8 @@ log = logging.getLogger(__name__) _COMET_AVAILABLE = RequirementCache("comet-ml>=3.44.4", module="comet_ml") +FRAMEWORK_NAME = "pytorch-lightning" + class CometLogger(Logger): r"""Track your parameters, metrics, source code and more using `Comet @@ -196,8 +198,6 @@ def __init__(self, *args, **kwarg): """ - LOGGER_JOIN_CHAR = "-" - def __init__( self, api_key: Optional[str] = None, @@ -293,7 +293,7 @@ def experiment(self) -> Union["Experiment", "ExistingExperiment", "OfflineExperi self._experiment = OfflineExperiment( offline_directory=self.save_dir, project_name=self._project_name, **self._kwargs ) - self._experiment.log_other("Created from", "pytorch-lightning") + self._experiment.log_other("Created from", FRAMEWORK_NAME) finally: if self._future_experiment_key is not None: os.environ.pop("COMET_EXPERIMENT_KEY") @@ -308,7 +308,10 @@ def experiment(self) -> Union["Experiment", "ExistingExperiment", "OfflineExperi @rank_zero_only def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: params = _convert_params(params) - self.experiment.log_parameters(params) + self.experiment.__internal_api__log_parameters__( + parameters=params, + framework=FRAMEWORK_NAME, + ) @override @rank_zero_only @@ -321,8 +324,13 @@ def log_metrics(self, metrics: Mapping[str, Union[Tensor, float]], step: Optiona metrics_without_epoch[key] = val.cpu().detach() epoch = metrics_without_epoch.pop("epoch", None) - metrics_without_epoch = _add_prefix(metrics_without_epoch, self._prefix, self.LOGGER_JOIN_CHAR) - self.experiment.log_metrics(metrics_without_epoch, step=step, epoch=epoch) + self.experiment.__internal_api__log_metrics__( + metrics_without_epoch, + step=step, + epoch=epoch, + prefix=self._prefix, + framework=FRAMEWORK_NAME, + ) @override @@ -423,4 +431,7 @@ def __getstate__(self) -> Dict[str, Any]: @override def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None: if self._experiment is not None: - self._experiment.set_model_graph(model) + self._experiment.__internal_api__set_model_graph__( + graph=model, + framework=FRAMEWORK_NAME, + ) diff --git a/tests/tests_pytorch/loggers/conftest.py b/tests/tests_pytorch/loggers/conftest.py index 7cc5cc94fe8cc..8435fb56172a3 100644 --- a/tests/tests_pytorch/loggers/conftest.py +++ b/tests/tests_pytorch/loggers/conftest.py @@ -94,17 +94,16 @@ def comet_mock(monkeypatch): comet = ModuleType("comet_ml") monkeypatch.setitem(sys.modules, "comet_ml", comet) - comet.Experiment = Mock() - comet.ExistingExperiment = Mock() - comet.OfflineExperiment = Mock() - comet.API = Mock() - comet.config = Mock() + # to support dunder methods calling we will create a special mock + comet_experiment = MagicMock() + setattr(comet_experiment, '__internal_api__set_model_graph__', MagicMock()) + setattr(comet_experiment, '__internal_api__log_metrics__', MagicMock()) - comet_api = ModuleType("api") - comet_api.API = Mock() - monkeypatch.setitem(sys.modules, "comet_ml.api", comet_api) + comet.Experiment = MagicMock(return_value=comet_experiment) + comet.ExistingExperiment = MagicMock(return_value=comet_experiment) + comet.OfflineExperiment = MagicMock(return_value=comet_experiment) - comet.api = comet_api + comet.config = Mock() monkeypatch.setattr("lightning.pytorch.loggers.comet._COMET_AVAILABLE", True) return comet diff --git a/tests/tests_pytorch/loggers/test_comet.py b/tests/tests_pytorch/loggers/test_comet.py index 75504ace87ce2..06c00af9b171f 100644 --- a/tests/tests_pytorch/loggers/test_comet.py +++ b/tests/tests_pytorch/loggers/test_comet.py @@ -216,7 +216,13 @@ def test_comet_epoch_logging(comet_mock, tmp_path, monkeypatch): _patch_comet_atexit(monkeypatch) logger = CometLogger(project_name="test", save_dir=str(tmp_path)) logger.log_metrics({"test": 1, "epoch": 1}, step=123) - logger.experiment.log_metrics.assert_called_once_with({"test": 1}, epoch=1, step=123) + logger.experiment.__internal_api__log_metrics__.assert_called_once_with( + {"test": 1}, + epoch=1, + step=123, + prefix=logger._prefix, + framework="pytorch-lightning", + ) @mock.patch.dict(os.environ, {}) From 9d9c86f36e647f527128e83ee7627c532ce23ca7 Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Wed, 4 Sep 2024 16:40:33 +0100 Subject: [PATCH 05/25] linter fix --- src/lightning/pytorch/loggers/comet.py | 9 ++------- tests/tests_pytorch/loggers/conftest.py | 4 ++-- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index 8c5eb7462114e..9bfeff462014d 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -332,16 +332,11 @@ def log_metrics(self, metrics: Mapping[str, Union[Tensor, float]], step: Optiona framework=FRAMEWORK_NAME, ) - @override @rank_zero_only def finalize(self, status: str) -> None: - """ - We will not end experiment (self._experiment.end()) here - to have an ability to continue using it after training is complete - but instead of ending we will upload/save all the data - - """ + """We will not end experiment (self._experiment.end()) here to have an ability to continue using it after + training is complete but instead of ending we will upload/save all the data.""" if self._experiment is None: # When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been # initialized there diff --git a/tests/tests_pytorch/loggers/conftest.py b/tests/tests_pytorch/loggers/conftest.py index 8435fb56172a3..d04fad43ca6f6 100644 --- a/tests/tests_pytorch/loggers/conftest.py +++ b/tests/tests_pytorch/loggers/conftest.py @@ -96,8 +96,8 @@ def comet_mock(monkeypatch): # to support dunder methods calling we will create a special mock comet_experiment = MagicMock() - setattr(comet_experiment, '__internal_api__set_model_graph__', MagicMock()) - setattr(comet_experiment, '__internal_api__log_metrics__', MagicMock()) + setattr(comet_experiment, "__internal_api__set_model_graph__", MagicMock()) + setattr(comet_experiment, "__internal_api__log_metrics__", MagicMock()) comet.Experiment = MagicMock(return_value=comet_experiment) comet.ExistingExperiment = MagicMock(return_value=comet_experiment) From f2cd1b9f84d0d547b61eb43541461ab2e963e9ea Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Wed, 4 Sep 2024 23:15:53 +0100 Subject: [PATCH 06/25] use comet_ml.start() instead of manual initialization --- src/lightning/pytorch/loggers/comet.py | 226 +++++++++---------------- 1 file changed, 84 insertions(+), 142 deletions(-) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index 9bfeff462014d..76dbc7f23c914 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -19,7 +19,7 @@ import logging import os from argparse import Namespace -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Literal, Mapping, Optional, Union from lightning_utilities.core.imports import RequirementCache from torch import Tensor @@ -27,8 +27,8 @@ from typing_extensions import override from lightning.fabric.utilities.logger import _convert_params +from lightning.fabric.utilities.rank_zero import _get_rank from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment -from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.rank_zero import rank_zero_only if TYPE_CHECKING: @@ -38,6 +38,7 @@ _COMET_AVAILABLE = RequirementCache("comet-ml>=3.44.4", module="comet_ml") FRAMEWORK_NAME = "pytorch-lightning" +comet_experiment = Union["Experiment", "ExistingExperiment", "OfflineExperiment"] class CometLogger(Logger): @@ -62,11 +63,10 @@ class CometLogger(Logger): # arguments made to CometLogger are passed on to the comet_ml.Experiment class comet_logger = CometLogger( - api_key=os.environ.get("COMET_API_KEY"), - workspace=os.environ.get("COMET_WORKSPACE"), # Optional - save_dir=".", # Optional - project_name="default_project", # Optional - experiment_key=os.environ.get("COMET_EXPERIMENT_KEY"), # Optional + api_key="COMET_API_KEY", # Optional + workspace="COMET_WORKSPACE", # Optional + project="default_project", # Optional + experiment_key="COMET_EXPERIMENT_KEY", # Optional experiment_name="lightning_logs", # Optional ) trainer = Trainer(logger=comet_logger) @@ -79,10 +79,10 @@ class CometLogger(Logger): # arguments made to CometLogger are passed on to the comet_ml.Experiment class comet_logger = CometLogger( - save_dir=".", - workspace=os.environ.get("COMET_WORKSPACE"), # Optional - project_name="default_project", # Optional + workspace="COMET_WORKSPACE", # Optional + project="default_project", # Optional experiment_name="lightning_logs", # Optional + online=False ) trainer = Trainer(logger=comet_logger) @@ -172,48 +172,66 @@ def __init__(self, *args, **kwarg): - `Comet Documentation `__ Args: - api_key: Required in online mode. API key, found on Comet.com. If not given, this - will be loaded from the environment variable COMET_API_KEY or ~/.comet.config - if either exists. - save_dir: Required in offline mode. The path for the directory to save local - comet logs. If given, this also sets the directory for saving checkpoints. - project_name: Optional. Send your experiment to a specific project. - Otherwise, will be sent to Uncategorized Experiments. - If the project name does not already exist, Comet.com will create a new project. - experiment_name: Optional. String representing the name for this particular experiment on Comet.com. - experiment_key: Optional. If set, restores from existing experiment. - offline: If api_key and save_dir are both given, this determines whether - the experiment will be in online or offline mode. This is useful if you use - save_dir to control the checkpoints directory and have a ~/.comet.config - file but still want to run offline experiments. - prefix: A string to put at the beginning of metric keys. - \**kwargs: Additional arguments like `workspace`, `log_code`, etc. used by + api_key (str, optional): Comet API key. It's recommended to configure the API Key with `comet login`. + workspace (str, optional): Comet workspace name. If not provided, uses the default workspace. + project (str, optional): Comet project name. Defaults to `Uncategorized`. + experiment_key (str, optional): The Experiment identifier to be used for logging. This is used either to append + data to an Existing Experiment or to control the key of new experiments (for example to match another + identifier). Must be an alphanumeric string whose length is between 32 and 50 characters. + mode (str, optional): Control how the Comet experiment is started. + * ``"get_or_create"``: Starts a fresh experiment if required, or persists logging to an existing one. + * ``"get"``: Continue logging to an existing experiment identified by the ``experiment_key`` value. + * ``"create"``: Always creates of a new experiment, useful for HPO sweeps. + online (boolean, optional): If True, the data will be logged to Comet server, otherwise it will be stored + locally in an offline experiment. Default is ``True``. + **kwargs: Additional arguments like `experiment_name`, `log_code`, `prefix`, `offline_directory` etc. used by :class:`CometExperiment` can be passed as keyword arguments in this logger. Raises: ModuleNotFoundError: If required Comet package is not installed on the device. - MisconfigurationException: - If neither ``api_key`` nor ``save_dir`` are passed as arguments. """ def __init__( self, api_key: Optional[str] = None, - save_dir: Optional[str] = None, - project_name: Optional[str] = None, - experiment_name: Optional[str] = None, + workspace: Optional[str] = None, + project: Optional[str] = None, experiment_key: Optional[str] = None, - offline: bool = False, - prefix: str = "", + mode: Optional[Literal["get_or_create", "get", "create"]] = None, + online: Optional[bool] = None, **kwargs: Any, ): if not _COMET_AVAILABLE: raise ModuleNotFoundError(str(_COMET_AVAILABLE)) + super().__init__() - self._experiment = None - self._save_dir: Optional[str] + + ################################################## + # HANDLE PASSED OLD_TYPE PARAMS + self._prefix: Optional[str] = kwargs.pop("prefix", None) + + # handle old "project name" param + if "project_name" in kwargs and project is None: + project = kwargs.pop("project_name") + + # handle old "offline" experiment flag + if "offline" in kwargs and online is None: + online = kwargs.pop("offline") + + # handle old "save_dir" param + if "save_dir" in kwargs and "offline_directory" not in kwargs: + kwargs["offline_directory"] = kwargs.pop("save_dir") + ################################################## + + self._experiment: Optional[comet_experiment] = None + self._workspace: Optional[str] = workspace + self._mode: Optional[Literal["get_or_create", "get", "create"]] = mode + self._online: Optional[bool] = online + self._project_name: Optional[str] = project + self._experiment_key: Optional[str] = experiment_key + self._kwargs: Dict[str, Any] = kwargs # needs to be set before the first `comet_ml` import # because comet_ml imported after another machine learning libraries (Torch) @@ -221,46 +239,31 @@ def __init__( import comet_ml - # Determine online or offline mode based on which arguments were passed to CometLogger - api_key = api_key or comet_ml.config.get_api_key(None, comet_ml.config.get_config()) - - if api_key is not None and save_dir is not None: - self.mode = "offline" if offline else "online" - self.api_key = api_key - self._save_dir = save_dir - elif api_key is not None: - self.mode = "online" - self.api_key = api_key - self._save_dir = None - elif save_dir is not None: - self.mode = "offline" - self._save_dir = save_dir - else: - # If neither api_key nor save_dir are passed as arguments, raise an exception - raise MisconfigurationException("CometLogger requires either api_key or save_dir during initialization.") - - log.info(f"CometLogger will be initialized in {self.mode} mode") - - self._project_name: Optional[str] = project_name - self._experiment_key: Optional[str] = experiment_key - self._experiment_name: Optional[str] = experiment_name - self._prefix: str = prefix - self._kwargs: Any = kwargs - self._future_experiment_key: Optional[str] = None + self._comet_config = comet_ml.ExperimentConfig(**self._kwargs) + + # create real experiment only on main node/process + if _get_rank() is not None and _get_rank() != 0: + return + + self._experiment = comet_ml.start( + api_key=api_key, + workspace=self._workspace, + project=self._project_name, + experiment_key=self._experiment, + mode=self._mode, + online=self._online, + experiment_config=self._comet_config, + ) + + self._experiment_key = self._experiment.get_key() + self._project_name = self.experiment.project_name - if rest_api_key is not None: - from comet_ml.api import API + self._experiment.log_other("Created from", FRAMEWORK_NAME) - # Comet.ml rest API, used to determine version number - self.rest_api_key = rest_api_key - self.comet_api = API(self.rest_api_key) - else: - self.rest_api_key = None - self.comet_api = None @property @rank_zero_experiment - def experiment(self) -> Union["Experiment", "ExistingExperiment", "OfflineExperiment"]: + def experiment(self) -> comet_experiment: r"""Actual Comet object. To use Comet features in your :class:`~lightning.pytorch.core.LightningModule` do the following. @@ -269,38 +272,6 @@ def experiment(self) -> Union["Experiment", "ExistingExperiment", "OfflineExperi self.logger.experiment.some_comet_function() """ - if self._experiment is not None and self._experiment.alive: - return self._experiment - - if self._future_experiment_key is not None: - os.environ["COMET_EXPERIMENT_KEY"] = self._future_experiment_key - - from comet_ml import ExistingExperiment, Experiment, OfflineExperiment - - try: - if self.mode == "online": - if self._experiment_key is None: - self._experiment = Experiment(api_key=self.api_key, project_name=self._project_name, **self._kwargs) - self._experiment_key = self._experiment.get_key() - else: - self._experiment = ExistingExperiment( - api_key=self.api_key, - project_name=self._project_name, - previous_experiment=self._experiment_key, - **self._kwargs, - ) - else: - self._experiment = OfflineExperiment( - offline_directory=self.save_dir, project_name=self._project_name, **self._kwargs - ) - self._experiment.log_other("Created from", FRAMEWORK_NAME) - finally: - if self._future_experiment_key is not None: - os.environ.pop("COMET_EXPERIMENT_KEY") - self._future_experiment_key = None - - if self._experiment_name: - self._experiment.set_name(self._experiment_name) return self._experiment @@ -335,13 +306,14 @@ def log_metrics(self, metrics: Mapping[str, Union[Tensor, float]], step: Optiona @override @rank_zero_only def finalize(self, status: str) -> None: - """We will not end experiment (self._experiment.end()) here to have an ability to continue using it after - training is complete but instead of ending we will upload/save all the data.""" + """We will not end experiment (will not call self._experiment.end()) here to have an ability to continue using + it after training is complete but instead of ending we will upload/save all the data.""" if self._experiment is None: # When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been # initialized there return + # just save the data self.experiment.flush() @property @@ -353,61 +325,31 @@ def save_dir(self) -> Optional[str]: The path to the save directory. """ - return self._save_dir + return self._comet_config.offline_directory @property @override - def name(self) -> str: + def name(self) -> Optional[str]: """Gets the project name. Returns: - The project name if it is specified, else "comet-default". + The project name if it is specified. """ - # Don't create an experiment if we don't have one - if self._experiment is not None and self._experiment.project_name is not None: - return self._experiment.project_name - - if self._project_name is not None: - return self._project_name - - return "comet-default" + return self._project_name @property @override - def version(self) -> str: + def version(self) -> Optional[str]: """Gets the version. Returns: - The first one of the following that is set in the following order - - 1. experiment id. - 2. experiment key. - 3. "COMET_EXPERIMENT_KEY" environment variable. - 4. future experiment key. - - If none are present generates a new guid. + The experiment key if present """ # Don't create an experiment if we don't have one if self._experiment is not None: - return self._experiment.id - - if self._experiment_key is not None: - return self._experiment_key - - if "COMET_EXPERIMENT_KEY" in os.environ: - return os.environ["COMET_EXPERIMENT_KEY"] - - if self._future_experiment_key is not None: - return self._future_experiment_key - - import comet_ml - - # Pre-generate an experiment key - self._future_experiment_key = comet_ml.generate_guid() - - return self._future_experiment_key + return self._experiment.get_key() def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() @@ -415,7 +357,7 @@ def __getstate__(self) -> Dict[str, Any]: # Save the experiment id in case an experiment object already exists, # this way we could create an ExistingExperiment pointing to the same # experiment - state["_experiment_key"] = self._experiment.id if self._experiment is not None else None + state["_experiment_key"] = self._experiment.get_key() if self._experiment is not None else None # Remove the experiment object as it contains hard to pickle objects # (like network connections), the experiment object will be recreated if From 946c92e8218f78a6a5d4f2501adfd2457f6518a3 Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Thu, 5 Sep 2024 15:51:50 +0100 Subject: [PATCH 07/25] update unit tests --- src/lightning/pytorch/loggers/comet.py | 2 +- tests/tests_pytorch/loggers/conftest.py | 10 +- tests/tests_pytorch/loggers/test_comet.py | 224 +++++++--------------- 3 files changed, 79 insertions(+), 157 deletions(-) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index 76dbc7f23c914..e64a8a99fc802 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -249,7 +249,7 @@ def __init__( api_key=api_key, workspace=self._workspace, project=self._project_name, - experiment_key=self._experiment, + experiment_key=self._experiment_key, mode=self._mode, online=self._online, experiment_config=self._comet_config, diff --git a/tests/tests_pytorch/loggers/conftest.py b/tests/tests_pytorch/loggers/conftest.py index d04fad43ca6f6..275e9cff8e3af 100644 --- a/tests/tests_pytorch/loggers/conftest.py +++ b/tests/tests_pytorch/loggers/conftest.py @@ -95,14 +95,16 @@ def comet_mock(monkeypatch): monkeypatch.setitem(sys.modules, "comet_ml", comet) # to support dunder methods calling we will create a special mock - comet_experiment = MagicMock() + comet_experiment = MagicMock(name="CommonExperiment") setattr(comet_experiment, "__internal_api__set_model_graph__", MagicMock()) setattr(comet_experiment, "__internal_api__log_metrics__", MagicMock()) - comet.Experiment = MagicMock(return_value=comet_experiment) - comet.ExistingExperiment = MagicMock(return_value=comet_experiment) - comet.OfflineExperiment = MagicMock(return_value=comet_experiment) + comet.Experiment = MagicMock(name="Experiment", return_value=comet_experiment) + comet.ExistingExperiment = MagicMock(name="ExistingExperiment", return_value=comet_experiment) + comet.OfflineExperiment = MagicMock(name="OfflineExperiment", return_value=comet_experiment) + comet.ExperimentConfig = Mock() + comet.start = Mock(name="comet_ml.start", return_value=comet.Experiment()) comet.config = Mock() monkeypatch.setattr("lightning.pytorch.loggers.comet._COMET_AVAILABLE", True) diff --git a/tests/tests_pytorch/loggers/test_comet.py b/tests/tests_pytorch/loggers/test_comet.py index 06c00af9b171f..254d71e01b121 100644 --- a/tests/tests_pytorch/loggers/test_comet.py +++ b/tests/tests_pytorch/loggers/test_comet.py @@ -13,13 +13,9 @@ # limitations under the License. import os from unittest import mock -from unittest.mock import DEFAULT, Mock, patch +from unittest.mock import Mock, call -import pytest -from lightning.pytorch import Trainer -from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import CometLogger -from lightning.pytorch.utilities.exceptions import MisconfigurationException from torch import tensor @@ -33,181 +29,105 @@ def _patch_comet_atexit(monkeypatch): @mock.patch.dict(os.environ, {}) def test_comet_logger_online(comet_mock): """Test comet online with mocks.""" - # Test api_key given - comet_experiment = comet_mock.Experiment - logger = CometLogger(api_key="key", workspace="dummy-test", project_name="general") - _ = logger.experiment - comet_experiment.assert_called_once_with(api_key="key", workspace="dummy-test", project_name="general") - - # Test both given - comet_experiment.reset_mock() - logger = CometLogger(save_dir="test", api_key="key", workspace="dummy-test", project_name="general") - _ = logger.experiment - comet_experiment.assert_called_once_with(api_key="key", workspace="dummy-test", project_name="general") - - # Test already exists - comet_existing = comet_mock.ExistingExperiment - logger = CometLogger( - experiment_key="test", - experiment_name="experiment", + + comet_start = comet_mock.start + + # Test api_key given with old param "project_name" + _logger = CometLogger(api_key="key", workspace="dummy-test", project_name="general") + comet_start.assert_called_once_with( + api_key="key", + workspace="dummy-test", + project="general", + experiment_key=None, + mode=None, + online=None, + experiment_config=comet_mock.ExperimentConfig(), + ) + + # Test online given + comet_start.reset_mock() + _logger = CometLogger(save_dir="test", api_key="key", workspace="dummy-test", project_name="general", online=True) + comet_start.assert_called_once_with( api_key="key", workspace="dummy-test", - project_name="general", + project="general", + experiment_key=None, + mode=None, + online=True, + experiment_config=comet_mock.ExperimentConfig(), ) - _ = logger.experiment - comet_existing.assert_called_once_with( - api_key="key", workspace="dummy-test", project_name="general", previous_experiment="test" + + # Test experiment_key given + comet_start.reset_mock() + _logger = CometLogger( + experiment_key="test_key", + api_key="key", + project="general", + ) + comet_start.assert_called_once_with( + api_key="key", + workspace=None, + project="general", + experiment_key="test_key", + mode=None, + online=None, + experiment_config=comet_mock.ExperimentConfig(), ) - comet_existing().set_name.assert_called_once_with("experiment") @mock.patch.dict(os.environ, {}) -def test_comet_experiment_resets_if_not_alive(comet_mock): - """Test that the CometLogger creates a new experiment if the old one is not alive anymore.""" +def test_comet_experiment_is_still_alive_after_training_complete(comet_mock): + """Test that the CometLogger will not end an experiment after training is complete.""" + logger = CometLogger() - assert logger._experiment is None - alive_experiment = Mock(alive=True) - logger._experiment = alive_experiment - assert logger.experiment is alive_experiment + assert logger.experiment is not None - unalive_experiment = Mock(alive=False) - logger._experiment = unalive_experiment - assert logger.experiment is not unalive_experiment + logger._experiment = Mock() + logger.finalize("ended") + # Assert that data was saved to comet.com + logger._experiment.flush.assert_called_once() -@mock.patch.dict(os.environ, {}) -def test_comet_logger_no_api_key_given(comet_mock): - """Test that CometLogger fails to initialize if both api key and save_dir are missing.""" - with pytest.raises(MisconfigurationException, match="requires either api_key or save_dir"): - comet_mock.config.get_api_key.return_value = None - CometLogger(workspace="dummy-test", project_name="general") + # Assert that was not ended + logger._experiment.end.assert_not_called() @mock.patch.dict(os.environ, {}) def test_comet_logger_experiment_name(comet_mock): """Test that Comet Logger experiment name works correctly.""" - api_key = "key" - experiment_name = "My Name" - - # Test api_key given - comet_experiment = comet_mock.Experiment - logger = CometLogger(api_key=api_key, experiment_name=experiment_name) - assert logger._experiment is None - - _ = logger.experiment - comet_experiment.assert_called_once_with(api_key=api_key, project_name=None) - comet_experiment().set_name.assert_called_once_with(experiment_name) - - -@mock.patch.dict(os.environ, {}) -def test_comet_logger_manual_experiment_key(comet_mock): - """Test that Comet Logger respects manually set COMET_EXPERIMENT_KEY.""" - api_key = "key" - experiment_key = "96346da91469407a85641afe5766b554" - - instantiation_environ = {} - - def save_os_environ(*args, **kwargs): - nonlocal instantiation_environ - instantiation_environ = os.environ.copy() + api_key = "api_key" + experiment_name = "My Experiment Name" - return DEFAULT + comet_start = comet_mock.start - comet_experiment = comet_mock.Experiment - comet_experiment.side_effect = save_os_environ - - # Test api_key given - with patch.dict(os.environ, {"COMET_EXPERIMENT_KEY": experiment_key}): - logger = CometLogger(api_key=api_key) - assert logger.version == experiment_key - assert logger._experiment is None - - _ = logger.experiment - comet_experiment.assert_called_once_with(api_key=api_key, project_name=None) - - assert instantiation_environ["COMET_EXPERIMENT_KEY"] == experiment_key - - -@mock.patch.dict(os.environ, {}) -def test_comet_logger_dirs_creation(comet_mock, tmp_path, monkeypatch): - """Test that the logger creates the folders and files in the right place.""" - _patch_comet_atexit(monkeypatch) - comet_experiment = comet_mock.OfflineExperiment - - comet_mock.config.get_api_key.return_value = None - comet_mock.generate_guid = Mock() - comet_mock.generate_guid.return_value = "4321" - - logger = CometLogger(project_name="test", save_dir=str(tmp_path)) - assert not os.listdir(tmp_path) - assert logger.mode == "offline" - assert logger.save_dir == str(tmp_path) - assert logger.name == "test" - assert logger.version == "4321" - - _ = logger.experiment - comet_experiment.assert_called_once_with(offline_directory=str(tmp_path), project_name="test") - - # mock return values of experiment - logger.experiment.id = "1" - logger.experiment.project_name = "test" - - model = BoringModel() - trainer = Trainer( - default_root_dir=tmp_path, logger=logger, max_epochs=1, limit_train_batches=3, limit_val_batches=3 + logger = CometLogger(api_key=api_key, experiment_name=experiment_name) + comet_start.assert_called_once_with( + api_key=api_key, + workspace=None, + project=None, + experiment_key=None, + mode=None, + online=None, + experiment_config=comet_mock.ExperimentConfig(), ) - assert trainer.log_dir == logger.save_dir - trainer.fit(model) + # check that we saved "experiment name" in kwargs + assert logger._kwargs["experiment_name"] == experiment_name - assert trainer.checkpoint_callback.dirpath == str(tmp_path / "test" / "1" / "checkpoints") - assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {"epoch=0-step=3.ckpt"} - assert trainer.log_dir == logger.save_dir + # check that "experiment name" was passed to experiment config + assert call(experiment_name=experiment_name) in comet_mock.ExperimentConfig.call_args_list @mock.patch.dict(os.environ, {}) -def test_comet_name_default(comet_mock): - """Test that CometLogger.name don't create an Experiment and returns a default value.""" - api_key = "key" - logger = CometLogger(api_key=api_key) - assert logger._experiment is None - assert logger.name == "comet-default" - assert logger._experiment is None - - -@mock.patch.dict(os.environ, {}) -def test_comet_name_project_name(comet_mock): - """Test that CometLogger.name does not create an Experiment and returns project name if passed.""" - api_key = "key" - project_name = "My Project Name" - logger = CometLogger(api_key=api_key, project_name=project_name) - assert logger._experiment is None - assert logger.name == project_name - assert logger._experiment is None - - -@mock.patch.dict(os.environ, {}) -def test_comet_version_without_experiment(comet_mock): - """Test that CometLogger.version does not create an Experiment.""" +def test_comet_version(comet_mock): + """Test that CometLogger.version returns an Experiment key.""" api_key = "key" experiment_name = "My Name" - comet_mock.generate_guid = Mock() - comet_mock.generate_guid.return_value = "1234" logger = CometLogger(api_key=api_key, experiment_name=experiment_name) - assert logger._experiment is None - - first_version = logger.version - assert first_version is not None - assert logger.version == first_version - assert logger._experiment is None - - _ = logger.experiment - - logger._experiment = None + assert logger._experiment is not None + _ = logger.version - second_version = logger.version == "1234" - assert second_version is not None - assert second_version != first_version + logger._experiment.get_key.assert_called() @mock.patch.dict(os.environ, {}) From 3a2a27ddb37214ba63e6e68365be48ae43dfe3a4 Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Mon, 9 Sep 2024 14:24:53 +0100 Subject: [PATCH 08/25] add PR suggestions --- src/lightning/pytorch/loggers/comet.py | 58 +++++++++++++++++++------- 1 file changed, 42 insertions(+), 16 deletions(-) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index e64a8a99fc802..1c44d08f06137 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -120,7 +120,7 @@ def __init__(self, *args, **kwarg): logger.log_metrics({"train/loss": 0.001, "val/loss": 0.002}) # add nested metrics - logger.log_hyperparams({"specific": {'metric': {'submetric': "value"}}}) + logger.log_metrics({"specific": {'metric': {'submetric': "value"}}}) **Access the Comet Experiment object:** @@ -184,7 +184,9 @@ def __init__(self, *args, **kwarg): * ``"create"``: Always creates of a new experiment, useful for HPO sweeps. online (boolean, optional): If True, the data will be logged to Comet server, otherwise it will be stored locally in an offline experiment. Default is ``True``. - **kwargs: Additional arguments like `experiment_name`, `log_code`, `prefix`, `offline_directory` etc. used by + prefix (str, optional): The prefix to add to names of the logged metrics. + example: prefix=`exp1`, then metric name will be `exp1_metric_name` + **kwargs: Additional arguments like `experiment_name`, `log_code`, `offline_directory` etc. used by :class:`CometExperiment` can be passed as keyword arguments in this logger. Raises: @@ -201,6 +203,7 @@ def __init__( experiment_key: Optional[str] = None, mode: Optional[Literal["get_or_create", "get", "create"]] = None, online: Optional[bool] = None, + prefix: Optional[str] = None, **kwargs: Any, ): if not _COMET_AVAILABLE: @@ -209,28 +212,44 @@ def __init__( super().__init__() ################################################## - # HANDLE PASSED OLD_TYPE PARAMS - self._prefix: Optional[str] = kwargs.pop("prefix", None) + # HANDLE PASSED OLD TYPE PARAMS # handle old "project name" param - if "project_name" in kwargs and project is None: - project = kwargs.pop("project_name") + if "project_name" in kwargs: + log.warning('The parameter "project_name" is deprecated, please use "project" instead.') + if project is None: + project = kwargs.pop("project_name") + else: + log.warning('You specified both "project_name" and "project" parameters, please use "project" only') # handle old "offline" experiment flag - if "offline" in kwargs and online is None: - online = kwargs.pop("offline") + if "offline" in kwargs: + log.warning('The parameter "offline" is deprecated, please use "online" instead.') + if online is None: + online = kwargs.pop("offline") + else: + log.warning('You specified both "offline" and "online" parameters, please use "online" only') # handle old "save_dir" param - if "save_dir" in kwargs and "offline_directory" not in kwargs: - kwargs["offline_directory"] = kwargs.pop("save_dir") + if "save_dir" in kwargs: + log.warning('The parameter "save_dir" is deprecated, please use "offline_directory" instead.') + if "offline_directory" not in kwargs: + kwargs["offline_directory"] = kwargs.pop("save_dir") + else: + log.warning( + 'You specified both "save_dir" and "offline_directory" parameters, ' + 'please use "offline_directory" only' + ) ################################################## + self._api_key: Optional[str] = api_key self._experiment: Optional[comet_experiment] = None self._workspace: Optional[str] = workspace self._mode: Optional[Literal["get_or_create", "get", "create"]] = mode self._online: Optional[bool] = online self._project_name: Optional[str] = project self._experiment_key: Optional[str] = experiment_key + self._prefix: Optional[str] = prefix self._kwargs: Dict[str, Any] = kwargs # needs to be set before the first `comet_ml` import @@ -241,12 +260,17 @@ def __init__( self._comet_config = comet_ml.ExperimentConfig(**self._kwargs) - # create real experiment only on main node/process + # create real experiment only on main node/process (when strategy=auto/ddp) if _get_rank() is not None and _get_rank() != 0: return + self._create_experiment() + + def _create_experiment(self): + import comet_ml + self._experiment = comet_ml.start( - api_key=api_key, + api_key=self._api_key, workspace=self._workspace, project=self._project_name, experiment_key=self._experiment_key, @@ -254,13 +278,10 @@ def __init__( online=self._online, experiment_config=self._comet_config, ) - self._experiment_key = self._experiment.get_key() - self._project_name = self.experiment.project_name - + self._project_name = self._experiment.project_name self._experiment.log_other("Created from", FRAMEWORK_NAME) - @property @rank_zero_experiment def experiment(self) -> comet_experiment: @@ -273,6 +294,11 @@ def experiment(self) -> comet_experiment: """ + # if by some chance there is no experiment created yet (for example, when strategy=ddp_spawn) + # then we will create a new one + if not self._experiment: + self._create_experiment() + return self._experiment @override From 05de746238efc0a6a8f46fd45f0af80d3cb058f3 Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Mon, 9 Sep 2024 21:30:33 +0100 Subject: [PATCH 09/25] fix mypy check --- src/lightning/pytorch/loggers/comet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index 1c44d08f06137..0f0ff4047e3b7 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -266,7 +266,7 @@ def __init__( self._create_experiment() - def _create_experiment(self): + def _create_experiment(self) -> None: import comet_ml self._experiment = comet_ml.start( From e1f904280a5c967e7ce61549d896c65436d08b3c Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Tue, 10 Sep 2024 16:06:39 +0100 Subject: [PATCH 10/25] fix mypy check #2 --- src/lightning/pytorch/loggers/comet.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index 0f0ff4047e3b7..59ee2206ca375 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -278,6 +278,10 @@ def _create_experiment(self) -> None: online=self._online, experiment_config=self._comet_config, ) + + if self._experiment is None: + raise comet_ml.exceptions.ExperimentNotFound("Failed to create Comet experiment.") + self._experiment_key = self._experiment.get_key() self._project_name = self._experiment.project_name self._experiment.log_other("Created from", FRAMEWORK_NAME) From d9c7480984fa86690f89949746b21d3c6c1016e6 Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Wed, 11 Sep 2024 17:38:16 +0100 Subject: [PATCH 11/25] update changelog --- src/lightning/pytorch/CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 1b251b8fb06fa..d2152844260b3 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +## [unreleased] - YYYY-MM-DD + +### Changed + +- CometML logger was updated to support the recent Comet SDK + ## [2.4.0] - 2024-08-06 From 0fea284159509190b0a8f9bed401e2ef37984bcd Mon Sep 17 00:00:00 2001 From: Alexander Barannikov <32936723+japdubengsub@users.noreply.github.com> Date: Wed, 11 Sep 2024 18:15:40 +0100 Subject: [PATCH 12/25] Update CHANGELOG.md with PR number --- src/lightning/pytorch/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index d2152844260b3..2762664b9b0b1 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed -- CometML logger was updated to support the recent Comet SDK +- CometML logger was updated to support the recent Comet SDK ([#20275](https://github.com/Lightning-AI/pytorch-lightning/pull/20275)) ## [2.4.0] - 2024-08-06 From beb8d527cda2069ca44f7f1dad9b5b3fb7b7eb08 Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Wed, 25 Sep 2024 16:44:28 +0100 Subject: [PATCH 13/25] use ` notation --- src/lightning/pytorch/loggers/comet.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index 59ee2206ca375..9473ef49133e0 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -216,29 +216,29 @@ def __init__( # handle old "project name" param if "project_name" in kwargs: - log.warning('The parameter "project_name" is deprecated, please use "project" instead.') + log.warning('The parameter `project_name` is deprecated, please use `project` instead.') if project is None: project = kwargs.pop("project_name") else: - log.warning('You specified both "project_name" and "project" parameters, please use "project" only') + log.warning('You specified both `project_name` and `project` parameters, please use `project` only') # handle old "offline" experiment flag if "offline" in kwargs: - log.warning('The parameter "offline" is deprecated, please use "online" instead.') + log.warning('The parameter `offline is deprecated, please use `online` instead.') if online is None: online = kwargs.pop("offline") else: - log.warning('You specified both "offline" and "online" parameters, please use "online" only') + log.warning('You specified both `offline` and `online` parameters, please use `online` only') # handle old "save_dir" param if "save_dir" in kwargs: - log.warning('The parameter "save_dir" is deprecated, please use "offline_directory" instead.') + log.warning('The parameter `save_dir` is deprecated, please use `offline_directory` instead.') if "offline_directory" not in kwargs: kwargs["offline_directory"] = kwargs.pop("save_dir") else: log.warning( - 'You specified both "save_dir" and "offline_directory" parameters, ' - 'please use "offline_directory" only' + 'You specified both `save_dir` and `offline_directory` parameters, ' + 'please use `offline_directory` only' ) ################################################## From 5ab36626bf70d12196268d920477053d367816fe Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Wed, 25 Sep 2024 17:16:29 +0100 Subject: [PATCH 14/25] enable keyword arguments only --- src/lightning/pytorch/loggers/comet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index 9473ef49133e0..bd5a530aad6a0 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -197,6 +197,7 @@ def __init__(self, *args, **kwarg): def __init__( self, + *, api_key: Optional[str] = None, workspace: Optional[str] = None, project: Optional[str] = None, From 036a2ebe9dd6b757d42d21f1a1c8236229abf439 Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Wed, 25 Sep 2024 17:37:06 +0100 Subject: [PATCH 15/25] fix linter warning --- src/lightning/pytorch/loggers/comet.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index bd5a530aad6a0..bf651a1ead129 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -217,29 +217,29 @@ def __init__( # handle old "project name" param if "project_name" in kwargs: - log.warning('The parameter `project_name` is deprecated, please use `project` instead.') + log.warning("The parameter `project_name` is deprecated, please use `project` instead.") if project is None: project = kwargs.pop("project_name") else: - log.warning('You specified both `project_name` and `project` parameters, please use `project` only') + log.warning("You specified both `project_name` and `project` parameters, please use `project` only") # handle old "offline" experiment flag if "offline" in kwargs: - log.warning('The parameter `offline is deprecated, please use `online` instead.') + log.warning("The parameter `offline is deprecated, please use `online` instead.") if online is None: online = kwargs.pop("offline") else: - log.warning('You specified both `offline` and `online` parameters, please use `online` only') + log.warning("You specified both `offline` and `online` parameters, please use `online` only") # handle old "save_dir" param if "save_dir" in kwargs: - log.warning('The parameter `save_dir` is deprecated, please use `offline_directory` instead.') + log.warning("The parameter `save_dir` is deprecated, please use `offline_directory` instead.") if "offline_directory" not in kwargs: kwargs["offline_directory"] = kwargs.pop("save_dir") else: log.warning( - 'You specified both `save_dir` and `offline_directory` parameters, ' - 'please use `offline_directory` only' + "You specified both `save_dir` and `offline_directory` parameters, " + "please use `offline_directory` only" ) ################################################## From a804240038e48cef300f982946ca33b9f2828991 Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Mon, 30 Sep 2024 20:04:01 +0100 Subject: [PATCH 16/25] fix unit tests --- tests/tests_pytorch/loggers/test_all.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index c5b07562afb0a..1e4cb4ce380f3 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -107,6 +107,7 @@ def log_metrics(self, metrics, step): if logger_class == CometLogger: logger.experiment.id = "foo" + logger._comet_config.offline_directory = None logger.experiment.project_name = "bar" if logger_class == NeptuneLogger: @@ -299,7 +300,9 @@ def test_logger_with_prefix_all(mlflow_mock, wandb_mock, comet_mock, neptune_moc _patch_comet_atexit(monkeypatch) logger = _instantiate_logger(CometLogger, save_dir=tmp_path, prefix=prefix) logger.log_metrics({"test": 1.0}, step=0) - logger.experiment.log_metrics.assert_called_once_with({"tmp-test": 1.0}, epoch=None, step=0) + logger.experiment.__internal_api__log_metrics__.assert_called_once_with( + {"test": 1.0}, epoch=None, step=0, prefix=prefix, framework="pytorch-lightning" + ) # MLflow Metric = mlflow_mock.entities.Metric From cb2ce527b21e5471f04de323745440f8d4e57ea7 Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Wed, 2 Oct 2024 17:04:47 +0100 Subject: [PATCH 17/25] add support of old style "experiment_name" argument --- src/lightning/pytorch/loggers/comet.py | 21 +++++++++++++++++---- tests/tests_pytorch/loggers/test_comet.py | 13 ++++++++----- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index bf651a1ead129..fb4550e84c9af 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -67,7 +67,7 @@ class CometLogger(Logger): workspace="COMET_WORKSPACE", # Optional project="default_project", # Optional experiment_key="COMET_EXPERIMENT_KEY", # Optional - experiment_name="lightning_logs", # Optional + name="lightning_logs", # Optional ) trainer = Trainer(logger=comet_logger) @@ -81,7 +81,7 @@ class CometLogger(Logger): comet_logger = CometLogger( workspace="COMET_WORKSPACE", # Optional project="default_project", # Optional - experiment_name="lightning_logs", # Optional + name="lightning_logs", # Optional online=False ) trainer = Trainer(logger=comet_logger) @@ -186,7 +186,7 @@ def __init__(self, *args, **kwarg): locally in an offline experiment. Default is ``True``. prefix (str, optional): The prefix to add to names of the logged metrics. example: prefix=`exp1`, then metric name will be `exp1_metric_name` - **kwargs: Additional arguments like `experiment_name`, `log_code`, `offline_directory` etc. used by + **kwargs: Additional arguments like `name`, `log_code`, `offline_directory` etc. used by :class:`CometExperiment` can be passed as keyword arguments in this logger. Raises: @@ -215,7 +215,20 @@ def __init__( ################################################## # HANDLE PASSED OLD TYPE PARAMS - # handle old "project name" param + # handle old "experiment_name" param + if "experiment_name" in kwargs: + log.warning("The parameter `experiment_name` is deprecated, please use `name` instead.") + experiment_name = kwargs.pop("experiment_name") + + if "name" not in kwargs: + kwargs["name"] = experiment_name + else: + log.warning( + "You specified both `experiment_name` and `name` parameters, " + "please use `name` only" + ) + + # handle old "project_name" param if "project_name" in kwargs: log.warning("The parameter `project_name` is deprecated, please use `project` instead.") if project is None: diff --git a/tests/tests_pytorch/loggers/test_comet.py b/tests/tests_pytorch/loggers/test_comet.py index 254d71e01b121..51ab1efca747c 100644 --- a/tests/tests_pytorch/loggers/test_comet.py +++ b/tests/tests_pytorch/loggers/test_comet.py @@ -100,6 +100,7 @@ def test_comet_logger_experiment_name(comet_mock): comet_start = comet_mock.start + # here we use old style arg "experiment_name" (new one is "name") logger = CometLogger(api_key=api_key, experiment_name=experiment_name) comet_start.assert_called_once_with( api_key=api_key, @@ -110,11 +111,13 @@ def test_comet_logger_experiment_name(comet_mock): online=None, experiment_config=comet_mock.ExperimentConfig(), ) - # check that we saved "experiment name" in kwargs - assert logger._kwargs["experiment_name"] == experiment_name + # check that we saved "experiment name" in kwargs as new "name" arg + assert logger._kwargs["name"] == experiment_name + assert "experiment_name" not in logger._kwargs - # check that "experiment name" was passed to experiment config - assert call(experiment_name=experiment_name) in comet_mock.ExperimentConfig.call_args_list + # check that "experiment name" was passed to experiment config correctly + assert call(experiment_name=experiment_name) not in comet_mock.ExperimentConfig.call_args_list + assert call(name=experiment_name) in comet_mock.ExperimentConfig.call_args_list @mock.patch.dict(os.environ, {}) @@ -123,7 +126,7 @@ def test_comet_version(comet_mock): api_key = "key" experiment_name = "My Name" - logger = CometLogger(api_key=api_key, experiment_name=experiment_name) + logger = CometLogger(api_key=api_key, name=experiment_name) assert logger._experiment is not None _ = logger.version From caab33840fe9fd5062b97fd334b86560fd7aad37 Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Wed, 2 Oct 2024 17:36:41 +0100 Subject: [PATCH 18/25] fix linter warning --- src/lightning/pytorch/loggers/comet.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index fb4550e84c9af..acec422ed08bc 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -223,10 +223,7 @@ def __init__( if "name" not in kwargs: kwargs["name"] = experiment_name else: - log.warning( - "You specified both `experiment_name` and `name` parameters, " - "please use `name` only" - ) + log.warning("You specified both `experiment_name` and `name` parameters, please use `name` only") # handle old "project_name" param if "project_name" in kwargs: From 5c01290ddde988fc9bb6969e33a12e2a916f9c44 Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Fri, 4 Oct 2024 20:10:06 +0100 Subject: [PATCH 19/25] remove type hints from docstrings --- src/lightning/pytorch/loggers/comet.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index acec422ed08bc..357d6498afb3d 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -172,20 +172,20 @@ def __init__(self, *args, **kwarg): - `Comet Documentation `__ Args: - api_key (str, optional): Comet API key. It's recommended to configure the API Key with `comet login`. - workspace (str, optional): Comet workspace name. If not provided, uses the default workspace. - project (str, optional): Comet project name. Defaults to `Uncategorized`. - experiment_key (str, optional): The Experiment identifier to be used for logging. This is used either to append + api_key: Comet API key. It's recommended to configure the API Key with `comet login`. + workspace: Comet workspace name. If not provided, uses the default workspace. + project: Comet project name. Defaults to `Uncategorized`. + experiment_key: The Experiment identifier to be used for logging. This is used either to append data to an Existing Experiment or to control the key of new experiments (for example to match another identifier). Must be an alphanumeric string whose length is between 32 and 50 characters. - mode (str, optional): Control how the Comet experiment is started. + mode: Control how the Comet experiment is started. * ``"get_or_create"``: Starts a fresh experiment if required, or persists logging to an existing one. * ``"get"``: Continue logging to an existing experiment identified by the ``experiment_key`` value. * ``"create"``: Always creates of a new experiment, useful for HPO sweeps. - online (boolean, optional): If True, the data will be logged to Comet server, otherwise it will be stored + online: If True, the data will be logged to Comet server, otherwise it will be stored locally in an offline experiment. Default is ``True``. - prefix (str, optional): The prefix to add to names of the logged metrics. - example: prefix=`exp1`, then metric name will be `exp1_metric_name` + prefix: The prefix to add to names of the logged metrics. + example: prefix=`exp1`, then metric name will be logged as `exp1_metric_name` **kwargs: Additional arguments like `name`, `log_code`, `offline_directory` etc. used by :class:`CometExperiment` can be passed as keyword arguments in this logger. From 06a8028c695e1d8c3297e634d3a54a505421c121 Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Tue, 8 Oct 2024 18:25:24 +0100 Subject: [PATCH 20/25] revert back default values in docstring examples --- src/lightning/pytorch/loggers/comet.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index 357d6498afb3d..86f437da45c56 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -63,10 +63,10 @@ class CometLogger(Logger): # arguments made to CometLogger are passed on to the comet_ml.Experiment class comet_logger = CometLogger( - api_key="COMET_API_KEY", # Optional - workspace="COMET_WORKSPACE", # Optional + api_key=os.environ.get("COMET_API_KEY"), # Optional + workspace=os.environ.get("COMET_WORKSPACE"), # Optional project="default_project", # Optional - experiment_key="COMET_EXPERIMENT_KEY", # Optional + experiment_key=os.environ.get("COMET_EXPERIMENT_KEY"), # Optional name="lightning_logs", # Optional ) trainer = Trainer(logger=comet_logger) @@ -79,7 +79,7 @@ class CometLogger(Logger): # arguments made to CometLogger are passed on to the comet_ml.Experiment class comet_logger = CometLogger( - workspace="COMET_WORKSPACE", # Optional + workspace=os.environ.get("COMET_WORKSPACE"), # Optional project="default_project", # Optional name="lightning_logs", # Optional online=False From 467378d0f2dae3f1831cf90294eb13c78cdc930a Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Tue, 8 Oct 2024 19:04:06 +0100 Subject: [PATCH 21/25] flatten nested params --- src/lightning/pytorch/loggers/comet.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index 86f437da45c56..7f626c59f985c 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -323,6 +323,8 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: self.experiment.__internal_api__log_parameters__( parameters=params, framework=FRAMEWORK_NAME, + flatten_nested=True, + source="manual", ) @override From 512d7633b0b7c710f137ea1d2f586846696215e9 Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Tue, 8 Oct 2024 21:32:25 +0100 Subject: [PATCH 22/25] fix unit test mock issue --- tests/tests_pytorch/loggers/test_all.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index 1e4cb4ce380f3..4e4d0bc4c6a91 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -108,7 +108,8 @@ def log_metrics(self, metrics, step): if logger_class == CometLogger: logger.experiment.id = "foo" logger._comet_config.offline_directory = None - logger.experiment.project_name = "bar" + logger._project_name = "bar" + logger.experiment.get_key.return_value = "SOME_KEY" if logger_class == NeptuneLogger: logger._retrieve_run_data = Mock() From a785341db084211c6389ff836bbcb011c3fa1b9b Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Tue, 8 Oct 2024 22:03:58 +0100 Subject: [PATCH 23/25] fix cli arguments unit test --- tests/tests_pytorch/test_cli.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 56b58d4d157a1..d28e23d5bf87b 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -1580,8 +1580,13 @@ def _test_logger_init_args(logger_name, init, unresolved=None): def test_comet_logger_init_args(): _test_logger_init_args( "CometLogger", - init={"save_dir": "comet"}, # Resolve from CometLogger.__init__ - unresolved={"workspace": "comet"}, # Resolve from Comet{,Existing,Offline}Experiment.__init__ + init={ + "experiment_key": "some_key", # Resolve from CometLogger.__init__ + "workspace": "comet", + }, + unresolved={ + "save_dir": "comet", # Resolve from CometLogger.__init__ as kwarg + }, ) From ad2d319c1d2af1d20c02290fa4bc38cb91e8bbca Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Thu, 10 Oct 2024 10:17:25 +0100 Subject: [PATCH 24/25] add unit tests for public/double underscore methods --- tests/tests_pytorch/loggers/conftest.py | 1 + tests/tests_pytorch/loggers/test_comet.py | 42 +++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/tests/tests_pytorch/loggers/conftest.py b/tests/tests_pytorch/loggers/conftest.py index 275e9cff8e3af..3912a4dea0db9 100644 --- a/tests/tests_pytorch/loggers/conftest.py +++ b/tests/tests_pytorch/loggers/conftest.py @@ -98,6 +98,7 @@ def comet_mock(monkeypatch): comet_experiment = MagicMock(name="CommonExperiment") setattr(comet_experiment, "__internal_api__set_model_graph__", MagicMock()) setattr(comet_experiment, "__internal_api__log_metrics__", MagicMock()) + setattr(comet_experiment, "__internal_api__log_parameters__", MagicMock()) comet.Experiment = MagicMock(name="Experiment", return_value=comet_experiment) comet.ExistingExperiment = MagicMock(name="ExistingExperiment", return_value=comet_experiment) diff --git a/tests/tests_pytorch/loggers/test_comet.py b/tests/tests_pytorch/loggers/test_comet.py index 51ab1efca747c..34c24211d13c9 100644 --- a/tests/tests_pytorch/loggers/test_comet.py +++ b/tests/tests_pytorch/loggers/test_comet.py @@ -18,6 +18,8 @@ from lightning.pytorch.loggers import CometLogger from torch import tensor +FRAMEWORK_NAME = "pytorch-lightning" + def _patch_comet_atexit(monkeypatch): """Prevent comet logger from trying to print at exit, since pytest's stdout/stderr redirection breaks it.""" @@ -148,6 +150,46 @@ def test_comet_epoch_logging(comet_mock, tmp_path, monkeypatch): ) +@mock.patch.dict(os.environ, {}) +def test_comet_log_hyperparams(comet_mock, tmp_path, monkeypatch): + """Test that CometLogger.log_hyperparams calls internal API method.""" + _patch_comet_atexit(monkeypatch) + + logger = CometLogger(project_name="test") + hyperparams = { + "batch_size": 256, + "config": { + "SLURM Job ID": "22334455", + "RGB slurm jobID": "12345678", + "autoencoder_model": False, + }, + } + logger.log_hyperparams(hyperparams) + + logger.experiment.__internal_api__log_parameters__.assert_called_once_with( + parameters=hyperparams, + framework=FRAMEWORK_NAME, + flatten_nested=True, + source="manual", + ) + + +@mock.patch.dict(os.environ, {}) +def test_comet_log_graph(comet_mock, tmp_path, monkeypatch): + """Test that CometLogger.log_hyperparams calls internal API method.""" + _patch_comet_atexit(monkeypatch) + + logger = CometLogger(project_name="test") + model = Mock() + + logger.log_graph(model=model) + + logger.experiment.__internal_api__set_model_graph__.assert_called_once_with( + graph=model, + framework="pytorch-lightning", + ) + + @mock.patch.dict(os.environ, {}) def test_comet_metrics_safe(comet_mock, tmp_path, monkeypatch): """Test that CometLogger.log_metrics doesn't do inplace modification of metrics.""" From 27bdbcf602b312de1391bdcd083841a53f600386 Mon Sep 17 00:00:00 2001 From: Alexander Barannikov Date: Mon, 25 Nov 2024 18:13:38 +0000 Subject: [PATCH 25/25] fix linter warning --- src/lightning/pytorch/loggers/comet.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index 7f626c59f985c..7ff7605249f2f 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -19,7 +19,8 @@ import logging import os from argparse import Namespace -from typing import TYPE_CHECKING, Any, Dict, Literal, Mapping, Optional, Union +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, Literal, Optional, Union from lightning_utilities.core.imports import RequirementCache from torch import Tensor @@ -261,7 +262,7 @@ def __init__( self._project_name: Optional[str] = project self._experiment_key: Optional[str] = experiment_key self._prefix: Optional[str] = prefix - self._kwargs: Dict[str, Any] = kwargs + self._kwargs: dict[str, Any] = kwargs # needs to be set before the first `comet_ml` import # because comet_ml imported after another machine learning libraries (Torch) @@ -318,7 +319,7 @@ def experiment(self) -> comet_experiment: @override @rank_zero_only - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: params = _convert_params(params) self.experiment.__internal_api__log_parameters__( parameters=params, @@ -394,7 +395,7 @@ def version(self) -> Optional[str]: if self._experiment is not None: return self._experiment.get_key() - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() # Save the experiment id in case an experiment object already exists,