
Commit 4203efa

Fix #387 so that checkpoint saver works with max history of 1. Add checkpoint-hist arg to train.py.
1 parent: 99b82ae

2 files changed: 6 additions, 7 deletions


timm/utils/checkpoint_saver.py (3 additions, 6 deletions)

@@ -66,7 +66,7 @@ def save_checkpoint(self, epoch, metric=None):
         last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
         self._save(tmp_save_path, epoch, metric)
         if os.path.exists(last_save_path):
-            os.unlink(last_save_path)  # required for Windows support.
+            os.unlink(last_save_path)  # required for Windows support.
         os.rename(tmp_save_path, last_save_path)
         worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
         if (len(self.checkpoint_files) < self.max_history
@@ -118,7 +118,7 @@ def _save(self, save_path, epoch, metric=None):
     def _cleanup_checkpoints(self, trim=0):
         trim = min(len(self.checkpoint_files), trim)
         delete_index = self.max_history - trim
-        if delete_index <= 0 or len(self.checkpoint_files) <= delete_index:
+        if delete_index < 0 or len(self.checkpoint_files) <= delete_index:
             return
         to_delete = self.checkpoint_files[delete_index:]
         for d in to_delete:
@@ -147,7 +147,4 @@ def find_recovery(self):
         recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix)
         files = glob.glob(recovery_path + '*' + self.extension)
         files = sorted(files)
-        if len(files):
-            return files[0]
-        else:
-            return ''
+        return files[0] if len(files) else ''
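
The functional change is the guard in _cleanup_checkpoints. When the history is already full and a better checkpoint arrives, the saver trims one slot before inserting the new entry (it calls _cleanup_checkpoints with trim=1); with max_history=1 that makes delete_index equal to 0, so the old `delete_index <= 0` test returned early and the stale checkpoint was never removed. A minimal sketch of that arithmetic follows; the filename and metric are made up, and checkpoint_files merely mimics the saver's (path, metric) list:

    # Sketch only: replays the _cleanup_checkpoints guard arithmetic for max_history=1.
    checkpoint_files = [('checkpoint-3.pth.tar', 0.71)]   # history already full
    max_history = 1
    trim = 1   # one slot trimmed to make room for a better checkpoint

    trim = min(len(checkpoint_files), trim)
    delete_index = max_history - trim                     # 1 - 1 == 0

    old_guard = delete_index <= 0 or len(checkpoint_files) <= delete_index
    new_guard = delete_index < 0 or len(checkpoint_files) <= delete_index
    print(old_guard, new_guard)   # True False -> old code returned early, new code proceeds

    if not new_guard:
        to_delete = checkpoint_files[delete_index:]       # whole list: the stale best checkpoint
        print(to_delete)                                  # [('checkpoint-3.pth.tar', 0.71)]

The two guards differ only when delete_index is exactly 0, which happens only when trim equals max_history, i.e. at a history of 1; with the default history of 10 both versions behave identically.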

train.py (3 additions, 1 deletion)

@@ -236,6 +236,8 @@
                     help='how many batches to wait before logging training status')
 parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                     help='how many batches to wait before writing recovery checkpoint')
+parser.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
+                    help='number of checkpoints to keep (default: 10)')
 parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                     help='how many training processes to use (default: 1)')
 parser.add_argument('--save-images', action='store_true', default=False,
@@ -547,7 +549,7 @@ def main():
         decreasing = True if eval_metric == 'loss' else False
         saver = CheckpointSaver(
             model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
-            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing)
+            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist)
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
             f.write(args_text)
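
The new flag threads straight through to CheckpointSaver's existing max_history parameter, so the fix above can be exercised from the command line, for example to keep only the single best checkpoint. The dataset path and model name below are placeholders, not part of this commit:

    python train.py /path/to/imagenet --model resnet50 --checkpoint-hist 1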
