From 3e579522bbd7b16a8f091de3af9ef2347ed3756e Mon Sep 17 00:00:00 2001
From: zwhe99
Date: Fri, 14 Feb 2025 13:11:45 +0800
Subject: [PATCH 1/5] feat: Expose `remove_previous_ckpt` option to training entry point and user configuration

- Add `remove_previous_ckpt` configuration option in `ppo_trainer.yaml`
- Update `RayPPOTrainer` to support optional checkpoint deletion during loading
- Modify `ActorRolloutRefWorker` and `CriticWorker` to pass checkpoint removal flag
---
 verl/trainer/config/ppo_trainer.yaml | 1 +
 verl/trainer/ppo/ray_trainer.py      | 4 ++--
 verl/workers/fsdp_workers.py         | 4 ++--
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml
index 03fed370..7915e13a 100644
--- a/verl/trainer/config/ppo_trainer.yaml
+++ b/verl/trainer/config/ppo_trainer.yaml
@@ -170,4 +170,5 @@ trainer:
   test_freq: -1
   critic_warmup: 0
   default_hdfs_dir: null
+  remove_previous_ckpt: True
   default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index c2563767..4dad77f8 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -770,10 +770,10 @@ def _load_checkpoint(self):
         actor_path = os.path.join(global_step_folder, 'actor')
         critic_path = os.path.join(global_step_folder, 'critic')
         # load actor
-        self.actor_rollout_wg.load_checkpoint(actor_path)
+        self.actor_rollout_wg.load_checkpoint(actor_path, del_local_after_load=self.config.trainer.remove_previous_ckpt)
         # load critic
         if self.use_critic:
-            self.critic_wg.load_checkpoint(critic_path)
+            self.critic_wg.load_checkpoint(critic_path, del_local_after_load=self.config.trainer.remove_previous_ckpt)

         # load dataloader,
         # TODO: from remote not implemented yet
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index ed96127d..980e3e63 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -554,7 +554,7 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
                                      device_id=torch.cuda.current_device(),
                                      load_grad=self._is_offload_grad)

-        self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step)
+        self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, remove_previous_ckpt=self.config.trainer.remove_previous_ckpt)

         torch.distributed.barrier()
         if self._is_offload_param:
@@ -828,7 +828,7 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
                                      device_id=torch.cuda.current_device(),
                                      load_grad=self._is_offload_grad)

-        self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step)
+        self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, remove_previous_ckpt=self.config.trainer.remove_previous_ckpt)

         torch.distributed.barrier()
         if self._is_offload_param:

From 257d25f7f35da5be84e2ce017641bec317f7b0e6 Mon Sep 17 00:00:00 2001
From: zwhe99
Date: Fri, 14 Feb 2025 13:38:40 +0800
Subject: [PATCH 2/5] refactor: Rename checkpoint removal configuration for clarity

- Rename `remove_previous_ckpt` in config file to more specific `remove_previous_ckpt_in_save`
- Add new `del_local_ckpt_after_load` configuration option in config file
- Update checkpoint loading and saving methods to use new configuration names
---
 verl/trainer/config/ppo_trainer.yaml | 3 ++-
 verl/trainer/ppo/ray_trainer.py      | 4 ++--
 verl/workers/fsdp_workers.py         | 4 ++--
 3 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml
index 7915e13a..6e2f7445 100644
--- a/verl/trainer/config/ppo_trainer.yaml
+++ b/verl/trainer/config/ppo_trainer.yaml
@@ -170,5 +170,6 @@ trainer:
   test_freq: -1
   critic_warmup: 0
   default_hdfs_dir: null
-  remove_previous_ckpt: True
+  remove_previous_ckpt_in_save: True
+  del_local_ckpt_after_load: True
   default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index 4dad77f8..834a5340 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -770,10 +770,10 @@ def _load_checkpoint(self):
         actor_path = os.path.join(global_step_folder, 'actor')
         critic_path = os.path.join(global_step_folder, 'critic')
         # load actor
-        self.actor_rollout_wg.load_checkpoint(actor_path, del_local_after_load=self.config.trainer.remove_previous_ckpt)
+        self.actor_rollout_wg.load_checkpoint(actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load)
         # load critic
         if self.use_critic:
-            self.critic_wg.load_checkpoint(critic_path, del_local_after_load=self.config.trainer.remove_previous_ckpt)
+            self.critic_wg.load_checkpoint(critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load)

         # load dataloader,
         # TODO: from remote not implemented yet
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index 980e3e63..6e250e43 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -554,7 +554,7 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
                                      device_id=torch.cuda.current_device(),
                                      load_grad=self._is_offload_grad)

-        self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, remove_previous_ckpt=self.config.trainer.remove_previous_ckpt)
+        self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)

         torch.distributed.barrier()
         if self._is_offload_param:
@@ -828,7 +828,7 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
                                      device_id=torch.cuda.current_device(),
                                      load_grad=self._is_offload_grad)

-        self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, remove_previous_ckpt=self.config.trainer.remove_previous_ckpt)
+        self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)

         torch.distributed.barrier()
         if self._is_offload_param:

From fe4a4fa51ab4cca77dd785eb823c3a1886d1f9d3 Mon Sep 17 00:00:00 2001
From: zwhe99
Date: Fri, 14 Feb 2025 14:06:34 +0800
Subject: [PATCH 3/5] fix: Modify checkpoint management default behavior

- Change default value of `remove_previous_ckpt_in_save` to `False` in `ppo_trainer.yaml`
- Change default value of `del_local_ckpt_after_load` to `False` in `ppo_trainer.yaml`
- Update `FSDPCheckpointManager` and `ActorRolloutRefWorker` to use new default values
---
 verl/trainer/config/ppo_trainer.yaml             | 4 ++--
 verl/utils/checkpoint/fsdp_checkpoint_manager.py | 4 ++--
 verl/workers/fsdp_workers.py                     | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml
index 6e2f7445..d9ea39c7 100644
--- a/verl/trainer/config/ppo_trainer.yaml
+++ b/verl/trainer/config/ppo_trainer.yaml
@@ -170,6 +170,6 @@ trainer:
   test_freq: -1
   critic_warmup: 0
   default_hdfs_dir: null
-  remove_previous_ckpt_in_save: True
-  del_local_ckpt_after_load: True
+  remove_previous_ckpt_in_save: False
+  del_local_ckpt_after_load: False
   default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
diff --git a/verl/utils/checkpoint/fsdp_checkpoint_manager.py b/verl/utils/checkpoint/fsdp_checkpoint_manager.py
index aa4806de..e5ec9efd 100644
--- a/verl/utils/checkpoint/fsdp_checkpoint_manager.py
+++ b/verl/utils/checkpoint/fsdp_checkpoint_manager.py
@@ -48,7 +48,7 @@ def __init__(self, model: FSDP, optimizer: torch.optim.Optimizer,
                  lr_scheduler: torch.optim.lr_scheduler.LRScheduler, tokenizer: PreTrainedTokenizer, *args, **kwargs):
         super().__init__(model, optimizer, lr_scheduler, tokenizer)

-    def load_checkpoint(self, path=None, del_local_after_load=True, *args, **kwargs):
+    def load_checkpoint(self, path=None, del_local_after_load=False, *args, **kwargs):
         if path is None:
             return

@@ -93,7 +93,7 @@ def load_checkpoint(self, path=None, del_local_after_load=True, *args, **kwargs)
         if self.lr_scheduler is not None:
             self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)

-    def save_checkpoint(self, local_path: str, global_step: int, remove_previous_ckpt=True, *args, **kwargs):
+    def save_checkpoint(self, local_path: str, global_step: int, remove_previous_ckpt=False, *args, **kwargs):
         # record the previous global step
         self.previous_global_step = global_step

diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index 6e250e43..50366869 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -561,7 +561,7 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
             offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def load_checkpoint(self, path, del_local_after_load=True):
+    def load_checkpoint(self, path, del_local_after_load=False):
         if self._is_offload_param:
             load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                      device_id=torch.cuda.current_device(),

From 3acb04b75f67a317221de5161a89a312c0b89ccb Mon Sep 17 00:00:00 2001
From: zwhe99
Date: Fri, 14 Feb 2025 14:21:04 +0800
Subject: [PATCH 4/5] refactor: Improve checkpoint loading and saving method formatting

---
 verl/trainer/ppo/ray_trainer.py |  6 ++++--
 verl/workers/fsdp_workers.py    | 10 ++++++++--
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index 834a5340..5150d2e1 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -770,10 +770,12 @@ def _load_checkpoint(self):
         actor_path = os.path.join(global_step_folder, 'actor')
         critic_path = os.path.join(global_step_folder, 'critic')
         # load actor
-        self.actor_rollout_wg.load_checkpoint(actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load)
+        self.actor_rollout_wg.load_checkpoint(actor_path,
+                                              del_local_after_load=self.config.trainer.del_local_ckpt_after_load)
         # load critic
         if self.use_critic:
-            self.critic_wg.load_checkpoint(critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load)
+            self.critic_wg.load_checkpoint(critic_path,
+                                           del_local_after_load=self.config.trainer.del_local_ckpt_after_load)

         # load dataloader,
         # TODO: from remote not implemented yet
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index 50366869..323581e4 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -554,7 +554,10 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
                                      device_id=torch.cuda.current_device(),
                                      load_grad=self._is_offload_grad)

-        self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)
+        self.checkpoint_manager.save_checkpoint(local_path=local_path,
+                                                hdfs_path=hdfs_path,
+                                                global_step=global_step,
+                                                remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)

         torch.distributed.barrier()
         if self._is_offload_param:
@@ -828,7 +831,10 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
                                      device_id=torch.cuda.current_device(),
                                      load_grad=self._is_offload_grad)

-        self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)
+        self.checkpoint_manager.save_checkpoint(local_path=local_path,
+                                                hdfs_path=hdfs_path,
+                                                global_step=global_step,
+                                                remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)

         torch.distributed.barrier()
         if self._is_offload_param:

From c8633537d2dcf7fcfe03b581346a4964b17e3d06 Mon Sep 17 00:00:00 2001
From: zwhe99
Date: Fri, 14 Feb 2025 16:59:13 +0800
Subject: [PATCH 5/5] - Update `RayPPOTrainer` to pass `remove_previous_ckpt` flag to checkpoint saving methods

- Modify `ActorRolloutRefWorker` and `CriticWorker` to accept optional `remove_previous_ckpt` parameter
---
 verl/trainer/ppo/ray_trainer.py | 10 ++++++++--
 verl/workers/fsdp_workers.py    |  8 ++++----
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index 5150d2e1..9ba116df 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -714,13 +714,19 @@ def _save_checkpoint(self):

         actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
             self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'actor')
-        self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path, self.global_steps)
+        self.actor_rollout_wg.save_checkpoint(actor_local_path,
+                                              actor_remote_path,
+                                              self.global_steps,
+                                              remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)

         if self.use_critic:
             critic_local_path = os.path.join(local_global_step_folder, 'critic')
             critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
                 self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'critic')
-            self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path, self.global_steps)
+            self.critic_wg.save_checkpoint(critic_local_path,
+                                           critic_remote_path,
+                                           self.global_steps,
+                                           remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)

         # save dataloader
         dataloader_local_path = os.path.join(local_global_step_folder, 'data.pt')
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index 323581e4..42397518 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -545,7 +545,7 @@ def compute_ref_log_prob(self, data: DataProto):
         return output

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
+    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, remove_previous_ckpt=False):
         # only support save and load ckpt for actor
         assert self._is_actor
         import torch
@@ -557,7 +557,7 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
         self.checkpoint_manager.save_checkpoint(local_path=local_path,
                                                 hdfs_path=hdfs_path,
                                                 global_step=global_step,
-                                                remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)
+                                                remove_previous_ckpt=remove_previous_ckpt)

         torch.distributed.barrier()
         if self._is_offload_param:
@@ -824,7 +824,7 @@ def update_critic(self, data: DataProto):
         return output

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
+    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, remove_previous_ckpt=False):
         import torch
         if self._is_offload_param:
             load_fsdp_param_and_grad(module=self.critic_module,
@@ -834,7 +834,7 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
                                      device_id=torch.cuda.current_device(),
                                      load_grad=self._is_offload_grad)

         self.checkpoint_manager.save_checkpoint(local_path=local_path,
                                                 hdfs_path=hdfs_path,
                                                 global_step=global_step,
-                                                remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)
+                                                remove_previous_ckpt=remove_previous_ckpt)

         torch.distributed.barrier()
         if self._is_offload_param:
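
For readers who want the behavior in one place rather than spread across five patches, here is a minimal sketch of what the two new trainer flags are meant to control. This is a hypothetical illustration, not verl's actual `FSDPCheckpointManager`: the class name, the `previous_saved_path` attribute, and the bare `shutil.rmtree` calls are invented for the example. `remove_previous_ckpt_in_save` (passed through as `remove_previous_ckpt`) deletes the previously saved local checkpoint once a newer one has been written, and `del_local_ckpt_after_load` (passed through as `del_local_after_load`) deletes the local copy after its state has been restored.

import shutil


class ToyCheckpointManager:
    """Hypothetical stand-in for a checkpoint manager, showing only the cleanup
    behavior that the two new trainer flags toggle."""

    def __init__(self):
        self.previous_saved_path = None   # local dir written by the last save
        self.previous_global_step = None  # step recorded at the last save

    def save_checkpoint(self, local_path, global_step, remove_previous_ckpt=False):
        # ... write model / optimizer / lr_scheduler state into `local_path` here ...

        # remove_previous_ckpt_in_save: once the new checkpoint exists on disk,
        # optionally delete the one written by the previous call to save space.
        if remove_previous_ckpt and self.previous_saved_path is not None:
            shutil.rmtree(self.previous_saved_path, ignore_errors=True)

        self.previous_saved_path = local_path
        self.previous_global_step = global_step

    def load_checkpoint(self, path=None, del_local_after_load=False):
        if path is None:
            return
        # ... restore model / optimizer / lr_scheduler state from `path` here ...

        # del_local_ckpt_after_load: once the state is back in memory, optionally
        # drop the local copy (useful when checkpoints are also kept remotely, e.g. on HDFS).
        if del_local_after_load:
            shutil.rmtree(path, ignore_errors=True)

With the full series applied, both flags default to `False` in `ppo_trainer.yaml`, so nothing is deleted unless a run opts in, e.g. with Hydra-style overrides such as `trainer.remove_previous_ckpt_in_save=True trainer.del_local_ckpt_after_load=True`.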