
Resuming training from an interrupted checkpoint fails to save the final checkpoint. #38939

@rangehow

Description


System Info

  • transformers version: 4.53.0.dev0
  • Platform: Linux-4.18.0-147.mt20200626.413.el8_1.x86_64-x86_64-with-glibc2.17
  • Python version: 3.12.3
  • Huggingface_hub version: 0.30.1
  • Safetensors version: 0.5.3
  • Accelerate version: 1.6.0
  • Accelerate config: not found
  • DeepSpeed version: 0.16.7
  • PyTorch version (accelerator?): 2.6.0+cu124 (NA)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: yes

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Normally, a Trainer that saves checkpoints on a step or epoch schedule is guaranteed to save a final checkpoint when training completes. This is ensured by the following sequence of logic:

First, when training reaches the final step of an epoch, do_sync_step is set to True by the second clause of this conditional:

do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch

This flag indicates that a gradient update is about to be performed, and global_step is incremented precisely within that gradient-update branch (see L2623):

if do_sync_step:
    # Since we perform prefetching, we need to manually set sync_gradients to True
    self.accelerator.gradient_state._set_sync_gradients(True)

    # Gradient clipping
    if args.max_grad_norm is not None and args.max_grad_norm > 0:
        if is_sagemaker_mp_enabled() and args.fp16:
            _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
        elif self.use_apex:
            from apex import amp

            # Revert to normal clipping otherwise, handling Apex or full precision
            _grad_norm = nn.utils.clip_grad_norm_(
                amp.master_params(self.optimizer),
                args.max_grad_norm,
            )
        else:
            grad_norm_context = contextlib.nullcontext
            if self.is_tp_enabled:
                from torch.distributed._tensor.experimental import implicit_replication

                grad_norm_context = implicit_replication
            with grad_norm_context():
                _grad_norm = self.accelerator.clip_grad_norm_(
                    model.parameters(),
                    args.max_grad_norm,
                )

        if (
            is_accelerate_available()
            and self.accelerator.distributed_type == DistributedType.DEEPSPEED
        ):
            grad_norm = model.get_global_grad_norm()
            # In some cases the grad norm may not return a float
            if hasattr(grad_norm, "item"):
                grad_norm = grad_norm.item()
        else:
            grad_norm = _grad_norm

    self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)

    self.optimizer.step()

    self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)

    # get learning rate before update
    learning_rate = self._get_learning_rate()

    if not self.accelerator.optimizer_step_was_skipped:
        # Delay optimizer scheduling until metrics are generated
        if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            self.lr_scheduler.step()
    model.zero_grad()
    self.state.global_step += 1
    self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
    self.control = self.callback_handler.on_step_end(args, self.state, self.control)

    self._maybe_log_save_evaluate(
        tr_loss,
        grad_norm,
        model,
        trial,
        epoch,
        ignore_keys_for_eval,
        start_time,
        learning_rate=learning_rate,
    )
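To make the end-of-epoch guarantee concrete, here is a minimal standalone sketch (variable names mirror the loop above; the specific numbers are hypothetical). It shows that the last accumulation window still syncs even when steps_in_epoch is not divisible by gradient_accumulation_steps:

```python
# Hypothetical example: 10 batches per epoch, accumulate over 4 steps.
gradient_accumulation_steps = 4
steps_in_epoch = 10

# Collect every step index where do_sync_step would be True.
sync_steps = [
    step
    for step in range(steps_in_epoch)
    if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
]
print(sync_steps)  # → [3, 7, 9]
```

Steps 3 and 7 sync via the divisibility clause; step 9, the final (partial) window of only two batches, syncs solely because of the (step + 1) == steps_in_epoch clause.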

This logic works without issue in an uninterrupted training run. However, when training is resumed from a checkpoint, the step counter is not restored: enumeration is re-initialized from -1 and counts only the remaining batches.

As a result, the final set of batches in an epoch never triggers a gradient update (and thus no final checkpoint is saved) whenever the total dataset size is not evenly divisible by the product of gradient_accumulation_steps and the data-parallel world size.
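The failure mode can be sketched the same way (all numbers hypothetical): after resuming, the loop enumerates only the remaining batches, so step + 1 can never reach steps_in_epoch and the second clause of the conditional can never fire:

```python
# Hypothetical resumed run: 10 batches per epoch, accumulate over 4 steps,
# 4 batches already consumed before the checkpoint was written.
gradient_accumulation_steps = 4
steps_in_epoch = 10
steps_skipped = 4

remaining = steps_in_epoch - steps_skipped  # only 6 batches are enumerated

# step restarts from 0 over the remaining batches instead of resuming
# from the restored position, so step + 1 <= 6 < steps_in_epoch.
sync_steps = [
    step
    for step in range(remaining)
    if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
]
print(sync_steps)  # → [3]
```

Only the divisibility clause fires (at step 3); the final two batches are accumulated but never synced, so the optimizer step, global_step increment, and final checkpoint save are all skipped.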
