diff --git a/CHANGELOG.md b/CHANGELOG.md index f78569c1b7a0b..232967b16b832 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915)) +- Added `auto_insert_metric_name` parameter to `ModelCheckpoint` ([#6277](https://github.com/PyTorchLightning/pytorch-lightning/pull/6277)) + + - Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274)) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index f457e9de7d0fa..f05a10a41996b 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -131,6 +131,16 @@ class ModelCheckpoint(Callback): ... filename='sample-mnist-{epoch:02d}-{val_loss:.2f}' ... ) + # save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with Tensorboard + # or Neptune, due to the presence of characters like '=' or '/') + # saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt + >>> checkpoint_callback = ModelCheckpoint( + ... monitor='val/loss', + ... dirpath='my/path/', + ... filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}', + ... auto_insert_metric_name=False + ... ) + # retrieve the best checkpoint after training checkpoint_callback = ModelCheckpoint(dirpath='my/path/') trainer = Trainer(callbacks=[checkpoint_callback]) @@ -156,6 +166,7 @@ def __init__( save_weights_only: bool = False, mode: str = "min", period: int = 1, + auto_insert_metric_name: bool = True ): super().__init__() self.monitor = monitor @@ -164,6 +175,7 @@ def __init__( self.save_top_k = save_top_k self.save_weights_only = save_weights_only self.period = period + self.auto_insert_metric_name = auto_insert_metric_name self._last_global_step_saved = -1 self.current_score = None self.best_k_models = {} @@ -356,6 +368,7 @@ def _format_checkpoint_name( step: int, metrics: Dict[str, Any], prefix: str = "", + auto_insert_metric_name: bool = True ) -> str: if not filename: # filename is not set, use default name @@ -367,7 +380,10 @@ def _format_checkpoint_name( metrics.update({"epoch": epoch, 'step': step}) for group in groups: name = group[1:] - filename = filename.replace(group, name + "={" + name) + + if auto_insert_metric_name: + filename = filename.replace(group, name + "={" + name) + if name not in metrics: metrics[name] = 0 filename = filename.format(**metrics) @@ -392,6 +408,11 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any], >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}') >>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456))) 'epoch=2-val_loss=0.12.ckpt' + >>> ckpt = ModelCheckpoint(dirpath=tmpdir, + ... filename='epoch={epoch}-validation_loss={val_loss:.2f}', + ... auto_insert_metric_name=False) + >>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456))) + 'epoch=2-validation_loss=0.12.ckpt' >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}') >>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={})) 'missing=0.ckpt' @@ -400,7 +421,13 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any], 'step=0.ckpt' """ - filename = self._format_checkpoint_name(self.filename, epoch, step, metrics) + filename = self._format_checkpoint_name( + self.filename, + epoch, + step, + metrics, + auto_insert_metric_name=self.auto_insert_metric_name) + if ver is not None: filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}")) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 3d5cddc4537a7..d1f33b2e6d007 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -425,6 +425,15 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir): ckpt_name = ckpt.format_checkpoint_name(4, 3, {'val/loss': 0.03}) assert ckpt_name == 'epoch=4_val/loss=0.03000.ckpt' + # auto_insert_metric_name=False + ckpt_name = ModelCheckpoint._format_checkpoint_name( + 'epoch={epoch:03d}-val_acc={val/acc}', + 3, + 2, + {'val/acc': 0.03}, + auto_insert_metric_name=False) + assert ckpt_name == 'epoch=003-val_acc=0.03' + class ModelCheckpointExtensionTest(ModelCheckpoint): FILE_EXTENSION = '.tpkc'