Skip to content

Commit

Permalink
Fix loaded checkpoint id one-off error for AIR trainers w/ custom tra…
Browse files Browse the repository at this point in the history
…in loop

Signed-off-by: Justin Yu <justinvyu@berkeley.edu>
  • Loading branch information
justinvyu committed Jan 4, 2023
1 parent 892b4f0 commit 41ce606
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion python/ray/train/_internal/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,12 @@ def _load_checkpoint(
) -> Optional[Union[Dict, Checkpoint]]:
loaded_checkpoint = super()._load_checkpoint(checkpoint_to_load)
assert not loaded_checkpoint or isinstance(loaded_checkpoint, Checkpoint)
self._latest_checkpoint_id = getattr(loaded_checkpoint, TUNE_CHECKPOINT_ID, 0)
# `latest_checkpoint_id` will be the id assigned to the next checkpoint,
# which should be one more than the loaded checkpoint's id
# If no checkpoint is loaded, initialize this to 0
self._latest_checkpoint_id = (
getattr(loaded_checkpoint, TUNE_CHECKPOINT_ID, -1) + 1
)
return loaded_checkpoint

def add_tune_checkpoint_id(self, checkpoint: Checkpoint):
Expand Down

0 comments on commit 41ce606

Please sign in to comment.