tests/trainer/trainer_test.py (21 changes: 16 additions & 5 deletions)
@@ -173,10 +173,12 @@ def tearDown(self):
 
 
 @parameterized_class(
-    ("fsdp_strategy",),
+    ("fsdp_strategy", "offloading"),
     [
-        ("fsdp",),
-        ("fsdp2",),
+        ("fsdp", False),
+        ("fsdp2", False),
+        ("fsdp", True),
+        ("fsdp2", True),
     ],
 )
 class TestTrainerGSM8K(BaseTrainerCase):
@@ -194,9 +196,18 @@ def test_trainer(self):
         self.config.buffer.total_epochs = 1
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
         self.config.check_and_update()
-        self.config.trainer.trainer_config.actor_rollout_ref.actor.strategy = self.fsdp_strategy
         self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
-        self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5
+        actor_rollout_ref = self.config.trainer.trainer_config.actor_rollout_ref
+        actor_rollout_ref.actor.strategy = self.fsdp_strategy
+        actor_rollout_ref.actor.optim.lr = 1e-5
+        if self.fsdp_strategy == "fsdp":
+            actor_rollout_ref.actor.fsdp_config.param_offload = self.offloading
+            actor_rollout_ref.actor.fsdp_config.optimizer_offload = self.offloading
+            actor_rollout_ref.ref.fsdp_config.param_offload = self.offloading
+            actor_rollout_ref.ref.fsdp_config.optimizer_offload = self.offloading
+        else:  # fsdp2
+            actor_rollout_ref.actor.fsdp_config.offload_policy = self.offloading
+            actor_rollout_ref.ref.fsdp_config.offload_policy = self.offloading
         both(self.config)
         parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
         rollout_metrics = parser.metric_list("rollout")
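The expanded test matrix relies on `parameterized_class` turning each tuple into class attributes, which is why the test body can read `self.fsdp_strategy` and `self.offloading` directly. A minimal sketch of that mechanism, using a hypothetical `ExampleCase` that is not part of this repository:

```python
import unittest

from parameterized import parameterized_class


@parameterized_class(
    ("fsdp_strategy", "offloading"),
    [
        ("fsdp", False),
        ("fsdp2", True),
    ],
)
class ExampleCase(unittest.TestCase):
    def test_attrs(self):
        # Each generated class carries one tuple as class attributes,
        # e.g. the first generated class gets fsdp_strategy="fsdp", offloading=False.
        self.assertIn(self.fsdp_strategy, ("fsdp", "fsdp2"))
        self.assertIsInstance(self.offloading, bool)
```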
trinity/trainer/verl/fsdp_workers.py (102 changes: 60 additions & 42 deletions)
@@ -20,6 +20,7 @@
 import logging
 import os
 import warnings
+from contextlib import contextmanager
 from dataclasses import asdict
 from datetime import timedelta
 
@@ -187,6 +188,19 @@ def __init__(self, config: DictConfig, role: str):
                 self.config.ref.log_prob_micro_batch_size
             )
 
+    @contextmanager
+    def _fsdp_offload_context(self):
+        """A context manager to handle FSDP model GPU loading and CPU offloading."""
+        if self._is_offload_param:
+            load_fsdp_model_to_gpu(self.actor_module_fsdp)
+        try:
+            yield
+        finally:
+            if self._is_offload_param:
+                offload_fsdp_model_to_cpu(self.actor_module_fsdp)
+            torch.distributed.barrier()
+            torch.cuda.empty_cache()
+
     def _build_model_optimizer(  # noqa: C901
         self,
         model_path,
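The context manager gates loading and offloading on `self._is_offload_param`, which the worker sets elsewhere from its FSDP configuration. As a rough, hypothetical illustration only (not verl's actual derivation), the offloading switches toggled in the test above could collapse into that single boolean like this:

```python
def resolve_offload_flag(strategy: str, fsdp_config: dict) -> bool:
    """Hypothetical mapping from config switches to a single offload flag."""
    if strategy == "fsdp":
        # fsdp exposes separate param/optimizer switches; the param switch is the
        # one that controls moving module weights between CPU and GPU.
        return bool(fsdp_config.get("param_offload", False))
    # fsdp2 exposes a single offload_policy switch in the test configuration.
    return bool(fsdp_config.get("offload_policy", False))


assert resolve_offload_flag("fsdp", {"param_offload": True}) is True
assert resolve_offload_flag("fsdp2", {"offload_policy": False}) is False
```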
@@ -570,28 +584,29 @@ def setup_weight_sync_group(self):
         model = self.actor_module_fsdp
         self.named_modules = []
         self.state_dict_meta = []
-        if self.config.actor.strategy == "fsdp":
-            for name, module in model.named_modules():
-                if isinstance(module, FSDP):
-                    self.named_modules.append((name, module))
-            for name_prefix, module in self.named_modules:
-                with FSDP.summon_full_params(module, recurse=False):
-                    for name, param in module.named_parameters():
-                        if isinstance(param, FlatParameter):
-                            continue
-                        realname = (
-                            name_prefix[len(FSDP_PREFIX) :] + "." + name
-                            if name_prefix
-                            else name
-                        )
-                        self.state_dict_meta.append(
-                            (realname, str(param.dtype), tuple(param.shape))
-                        )
-                        param = None
-            torch.cuda.empty_cache()
-        else:  # fsdp2
-            for name, param in model.named_parameters():
-                self.state_dict_meta.append((name, str(param.dtype), tuple(param.shape)))
+        with self._fsdp_offload_context():
+            if self.config.actor.strategy == "fsdp":
+                for name, module in model.named_modules():
+                    if isinstance(module, FSDP):
+                        self.named_modules.append((name, module))
+                for name_prefix, module in self.named_modules:
+                    with FSDP.summon_full_params(module, recurse=False):
+                        for name, param in module.named_parameters():
+                            if isinstance(param, FlatParameter):
+                                continue
+                            realname = (
+                                name_prefix[len(FSDP_PREFIX) :] + "." + name
+                                if name_prefix
+                                else name
+                            )
+                            self.state_dict_meta.append(
+                                (realname, str(param.dtype), tuple(param.shape))
+                            )
+                            param = None
+                torch.cuda.empty_cache()
+            else:  # fsdp2
+                for name, param in model.named_parameters():
+                    self.state_dict_meta.append((name, str(param.dtype), tuple(param.shape)))
 
         if torch.distributed.get_rank() == 0:
             import ray
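Each entry collected into `self.state_dict_meta` is a `(name, dtype, shape)` triple; rank 0 forwards the list to the synchronizer actor, which presumably uses it to know which tensors to expect during weight sync. A small, self-contained example of what one entry looks like (the parameter name is hypothetical):

```python
import torch

weight = torch.nn.Linear(4, 2).weight  # toy tensor standing in for a model weight
entry = ("model.layers.0.weight", str(weight.dtype), tuple(weight.shape))
print(entry)  # ('model.layers.0.weight', 'torch.float32', (2, 4))
```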
@@ -606,6 +621,7 @@ def setup_weight_sync_group(self):
                     master_address, master_port, self.state_dict_meta
                 )
         timeout = self.config.synchronizer.sync_timeout
+
         self._model_update_group = init_process_group(
             host=master_address,
             port=master_port,
@@ -621,30 +637,32 @@
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def sync_weight(self):
-        if self.config.actor.strategy == "fsdp":
-            for name_prefix, module in self.named_modules:
-                with FSDP.summon_full_params(module, recurse=False):
+        with self._fsdp_offload_context():
+            if self.config.actor.strategy == "fsdp":
+                for name_prefix, module in self.named_modules:
+                    with FSDP.summon_full_params(module, recurse=False):
+                        if torch.distributed.get_rank() == 0:
+                            for name, param in module.named_parameters():
+                                if isinstance(param, FlatParameter):
+                                    continue
+                                torch.distributed.broadcast(
+                                    param, 0, group=self._model_update_group
+                                )
+                                param = None
+            else:  # fsdp2
+                for name, param in self.actor_module_fsdp.named_parameters():
+                    full_param = param.full_tensor().detach().to(device=get_device_id())
                     if torch.distributed.get_rank() == 0:
-                        for name, param in module.named_parameters():
-                            if isinstance(param, FlatParameter):
-                                continue
-                            torch.distributed.broadcast(param, 0, group=self._model_update_group)
-                            param = None
-        else:  # fsdp2
-            for name, param in self.actor_module_fsdp.named_parameters():
-                full_param = param.full_tensor()
-                if torch.distributed.get_rank() == 0:
-                    torch.distributed.broadcast(full_param, 0, group=self._model_update_group)
-                del full_param
-        if torch.distributed.get_rank() == 0:
-            torch.distributed.barrier(group=self._model_update_group)
-        torch.cuda.synchronize()
-        torch.distributed.barrier()
-        torch.cuda.empty_cache()
+                        torch.distributed.broadcast(full_param, 0, group=self._model_update_group)
+                    del full_param
+            if torch.distributed.get_rank() == 0:
+                torch.distributed.barrier(group=self._model_update_group)
+            torch.cuda.synchronize()
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def upload_state_dict(self, trainer_step: int):
-        self.checkpoint_manager.upload_state_dict(trainer_step)
+        with self._fsdp_offload_context():
+            self.checkpoint_manager.upload_state_dict(trainer_step)
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def set_algorithm(self, algo_config: AlgorithmConfig):
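For readers unfamiliar with the pattern in `sync_weight`: `torch.distributed.broadcast` is a collective, so every member of the update group calls it on a tensor of the same shape, and the source rank's values overwrite everyone else's copy. A stripped-down, hypothetical sender-side sketch, assuming a process group has already been initialized elsewhere:

```python
import torch
import torch.distributed as dist


def push_weights_from_rank0(params, group):
    """Hypothetical sketch: rank 0 broadcasts each parameter to the other members
    of the update group, which are expected to call the matching broadcast on
    their side to receive the new weights."""
    for p in params:
        dist.broadcast(p, src=0, group=group)  # rank 0 is the source of truth
    torch.cuda.synchronize()  # wait for the transfers before the params change again
```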