diff --git a/CHANGELOG.md b/CHANGELOG.md index bfc89a3..4eef25a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +## neptune-pytorch 1.1.0 + +### Fixes +- Rename `save_model` to `log_model` and `save_checkpoint` to `log_checkpoint`. (https://github.com/neptune-ai/neptune-pytorch/pull/9) + ## neptune-pytorch 1.0.1 ### Fixes diff --git a/src/neptune_pytorch/impl/__init__.py b/src/neptune_pytorch/impl/__init__.py index aa0a827..b9f37a7 100644 --- a/src/neptune_pytorch/impl/__init__.py +++ b/src/neptune_pytorch/impl/__init__.py @@ -183,7 +183,7 @@ def hook(module, inp, output): def base_namespace(self): return self._base_namespace - def save_model(self, model_name: Optional[str] = None): + def log_model(self, model_name: Optional[str] = None): if model_name is None: # Default model name model_name = "model.pt" @@ -193,7 +193,7 @@ def save_model(self, model_name: Optional[str] = None): safe_upload_model(self._namespace_handler["model"], model_name, self.model) - def save_checkpoint(self, checkpoint_name: Optional[str] = None): + def log_checkpoint(self, checkpoint_name: Optional[str] = None): if checkpoint_name is None: # Default checkpoint name checkpoint_name = f"checkpoint_{self.ckpt_number}.pt" diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 8147306..54415e3 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -44,10 +44,10 @@ def test_e2e(model, dataset): run[npt_logger.base_namespace]["batch/loss"].append(loss.item()) - npt_logger.save_checkpoint() + npt_logger.log_checkpoint() # Save final model - npt_logger.save_model("model") + npt_logger.log_model("model") run.wait() run.exists(f"{npt_logger.base_namespace}/batch/loss")