diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 7826b184488..678ab34449e 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -64,6 +64,12 @@
     selective_log_softmax,
 )
 
+try:
+    from torch.distributed.checkpoint.state_dict import get_model_state_dict
+
+    fsdp2_available = True
+except ImportError:
+    fsdp2_available = False
 
 if is_peft_available():
     from peft import PeftConfig, get_peft_model
@@ -925,6 +931,25 @@ def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited
                     elif self.vllm_mode == "colocate":
                         llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                         llm_model.load_weights([(full_name, param.data)])
+        elif isinstance(module, torch.distributed.fsdp.FSDPModule):
+            assert fsdp2_available, "FSDP2 is not available"
+
+            # Only run this logic at the root call (prefix is empty)
+            if prefix == "":
+                # Get the canonical state dict using the high-level torch.distributed.checkpoint API
+                model_state_dict = get_model_state_dict(module)
+
+                # Sync the state dict to vLLM
+                if self.vllm_mode == "server" and self.accelerator.is_main_process:
+                    for name, param in model_state_dict.items():
+                        self.vllm_client.update_named_param(name, param)
+                elif self.vllm_mode == "colocate":
+                    llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
+                    try:
+                        for k, v in model_state_dict.items():
+                            llm_model.load_weights([(k, v)])
+                    except ValueError:
+                        print(f"Error loading weights for {k} with shape {v.shape}")
 
     @profiling_decorator
     def _move_model_to_vllm(self):