feat: Expose remove_previous_ckpt option to training entry point an… #274

Merged 5 commits on Feb 15, 2025
Changes from all commits
2 changes: 2 additions & 0 deletions verl/trainer/config/ppo_trainer.yaml
```diff
@@ -170,4 +170,6 @@ trainer:
   test_freq: -1
   critic_warmup: 0
   default_hdfs_dir: null
+  remove_previous_ckpt_in_save: False
+  del_local_ckpt_after_load: False
   default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
```
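Both options default to `False`, so checkpoint removal stays opt-in. A minimal sketch of enabling them programmatically, assuming this config is loaded with OmegaConf (suggested by the `${...}` interpolation above); the load path and usage here are illustrative, not the PR's actual entry point:

```python
# Sketch only: opt in to the new cleanup flags.
from omegaconf import OmegaConf

cfg = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml")
cfg.trainer.remove_previous_ckpt_in_save = True  # keep only the newest local checkpoint on save
cfg.trainer.del_local_ckpt_after_load = True     # drop the local copy once it is loaded
```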
16 changes: 12 additions & 4 deletions verl/trainer/ppo/ray_trainer.py
```diff
@@ -714,13 +714,19 @@ def _save_checkpoint(self):

         actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
             self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'actor')
-        self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path, self.global_steps)
+        self.actor_rollout_wg.save_checkpoint(actor_local_path,
+                                              actor_remote_path,
+                                              self.global_steps,
+                                              remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)

         if self.use_critic:
             critic_local_path = os.path.join(local_global_step_folder, 'critic')
             critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
                 self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'critic')
-            self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path, self.global_steps)
+            self.critic_wg.save_checkpoint(critic_local_path,
+                                           critic_remote_path,
+                                           self.global_steps,
+                                           remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)

         # save dataloader
         dataloader_local_path = os.path.join(local_global_step_folder, 'data.pt')
@@ -770,10 +776,12 @@ def _load_checkpoint(self):
         actor_path = os.path.join(global_step_folder, 'actor')
         critic_path = os.path.join(global_step_folder, 'critic')
         # load actor
-        self.actor_rollout_wg.load_checkpoint(actor_path)
+        self.actor_rollout_wg.load_checkpoint(actor_path,
+                                              del_local_after_load=self.config.trainer.del_local_ckpt_after_load)
         # load critic
         if self.use_critic:
-            self.critic_wg.load_checkpoint(critic_path)
+            self.critic_wg.load_checkpoint(critic_path,
+                                           del_local_after_load=self.config.trainer.del_local_ckpt_after_load)

         # load dataloader,
         # TODO: from remote not implemented yet
```
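On resume, the trainer now forwards `del_local_ckpt_after_load` into each worker group's `load_checkpoint`, which matters when checkpoints are mirrored to HDFS and local disk is scarce. A hedged sketch of the intended load-side contract (the helper and the `model.pt` file name are hypothetical, not verl's actual checkpoint layout):

```python
import os
import shutil

import torch


def load_then_maybe_delete(path: str, del_local_after_load: bool = False):
    """Hypothetical sketch: load checkpoint state into memory first,
    then optionally reclaim the local disk it occupied."""
    state = torch.load(os.path.join(path, "model.pt"), map_location="cpu")
    if del_local_after_load:
        # Safe to delete: the state dict is already held in memory.
        shutil.rmtree(path, ignore_errors=True)
    return state
```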
4 changes: 2 additions & 2 deletions verl/utils/checkpoint/fsdp_checkpoint_manager.py
```diff
@@ -48,7 +48,7 @@ def __init__(self, model: FSDP, optimizer: torch.optim.Optimizer,
                  lr_scheduler: torch.optim.lr_scheduler.LRScheduler, tokenizer: PreTrainedTokenizer, *args, **kwargs):
         super().__init__(model, optimizer, lr_scheduler, tokenizer)

-    def load_checkpoint(self, path=None, del_local_after_load=True, *args, **kwargs):
+    def load_checkpoint(self, path=None, del_local_after_load=False, *args, **kwargs):
         if path is None:
             return

@@ -93,7 +93,7 @@ def load_checkpoint(self, path=None, del_local_after_load=True, *args, **kwargs)
         if self.lr_scheduler is not None:
             self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)

-    def save_checkpoint(self, local_path: str, global_step: int, remove_previous_ckpt=True, *args, **kwargs):
+    def save_checkpoint(self, local_path: str, global_step: int, remove_previous_ckpt=False, *args, **kwargs):
         # record the previous global step
         self.previous_global_step = global_step
```
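Note that both defaults flip from `True` to `False`, making checkpoint deletion strictly opt-in at every layer. A minimal sketch of the removal step that `remove_previous_ckpt=True` presumably gates on the save side, using the `previous_*` bookkeeping visible above (the helper name is hypothetical):

```python
import os
import shutil


def maybe_remove_previous(previous_local_path: str, remove_previous_ckpt: bool) -> None:
    """Hypothetical sketch: once the new checkpoint is fully written,
    delete the previous step's local folder if the caller opted in.
    In a distributed run, only one rank should perform the deletion."""
    if remove_previous_ckpt and previous_local_path and os.path.isdir(previous_local_path):
        shutil.rmtree(previous_local_path, ignore_errors=True)
```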
16 changes: 11 additions & 5 deletions verl/workers/fsdp_workers.py
```diff
@@ -545,7 +545,7 @@ def compute_ref_log_prob(self, data: DataProto):
         return output

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
+    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, remove_previous_ckpt=False):
         # only support save and load ckpt for actor
         assert self._is_actor
         import torch
@@ -554,14 +554,17 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
                                      device_id=torch.cuda.current_device(),
                                      load_grad=self._is_offload_grad)

-        self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step)
+        self.checkpoint_manager.save_checkpoint(local_path=local_path,
+                                                hdfs_path=hdfs_path,
+                                                global_step=global_step,
+                                                remove_previous_ckpt=remove_previous_ckpt)

         torch.distributed.barrier()
         if self._is_offload_param:
             offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def load_checkpoint(self, path, del_local_after_load=True):
+    def load_checkpoint(self, path, del_local_after_load=False):
         if self._is_offload_param:
             load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                      device_id=torch.cuda.current_device(),
@@ -821,14 +824,17 @@ def update_critic(self, data: DataProto):
         return output

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
+    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, remove_previous_ckpt=False):
         import torch
         if self._is_offload_param:
             load_fsdp_param_and_grad(module=self.critic_module,
                                      device_id=torch.cuda.current_device(),
                                      load_grad=self._is_offload_grad)

-        self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step)
+        self.checkpoint_manager.save_checkpoint(local_path=local_path,
+                                                hdfs_path=hdfs_path,
+                                                global_step=global_step,
+                                                remove_previous_ckpt=remove_previous_ckpt)

         torch.distributed.barrier()
         if self._is_offload_param:
```
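End to end, the flag travels config → `RayPPOTrainer` → worker group → checkpoint manager, with conservative defaults everywhere. A quick, hypothetical sanity check that opting in retains only the latest step locally (the checkpoint root is an assumed expansion of `trainer.default_local_dir`):

```python
import os

# Hypothetical expansion of trainer.default_local_dir for a run.
ckpt_root = "checkpoints/my_project/my_experiment"
steps = sorted(d for d in os.listdir(ckpt_root) if d.startswith("global_step_"))
# With trainer.remove_previous_ckpt_in_save=True, only the newest step should remain.
assert len(steps) == 1, f"expected one retained checkpoint, found: {steps}"
```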