Skip to content

Commit

Permalink
Fix support for ModelCheckpoint monitors with dots (#12783)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
  • Loading branch information
2 people authored and otaj committed Apr 25, 2022
1 parent 026a245 commit 5bdd603
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- When using custom DataLoaders in LightningDataModule, multiple inheritance is resolved properly ([#12716](https://github.com/PyTorchLightning/pytorch-lightning/pull/12716))


- Fixed support for `ModelCheckpoint` monitors with dots ([#12783](https://github.com/PyTorchLightning/pytorch-lightning/pull/12783))


## [1.6.1] - 2022-04-13

### Changed
Expand Down
5 changes: 4 additions & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,9 +517,12 @@ def _format_checkpoint_name(
if auto_insert_metric_name:
filename = filename.replace(group, name + "={" + name)

# support for dots: https://stackoverflow.com/a/7934969
filename = filename.replace(group, f"{{0[{name}]")

if name not in metrics:
metrics[name] = 0
filename = filename.format(**metrics)
filename = filename.format(metrics)

if prefix:
filename = cls.CHECKPOINT_JOIN_CHAR.join([prefix, filename])
Expand Down
6 changes: 6 additions & 0 deletions tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,12 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
)
assert ckpt_name == "epoch=003-val_acc=0.03"

# dots in the metric name
ckpt_name = ModelCheckpoint._format_checkpoint_name(
"mAP@0.50={val/mAP@0.50:.4f}", {"val/mAP@0.50": 0.2}, auto_insert_metric_name=False
)
assert ckpt_name == "mAP@0.50=0.2000"


class ModelCheckpointExtensionTest(ModelCheckpoint):
FILE_EXTENSION = ".tpkc"
Expand Down

0 comments on commit 5bdd603

Please sign in to comment.