diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 1b251b8fb06fa..2762664b9b0b1 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +## [unreleased] - YYYY-MM-DD + +### Changed + +- CometML logger was updated to support the recent Comet SDK ([#20275](https://github.com/Lightning-AI/pytorch-lightning/pull/20275)) + ## [2.4.0] - 2024-08-06 diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index 9c05317655129..7ff7605249f2f 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -20,23 +20,26 @@ import os from argparse import Namespace from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union from lightning_utilities.core.imports import RequirementCache from torch import Tensor from torch.nn import Module from typing_extensions import override -from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict +from lightning.fabric.utilities.logger import _convert_params +from lightning.fabric.utilities.rank_zero import _get_rank from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment -from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.rank_zero import rank_zero_only if TYPE_CHECKING: from comet_ml import ExistingExperiment, Experiment, OfflineExperiment log = logging.getLogger(__name__) -_COMET_AVAILABLE = RequirementCache("comet-ml>=3.31.0", module="comet_ml") +_COMET_AVAILABLE = RequirementCache("comet-ml>=3.44.4", module="comet_ml") + +FRAMEWORK_NAME = "pytorch-lightning" +comet_experiment = Union["Experiment", "ExistingExperiment", "OfflineExperiment"] class CometLogger(Logger): @@ -61,13 +64,11 @@ class CometLogger(Logger): # arguments made to CometLogger are passed on to the comet_ml.Experiment class comet_logger = CometLogger( - api_key=os.environ.get("COMET_API_KEY"), + api_key=os.environ.get("COMET_API_KEY"), # Optional workspace=os.environ.get("COMET_WORKSPACE"), # Optional - save_dir=".", # Optional - project_name="default_project", # Optional - rest_api_key=os.environ.get("COMET_REST_API_KEY"), # Optional + project="default_project", # Optional experiment_key=os.environ.get("COMET_EXPERIMENT_KEY"), # Optional - experiment_name="lightning_logs", # Optional + name="lightning_logs", # Optional ) trainer = Trainer(logger=comet_logger) @@ -79,11 +80,10 @@ class CometLogger(Logger): # arguments made to CometLogger are passed on to the comet_ml.Experiment class comet_logger = CometLogger( - save_dir=".", workspace=os.environ.get("COMET_WORKSPACE"), # Optional - project_name="default_project", # Optional - rest_api_key=os.environ.get("COMET_REST_API_KEY"), # Optional - experiment_name="lightning_logs", # Optional + project="default_project", # Optional + name="lightning_logs", # Optional + online=False ) trainer = Trainer(logger=comet_logger) @@ -107,6 +107,9 @@ def __init__(self, *args, **kwarg): # log multiple parameters logger.log_hyperparams({"batch_size": 16, "learning_rate": 0.001}) + # log nested parameters + logger.log_hyperparams({"specific": {'param': {'subparam': "value"}}}) + **Log Metrics:** .. code-block:: python @@ -117,6 +120,9 @@ def __init__(self, *args, **kwarg): # add multiple metrics logger.log_metrics({"train/loss": 0.001, "val/loss": 0.002}) + # add nested metrics + logger.log_metrics({"specific": {'metric': {'submetric': "value"}}}) + **Access the Comet Experiment object:** You can gain access to the underlying Comet @@ -167,100 +173,134 @@ def __init__(self, *args, **kwarg): - `Comet Documentation `__ Args: - api_key: Required in online mode. API key, found on Comet.ml. If not given, this - will be loaded from the environment variable COMET_API_KEY or ~/.comet.config - if either exists. - save_dir: Required in offline mode. The path for the directory to save local - comet logs. If given, this also sets the directory for saving checkpoints. - project_name: Optional. Send your experiment to a specific project. - Otherwise will be sent to Uncategorized Experiments. - If the project name does not already exist, Comet.ml will create a new project. - rest_api_key: Optional. Rest API key found in Comet.ml settings. - This is used to determine version number - experiment_name: Optional. String representing the name for this particular experiment on Comet.ml. - experiment_key: Optional. If set, restores from existing experiment. - offline: If api_key and save_dir are both given, this determines whether - the experiment will be in online or offline mode. This is useful if you use - save_dir to control the checkpoints directory and have a ~/.comet.config - file but still want to run offline experiments. - prefix: A string to put at the beginning of metric keys. - \**kwargs: Additional arguments like `workspace`, `log_code`, etc. used by + api_key: Comet API key. It's recommended to configure the API Key with `comet login`. + workspace: Comet workspace name. If not provided, uses the default workspace. + project: Comet project name. Defaults to `Uncategorized`. + experiment_key: The Experiment identifier to be used for logging. This is used either to append + data to an Existing Experiment or to control the key of new experiments (for example to match another + identifier). Must be an alphanumeric string whose length is between 32 and 50 characters. + mode: Control how the Comet experiment is started. + * ``"get_or_create"``: Starts a fresh experiment if required, or persists logging to an existing one. + * ``"get"``: Continue logging to an existing experiment identified by the ``experiment_key`` value. + * ``"create"``: Always creates of a new experiment, useful for HPO sweeps. + online: If True, the data will be logged to Comet server, otherwise it will be stored + locally in an offline experiment. Default is ``True``. + prefix: The prefix to add to names of the logged metrics. + example: prefix=`exp1`, then metric name will be logged as `exp1_metric_name` + **kwargs: Additional arguments like `name`, `log_code`, `offline_directory` etc. used by :class:`CometExperiment` can be passed as keyword arguments in this logger. Raises: ModuleNotFoundError: If required Comet package is not installed on the device. - MisconfigurationException: - If neither ``api_key`` nor ``save_dir`` are passed as arguments. """ - LOGGER_JOIN_CHAR = "-" - def __init__( self, + *, api_key: Optional[str] = None, - save_dir: Optional[str] = None, - project_name: Optional[str] = None, - rest_api_key: Optional[str] = None, - experiment_name: Optional[str] = None, + workspace: Optional[str] = None, + project: Optional[str] = None, experiment_key: Optional[str] = None, - offline: bool = False, - prefix: str = "", + mode: Optional[Literal["get_or_create", "get", "create"]] = None, + online: Optional[bool] = None, + prefix: Optional[str] = None, **kwargs: Any, ): if not _COMET_AVAILABLE: raise ModuleNotFoundError(str(_COMET_AVAILABLE)) + super().__init__() - self._experiment = None - self._save_dir: Optional[str] - self.rest_api_key: Optional[str] + + ################################################## + # HANDLE PASSED OLD TYPE PARAMS + + # handle old "experiment_name" param + if "experiment_name" in kwargs: + log.warning("The parameter `experiment_name` is deprecated, please use `name` instead.") + experiment_name = kwargs.pop("experiment_name") + + if "name" not in kwargs: + kwargs["name"] = experiment_name + else: + log.warning("You specified both `experiment_name` and `name` parameters, please use `name` only") + + # handle old "project_name" param + if "project_name" in kwargs: + log.warning("The parameter `project_name` is deprecated, please use `project` instead.") + if project is None: + project = kwargs.pop("project_name") + else: + log.warning("You specified both `project_name` and `project` parameters, please use `project` only") + + # handle old "offline" experiment flag + if "offline" in kwargs: + log.warning("The parameter `offline is deprecated, please use `online` instead.") + if online is None: + online = kwargs.pop("offline") + else: + log.warning("You specified both `offline` and `online` parameters, please use `online` only") + + # handle old "save_dir" param + if "save_dir" in kwargs: + log.warning("The parameter `save_dir` is deprecated, please use `offline_directory` instead.") + if "offline_directory" not in kwargs: + kwargs["offline_directory"] = kwargs.pop("save_dir") + else: + log.warning( + "You specified both `save_dir` and `offline_directory` parameters, " + "please use `offline_directory` only" + ) + ################################################## + + self._api_key: Optional[str] = api_key + self._experiment: Optional[comet_experiment] = None + self._workspace: Optional[str] = workspace + self._mode: Optional[Literal["get_or_create", "get", "create"]] = mode + self._online: Optional[bool] = online + self._project_name: Optional[str] = project + self._experiment_key: Optional[str] = experiment_key + self._prefix: Optional[str] = prefix + self._kwargs: dict[str, Any] = kwargs # needs to be set before the first `comet_ml` import + # because comet_ml imported after another machine learning libraries (Torch) os.environ["COMET_DISABLE_AUTO_LOGGING"] = "1" import comet_ml - # Determine online or offline mode based on which arguments were passed to CometLogger - api_key = api_key or comet_ml.config.get_api_key(None, comet_ml.config.get_config()) - - if api_key is not None and save_dir is not None: - self.mode = "offline" if offline else "online" - self.api_key = api_key - self._save_dir = save_dir - elif api_key is not None: - self.mode = "online" - self.api_key = api_key - self._save_dir = None - elif save_dir is not None: - self.mode = "offline" - self._save_dir = save_dir - else: - # If neither api_key nor save_dir are passed as arguments, raise an exception - raise MisconfigurationException("CometLogger requires either api_key or save_dir during initialization.") - - log.info(f"CometLogger will be initialized in {self.mode} mode") - - self._project_name: Optional[str] = project_name - self._experiment_key: Optional[str] = experiment_key - self._experiment_name: Optional[str] = experiment_name - self._prefix: str = prefix - self._kwargs: Any = kwargs - self._future_experiment_key: Optional[str] = None + self._comet_config = comet_ml.ExperimentConfig(**self._kwargs) - if rest_api_key is not None: - from comet_ml.api import API + # create real experiment only on main node/process (when strategy=auto/ddp) + if _get_rank() is not None and _get_rank() != 0: + return + + self._create_experiment() + + def _create_experiment(self) -> None: + import comet_ml - # Comet.ml rest API, used to determine version number - self.rest_api_key = rest_api_key - self.comet_api = API(self.rest_api_key) - else: - self.rest_api_key = None - self.comet_api = None + self._experiment = comet_ml.start( + api_key=self._api_key, + workspace=self._workspace, + project=self._project_name, + experiment_key=self._experiment_key, + mode=self._mode, + online=self._online, + experiment_config=self._comet_config, + ) + + if self._experiment is None: + raise comet_ml.exceptions.ExperimentNotFound("Failed to create Comet experiment.") + + self._experiment_key = self._experiment.get_key() + self._project_name = self._experiment.project_name + self._experiment.log_other("Created from", FRAMEWORK_NAME) @property @rank_zero_experiment - def experiment(self) -> Union["Experiment", "ExistingExperiment", "OfflineExperiment"]: + def experiment(self) -> comet_experiment: r"""Actual Comet object. To use Comet features in your :class:`~lightning.pytorch.core.LightningModule` do the following. @@ -269,38 +309,11 @@ def experiment(self) -> Union["Experiment", "ExistingExperiment", "OfflineExperi self.logger.experiment.some_comet_function() """ - if self._experiment is not None and self._experiment.alive: - return self._experiment - - if self._future_experiment_key is not None: - os.environ["COMET_EXPERIMENT_KEY"] = self._future_experiment_key - - from comet_ml import ExistingExperiment, Experiment, OfflineExperiment - - try: - if self.mode == "online": - if self._experiment_key is None: - self._experiment = Experiment(api_key=self.api_key, project_name=self._project_name, **self._kwargs) - self._experiment_key = self._experiment.get_key() - else: - self._experiment = ExistingExperiment( - api_key=self.api_key, - project_name=self._project_name, - previous_experiment=self._experiment_key, - **self._kwargs, - ) - else: - self._experiment = OfflineExperiment( - offline_directory=self.save_dir, project_name=self._project_name, **self._kwargs - ) - self._experiment.log_other("Created from", "pytorch-lightning") - finally: - if self._future_experiment_key is not None: - os.environ.pop("COMET_EXPERIMENT_KEY") - self._future_experiment_key = None - if self._experiment_name: - self._experiment.set_name(self._experiment_name) + # if by some chance there is no experiment created yet (for example, when strategy=ddp_spawn) + # then we will create a new one + if not self._experiment: + self._create_experiment() return self._experiment @@ -308,43 +321,44 @@ def experiment(self) -> Union["Experiment", "ExistingExperiment", "OfflineExperi @rank_zero_only def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: params = _convert_params(params) - params = _flatten_dict(params) - self.experiment.log_parameters(params) + self.experiment.__internal_api__log_parameters__( + parameters=params, + framework=FRAMEWORK_NAME, + flatten_nested=True, + source="manual", + ) @override @rank_zero_only def log_metrics(self, metrics: Mapping[str, Union[Tensor, float]], step: Optional[int] = None) -> None: assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" - # Comet.ml expects metrics to be a dictionary of detached tensors on CPU + # Comet.com expects metrics to be a dictionary of detached tensors on CPU metrics_without_epoch = metrics.copy() for key, val in metrics_without_epoch.items(): if isinstance(val, Tensor): metrics_without_epoch[key] = val.cpu().detach() epoch = metrics_without_epoch.pop("epoch", None) - metrics_without_epoch = _add_prefix(metrics_without_epoch, self._prefix, self.LOGGER_JOIN_CHAR) - self.experiment.log_metrics(metrics_without_epoch, step=step, epoch=epoch) - - def reset_experiment(self) -> None: - self._experiment = None + self.experiment.__internal_api__log_metrics__( + metrics_without_epoch, + step=step, + epoch=epoch, + prefix=self._prefix, + framework=FRAMEWORK_NAME, + ) @override @rank_zero_only def finalize(self, status: str) -> None: - r"""When calling ``self.experiment.end()``, that experiment won't log any more data to Comet. That's why, if you - need to log any more data, you need to create an ExistingCometExperiment. For example, to log data when testing - your model after training, because when training is finalized :meth:`CometLogger.finalize` is called. - - This happens automatically in the :meth:`~CometLogger.experiment` property, when - ``self._experiment`` is set to ``None``, i.e. ``self.reset_experiment()``. - - """ + """We will not end experiment (will not call self._experiment.end()) here to have an ability to continue using + it after training is complete but instead of ending we will upload/save all the data.""" if self._experiment is None: # When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been # initialized there return - self.experiment.end() - self.reset_experiment() + + # just save the data + self.experiment.flush() @property @override @@ -355,61 +369,31 @@ def save_dir(self) -> Optional[str]: The path to the save directory. """ - return self._save_dir + return self._comet_config.offline_directory @property @override - def name(self) -> str: + def name(self) -> Optional[str]: """Gets the project name. Returns: - The project name if it is specified, else "comet-default". + The project name if it is specified. """ - # Don't create an experiment if we don't have one - if self._experiment is not None and self._experiment.project_name is not None: - return self._experiment.project_name - - if self._project_name is not None: - return self._project_name - - return "comet-default" + return self._project_name @property @override - def version(self) -> str: + def version(self) -> Optional[str]: """Gets the version. Returns: - The first one of the following that is set in the following order - - 1. experiment id. - 2. experiment key. - 3. "COMET_EXPERIMENT_KEY" environment variable. - 4. future experiment key. - - If none are present generates a new guid. + The experiment key if present """ # Don't create an experiment if we don't have one if self._experiment is not None: - return self._experiment.id - - if self._experiment_key is not None: - return self._experiment_key - - if "COMET_EXPERIMENT_KEY" in os.environ: - return os.environ["COMET_EXPERIMENT_KEY"] - - if self._future_experiment_key is not None: - return self._future_experiment_key - - import comet_ml - - # Pre-generate an experiment key - self._future_experiment_key = comet_ml.generate_guid() - - return self._future_experiment_key + return self._experiment.get_key() def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() @@ -417,7 +401,7 @@ def __getstate__(self) -> dict[str, Any]: # Save the experiment id in case an experiment object already exists, # this way we could create an ExistingExperiment pointing to the same # experiment - state["_experiment_key"] = self._experiment.id if self._experiment is not None else None + state["_experiment_key"] = self._experiment.get_key() if self._experiment is not None else None # Remove the experiment object as it contains hard to pickle objects # (like network connections), the experiment object will be recreated if @@ -428,4 +412,7 @@ def __getstate__(self) -> dict[str, Any]: @override def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None: if self._experiment is not None: - self._experiment.set_model_graph(model) + self._experiment.__internal_api__set_model_graph__( + graph=model, + framework=FRAMEWORK_NAME, + ) diff --git a/tests/tests_pytorch/loggers/conftest.py b/tests/tests_pytorch/loggers/conftest.py index 7cc5cc94fe8cc..3912a4dea0db9 100644 --- a/tests/tests_pytorch/loggers/conftest.py +++ b/tests/tests_pytorch/loggers/conftest.py @@ -94,18 +94,20 @@ def comet_mock(monkeypatch): comet = ModuleType("comet_ml") monkeypatch.setitem(sys.modules, "comet_ml", comet) - comet.Experiment = Mock() - comet.ExistingExperiment = Mock() - comet.OfflineExperiment = Mock() - comet.API = Mock() + # to support dunder methods calling we will create a special mock + comet_experiment = MagicMock(name="CommonExperiment") + setattr(comet_experiment, "__internal_api__set_model_graph__", MagicMock()) + setattr(comet_experiment, "__internal_api__log_metrics__", MagicMock()) + setattr(comet_experiment, "__internal_api__log_parameters__", MagicMock()) + + comet.Experiment = MagicMock(name="Experiment", return_value=comet_experiment) + comet.ExistingExperiment = MagicMock(name="ExistingExperiment", return_value=comet_experiment) + comet.OfflineExperiment = MagicMock(name="OfflineExperiment", return_value=comet_experiment) + + comet.ExperimentConfig = Mock() + comet.start = Mock(name="comet_ml.start", return_value=comet.Experiment()) comet.config = Mock() - comet_api = ModuleType("api") - comet_api.API = Mock() - monkeypatch.setitem(sys.modules, "comet_ml.api", comet_api) - - comet.api = comet_api - monkeypatch.setattr("lightning.pytorch.loggers.comet._COMET_AVAILABLE", True) return comet diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index 1b845c57ec35d..d6de763b8f74e 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -105,7 +105,9 @@ def log_metrics(self, metrics, step): if logger_class == CometLogger: logger.experiment.id = "foo" - logger.experiment.project_name = "bar" + logger._comet_config.offline_directory = None + logger._project_name = "bar" + logger.experiment.get_key.return_value = "SOME_KEY" if logger_class == NeptuneLogger: logger._retrieve_run_data = Mock() @@ -292,7 +294,9 @@ def test_logger_with_prefix_all(mlflow_mock, wandb_mock, comet_mock, neptune_moc _patch_comet_atexit(monkeypatch) logger = _instantiate_logger(CometLogger, save_dir=tmp_path, prefix=prefix) logger.log_metrics({"test": 1.0}, step=0) - logger.experiment.log_metrics.assert_called_once_with({"tmp-test": 1.0}, epoch=None, step=0) + logger.experiment.__internal_api__log_metrics__.assert_called_once_with( + {"test": 1.0}, epoch=None, step=0, prefix=prefix, framework="pytorch-lightning" + ) # MLflow Metric = mlflow_mock.entities.Metric diff --git a/tests/tests_pytorch/loggers/test_comet.py b/tests/tests_pytorch/loggers/test_comet.py index e467c63543ede..34c24211d13c9 100644 --- a/tests/tests_pytorch/loggers/test_comet.py +++ b/tests/tests_pytorch/loggers/test_comet.py @@ -13,15 +13,13 @@ # limitations under the License. import os from unittest import mock -from unittest.mock import DEFAULT, Mock, patch +from unittest.mock import Mock, call -import pytest -from lightning.pytorch import Trainer -from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import CometLogger -from lightning.pytorch.utilities.exceptions import MisconfigurationException from torch import tensor +FRAMEWORK_NAME = "pytorch-lightning" + def _patch_comet_atexit(monkeypatch): """Prevent comet logger from trying to print at exit, since pytest's stdout/stderr redirection breaks it.""" @@ -33,195 +31,163 @@ def _patch_comet_atexit(monkeypatch): @mock.patch.dict(os.environ, {}) def test_comet_logger_online(comet_mock): """Test comet online with mocks.""" - # Test api_key given - comet_experiment = comet_mock.Experiment - logger = CometLogger(api_key="key", workspace="dummy-test", project_name="general") - _ = logger.experiment - comet_experiment.assert_called_once_with(api_key="key", workspace="dummy-test", project_name="general") - - # Test both given - comet_experiment.reset_mock() - logger = CometLogger(save_dir="test", api_key="key", workspace="dummy-test", project_name="general") - _ = logger.experiment - comet_experiment.assert_called_once_with(api_key="key", workspace="dummy-test", project_name="general") - - # Test already exists - comet_existing = comet_mock.ExistingExperiment - logger = CometLogger( - experiment_key="test", - experiment_name="experiment", + + comet_start = comet_mock.start + + # Test api_key given with old param "project_name" + _logger = CometLogger(api_key="key", workspace="dummy-test", project_name="general") + comet_start.assert_called_once_with( api_key="key", workspace="dummy-test", - project_name="general", + project="general", + experiment_key=None, + mode=None, + online=None, + experiment_config=comet_mock.ExperimentConfig(), ) - _ = logger.experiment - comet_existing.assert_called_once_with( - api_key="key", workspace="dummy-test", project_name="general", previous_experiment="test" + + # Test online given + comet_start.reset_mock() + _logger = CometLogger(save_dir="test", api_key="key", workspace="dummy-test", project_name="general", online=True) + comet_start.assert_called_once_with( + api_key="key", + workspace="dummy-test", + project="general", + experiment_key=None, + mode=None, + online=True, + experiment_config=comet_mock.ExperimentConfig(), ) - comet_existing().set_name.assert_called_once_with("experiment") - # API experiment - api = comet_mock.api.API - CometLogger(api_key="key", workspace="dummy-test", project_name="general", rest_api_key="rest") - api.assert_called_once_with("rest") + # Test experiment_key given + comet_start.reset_mock() + _logger = CometLogger( + experiment_key="test_key", + api_key="key", + project="general", + ) + comet_start.assert_called_once_with( + api_key="key", + workspace=None, + project="general", + experiment_key="test_key", + mode=None, + online=None, + experiment_config=comet_mock.ExperimentConfig(), + ) @mock.patch.dict(os.environ, {}) -def test_comet_experiment_resets_if_not_alive(comet_mock): - """Test that the CometLogger creates a new experiment if the old one is not alive anymore.""" +def test_comet_experiment_is_still_alive_after_training_complete(comet_mock): + """Test that the CometLogger will not end an experiment after training is complete.""" + logger = CometLogger() - assert logger._experiment is None - alive_experiment = Mock(alive=True) - logger._experiment = alive_experiment - assert logger.experiment is alive_experiment + assert logger.experiment is not None - unalive_experiment = Mock(alive=False) - logger._experiment = unalive_experiment - assert logger.experiment is not unalive_experiment + logger._experiment = Mock() + logger.finalize("ended") + # Assert that data was saved to comet.com + logger._experiment.flush.assert_called_once() -@mock.patch.dict(os.environ, {}) -def test_comet_logger_no_api_key_given(comet_mock): - """Test that CometLogger fails to initialize if both api key and save_dir are missing.""" - with pytest.raises(MisconfigurationException, match="requires either api_key or save_dir"): - comet_mock.config.get_api_key.return_value = None - CometLogger(workspace="dummy-test", project_name="general") + # Assert that was not ended + logger._experiment.end.assert_not_called() @mock.patch.dict(os.environ, {}) def test_comet_logger_experiment_name(comet_mock): """Test that Comet Logger experiment name works correctly.""" - api_key = "key" - experiment_name = "My Name" + api_key = "api_key" + experiment_name = "My Experiment Name" + + comet_start = comet_mock.start - # Test api_key given - comet_experiment = comet_mock.Experiment + # here we use old style arg "experiment_name" (new one is "name") logger = CometLogger(api_key=api_key, experiment_name=experiment_name) - assert logger._experiment is None + comet_start.assert_called_once_with( + api_key=api_key, + workspace=None, + project=None, + experiment_key=None, + mode=None, + online=None, + experiment_config=comet_mock.ExperimentConfig(), + ) + # check that we saved "experiment name" in kwargs as new "name" arg + assert logger._kwargs["name"] == experiment_name + assert "experiment_name" not in logger._kwargs - _ = logger.experiment - comet_experiment.assert_called_once_with(api_key=api_key, project_name=None) - comet_experiment().set_name.assert_called_once_with(experiment_name) + # check that "experiment name" was passed to experiment config correctly + assert call(experiment_name=experiment_name) not in comet_mock.ExperimentConfig.call_args_list + assert call(name=experiment_name) in comet_mock.ExperimentConfig.call_args_list @mock.patch.dict(os.environ, {}) -def test_comet_logger_manual_experiment_key(comet_mock): - """Test that Comet Logger respects manually set COMET_EXPERIMENT_KEY.""" +def test_comet_version(comet_mock): + """Test that CometLogger.version returns an Experiment key.""" api_key = "key" - experiment_key = "96346da91469407a85641afe5766b554" - - instantiation_environ = {} - - def save_os_environ(*args, **kwargs): - nonlocal instantiation_environ - instantiation_environ = os.environ.copy() - - return DEFAULT - - comet_experiment = comet_mock.Experiment - comet_experiment.side_effect = save_os_environ - - # Test api_key given - with patch.dict(os.environ, {"COMET_EXPERIMENT_KEY": experiment_key}): - logger = CometLogger(api_key=api_key) - assert logger.version == experiment_key - assert logger._experiment is None + experiment_name = "My Name" - _ = logger.experiment - comet_experiment.assert_called_once_with(api_key=api_key, project_name=None) + logger = CometLogger(api_key=api_key, name=experiment_name) + assert logger._experiment is not None + _ = logger.version - assert instantiation_environ["COMET_EXPERIMENT_KEY"] == experiment_key + logger._experiment.get_key.assert_called() @mock.patch.dict(os.environ, {}) -def test_comet_logger_dirs_creation(comet_mock, tmp_path, monkeypatch): - """Test that the logger creates the folders and files in the right place.""" +def test_comet_epoch_logging(comet_mock, tmp_path, monkeypatch): + """Test that CometLogger removes the epoch key from the metrics dict and passes it as argument.""" _patch_comet_atexit(monkeypatch) - comet_experiment = comet_mock.OfflineExperiment - - comet_mock.config.get_api_key.return_value = None - comet_mock.generate_guid = Mock() - comet_mock.generate_guid.return_value = "4321" - logger = CometLogger(project_name="test", save_dir=str(tmp_path)) - assert not os.listdir(tmp_path) - assert logger.mode == "offline" - assert logger.save_dir == str(tmp_path) - assert logger.name == "test" - assert logger.version == "4321" - - _ = logger.experiment - comet_experiment.assert_called_once_with(offline_directory=str(tmp_path), project_name="test") - - # mock return values of experiment - logger.experiment.id = "1" - logger.experiment.project_name = "test" - - model = BoringModel() - trainer = Trainer( - default_root_dir=tmp_path, logger=logger, max_epochs=1, limit_train_batches=3, limit_val_batches=3 + logger.log_metrics({"test": 1, "epoch": 1}, step=123) + logger.experiment.__internal_api__log_metrics__.assert_called_once_with( + {"test": 1}, + epoch=1, + step=123, + prefix=logger._prefix, + framework="pytorch-lightning", ) - assert trainer.log_dir == logger.save_dir - trainer.fit(model) - - assert trainer.checkpoint_callback.dirpath == str(tmp_path / "test" / "1" / "checkpoints") - assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {"epoch=0-step=3.ckpt"} - assert trainer.log_dir == logger.save_dir @mock.patch.dict(os.environ, {}) -def test_comet_name_default(comet_mock): - """Test that CometLogger.name don't create an Experiment and returns a default value.""" - api_key = "key" - logger = CometLogger(api_key=api_key) - assert logger._experiment is None - assert logger.name == "comet-default" - assert logger._experiment is None - +def test_comet_log_hyperparams(comet_mock, tmp_path, monkeypatch): + """Test that CometLogger.log_hyperparams calls internal API method.""" + _patch_comet_atexit(monkeypatch) -@mock.patch.dict(os.environ, {}) -def test_comet_name_project_name(comet_mock): - """Test that CometLogger.name does not create an Experiment and returns project name if passed.""" - api_key = "key" - project_name = "My Project Name" - logger = CometLogger(api_key=api_key, project_name=project_name) - assert logger._experiment is None - assert logger.name == project_name - assert logger._experiment is None + logger = CometLogger(project_name="test") + hyperparams = { + "batch_size": 256, + "config": { + "SLURM Job ID": "22334455", + "RGB slurm jobID": "12345678", + "autoencoder_model": False, + }, + } + logger.log_hyperparams(hyperparams) + + logger.experiment.__internal_api__log_parameters__.assert_called_once_with( + parameters=hyperparams, + framework=FRAMEWORK_NAME, + flatten_nested=True, + source="manual", + ) @mock.patch.dict(os.environ, {}) -def test_comet_version_without_experiment(comet_mock): - """Test that CometLogger.version does not create an Experiment.""" - api_key = "key" - experiment_name = "My Name" - comet_mock.generate_guid = Mock() - comet_mock.generate_guid.return_value = "1234" - - logger = CometLogger(api_key=api_key, experiment_name=experiment_name) - assert logger._experiment is None - - first_version = logger.version - assert first_version is not None - assert logger.version == first_version - assert logger._experiment is None - - _ = logger.experiment - - logger.reset_experiment() +def test_comet_log_graph(comet_mock, tmp_path, monkeypatch): + """Test that CometLogger.log_hyperparams calls internal API method.""" + _patch_comet_atexit(monkeypatch) - second_version = logger.version == "1234" - assert second_version is not None - assert second_version != first_version + logger = CometLogger(project_name="test") + model = Mock() + logger.log_graph(model=model) -@mock.patch.dict(os.environ, {}) -def test_comet_epoch_logging(comet_mock, tmp_path, monkeypatch): - """Test that CometLogger removes the epoch key from the metrics dict and passes it as argument.""" - _patch_comet_atexit(monkeypatch) - logger = CometLogger(project_name="test", save_dir=str(tmp_path)) - logger.log_metrics({"test": 1, "epoch": 1}, step=123) - logger.experiment.log_metrics.assert_called_once_with({"test": 1}, epoch=1, step=123) + logger.experiment.__internal_api__set_model_graph__.assert_called_once_with( + graph=model, + framework="pytorch-lightning", + ) @mock.patch.dict(os.environ, {}) diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 4fc836c764833..716a6556e61bd 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -1622,8 +1622,13 @@ def _test_logger_init_args(logger_name, init, unresolved=None): def test_comet_logger_init_args(): _test_logger_init_args( "CometLogger", - init={"save_dir": "comet"}, # Resolve from CometLogger.__init__ - unresolved={"workspace": "comet"}, # Resolve from Comet{,Existing,Offline}Experiment.__init__ + init={ + "experiment_key": "some_key", # Resolve from CometLogger.__init__ + "workspace": "comet", + }, + unresolved={ + "save_dir": "comet", # Resolve from CometLogger.__init__ as kwarg + }, )