Commit b88b888

Update model_checkpoint.py
1 parent 20eeebe

File tree: 1 file changed

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 19 additions & 4 deletions
```diff
@@ -165,7 +165,13 @@ def __init__(
         every_n_epochs: int = 1,
         every_n_batches: int = -1,
         mode: str = "min",
+<<<<<<< HEAD
         period: Optional[int] = None,
+=======
+        period: int = 1,
+        every_n_epochs: int = 1,
+        every_n_batches: int = -1,
+>>>>>>> Update model_checkpoint.py
     ):
         super().__init__()
         self.monitor = monitor
```
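Note that this hunk commits unresolved Git conflict markers (`<<<<<<< HEAD`, `=======`, `>>>>>>>`) into the parameter list, which leaves the module with a syntax error until the conflict is resolved. A minimal sketch of one plausible resolution, assuming the duplicated parameters from the incoming branch are dropped and `period` is kept only as a legacy alias for `every_n_epochs` (the back-compat rule and everything outside the diff are illustrative, not part of this commit):

```python
from typing import Optional

class ModelCheckpoint:
    def __init__(
        self,
        monitor: Optional[str] = None,
        mode: str = "min",
        period: Optional[int] = None,  # legacy alias for every_n_epochs (assumption)
        every_n_epochs: int = 1,       # epoch-based checkpoint cadence
        every_n_batches: int = -1,     # batch-based cadence; -1 disables it
    ):
        self.monitor = monitor
        self.mode = mode
        # Assumed back-compat rule: an explicitly passed `period` wins.
        self.every_n_epochs = period if period is not None else every_n_epochs
        self.every_n_batches = every_n_batches

ckpt = ModelCheckpoint(every_n_batches=100)
print(ckpt.every_n_epochs, ckpt.every_n_batches)  # 1 100
```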
```diff
@@ -206,13 +212,10 @@ def on_pretrain_routine_start(self, trainer, pl_module):
         self.save_function = trainer.save_checkpoint
 
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None:
-        """
-        Save a checkpoint during the training loop if configured to do so.
-        """
         if self._should_skip_saving_checkpoint(trainer):
             return
         step = trainer.global_step
-        skip_step = self.every_n_batches < 1 or ((step + 1) % self.every_n_batches != 0)
+        skip_step = self.every_n_steps < 1 or ((step + 1) % self.every_n_steps != 0)
         if skip_step:
             return
         self.save_checkpoint(trainer, pl_module)
```
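The rewritten condition reads `self.every_n_steps`, although the constructor parameter added above is still named `every_n_batches`; assuming the attribute keeps the parameter's name, this lookup would raise `AttributeError`. The cadence itself is a simple 1-indexed modulo check, sketched standalone below (the helper name is ours, not the library's):

```python
def checkpoint_due(global_step: int, every_n_batches: int) -> bool:
    # Inverse of `skip_step` above: values < 1 disable batch-based saving;
    # otherwise a save is due when the 1-indexed step count is a multiple
    # of `every_n_batches`.
    if every_n_batches < 1:
        return False
    return (global_step + 1) % every_n_batches == 0

assert checkpoint_due(99, 100)        # 0-indexed step 99 is the 100th step
assert not checkpoint_due(100, 100)   # the 101st step is not
assert not checkpoint_due(99, -1)     # -1 disables batch-based saving
```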
```diff
@@ -240,6 +243,14 @@ def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
         self.best_model_score = checkpointed_state["best_model_score"]
         self.best_model_path = checkpointed_state["best_model_path"]
 
+    def _should_skip_saving_checkpoint(self, trainer) -> bool:
+        return (
+            trainer.fast_dev_run  # disable checkpointing with fast_dev_run
+            or trainer.running_sanity_check  # don't save anything during sanity check
+            or self.save_top_k == 0  # no models are saved
+            or self._last_global_step_saved == global_step  # already saved at the last step
+        )
+
     def save_checkpoint(self, trainer, pl_module):
         """
         Performs the main logic around saving a checkpoint.
```
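As committed, the new predicate compares against a bare `global_step`, which is not defined in the method's scope and would raise `NameError` when that branch is reached; it presumably means `trainer.global_step`. A self-contained sketch under that assumption (the stub classes and function name are ours):

```python
class _TrainerStub:
    # Minimal stand-in exposing only the trainer attributes the predicate reads.
    def __init__(self, fast_dev_run=False, running_sanity_check=False, global_step=0):
        self.fast_dev_run = fast_dev_run
        self.running_sanity_check = running_sanity_check
        self.global_step = global_step

class _CheckpointStub:
    save_top_k = 1               # 0 would mean "save nothing"
    _last_global_step_saved = -1

def should_skip_saving_checkpoint(cb, trainer) -> bool:
    # Same four short-circuiting conditions as the diff, with the dangling
    # `global_step` read as `trainer.global_step`.
    return (
        trainer.fast_dev_run                  # fast_dev_run disables checkpointing
        or trainer.running_sanity_check       # never save during the sanity check
        or cb.save_top_k == 0                 # user requested no checkpoints
        or cb._last_global_step_saved == trainer.global_step  # already saved this step
    )

assert not should_skip_saving_checkpoint(_CheckpointStub(), _TrainerStub(global_step=10))
assert should_skip_saving_checkpoint(_CheckpointStub(), _TrainerStub(fast_dev_run=True))
```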
```diff
@@ -277,6 +288,10 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool:
     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
             raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
+        if self.every_n_epochs == 0 or self.every_n_epochs < -1:
+            raise MisconfigurationException(f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1')
+        if self.every_n_batches == 0 or self.every_n_batches < -1:
+            raise MisconfigurationException(f'Invalid value for every_n_batches={self.every_n_batches}. Must be positive or -1')
         if self.monitor is None:
             # None: save last epoch, -1: save all epochs, 0: nothing is saved
             if self.save_top_k not in [None, -1, 0]:
```
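Both new checks share one shape: they accept either a positive integer or the sentinel -1 and reject everything else, including 0. A compact sketch of that shared rule (the helper name is ours, and the exception class below is a stand-in for Lightning's `MisconfigurationException`):

```python
class MisconfigurationException(Exception):
    # Stand-in for Lightning's exception of the same name.
    pass

def validate_every_n(name: str, value: int) -> None:
    # Rejects 0 and anything below -1, so only positive integers or the
    # sentinel -1 get through.
    if value == 0 or value < -1:
        raise MisconfigurationException(
            f"Invalid value for {name}={value}. Must be positive or -1"
        )

validate_every_n("every_n_epochs", 1)    # ok
validate_every_n("every_n_batches", -1)  # ok: -1 disables the batch cadence
try:
    validate_every_n("every_n_batches", 0)
except MisconfigurationException as err:
    print(err)  # Invalid value for every_n_batches=0. Must be positive or -1
```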
