diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 190d3bf06a..80e8322d28 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -173,10 +173,12 @@ def tearDown(self):
 
 
 @parameterized_class(
-    ("fsdp_strategy",),
+    ("fsdp_strategy", "offloading"),
     [
-        ("fsdp",),
-        ("fsdp2",),
+        ("fsdp", False),
+        ("fsdp2", False),
+        ("fsdp", True),
+        ("fsdp2", True),
     ],
 )
 class TestTrainerGSM8K(BaseTrainerCase):
@@ -194,9 +196,18 @@ def test_trainer(self):
         self.config.buffer.total_epochs = 1
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
         self.config.check_and_update()
-        self.config.trainer.trainer_config.actor_rollout_ref.actor.strategy = self.fsdp_strategy
         self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
-        self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5
+        actor_rollout_ref = self.config.trainer.trainer_config.actor_rollout_ref
+        actor_rollout_ref.actor.strategy = self.fsdp_strategy
+        actor_rollout_ref.actor.optim.lr = 1e-5
+        if self.fsdp_strategy == "fsdp":
+            actor_rollout_ref.actor.fsdp_config.param_offload = self.offloading
+            actor_rollout_ref.actor.fsdp_config.optimizer_offload = self.offloading
+            actor_rollout_ref.ref.fsdp_config.param_offload = self.offloading
+            actor_rollout_ref.ref.fsdp_config.optimizer_offload = self.offloading
+        else:  # fsdp2
+            actor_rollout_ref.actor.fsdp_config.offload_policy = self.offloading
+            actor_rollout_ref.ref.fsdp_config.offload_policy = self.offloading
         both(self.config)
         parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
         rollout_metrics = parser.metric_list("rollout")
diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py
index 65e16f175e..9fce3e4cac 100644
--- a/trinity/trainer/verl/fsdp_workers.py
+++ b/trinity/trainer/verl/fsdp_workers.py
@@ -20,6 +20,7 @@
 import logging
 import os
 import warnings
+from contextlib import contextmanager
 from dataclasses import asdict
 from datetime import timedelta
 
@@ -187,6 +188,19 @@ def __init__(self, config: DictConfig, role: str):
                 self.config.ref.log_prob_micro_batch_size
             )
 
+    @contextmanager
+    def _fsdp_offload_context(self):
+        """A context manager to handle FSDP model GPU loading and CPU offloading."""
+        if self._is_offload_param:
+            load_fsdp_model_to_gpu(self.actor_module_fsdp)
+        try:
+            yield
+        finally:
+            if self._is_offload_param:
+                offload_fsdp_model_to_cpu(self.actor_module_fsdp)
+            torch.distributed.barrier()
+            torch.cuda.empty_cache()
+
     def _build_model_optimizer(  # noqa: C901
         self,
         model_path,
@@ -570,28 +584,29 @@ def setup_weight_sync_group(self):
         model = self.actor_module_fsdp
         self.named_modules = []
         self.state_dict_meta = []
-        if self.config.actor.strategy == "fsdp":
-            for name, module in model.named_modules():
-                if isinstance(module, FSDP):
-                    self.named_modules.append((name, module))
-            for name_prefix, module in self.named_modules:
-                with FSDP.summon_full_params(module, recurse=False):
-                    for name, param in module.named_parameters():
-                        if isinstance(param, FlatParameter):
-                            continue
-                        realname = (
-                            name_prefix[len(FSDP_PREFIX) :] + "." + name
-                            if name_prefix
-                            else name
-                        )
-                        self.state_dict_meta.append(
-                            (realname, str(param.dtype), tuple(param.shape))
-                        )
-                        param = None
-            torch.cuda.empty_cache()
-        else:  # fsdp2
-            for name, param in model.named_parameters():
-                self.state_dict_meta.append((name, str(param.dtype), tuple(param.shape)))
+        with self._fsdp_offload_context():
+            if self.config.actor.strategy == "fsdp":
+                for name, module in model.named_modules():
+                    if isinstance(module, FSDP):
+                        self.named_modules.append((name, module))
+                for name_prefix, module in self.named_modules:
+                    with FSDP.summon_full_params(module, recurse=False):
+                        for name, param in module.named_parameters():
+                            if isinstance(param, FlatParameter):
+                                continue
+                            realname = (
+                                name_prefix[len(FSDP_PREFIX) :] + "." + name
+                                if name_prefix
+                                else name
+                            )
+                            self.state_dict_meta.append(
+                                (realname, str(param.dtype), tuple(param.shape))
+                            )
+                            param = None
+                torch.cuda.empty_cache()
+            else:  # fsdp2
+                for name, param in model.named_parameters():
+                    self.state_dict_meta.append((name, str(param.dtype), tuple(param.shape)))
 
         if torch.distributed.get_rank() == 0:
             import ray
@@ -606,6 +621,7 @@
                 master_address, master_port, self.state_dict_meta
             )
         timeout = self.config.synchronizer.sync_timeout
+
         self._model_update_group = init_process_group(
             host=master_address,
             port=master_port,
@@ -621,30 +637,32 @@
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def sync_weight(self):
-        if self.config.actor.strategy == "fsdp":
-            for name_prefix, module in self.named_modules:
-                with FSDP.summon_full_params(module, recurse=False):
+        with self._fsdp_offload_context():
+            if self.config.actor.strategy == "fsdp":
+                for name_prefix, module in self.named_modules:
+                    with FSDP.summon_full_params(module, recurse=False):
+                        if torch.distributed.get_rank() == 0:
+                            for name, param in module.named_parameters():
+                                if isinstance(param, FlatParameter):
+                                    continue
+                                torch.distributed.broadcast(
+                                    param, 0, group=self._model_update_group
+                                )
+                                param = None
+            else:  # fsdp2
+                for name, param in self.actor_module_fsdp.named_parameters():
+                    full_param = param.full_tensor().detach().to(device=get_device_id())
                     if torch.distributed.get_rank() == 0:
-                        for name, param in module.named_parameters():
-                            if isinstance(param, FlatParameter):
-                                continue
-                            torch.distributed.broadcast(param, 0, group=self._model_update_group)
-                            param = None
-        else:  # fsdp2
-            for name, param in self.actor_module_fsdp.named_parameters():
-                full_param = param.full_tensor()
-                if torch.distributed.get_rank() == 0:
-                    torch.distributed.broadcast(full_param, 0, group=self._model_update_group)
-                del full_param
-        if torch.distributed.get_rank() == 0:
-            torch.distributed.barrier(group=self._model_update_group)
-            torch.cuda.synchronize()
-        torch.distributed.barrier()
-        torch.cuda.empty_cache()
+                        torch.distributed.broadcast(full_param, 0, group=self._model_update_group)
+                    del full_param
+            if torch.distributed.get_rank() == 0:
+                torch.distributed.barrier(group=self._model_update_group)
+                torch.cuda.synchronize()
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def upload_state_dict(self, trainer_step: int):
-        self.checkpoint_manager.upload_state_dict(trainer_step)
+        with self._fsdp_offload_context():
+            self.checkpoint_manager.upload_state_dict(trainer_step)
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def set_algorithm(self, algo_config: AlgorithmConfig):
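Note on the test change: assuming `@parameterized_class` here is the decorator from the `parameterized` package (its import sits outside this hunk), widening the attribute tuple to `("fsdp_strategy", "offloading")` turns the two generated test classes into four, one per strategy/offloading combination. A minimal sketch of how that expansion behaves:

```python
import unittest

from parameterized import parameterized_class


@parameterized_class(
    ("fsdp_strategy", "offloading"),
    [
        ("fsdp", False),
        ("fsdp2", False),
        ("fsdp", True),
        ("fsdp2", True),
    ],
)
class TestMatrixSketch(unittest.TestCase):
    def test_attributes(self):
        # Each generated class receives exactly one (strategy, offloading)
        # pair as class attributes, so the trainer test runs per combination.
        self.assertIn(self.fsdp_strategy, ("fsdp", "fsdp2"))
        self.assertIsInstance(self.offloading, bool)


if __name__ == "__main__":
    unittest.main()  # runs test_attributes four times, once per generated class
```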
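On the worker side, the refactor routes `setup_weight_sync_group`, `sync_weight`, and `upload_state_dict` through `_fsdp_offload_context`, so offloaded parameters are moved to GPU on entry and back to CPU on exit even if the wrapped body raises. Below is a self-contained sketch of that try/finally shape; the `_Worker` class and its string `device` field are hypothetical stand-ins for the FSDP worker and verl's `load_fsdp_model_to_gpu` / `offload_fsdp_model_to_cpu`, and the distributed barrier and cache flush from the real `finally` block are omitted:

```python
from contextlib import contextmanager


class _Worker:
    """Hypothetical stand-in for the FSDP worker; tracks where the weights live."""

    def __init__(self, offload_param: bool):
        self._is_offload_param = offload_param
        self.device = "cpu" if offload_param else "cuda"

    @contextmanager
    def _offload_context(self):
        # Entry: mirrors load_fsdp_model_to_gpu(...) when offloading is enabled.
        if self._is_offload_param:
            self.device = "cuda"
        try:
            yield
        finally:
            # Exit: mirrors offload_fsdp_model_to_cpu(...); runs even on
            # exceptions, so a failed sync cannot leave weights pinned on GPU.
            if self._is_offload_param:
                self.device = "cpu"


worker = _Worker(offload_param=True)
try:
    with worker._offload_context():
        assert worker.device == "cuda"  # weights resident on GPU inside the block
        raise RuntimeError("simulated sync failure")
except RuntimeError:
    pass
assert worker.device == "cpu"  # offloaded again despite the exception
```

Centralizing the barrier and `torch.cuda.empty_cache()` in the context manager's `finally` clause is also what lets `sync_weight` drop its own trailing `torch.distributed.barrier()` / `torch.cuda.empty_cache()` calls in this diff.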