From f2f3ef5d3dfde71e8e1618f3ba49ff8cd82ef7e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Jankowski?= Date: Fri, 23 Feb 2024 20:33:17 +0100 Subject: [PATCH] Proper support for Remote Stop and Remote Abort with NeptuneLogger (#19130) --- src/lightning/pytorch/CHANGELOG.md | 3 +++ src/lightning/pytorch/loggers/neptune.py | 30 +++++++++++++++------ tests/tests_pytorch/loggers/conftest.py | 4 +++ tests/tests_pytorch/loggers/test_neptune.py | 10 +++++++ 4 files changed, 39 insertions(+), 8 deletions(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 44d33dc05bf6c..b8873f6978a9d 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -48,6 +48,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed the divisibility check for `Trainer.accumulate_grad_batches` and `Trainer.log_every_n_steps` in ThroughputMonitor ([#19470](https://github.com/Lightning-AI/lightning/pull/19470)) +- Fixed support for Remote Stop and Remote Abort with NeptuneLogger ([#19130](https://github.com/Lightning-AI/pytorch-lightning/pull/19130)) + + - diff --git a/src/lightning/pytorch/loggers/neptune.py b/src/lightning/pytorch/loggers/neptune.py index 35a0aa5655bf0..558d1f91ca6fc 100644 --- a/src/lightning/pytorch/loggers/neptune.py +++ b/src/lightning/pytorch/loggers/neptune.py @@ -20,7 +20,8 @@ import logging import os from argparse import Namespace -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Set, Union +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Set, Union from lightning_utilities.core.imports import RequirementCache from torch import Tensor @@ -48,6 +49,19 @@ _INTEGRATION_VERSION_KEY = "source_code/integrations/pytorch-lightning" +# Neptune client throws `InactiveRunException` when trying to log to an inactive run. +# This may happen when the run was stopped through the UI and the logger is still trying to log to it. +def _catch_inactive(func: Callable) -> Callable: + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + from neptune.exceptions import InactiveRunException + + with contextlib.suppress(InactiveRunException): + return func(*args, **kwargs) + + return wrapper + + class NeptuneLogger(Logger): r"""Log using `Neptune `_. @@ -245,10 +259,7 @@ def __init__( if self._run_instance is not None: self._retrieve_run_data() - if _NEPTUNE_AVAILABLE: - from neptune.handler import Handler - else: - from neptune.new.handler import Handler + from neptune.handler import Handler # make sure that we've log integration version for outside `Run` instances root_obj = self._run_instance @@ -383,6 +394,7 @@ def run(self) -> "Run": @override @rank_zero_only + @_catch_inactive def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: r"""Log hyperparameters to the run. @@ -430,9 +442,8 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: @override @rank_zero_only - def log_metrics( # type: ignore[override] - self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None - ) -> None: + @_catch_inactive + def log_metrics(self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None) -> None: """Log metrics (numeric values) in Neptune runs. Args: @@ -450,6 +461,7 @@ def log_metrics( # type: ignore[override] @override @rank_zero_only + @_catch_inactive def finalize(self, status: str) -> None: if not self._run_instance: # When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been @@ -473,6 +485,7 @@ def save_dir(self) -> Optional[str]: return os.path.join(os.getcwd(), ".neptune") @rank_zero_only + @_catch_inactive def log_model_summary(self, model: "pl.LightningModule", max_depth: int = -1) -> None: from neptune.types import File @@ -483,6 +496,7 @@ def log_model_summary(self, model: "pl.LightningModule", max_depth: int = -1) -> @override @rank_zero_only + @_catch_inactive def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None: """Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint. diff --git a/tests/tests_pytorch/loggers/conftest.py b/tests/tests_pytorch/loggers/conftest.py index 8bc671b4c9cf6..0a0923b6902a6 100644 --- a/tests/tests_pytorch/loggers/conftest.py +++ b/tests/tests_pytorch/loggers/conftest.py @@ -141,6 +141,10 @@ def __setitem__(self, key, value): neptune_utils.stringify_unsupported = Mock() monkeypatch.setitem(sys.modules, "neptune.utils", neptune_utils) + neptune_exceptions = ModuleType("exceptions") + neptune_exceptions.InactiveRunException = Exception + monkeypatch.setitem(sys.modules, "neptune.exceptions", neptune_exceptions) + neptune.handler = neptune_handler neptune.types = neptune_types neptune.utils = neptune_utils diff --git a/tests/tests_pytorch/loggers/test_neptune.py b/tests/tests_pytorch/loggers/test_neptune.py index ea070912f90dc..13941c18db8e8 100644 --- a/tests/tests_pytorch/loggers/test_neptune.py +++ b/tests/tests_pytorch/loggers/test_neptune.py @@ -303,3 +303,13 @@ def test_get_full_model_names_from_exp_structure(): } expected_keys = {"lvl1_1/lvl2/lvl3_1", "lvl1_1/lvl2/lvl3_2", "lvl1_2"} assert NeptuneLogger._get_full_model_names_from_exp_structure(input_dict, "foo/bar") == expected_keys + + +def test_inactive_run(neptune_mock, tmp_path): + from neptune.exceptions import InactiveRunException + + logger, run_instance_mock, _ = _get_logger_with_mocks(api_key="test", project="project") + run_instance_mock.__setitem__.side_effect = InactiveRunException + + # this should work without any exceptions + _fit_and_test(logger=logger, model=BoringModel(), tmp_path=tmp_path)