40 changes: 40 additions & 0 deletions trl/trainer/grpo_config.py
@@ -96,6 +96,13 @@ class GRPOConfig(TrainingArguments):
timeout, a `ConnectionError` is raised.
vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
vllm_colocation (`int` or `None`, *optional*, defaults to `None`):
Controls colocated vLLM execution and tensor parallelism via the `external_launcher` backend.
- Set to `None` to disable colocated vLLM.
- Set to `1` to enable colocated vLLM on each device independently (no tensor parallelism).
- Set to a value greater than `1` to enable colocated vLLM with tensor parallelism across multiple devices.
vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.3`):
Controls the GPU memory utilization for vLLM in colocation mode.

> Parameters that control the training

@@ -295,6 +302,39 @@ class GRPOConfig(TrainingArguments):
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
)
vllm_colocation: Optional[int] = field(
default=None,
metadata={
"help": (
"Controls colocated vLLM execution and tensor parallelism using the `external_launcher` backend. "
"Set to `None` to disable colocated vLLM. "
"Set to `1` to enable colocated vLLM on each device (no tensor parallelism). "
"Set to a value >1 to enable colocated vLLM with tensor parallelism across multiple devices."
)
},
)
vllm_gpu_memory_utilization: float = field(
default=0.3,
metadata={
"help": "This parameter is used control the GPU memory utilization for vLLM in colocation mode."
},
)
vllm_max_model_len: Optional[int] = field(
default=None,
metadata={
"help": "This parameter is used to control model length for the vLLM in colocation mode"
},
)
vllm_sleep_enabled: bool = field(
default=False,
metadata={
"help": (
"Enables sleep mode for colocated vLLM during training. "
"Set to `True` to keep vLLM in sleep state during training steps, helping reduce memory usage. "
"Set to `False` to disable this behavior."
)
},
)

# Parameters that control the training
learning_rate: float = field(
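For context, here is a minimal sketch of how the new arguments above might be combined; the output directory, max model length, and other values are illustrative and not part of this PR:

from trl import GRPOConfig

# Illustrative values: colocated vLLM with 2-way tensor parallelism per group,
# 30% of GPU memory reserved for vLLM, and sleep mode enabled during training steps.
training_args = GRPOConfig(
    output_dir="grpo-colocation-demo",   # illustrative path
    use_vllm=True,                        # colocation presumably still requires the vLLM path
    vllm_colocation=2,
    vllm_gpu_memory_utilization=0.3,
    vllm_max_model_len=4096,
    vllm_sleep_enabled=True,
)

Since colocation relies on vLLM's `external_launcher` backend, a script using this config would presumably be launched with torchrun or accelerate like any other distributed training run.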
155 changes: 122 additions & 33 deletions trl/trainer/grpo_trainer.py
@@ -24,6 +24,7 @@
import torch
import torch.utils.data
import transformers
from accelerate import PartialState
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
from datasets import Dataset, IterableDataset
from packaging import version
@@ -47,6 +48,8 @@
from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from ..extras.profiling import profiling_context, profiling_decorator
from ..extras.vllm_client import VLLMClient
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams
from ..import_utils import is_liger_kernel_available, is_rich_available, is_vllm_available
from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from .callbacks import SyncRefModelCallback
@@ -617,7 +620,39 @@ def data_collator(features): # No data collation is needed in GRPO
"`pip install vllm` to use it."
)

if self.accelerator.is_main_process:
if self.args.vllm_colocation:

# Ensure vllm_colocation TP value is valid (at least 1)
assert self.args.vllm_colocation >= 1, "vllm_colocation must be greater than 0"
# Make sure vllm_colocation TP group size evenly divides the world size - each group should have the same number of ranks
assert self.accelerator.num_processes % self.args.vllm_colocation == 0, (
f"TP size of vllm_colocation ({self.args.vllm_colocation}) must divide world size "
f"({self.accelerator.num_processes}) evenly."
)

if self.args.vllm_colocation > 1:
# Create subgroups of ranks for TP, each group with `vllm_colocation` ranks.
# For example, if world_size=8 and vllm_colocation=2 → groups: [0,1], [2,3], [4,5], [6,7]
self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration(
[
list(range(i * self.args.vllm_colocation, (i + 1) * self.args.vllm_colocation))
for i in range(self.accelerator.num_processes // self.args.vllm_colocation)
]
)

self.llm = LLM(
model=model.name_or_path,
# device=f"{device_type}:{self.accelerator.process_index}", # ToDo: we do not need to set the device
tensor_parallel_size=args.vllm_colocation,
enable_sleep_mode=self.args.vllm_sleep_enabled,
gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
# max_num_seqs=self.args.per_device_train_batch_size * self.args.vllm_colocation, # ToDo: this should be multiplied by gradient_accumulation_steps
max_model_len=self.args.vllm_max_model_len,
distributed_executor_backend="external_launcher",
seed=int(os.getenv("RANK", "0")) // self.args.vllm_colocation, # feed identical seed for tp groups
)

elif self.accelerator.is_main_process:
self.vllm_client = VLLMClient(
args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout
)
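For intuition, here is a minimal sketch (plain Python, no torch.distributed required; the helper name tp_group_layout is made up for illustration) of how ranks are partitioned into TP groups by new_subgroups_by_enumeration and how the per-group vLLM seed above is derived:

def tp_group_layout(world_size, vllm_colocation):
    # Contiguous blocks of `vllm_colocation` ranks, matching the enumeration above.
    assert world_size % vllm_colocation == 0
    groups = [
        list(range(i * vllm_colocation, (i + 1) * vllm_colocation))
        for i in range(world_size // vllm_colocation)
    ]
    # Every rank in a group gets the same seed, so TP peers generate identically.
    seeds = {rank: rank // vllm_colocation for rank in range(world_size)}
    return groups, seeds

groups, seeds = tp_group_layout(8, 2)
print(groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(seeds)   # {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2, 6: 3, 7: 3}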
@@ -903,7 +938,13 @@ def _move_model_to_vllm(self):
continue
name = name.replace("modules_to_save.default.", "")

if self.accelerator.is_main_process:
if self.args.vllm_colocation:
if self.args.vllm_sleep_enabled:
torch.cuda.empty_cache()
self.llm.wake_up()
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights([(name, param.data)])
elif self.accelerator.is_main_process:
self.vllm_client.update_named_param(name, param.data)

# Unmerge adapters while parameters are still gathered
@@ -913,11 +954,19 @@
# For non-PEFT models, simply gather and update each parameter individually.
for name, param in self.model.named_parameters():
with gather_if_zero3([param]):
if self.accelerator.is_main_process:
if self.args.vllm_colocation:
if self.args.vllm_sleep_enabled:
torch.cuda.empty_cache()
self.llm.wake_up()
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights([(name, param.data)])
elif self.accelerator.is_main_process:
self.vllm_client.update_named_param(name, param.data)

# Reset cache on main process
if self.accelerator.is_main_process:
# Reset cache on vLLM
if self.args.vllm_colocation:
self.llm.reset_prefix_cache()
elif self.accelerator.is_main_process:
self.vllm_client.reset_prefix_cache()

@profiling_decorator
@@ -977,35 +1026,75 @@ def _generate_and_score_completions(
self._move_model_to_vllm()
self._last_loaded_step = self.state.global_step

# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
# Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually.
ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
with profiling_context(self, "vLLM.generate"):
completion_ids = self.vllm_client.generate(
prompts=ordered_set_of_prompts,
n=self.num_generations,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=-1 if self.top_k is None else self.top_k,
min_p=0.0 if self.min_p is None else self.min_p,
max_tokens=self.max_completion_length,
guided_decoding_regex=self.guided_decoding_regex,
)
# Generate completions using colocated vLLM instances: each device holds a vLLM copy and works on its own batch of prompts
if self.args.vllm_colocation:
# ToDo: we may not need to wake up here anymore as model update is always called before each generation
if self.args.vllm_sleep_enabled:
torch.cuda.empty_cache()
self.llm.wake_up()
sampling_params = SamplingParams(
n=1, # each colocated vLLM instance generates only 1 completion per prompt in vllm_colocation mode
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=-1 if self.top_k is None else self.top_k,
min_p=0.0 if self.min_p is None else self.min_p,
max_tokens=self.max_completion_length,
guided_decoding=GuidedDecodingParams(backend="outlines", regex=self.guided_decoding_regex) if self.guided_decoding_regex else None,
)
if self.args.vllm_colocation > 1:
# Gather prompts from all ranks in the TP group and flatten.
# Each rank starts with its own prompts; after gathering, all ranks see the full group set.
orig_size = len(prompts_text)
gathered_prompts = [None for _ in range(self.args.vllm_colocation)]
torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
prompts_text = [p for sublist in gathered_prompts for p in sublist]

all_outputs = self.llm.generate(
prompts_text, sampling_params=sampling_params, use_tqdm=False
)
if self.args.vllm_sleep_enabled:
self.llm.sleep(level=2) # going back to sleep to free memory for training
torch.cuda.empty_cache()
completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]

if self.args.vllm_colocation > 1:
# Slice completions for this rank within its TP group: every rank generated the full
# group's outputs, so keep only this rank's share.
local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
completion_ids = completion_ids[tp_slice]

# Generate completions using vLLM server: gather all prompts and use them in a single call in the main process
else:
completion_ids = [None] * len(all_prompts_text)
# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice.
completion_ids = broadcast_object_list(completion_ids, from_process=0)
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
)
completion_ids = completion_ids[process_slice]
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
# Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually.
ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
with profiling_context(self, "vLLM.generate"):
completion_ids = self.vllm_client.generate(
prompts=ordered_set_of_prompts,
n=self.num_generations,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=-1 if self.top_k is None else self.top_k,
min_p=0.0 if self.min_p is None else self.min_p,
max_tokens=self.max_completion_length,
guided_decoding_regex=self.guided_decoding_regex,
)
else:
completion_ids = [None] * len(all_prompts_text)
# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice.
completion_ids = broadcast_object_list(completion_ids, from_process=0)
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
)
completion_ids = completion_ids[process_slice]
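To make the gather-then-slice round trip in the colocation branch above concrete, here is a small sketch using plain lists in place of torch.distributed.all_gather_object; the prompt and completion strings are made up:

orig_size = 2
per_rank_prompts = [["p0", "p1"], ["p2", "p3"]]  # prompts held by ranks 0 and 1 of one TP group

# After all_gather_object, every rank in the group sees the same flattened prompt list.
gathered = [p for sublist in per_rank_prompts for p in sublist]  # ["p0", "p1", "p2", "p3"]

# Generation is identical on every rank of the group (same weights, same seed),
# so each rank ends up with completions for the whole group.
completion_ids = [f"completion({p})" for p in gathered]

# Each rank then keeps only the slice corresponding to its own original prompts.
for local_rank_in_group in range(len(per_rank_prompts)):
    tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
    print(local_rank_in_group, completion_ids[tp_slice])
# 0 ['completion(p0)', 'completion(p1)']
# 1 ['completion(p2)', 'completion(p3)']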

# Pad the completions, and concatenate them with the prompts
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]