2 changes: 2 additions & 0 deletions docs/source/grpo_trainer.md
@@ -253,6 +253,8 @@ We provide a [HF Space](https://huggingface.co/spaces/trl-lib/recommend-vllm-mem

If the recommended value does not work in your environment, we suggest adding a small buffer (e.g., +0.05 or +0.1) to the recommended value to ensure stability.

If you still encounter out-of-memory errors, set `vllm_enable_sleep_mode` to `True` so that the vLLM parameters and cache are offloaded to CPU RAM during the optimization step. For more information, see [Reducing Memory Usage with vLLM Sleep Mode](reducing_memory_usage#vllm-sleep-mode).
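
For example, a minimal sketch of enabling it alongside colocated vLLM generation (the `output_dir` value is illustrative):

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="grpo-output",      # illustrative
    use_vllm=True,
    vllm_mode="colocate",          # sleep mode applies to the colocated vLLM engine
    vllm_enable_sleep_mode=True,
)
```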

</Tip>

<Tip>
25 changes: 25 additions & 0 deletions docs/source/reducing_memory_usage.md
@@ -267,3 +267,28 @@ training_args = RLOOConfig(..., ds3_gather_for_generation=False)
</hfoptions>

This adjustment prevents model weights from being gathered, avoiding OOM errors, but it may result in slower generation speeds.

## vLLM sleep mode

When using vLLM as the generation backend, you can enable _sleep mode_ to offload vLLM parameters and cache to CPU RAM during the optimization step and reload them back to GPU VRAM when needed for weight synchronization and generation.

<hfoptions id="vllm_sleep">
<hfoption id="GRPO">

```python
from trl import GRPOConfig

training_args = GRPOConfig(..., vllm_enable_sleep_mode=True)
```

</hfoption>
<hfoption id="RLOO">

```python
from trl import RLOOConfig

training_args = RLOOConfig(..., vllm_enable_sleep_mode=True)
```

</hfoption>
</hfoptions>
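
Under the hood, this setting maps onto vLLM's own sleep API (`enable_sleep_mode`, `sleep()`, `wake_up()`), which the trainer drives around weight synchronization and generation. The following is a minimal standalone sketch of that pattern, assuming an illustrative model name:

```python
from vllm import LLM, SamplingParams

# Illustrative model; any model supported by vLLM works
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True)

llm.sleep(level=1)  # offload weights to CPU RAM and discard the KV cache, freeing VRAM
# ... the memory-heavy optimization step would run here ...
llm.wake_up()       # reload weights onto the GPU before the next generation

outputs = llm.generate(["Hello, world!"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```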
10 changes: 10 additions & 0 deletions trl/trainer/grpo_config.py
@@ -140,6 +140,9 @@ class GRPOConfig(TrainingArguments):
Model implementation to use for vLLM. Must be one of `"transformers"` or `"vllm"`. `"transformers"`: Use
the `transformers` backend for model implementation. `"vllm"`: Use the `vllm` library for model
implementation.
vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`):
Whether to enable sleep mode for vLLM. If `True`, vLLM will be put to sleep during the optimization step
and woken up for weight sync and generation.

> Parameters that control the training

@@ -416,6 +419,13 @@ class GRPOConfig(TrainingArguments):
"model implementation."
},
)
vllm_enable_sleep_mode: bool = field(
default=False,
metadata={
"help": "Whether to enable sleep mode for vLLM. If `True`, vLLM will be put to sleep during the optimization "
"step and woken up for weight sync and generation."
},
)
vllm_guided_decoding_regex: Optional[str] = field(
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
Expand Down
13 changes: 12 additions & 1 deletion trl/trainer/grpo_trainer.py
@@ -533,10 +533,13 @@ def __init__(
distributed_executor_backend="external_launcher",
# Feed identical seed for tp groups to ensure sampling results are the same across workers
seed=self.accelerator.process_index // self.vllm_tensor_parallel_size,
# Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) - thinking there's not enough memory
max_num_batched_tokens=4096,
model_impl=self.args.vllm_model_impl,
enable_sleep_mode=self.args.vllm_enable_sleep_mode,
)
if self.args.vllm_enable_sleep_mode:
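# Put the engine to sleep right after construction; it is woken up again before weight sync and generation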
self.llm.sleep(level=1)
Contributor

Every time vLLM wakes up, it wakes up to an updated model, right? So why not use level = 2 to further improve efficiency?

Collaborator Author

Ah good point, let me rerun the benchmark with level=2.

Contributor

@toslali-ibm Aug 28, 2025

For level = 2, you will also need to wake up and then sleep in `_move_model_to_vllm`, `_sync_fsdp2_params_to_vllm` and `_sync_fsdp1_params_to_vllm`.

Basically, anytime you touch vLLM (in generation or when loading/syncing the model), you wake up, do the work, and then go back to sleep.

Additional note: the reason we did not move sleep upstream was a vLLM bug, which was only recently fixed by this PR. That means you need a vLLM version that incorporates the fix to be able to use sleep level 2 without a segmentation fault.

Collaborator Author

Yes, I have already included a wake-up call before the weight sync. I don't re-sleep since the generation step comes straight after.

It looks like the level 2 fix has not made it into their most recent release. I will leave level=1 for now for better backward compatibility. Unless there are other reasons to go with level 2?
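
To make the level-2 pattern discussed above concrete, here is a rough sketch of the wake/work/sleep wrapper the reviewer describes; the helper is hypothetical and not part of this PR:

```python
from contextlib import contextmanager

@contextmanager
def vllm_awake(llm, enabled: bool, sleep_level: int = 2):
    # Hypothetical helper: wake the colocated engine, do the work, then put it back to sleep
    if enabled:
        llm.wake_up()
    try:
        yield
    finally:
        if enabled:
            llm.sleep(level=sleep_level)

# Usage sketch: with level 2 the weights are discarded on sleep, so every place that
# touches vLLM (weight sync in _move_model_to_vllm / _sync_fsdp1_params_to_vllm /
# _sync_fsdp2_params_to_vllm, or generation) would be wrapped like this:
#
#     with vllm_awake(self.llm, self.args.vllm_enable_sleep_mode):
#         self._move_model_to_vllm()
```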

else:
raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.")

@@ -1127,6 +1130,11 @@ def _generate_and_score_completions(

# Generate completions using either vLLM or regular generation
if self.use_vllm:
if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode:
# wake up colocated vLLM instances if needed
torch.cuda.empty_cache() # required to avoid OOM in some cases
self.llm.wake_up()

# First, update the vLLM weights if needed
if self.state.global_step != self._last_loaded_step:
self._move_model_to_vllm()
@@ -1235,6 +1243,9 @@ def _generate_and_score_completions(
tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
completion_ids = completion_ids[tp_slice]

if self.args.vllm_enable_sleep_mode:
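# Generation for this step is done; re-offload the weights and discard the KV cache to free VRAM for the optimization step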
self.llm.sleep(level=1)

# Pad the completions, and concatenate them with the prompts
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=self.pad_token_id)