diff --git a/python/ray/train/_internal/checkpoint.py b/python/ray/train/_internal/checkpoint.py index 9f9bda77b3db2..39bcddc8fbcb8 100644 --- a/python/ray/train/_internal/checkpoint.py +++ b/python/ray/train/_internal/checkpoint.py @@ -214,7 +214,12 @@ def _load_checkpoint( ) -> Optional[Union[Dict, Checkpoint]]: loaded_checkpoint = super()._load_checkpoint(checkpoint_to_load) assert not loaded_checkpoint or isinstance(loaded_checkpoint, Checkpoint) - self._latest_checkpoint_id = getattr(loaded_checkpoint, TUNE_CHECKPOINT_ID, 0) + # `latest_checkpoint_id` will be the id assigned to the next checkpoint, + # which should be one more than the loaded checkpoint's id + # If no checkpoint is loaded, initialize this to 0 + self._latest_checkpoint_id = ( + getattr(loaded_checkpoint, TUNE_CHECKPOINT_ID, -1) + 1 + ) return loaded_checkpoint def add_tune_checkpoint_id(self, checkpoint: Checkpoint):