trl/models/utils.py (3 changes: 2 additions & 1 deletion)
@@ -371,11 +371,12 @@ def prepare_deepspeed(model: "Module", accelerator: "Accelerator"):
 
 def prepare_fsdp(model, accelerator):
     # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1421
+    from torch.distributed.fsdp import FSDPModule
     from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
 
     # Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
     # don't wrap it again
-    if not isinstance(model, FSDP):
+    if not (isinstance(model, FSDP) or isinstance(model, FSDPModule)):
         accelerator.state.fsdp_plugin.set_auto_wrap_policy(model)
         fsdp_plugin = accelerator.state.fsdp_plugin
         kwargs = {
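For context on the check above: with FSDP2, fully_shard marks the wrapped module via the FSDPModule mixin rather than the FSDP1 FullyShardedDataParallel class, so detecting manual wrapping has to look for either type. A minimal sketch of that combined check, assuming a PyTorch version where FSDPModule is importable from torch.distributed.fsdp (as in the diff); is_already_fsdp_wrapped is a hypothetical helper, not TRL code:

import torch.nn as nn
from torch.distributed.fsdp import FSDPModule
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

def is_already_fsdp_wrapped(model: nn.Module) -> bool:
    # True for a root module wrapped by FSDP1 (FullyShardedDataParallel)
    # or by FSDP2 (fully_shard adds the FSDPModule mixin).
    return isinstance(model, (FSDP, FSDPModule))
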
trl/trainer/grpo_trainer.py (37 changes: 31 additions & 6 deletions)
@@ -1138,14 +1138,14 @@ def _fix_param_name_to_vllm(self, name, extra_prefixes: Optional[list[str]] = None):
             name = name.replace(prefix, "")
         return name

Contributor:
Does this get duplicated and sent twice with the now L934-938?

SalmanMohammadi (Contributor, Author), Jul 16, 2025:
I think you're right - this should be the only logic for syncing params, and I believe we'll be dropping support for FSDP1, right? @kashif

Collaborator:
yes we are dropping support for FSDP1

-    def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None):
+    def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None):
         """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM."""
         # For FSDP1, we need to recurse into children and also use summon_full_params
         if visited is None:
             visited = set()
 
         for child_name, child_module in module.named_children():
             child_prefix = f"{prefix}.{child_name}" if prefix else child_name
-            self._sync_fsdp_params_to_vllm(
+            self._sync_fsdp1_params_to_vllm(
                 child_module, prefix=child_prefix, visited=visited
             )  # recurse into the child
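To make the docstring above concrete: a post-order sync visits children first and then materializes only the current wrapper's shards, so at most one module's full parameters are resident at a time. A hedged sketch of that pattern using FSDP1's summon_full_params context manager; this is an illustration of the technique, not the elided body of the method, and push_to_vllm is a hypothetical stand-in for the vLLM update call:

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def push_to_vllm(name: str, tensor: torch.Tensor) -> None:
    # Hypothetical sink standing in for vllm_client.update_named_param / load_weights.
    pass

def sync_fsdp1_post_order(module: nn.Module, prefix: str = "", visited=None) -> None:
    visited = set() if visited is None else visited
    # Recurse into children first (post-order), mirroring the loop shown above.
    for child_name, child in module.named_children():
        child_prefix = f"{prefix}.{child_name}" if prefix else child_name
        sync_fsdp1_post_order(child, prefix=child_prefix, visited=visited)
    if isinstance(module, FSDP):
        # Gather only this wrapper's shards: recurse=False keeps memory bounded,
        # writeback=False skips copying the (unchanged) params back into shards.
        with FSDP.summon_full_params(module, recurse=False, writeback=False):
            for param_name, param in module.named_parameters():
                full_name = f"{prefix}.{param_name}" if prefix else param_name
                if full_name not in visited:
                    visited.add(full_name)
                    push_to_vllm(full_name, param.data)
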

@@ -1165,6 +1165,19 @@ def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None):
                         llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                         llm_model.load_weights([(full_name, param.data)])
 
+    def _sync_fsdp2_params_to_vllm(self, module: nn.Module):
+        # For FSDP2, module.state_dict() already covers all parameters, so no need for recursion
+        for name, param in module.state_dict().items():
+            if param.is_cpu:
+                param = param.to(torch.device("cuda"))
+            param = param.full_tensor()
+
+            if self.vllm_mode == "server" and self.accelerator.is_main_process:
+                self.vllm_client.update_named_param(name, param)
+            elif self.vllm_mode == "colocate":
+                llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
+                llm_model.load_weights([(name, param)])
+
     @profiling_decorator
     def _move_model_to_vllm(self):
         # For DeepSpeed ZeRO-3 and FSDP, we need to gather all parameters before operations
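The new _sync_fsdp2_params_to_vllm above leans on the fact that an FSDP2 state_dict holds DTensors whose shards live on each rank, and full_tensor() all-gathers a DTensor back into an ordinary tensor that vLLM can load. A minimal single-process sketch of that gather, under the assumption of torch >= 2.5 where torch.distributed.tensor is public; it is not TRL code:

import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

# Single-process "distributed" setup so the example runs standalone.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
weight = torch.randn(4, 4)
sharded = distribute_tensor(weight, mesh, [Shard(0)])  # DTensor, row-sharded

full = sharded.full_tensor()  # all-gather the shards into a plain torch.Tensor
assert torch.equal(full, weight)

dist.destroy_process_group()
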
@@ -1188,7 +1201,14 @@ def _move_model_to_vllm(self):
                 if self.is_fsdp_enabled:  # note if using FSDP, gather_if_zero3 is nullcontext
                     # Update vLLM weights while parameters are gathered
                     # For PEFT with FSDP we need to use the memory efficient post-order traversal
-                    self._sync_fsdp_params_to_vllm(self.model)
+                    fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None)
+                    fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1
+                    if fsdp_version == 1:
+                        self._sync_fsdp1_params_to_vllm(
+                            self.model
+                        )  # use memory-efficient post-order traversal for FSDP
+                    elif fsdp_version == 2:
+                        self._sync_fsdp2_params_to_vllm(self.model)
                 else:
                     # DeepSpeed ZeRO-3 with PEFT
                     for name, param in self.model.named_parameters():
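The version probe above defaults to FSDP1 whenever the accelerate plugin does not expose fsdp_version (older accelerate releases). Isolated as a sketch, under the assumption that recent accelerate releases with FSDP2 support accept fsdp_version on FullyShardedDataParallelPlugin; detect_fsdp_version is illustrative, not TRL code:

from accelerate.utils import FullyShardedDataParallelPlugin

def detect_fsdp_version(fsdp_plugin) -> int:
    # Mirrors the dispatch above: missing plugin or missing attribute -> FSDP1 path.
    return getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1

print(detect_fsdp_version(None))  # 1 -> FSDP disabled / legacy path
plugin = FullyShardedDataParallelPlugin(fsdp_version=2)  # assumes a recent accelerate release
print(detect_fsdp_version(plugin))  # 2 -> _sync_fsdp2_params_to_vllm branch
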
@@ -1212,7 +1232,12 @@ def _move_model_to_vllm(self):
         else:
             # For non-PEFT models, simply gather (if needed) and update each parameter individually.
             if self.is_fsdp_enabled:
-                self._sync_fsdp_params_to_vllm(self.model)  # use memory-efficient post-order traversal for FSDP
+                fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None)
+                fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1
+                if fsdp_version == 1:
+                    self._sync_fsdp1_params_to_vllm(self.model)  # use memory-efficient post-order traversal for FSDP
+                elif fsdp_version == 2:
+                    self._sync_fsdp2_params_to_vllm(self.model)
             else:
                 for name, param in self.model.named_parameters():
                     name = self._fix_param_name_to_vllm(name)
@@ -1360,7 +1385,7 @@ def _generate_and_score_completions(
             padding_side="left",
             add_special_tokens=False,
             **kwargs,
-        ).to(self.model.dtype)
+        )
         prompt_inputs = super()._prepare_inputs(prompt_inputs)
         prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]