upload checkpoint files to neptune from stream #17430

Merged
merged 12 commits into from
Apr 27, 2023
7 changes: 6 additions & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -17,7 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for the TPU-v4 architecture ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))

--
+- Added support for XLA's new PJRT runtime ([#17352](https://github.com/Lightning-AI/lightning/pull/17352))


@@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for multiple optimizer parameter groups when using the FSDP strategy ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))


- Enabled saving the full model state dict when using the `FSDPStrategy` ([#16558](https://github.com/Lightning-AI/lightning/pull/16558))


@@ -72,6 +73,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed issue where `Model.load_from_checkpoint("checkpoint.ckpt", map_location=map_location)` would always return model on CPU ([#17308](https://github.com/Lightning-AI/lightning/pull/17308))


+- Fixed a potential bug with uploading model checkpoints to Neptune.ai by uploading files from stream ([#17430](https://github.com/Lightning-AI/lightning/pull/17430))


+- Fixed an issue that caused `num_nodes` not to be set correctly for `FSDPStrategy` ([#17438](https://github.com/Lightning-AI/lightning/pull/17438))


10 changes: 7 additions & 3 deletions src/lightning/pytorch/loggers/neptune.py
@@ -44,16 +44,18 @@
    import neptune
    from neptune import Run
    from neptune.handler import Handler
+    from neptune.types import File
    from neptune.utils import stringify_unsupported
elif _NEPTUNE_CLIENT_AVAILABLE:
    # <1.0 package structure
    import neptune.new as neptune
    from neptune.new import Run
    from neptune.new.handler import Handler
+    from neptune.new.types import File
    from neptune.new.utils import stringify_unsupported
else:
    # needed for tests, mocks and function signatures
-    neptune, Run, Handler, stringify_unsupported = None, None, None, None
+    neptune, Run, Handler, File, stringify_unsupported = None, None, None, None, None

log = logging.getLogger(__name__)
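The gated imports above keep the module importable whether neptune >=1.0, neptune-client <1.0, or neither is installed. A minimal, self-contained sketch of the same idea, using a plain try/except in place of Lightning's availability flags (the module paths are as in the diff; the try/except framing is an assumed equivalent, not Lightning's actual helper):

# Sketch only: prefer the >=1.0 layout, fall back to the <1.0 layout, and
# stub the name with None so the module still imports -- and tests can
# patch the name -- when neptune is absent.
try:
    from neptune.types import File  # >=1.0 package structure
except ImportError:
    try:
        from neptune.new.types import File  # <1.0 package structure
    except ImportError:
        File = None  # needed for tests, mocks and function signatures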

@@ -483,7 +485,8 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
        if hasattr(checkpoint_callback, "last_model_path") and checkpoint_callback.last_model_path:
            model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback)
            file_names.add(model_last_name)
-            self.run[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path)
+            with open(checkpoint_callback.last_model_path, "rb") as fp:
+                self.run[f"{checkpoints_namespace}/{model_last_name}"] = File.from_stream(fp)

        # save best k models
        if hasattr(checkpoint_callback, "best_k_models"):

@@ -498,7 +501,8 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:

            model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback)
            file_names.add(model_name)
-            self.run[f"{checkpoints_namespace}/{model_name}"].upload(checkpoint_callback.best_model_path)
+            with open(checkpoint_callback.best_model_path, "rb") as fp:
+                self.run[f"{checkpoints_namespace}/{model_name}"] = File.from_stream(fp)

        # remove old models logged to experiment if they are not part of best k models at this point
        if self.run.exists(checkpoints_namespace):
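The substance of the fix is in the two hunks above: instead of handing Neptune a checkpoint path via `.upload(...)`, the logger now opens the checkpoint itself and assigns a `File.from_stream(...)` value. A hedged sketch of the difference (the run setup and the "model.ckpt" path are illustrative only; assumes NEPTUNE_API_TOKEN and NEPTUNE_PROJECT are configured):

from pathlib import Path

import neptune
from neptune.types import File

Path("model.ckpt").write_bytes(b"demo checkpoint bytes")  # placeholder file
run = neptune.init_run()

# Path-based upload: Neptune can read the file later, asynchronously, so a
# checkpoint that ModelCheckpoint rotates or deletes in the meantime may
# fail to upload -- the "potential bug" named in the CHANGELOG entry.
run["checkpoints/by-path"].upload("model.ckpt")

# Stream-based upload: the content is read from the open handle, decoupling
# the upload from the file's later lifetime on disk.
with open("model.ckpt", "rb") as fp:
    run["checkpoints/by-stream"] = File.from_stream(fp)

run.stop()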
1 change: 1 addition & 0 deletions tests/tests_pytorch/loggers/test_all.py
@@ -48,6 +48,7 @@
mock.patch("lightning.pytorch.loggers.neptune._NEPTUNE_AVAILABLE", return_value=True),
mock.patch("lightning.pytorch.loggers.neptune.Run", new=mock.Mock),
mock.patch("lightning.pytorch.loggers.neptune.Handler", new=mock.Mock),
mock.patch("lightning.pytorch.loggers.neptune.File", new=mock.Mock()),
mock.patch("lightning.pytorch.loggers.wandb.wandb"),
mock.patch("lightning.pytorch.loggers.wandb.Run", new=mock.Mock),
)
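One subtlety worth noting: `File` is patched where it is looked up, in the `lightning.pytorch.loggers.neptune` namespace, not in the neptune package it was imported from. A minimal sketch of the idea (assumes Lightning is importable; the surrounding test harness is omitted):

from unittest import mock

# Replace the module-level `File` that the logger resolves at call time.
# `new=mock.Mock()` installs an instance, so attribute access such as
# File.from_stream(...) auto-creates recordable child mocks.
with mock.patch("lightning.pytorch.loggers.neptune.File", new=mock.Mock()):
    ...  # construct NeptuneLogger and exercise after_save_checkpoint here

`Run` and `Handler` are patched with `new=mock.Mock` (the class itself), presumably so type checks and instantiation still behave; `File` only needs to expose a recordable `from_stream` attribute, so an instance suffices.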
25 changes: 13 additions & 12 deletions tests/tests_pytorch/loggers/test_neptune.py
@@ -201,6 +201,7 @@ def _fit_and_test(self, logger, model):
        assert trainer.log_dir == os.path.join(os.getcwd(), ".neptune")

    @pytest.mark.usefixtures("tmpdir_unittest_fixture")
+    @patch("lightning.pytorch.loggers.neptune.File", new=mock.Mock())
    def test_neptune_leave_open_experiment_after_fit(self, neptune):
        """Verify that neptune experiment was NOT closed after training."""
        # given

@@ -216,6 +217,7 @@ def test_neptune_leave_open_experiment_after_fit(self, neptune):

        assert run_instance_mock.stop.call_count == 0

    @pytest.mark.usefixtures("tmpdir_unittest_fixture")
+    @patch("lightning.pytorch.loggers.neptune.File", new=mock.Mock())
    def test_neptune_log_metrics_on_trained_model(self, neptune):
        """Verify that trained models do log data."""
@@ -305,6 +307,7 @@ def test_log_model_summary(self, neptune):

        self.assertEqual(run_instance_mock.__getitem__.call_count, 0)
        run_instance_mock.__setitem__.assert_called_once_with(model_summary_key, file_from_content_mock)

+    @patch("builtins.open", mock.mock_open(read_data="test"))
    def test_after_save_checkpoint(self, neptune):
        test_variants = [
            ({}, "training/model"),
@@ -329,26 +332,24 @@ def test_after_save_checkpoint(self, neptune):
                best_model_score=None,
            )

-            # when: save checkpoint
-            logger.after_save_checkpoint(cb_mock)
+            with patch("lightning.pytorch.loggers.neptune.File", side_effect=mock.Mock()) as mock_file:
+                # when: save checkpoint
+                logger.after_save_checkpoint(cb_mock)

            # then:
-            self.assertEqual(run_instance_mock.__setitem__.call_count, 1)
-            self.assertEqual(run_instance_mock.__getitem__.call_count, 4)
-            self.assertEqual(run_attr_mock.upload.call_count, 4)
-            run_instance_mock.__setitem__.assert_called_once_with(
-                f"{model_key_prefix}/best_model_path", os.path.join(models_root_dir, "best_model")
-            )
-            run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/last")
+            self.assertEqual(run_instance_mock.__setitem__.call_count, 3)
+            self.assertEqual(run_instance_mock.__getitem__.call_count, 2)
+            self.assertEqual(run_attr_mock.upload.call_count, 2)
+
+            self.assertEqual(mock_file.from_stream.call_count, 2)

            run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model1")
            run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model2/with/slashes")
-            run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/best_model")

            run_attr_mock.upload.assert_has_calls(
                [
-                    call(os.path.join(models_root_dir, "last")),
                    call(os.path.join(models_root_dir, "model1")),
                    call(os.path.join(models_root_dir, "model2/with/slashes")),
-                    call(os.path.join(models_root_dir, "best_model")),
                ]
            )
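The rewritten test relies on two mocks: `builtins.open` is replaced with `mock.mock_open` so the logger can "open" checkpoint paths that never exist on disk, and the patched `File` records how often `from_stream` is called. A condensed, self-contained sketch of that pattern (`upload_from_stream` is a hypothetical stand-in for the logger code under test, not Lightning's API):

from unittest import mock

def upload_from_stream(run, key, path, file_cls):
    # Hypothetical stand-in for the logger's new upload path: open the
    # checkpoint and assign a stream-backed File value to the run.
    with open(path, "rb") as fp:
        run[key] = file_cls.from_stream(fp)

def test_upload_from_stream():
    run = mock.MagicMock()  # run[...] = ... is recorded via __setitem__
    file_cls = mock.Mock()  # stands in for the patched File class
    # mock_open lets open("best.ckpt", "rb") succeed without a real file
    with mock.patch("builtins.open", mock.mock_open(read_data=b"test")):
        upload_from_stream(run, "checkpoints/best", "best.ckpt", file_cls)
    file_cls.from_stream.assert_called_once()
    run.__setitem__.assert_called_once_with(
        "checkpoints/best", file_cls.from_stream.return_value
    )

test_upload_from_stream()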
