Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proper support for Remote Stop and Remote Abort with NeptuneLogger #19130

Merged
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed the divisibility check for `Trainer.accumulate_grad_batches` and `Trainer.log_every_n_steps` in ThroughputMonitor ([#19470](https://github.com/Lightning-AI/lightning/pull/19470))


- Fixed support for Remote Stop and Remote Abort with NeptuneLogger ([#19130](https://github.com/Lightning-AI/pytorch-lightning/pull/19130))


-


Expand Down
34 changes: 26 additions & 8 deletions src/lightning/pytorch/loggers/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import logging
import os
from argparse import Namespace
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Set, Union
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Set, Union

from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
Expand Down Expand Up @@ -48,6 +49,23 @@
_INTEGRATION_VERSION_KEY = "source_code/integrations/pytorch-lightning"


def catch_inactive(func: Callable) -> Callable:
    """Turn ``func`` into a no-op when the Neptune run it logs to is inactive.

    The Neptune client throws ``InactiveRunException`` when trying to log to an inactive run. This may happen when
    the run was stopped through the UI and the logger is still trying to log to it.

    Args:
        func: The callable to guard.

    Returns:
        A wrapper carrying the metadata of ``func``. It forwards all positional and keyword arguments and returns
        whatever ``func`` returns, or ``None`` if ``func`` raised ``InactiveRunException``. Any other exception
        propagates unchanged.

    """

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Resolved at call time rather than at decoration time, so applying the decorator
        # does not itself import ``neptune``.
        from neptune.exceptions import InactiveRunException

        try:
            return func(*args, **kwargs)
        except InactiveRunException:
            # Deliberately best-effort: the call is dropped, but leave a trace so that
            # missing data can be explained when debug logging is enabled.
            logging.getLogger(__name__).debug(
                "Skipping `%s`: the Neptune run is inactive.", getattr(func, "__name__", repr(func))
            )
            return None

    return wrapper


class NeptuneLogger(Logger):
r"""Log using `Neptune <https://neptune.ai>`_.

Expand Down Expand Up @@ -245,10 +263,7 @@ def __init__(
if self._run_instance is not None:
self._retrieve_run_data()

if _NEPTUNE_AVAILABLE:
from neptune.handler import Handler
else:
from neptune.new.handler import Handler
from neptune.handler import Handler

# make sure that we've logged the integration version for outside `Run` instances
root_obj = self._run_instance
Expand Down Expand Up @@ -383,6 +398,7 @@ def run(self) -> "Run":

@override
@rank_zero_only
@catch_inactive
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
r"""Log hyperparameters to the run.

Expand Down Expand Up @@ -430,9 +446,8 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:

@override
@rank_zero_only
def log_metrics( # type: ignore[override]
self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None
) -> None:
@catch_inactive
def log_metrics(self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None) -> None:
"""Log metrics (numeric values) in Neptune runs.

Args:
Expand All @@ -450,6 +465,7 @@ def log_metrics( # type: ignore[override]

@override
@rank_zero_only
@catch_inactive
def finalize(self, status: str) -> None:
if not self._run_instance:
# When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been
Expand All @@ -473,6 +489,7 @@ def save_dir(self) -> Optional[str]:
return os.path.join(os.getcwd(), ".neptune")

@rank_zero_only
@catch_inactive
def log_model_summary(self, model: "pl.LightningModule", max_depth: int = -1) -> None:
from neptune.types import File

Expand All @@ -483,6 +500,7 @@ def log_model_summary(self, model: "pl.LightningModule", max_depth: int = -1) ->

@override
@rank_zero_only
@catch_inactive
def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
"""Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.

Expand Down
4 changes: 4 additions & 0 deletions tests/tests_pytorch/loggers/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ def __setitem__(self, key, value):
neptune_utils.stringify_unsupported = Mock()
monkeypatch.setitem(sys.modules, "neptune.utils", neptune_utils)

neptune_exceptions = ModuleType("exceptions")
neptune_exceptions.InactiveRunException = Exception
Raalsky marked this conversation as resolved.
Show resolved Hide resolved
monkeypatch.setitem(sys.modules, "neptune.exceptions", neptune_exceptions)

neptune.handler = neptune_handler
neptune.types = neptune_types
neptune.utils = neptune_utils
Expand Down
10 changes: 10 additions & 0 deletions tests/tests_pytorch/loggers/test_neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,3 +303,13 @@ def test_get_full_model_names_from_exp_structure():
}
expected_keys = {"lvl1_1/lvl2/lvl3_1", "lvl1_1/lvl2/lvl3_2", "lvl1_2"}
assert NeptuneLogger._get_full_model_names_from_exp_structure(input_dict, "foo/bar") == expected_keys


def test_inactive_run(neptune_mock, tmp_path):
    """Fitting and testing must succeed even if every write to the run raises ``InactiveRunException``."""
    from neptune.exceptions import InactiveRunException

    mocks = _get_logger_with_mocks(api_key="test", project="project")
    neptune_logger, run_mock = mocks[0], mocks[1]

    # every item assignment on the run now behaves as if the run had been stopped
    run_mock.__setitem__.side_effect = InactiveRunException

    # this should work without any exceptions
    _fit_and_test(logger=neptune_logger, model=BoringModel(), tmp_path=tmp_path)
Loading