Skip to content

Commit

Permalink
Merge pull request #5 from neptune-ai/correct-checkpoint-namespace
Browse files Browse the repository at this point in the history
Correct checkpoint namespace
  • Loading branch information
kshitij12345 authored May 4, 2023
2 parents bb1b306 + c9b315b commit 4e69f16
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/neptune_pytorch/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def save_checkpoint(self, checkpoint_name: Optional[str] = None):
# User is not expected to add extension
checkpoint_name = checkpoint_name + ".pt"

safe_upload(self._namespace_handler["model"], checkpoint_name, self.model)
safe_upload(self._namespace_handler["model"]["checkpoints"], checkpoint_name, self.model)

def __del__(self):
# Remove hooks
Expand Down
4 changes: 2 additions & 2 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def test_e2e(model, dataset):

run.wait()
run.exists(f"{npt_logger.base_namespace}/batch/loss")
run.exists(f"{npt_logger.base_namespace}/model/checkpoint_1.pt")
run.exists(f"{npt_logger.base_namespace}/model/checkpoint_2.pt")
run.exists(f"{npt_logger.base_namespace}/model/checkpoints/checkpoint_1.pt")
run.exists(f"{npt_logger.base_namespace}/model/checkpoints/checkpoint_2.pt")
run.exists(f"{npt_logger.base_namespace}/model/model.pt")
run.exists(f"{npt_logger.base_namespace}/model/summary")
run.exists(f"{npt_logger.base_namespace}/model/visualization")

0 comments on commit 4e69f16

Please sign in to comment.