@@ -165,7 +165,13 @@ def __init__(
         every_n_epochs: int = 1,
         every_n_batches: int = -1,
         mode: str = "min",
+<<<<<<< HEAD
         period: Optional[int] = None,
+=======
+        period: int = 1,
+        every_n_epochs: int = 1,
+        every_n_batches: int = -1,
+>>>>>>> Update model_checkpoint.py
     ):
         super().__init__()
         self.monitor = monitor
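Assuming the conflict above is resolved in favour of the HEAD side (keeping the existing `period: Optional[int] = None` next to the new `every_n_epochs` / `every_n_batches` arguments), a minimal usage sketch of the extended signature; the keyword names follow this branch's diff and differ from released PyTorch Lightning versions, and `"val_loss"` is just an example metric key:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Save a checkpoint every 2 epochs and, independently, every 500 training batches.
checkpoint_cb = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    every_n_epochs=2,
    every_n_batches=500,
)
trainer = Trainer(callbacks=[checkpoint_cb], max_epochs=10)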
@@ -206,13 +212,10 @@ def on_pretrain_routine_start(self, trainer, pl_module):
         self.save_function = trainer.save_checkpoint

     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None:
-        """
-        Save a checkpoint during the training loop if configured to do so.
-        """
         if self._should_skip_saving_checkpoint(trainer):
             return
         step = trainer.global_step
-        skip_step = self.every_n_batches < 1 or ((step + 1) % self.every_n_batches != 0)
+        skip_step = self.every_n_steps < 1 or ((step + 1) % self.every_n_steps != 0)
         if skip_step:
             return
         self.save_checkpoint(trainer, pl_module)
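The batch-level gate reduces to a modulo check on the 1-indexed global step (note that the added line reads `self.every_n_steps`, while `__init__` and the validation below spell it `every_n_batches`). A standalone sketch of that check; the `should_save_on_batch` helper is hypothetical and only mirrors the expression above:

def should_save_on_batch(global_step: int, every_n_batches: int) -> bool:
    # Skip when the interval is disabled (< 1) or the 1-indexed batch count
    # is not a multiple of the interval.
    if every_n_batches < 1:
        return False
    return (global_step + 1) % every_n_batches == 0


# global_step is 0-indexed, so with every_n_batches=100 the first save happens
# after the 100th batch (global_step == 99).
assert should_save_on_batch(99, 100) is True
assert should_save_on_batch(100, 100) is False
assert should_save_on_batch(0, -1) is False  # -1 disables batch-based checkpointing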
@@ -240,6 +243,14 @@ def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
         self.best_model_score = checkpointed_state["best_model_score"]
         self.best_model_path = checkpointed_state["best_model_path"]

+    def _should_skip_saving_checkpoint(self, trainer) -> bool:
+        return (
+            trainer.fast_dev_run  # disable checkpointing with fast_dev_run
+            or trainer.running_sanity_check  # don't save anything during sanity check
+            or self.save_top_k == 0  # no models are saved
+            or self._last_global_step_saved == trainer.global_step  # already saved at the last step
+        )
+
     def save_checkpoint(self, trainer, pl_module):
         """
         Performs the main logic around saving a checkpoint.
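For illustration, a self-contained sketch of the guard added above, showing how `_last_global_step_saved` prevents writing two checkpoints for the same step; `_CheckpointGuard` and the `SimpleNamespace` trainer stub are hypothetical test scaffolding, not library code:

from types import SimpleNamespace


class _CheckpointGuard:
    """Hypothetical standalone version of the guard above, for illustration only."""

    def __init__(self, save_top_k: int = 1):
        self.save_top_k = save_top_k
        self._last_global_step_saved = -1

    def should_skip(self, trainer) -> bool:
        return (
            trainer.fast_dev_run                                    # debugging run: never save
            or trainer.running_sanity_check                         # sanity check: never save
            or self.save_top_k == 0                                 # checkpointing disabled
            or self._last_global_step_saved == trainer.global_step  # already saved at this step
        )


trainer = SimpleNamespace(fast_dev_run=False, running_sanity_check=False, global_step=10)
guard = _CheckpointGuard(save_top_k=3)
assert guard.should_skip(trainer) is False
guard._last_global_step_saved = 10  # pretend a checkpoint was just written at step 10
assert guard.should_skip(trainer) is True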
@@ -277,6 +288,10 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool:
     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
             raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
+        if self.every_n_epochs == 0 or self.every_n_epochs < -1:
+            raise MisconfigurationException(f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1')
+        if self.every_n_batches == 0 or self.every_n_batches < -1:
+            raise MisconfigurationException(f'Invalid value for every_n_batches={self.every_n_batches}. Must be positive or -1')
         if self.monitor is None:
             # None: save last epoch, -1: save all epochs, 0: nothing is saved
             if self.save_top_k not in [None, -1, 0]:
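The new interval validation accepts positive integers or -1 and rejects 0 and anything below -1. A hedged standalone sketch of the same rule, using `ValueError` in place of Lightning's `MisconfigurationException` so it runs without the library; the `_validate_interval` helper is hypothetical:

def _validate_interval(name: str, value: int) -> None:
    # A value is valid when it is a positive integer (save every N) or -1 (interval disabled).
    if value == 0 or value < -1:
        raise ValueError(f"Invalid value for {name}={value}. Must be positive or -1")


_validate_interval("every_n_epochs", 1)    # ok: save every epoch
_validate_interval("every_n_batches", -1)  # ok: batch-based saving disabled
try:
    _validate_interval("every_n_batches", 0)
except ValueError as err:
    print(err)  # Invalid value for every_n_batches=0. Must be positive or -1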