
Commit edbe823

Authored by edbeeching, qgallouedec, and sergiopaniego
[GRPO] Adds an option to sleep vllm when running in colocated mode (#3968)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
1 parent 4c47b32 commit edbe823

File tree

4 files changed: +49 −1 lines changed

docs/source/grpo_trainer.md

Lines changed: 2 additions & 0 deletions
@@ -253,6 +253,8 @@ We provide a [HF Space](https://huggingface.co/spaces/trl-lib/recommend-vllm-mem
 
 If the recommended value does not work in your environment, we suggest adding a small buffer (e.g., +0.05 or +0.1) to the recommended value to ensure stability.
 
+If you are still getting out-of-memory errors, set `vllm_enable_sleep_mode` to `True`; the vLLM parameters and cache will then be offloaded during the optimization step. For more information, see [Reducing Memory Usage with vLLM Sleep Mode](reducing_memory_usage#vllm-sleep-mode).
+
 </Tip>
 
 <Tip>

docs/source/reducing_memory_usage.md

Lines changed: 25 additions & 0 deletions
@@ -267,3 +267,28 @@ training_args = RLOOConfig(..., ds3_gather_for_generation=False)
 </hfoptions>
 
 This adjustment prevents model weights from being gathered, avoiding OOM errors, but it may result in slower generation speeds.
+
+## vLLM sleep mode
+
+When using vLLM as the generation backend, you can enable _sleep mode_ to offload the vLLM parameters and cache to CPU RAM during the optimization step and reload them to GPU VRAM when they are needed for weight synchronization and generation.
+
+<hfoptions id="vllm_sleep">
+<hfoption id="GRPO">
+
+```python
+from trl import GRPOConfig
+
+training_args = GRPOConfig(..., vllm_enable_sleep_mode=True)
+```
+
+</hfoption>
+<hfoption id="RLOO">
+
+```python
+from trl import RLOOConfig
+
+training_args = RLOOConfig(..., vllm_enable_sleep_mode=True)
+```
+
+</hfoption>
+</hfoptions>
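
As a quick sanity check (not part of the documentation added above), the sketch below starts a standalone vLLM engine with sleep mode enabled and prints the allocated GPU memory before and after putting it to sleep; the model name is only a placeholder.

```python
import torch
from vllm import LLM

# Placeholder model; any small model works for this check.
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True)

print(f"Allocated before sleep: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
llm.sleep(level=1)  # offload the weights to CPU RAM and drop the KV cache
torch.cuda.empty_cache()
print(f"Allocated after sleep:  {torch.cuda.memory_allocated() / 1e9:.2f} GB")

llm.wake_up()  # reload the weights before generating again
```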

trl/trainer/grpo_config.py

Lines changed: 10 additions & 0 deletions
@@ -140,6 +140,9 @@ class GRPOConfig(TrainingArguments):
             Model implementation to use for vLLM. Must be one of `"transformers"` or `"vllm"`. `"transformers"`: Use
             the `transformers` backend for model implementation. `"vllm"`: Use the `vllm` library for model
             implementation.
+        vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`):
+            Whether to enable sleep mode for vLLM. If `True`, vLLM will sleep during the optimization step and be
+            woken for weight sync and generation.
 
     > Parameters that control the training
 
@@ -416,6 +419,13 @@
             "model implementation."
         },
     )
+    vllm_enable_sleep_mode: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to enable sleep mode for vLLM. If `True`, vLLM will sleep during the optimization step "
+            "and be woken for weight sync and generation."
+        },
+    )
     vllm_guided_decoding_regex: Optional[str] = field(
         default=None,
         metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
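
For context, a minimal sketch of how the new flag is intended to be combined with colocated generation; `output_dir` and the other values are illustrative and not part of this diff.

```python
from trl import GRPOConfig

# Illustrative configuration: generate with a colocated vLLM engine and let it
# sleep during the optimization step.
training_args = GRPOConfig(
    output_dir="grpo-sleep-demo",   # hypothetical output directory
    use_vllm=True,                  # generate completions with vLLM
    vllm_mode="colocate",           # run vLLM inside the training process
    vllm_enable_sleep_mode=True,    # offload vLLM weights and cache between generations
)
```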

trl/trainer/grpo_trainer.py

Lines changed: 12 additions & 1 deletion
@@ -533,10 +533,13 @@ def __init__(
                 distributed_executor_backend="external_launcher",
                 # Feed identical seed for tp groups to ensure sampling results are the same across workers
                 seed=self.accelerator.process_index // self.vllm_tensor_parallel_size,
-                # Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) - thinking there's not enough memory
+                # Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) - thinking there's not enough memory
                 max_num_batched_tokens=4096,
                 model_impl=self.args.vllm_model_impl,
+                enable_sleep_mode=self.args.vllm_enable_sleep_mode,
             )
+            if self.args.vllm_enable_sleep_mode:
+                self.llm.sleep(level=1)
         else:
             raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.")

@@ -1127,6 +1130,11 @@ def _generate_and_score_completions(
 
         # Generate completions using either vLLM or regular generation
         if self.use_vllm:
+            if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode:
+                # wake up colocated vLLM instances if needed
+                torch.cuda.empty_cache()  # required to avoid OOM in some cases
+                self.llm.wake_up()
+
             # First, update the vLLM weights if needed
             if self.state.global_step != self._last_loaded_step:
                 self._move_model_to_vllm()

@@ -1235,6 +1243,9 @@ def _generate_and_score_completions(
                 tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
                 completion_ids = completion_ids[tp_slice]
 
+            if self.args.vllm_enable_sleep_mode:
+                self.llm.sleep(level=1)
+
             # Pad the completions, and concatenate them with the prompts
             completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
             completion_ids = pad(completion_ids, padding_value=self.pad_token_id)
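
Taken together, the trainer changes above amount to a sleep/wake cycle around each generation step. The sketch below reproduces that cycle with a standalone vLLM engine, assuming a vLLM version that supports sleep mode; the model name and prompt are placeholders.

```python
import torch
from vllm import LLM, SamplingParams

# Placeholder model; the trainer would use the policy model being trained.
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True)
llm.sleep(level=1)  # sleep right after init, as the trainer does

# ... optimization step runs here while vLLM's GPU memory is released ...

torch.cuda.empty_cache()  # mirror the trainer: free cached blocks before waking vLLM
llm.wake_up()             # restore the weights and re-allocate the KV cache

outputs = llm.generate(["Hello, world!"], SamplingParams(max_tokens=16))

llm.sleep(level=1)  # go back to sleep until the next generation step
```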
