@@ -98,6 +98,8 @@ class Checkpoint(Serializable):
98
98
details.
99
99
include_self (bool): Whether to include the `state_dict` of this object in the checkpoint. If `True`, then
100
100
there must not be another object in ``to_save`` with key ``checkpointer``.
101
+ greater_or_equal (bool): if `True`, the latest equally scored model is stored. Otherwise, the first model.
102
+ Default, `False`.
101
103
102
104
.. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
103
105
torch.nn.parallel.DistributedDataParallel.html
@@ -245,6 +247,8 @@ def score_function(engine):
245
247
trainer.run(data_loader, max_epochs=10)
246
248
> ["best_model_9_val_acc=0.77.pt", "best_model_10_val_acc=0.78.pt", ]
247
249
250
+ .. versionchanged:: 0.4.3
251
+ Added ``greater_or_equal`` parameter.
248
252
"""
249
253
250
254
Item = NamedTuple ("Item" , [("priority" , int ), ("filename" , str )])
@@ -261,6 +265,7 @@ def __init__(
261
265
global_step_transform : Optional [Callable ] = None ,
262
266
filename_pattern : Optional [str ] = None ,
263
267
include_self : bool = False ,
268
+ greater_or_equal : bool = False ,
264
269
) -> None :
265
270
266
271
if to_save is not None : # for compatibility with ModelCheckpoint
@@ -301,6 +306,7 @@ def __init__(
301
306
self .filename_pattern = filename_pattern
302
307
self ._saved = [] # type: List["Checkpoint.Item"]
303
308
self .include_self = include_self
309
+ self .greater_or_equal = greater_or_equal
304
310
305
311
def reset (self ) -> None :
306
312
"""Method to reset saved checkpoint names.
@@ -339,6 +345,12 @@ def _check_lt_n_saved(self, or_equal: bool = False) -> bool:
339
345
return True
340
346
return len (self ._saved ) < self .n_saved + int (or_equal )
341
347
348
+ def _compare_fn (self , new : Union [int , float ]) -> bool :
349
+ if self .greater_or_equal :
350
+ return new >= self ._saved [0 ].priority
351
+ else :
352
+ return new > self ._saved [0 ].priority
353
+
342
354
def __call__ (self , engine : Engine ) -> None :
343
355
344
356
global_step = None
@@ -354,7 +366,7 @@ def __call__(self, engine: Engine) -> None:
354
366
global_step = engine .state .get_event_attrib_value (Events .ITERATION_COMPLETED )
355
367
priority = global_step
356
368
357
- if self ._check_lt_n_saved () or self ._saved [ 0 ]. priority < priority :
369
+ if self ._check_lt_n_saved () or self ._compare_fn ( priority ) :
358
370
359
371
priority_str = f"{ priority } " if isinstance (priority , numbers .Integral ) else f"{ priority :.4f} "
360
372
0 commit comments