From 9ce54784ca81f295aff41bd17204aa482334a9b7 Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Mon, 24 Mar 2025 09:40:08 -0400 Subject: [PATCH 01/27] Add vllm colocation --- trl/extras/vllm_client.py | 4 +- trl/extras/vllm_coloc_client.py | 129 ++++++++++++++++++++++++++++++++ trl/extras/vllm_proxy.py | 52 +++++++++++++ trl/trainer/grpo_config.py | 12 +++ trl/trainer/grpo_trainer.py | 44 +++++------ 5 files changed, 217 insertions(+), 24 deletions(-) create mode 100644 trl/extras/vllm_coloc_client.py create mode 100644 trl/extras/vllm_proxy.py diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py index 940611f5d6b..6de19c6a394 100644 --- a/trl/extras/vllm_client.py +++ b/trl/extras/vllm_client.py @@ -21,7 +21,7 @@ from torch import nn from ..import_utils import is_requests_available, is_vllm_available - +from .vllm_proxy import BaseVLLMClient if is_requests_available(): import requests @@ -36,7 +36,7 @@ logger = logging.getLogger(__name__) -class VLLMClient: +class VLLMClient(BaseVLLMClient): """ A client class to interact with a vLLM server. diff --git a/trl/extras/vllm_coloc_client.py b/trl/extras/vllm_coloc_client.py new file mode 100644 index 00000000000..6d195c137e8 --- /dev/null +++ b/trl/extras/vllm_coloc_client.py @@ -0,0 +1,129 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +import torch +import warnings +from .vllm_proxy import BaseVLLMClient +from vllm import SamplingParams, LLM + +from accelerate import PartialState +from profiling import profiling_context + +class VLLMColocationClient(BaseVLLMClient): + def __init__(self, accelerator, args, model): + ## ToDo: get accelerator and other things - to offload if/else from trainer to the clients + self.args = args + self.accelerator = accelerator + self.model = model + self.vllm_device = None + self._initialize_colocated_vllm() + + def _initialize_colocated_vllm(self): + device_type = PartialState().default_device.type + self.vllm_device = f"{device_type}:{self.accelerator.process_index}" + + warnings.warn( + f"The requested device {self.vllm_device} is also being used for training. For higher throughput " + "and to avoid out-of-memory errors, it is recommended to use a dedicated device for vLLM. " + "If this is intentional, you may ignore this warning but should adjust " + "`vllm_gpu_memory_utilization` accordingly." + ) + + self.llm = LLM( + model=self.model.name_or_path, + device=self.vllm_device, + gpu_memory_utilization=self.args.vllm_gpu_memory_utilization, + dtype=self.args.vllm_dtype, + enable_prefix_caching=self.args.vllm_enable_prefix_caching, + max_model_len=self.args.vllm_max_model_len, + distributed_executor_backend="external_launcher", + ) + + def update_named_param(self, name: str, weights: torch.Tensor): + """ + Updates a specific named parameter in the model. + + Args: + name (`str`): + Name of the layer whose weights are being updated. + weights (`torch.Tensor`): + Tensor containing the updated weights. 
+ """ + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(name,weights)]) + + def generate( + self, + prompts: list[str], + n: int = 1, + repetition_penalty: float = 1.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + min_p: float = 0.0, + max_tokens: int = 16, + guided_decoding_regex: Optional[str] = None, + ) -> list[list[str]]: + """ + Generates model completions for the provided prompts. + + Args: + prompts (`list[str]`): + List of text prompts for which the model will generate completions. + n (`int`, *optional*, defaults to `1`): + Number of completions to generate for each prompt. + repetition_penalty (`float`, *optional*, defaults to `1.0`): + Parameter for repetition penalty. 1.0 means no penalty. + temperature (`float`, *optional*, defaults to `1.0`): + Temperature parameter for sampling. Higher values increase diversity. + top_p (`float`, *optional*, defaults to `1.0`): + Top-p sampling parameter.`1.0` means no truncation. + top_k (`int`, *optional*, defaults to `-1`): + Top-k sampling parameter. `-1` means no truncation. + min_p (`float`, *optional*, defaults to `0.0`): + Minimum probability for sampling. + max_tokens (`int`, *optional*, defaults to `16`): + Maximum number of tokens to generate for each prompt. + guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`): + Regular expression to guide the decoding process. + + Returns: + `list[list[int]]`: + List of lists of token IDs representing the model-generated completions for each prompt. + """ + sampling_params = SamplingParams( + n=1, # vLLM on each GPU generates only 1 in vllm_colocation mode + repetition_penalty=repetition_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + max_tokens=max_tokens, + guided_decoding_regex=guided_decoding_regex, + + ) + with profiling_context(self, "vLLM.generate"): + all_outputs = self.llm.generate( + prompts, sampling_params=sampling_params, use_tqdm=False + ) + completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] + return completion_ids + + def reset_prefix_cache(self): + """ + Resets the prefix cache for the model. + """ + self.llm.engine.reset_prefix_cache() + diff --git a/trl/extras/vllm_proxy.py b/trl/extras/vllm_proxy.py new file mode 100644 index 00000000000..f738c3f0214 --- /dev/null +++ b/trl/extras/vllm_proxy.py @@ -0,0 +1,52 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from abc import ABC, abstractmethod +from typing import Optional +import torch + +class BaseVLLMClient(ABC): + + @abstractmethod + def generate( + self, + prompts: list[str], + n: int = 1, + repetition_penalty: float = 1.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + min_p: float = 0.0, + max_tokens: int = 16, + guided_decoding_regex: Optional[str] = None, + ) -> list[list[str]]: + pass + + @abstractmethod + def update_named_param(self, name: str, weights: torch.Tensor): + pass + + @abstractmethod + def reset_prefix_cache(self): + pass + +def get_vllm_client(args, accelerator, model) -> BaseVLLMClient: + from .vllm_colocation_client import VLLMColocationClient + from .vllm_client import VLLMClient + + if args.vllm_colocation: + return VLLMColocationClient(accelerator, args, model) + else: + return VLLMClient(args.vllm_server_host, args.vllm_server_port, connection_timeout=120.0) + diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 776bb1ef3cb..8ec24d40108 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -90,6 +90,10 @@ class GRPOConfig(TrainingArguments): timeout, a `ConnectionError` is raised. vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`): Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled. + vllm_colocation (`bool`, *optional*, defaults to `False`): + Whether to use colocated vLLM execution via external launcher. If set to `True`, vLLM will be + initialized in **all processes**, each assigned to its respective device. This allows multi-GPU + or multi-node execution with vLLM's external launcher, enabling improved large-scale inference. > Parameters that control the training @@ -248,6 +252,14 @@ class GRPOConfig(TrainingArguments): default=None, metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."}, ) + vllm_colocation: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to use colocated vLLM execution via external launcher. If set to `True`, vLLM will be " + "initialized in all processes, each assigned to its respective device. This enables optimized " + "multi-GPU inference." + }, + ) # Parameters that control the training learning_rate: float = field( diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 121263c1861..14ace026640 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -43,7 +43,7 @@ from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template from ..extras.profiling import profiling_context, profiling_decorator -from ..extras.vllm_client import VLLMClient +from ..extras.vllm_proxy import get_vllm_client from ..import_utils import is_deepspeed_available, is_rich_available, is_vllm_available from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation from .callbacks import SyncRefModelCallback @@ -472,10 +472,8 @@ def data_collator(features): # No data collation is needed in GRPO "`pip install vllm` to use it." 
) - if self.accelerator.is_main_process: - self.vllm_client = VLLMClient( - args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout - ) + if self.accelerator.is_main_process or self.args.vllm_colocation: + self.vllm_client = get_vllm_client(self.args, self.accelerator, model) # vLLM specific sampling arguments self.guided_decoding_regex = args.vllm_guided_decoding_regex @@ -484,8 +482,9 @@ def data_collator(features): # No data collation is needed in GRPO # When using vLLM, the main process is responsible for loading the model weights. This can cause process # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we - # synchronize all processes after vLLM has been fully initialized. - self.accelerator.wait_for_everyone() + # synchronize all processes after vLLM has been fully initialized (if colocated, no need to wait). + if not self.args.vllm_colocation: + self.accelerator.wait_for_everyone() else: self.generation_config = GenerationConfig( max_new_tokens=self.max_completion_length, @@ -641,7 +640,7 @@ def _move_model_to_vllm(self): continue name = name.replace("modules_to_save.default.", "") - if self.accelerator.is_main_process: + if self.accelerator.is_main_process or self.args.vllm_colocation: self.vllm_client.update_named_param(name, param.data) # Unmerge adapters while parameters are still gathered @@ -651,11 +650,11 @@ def _move_model_to_vllm(self): # For non-PEFT models, simply gather and update each parameter individually. for name, param in self.model.named_parameters(): with gather_if_zero3([param]): - if self.accelerator.is_main_process: + if self.accelerator.is_main_process or self.args.vllm_colocation: self.vllm_client.update_named_param(name, param.data) - # Reset cache on main process - if self.accelerator.is_main_process: + # Reset cache on main process (if colocated, reset cache on all vllms) + if self.accelerator.is_main_process or self.args.vllm_colocation: self.vllm_client.reset_prefix_cache() @profiling_decorator @@ -696,13 +695,13 @@ def _generate_and_score_completions( self._move_model_to_vllm() self._last_loaded_step = self.state.global_step - # Generate completions using vLLM: gather all prompts and use them in a single call in the main process - all_prompts_text = gather_object(prompts_text) + # Generate completions using vLLM: gather all prompts and use them in a single call in the main process (if colocated, work on your own batch) + all_prompts_text = prompts_text if self.args.vllm_colocation else gather_object(prompts_text) if self.accelerator.is_main_process: # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate # num_generations outputs for each one. This is faster than generating outputs for each duplicate - # prompt individually. - ordered_set_of_prompts = all_prompts_text[:: self.num_generations] + # prompt individually (if colocated, work on your own batch). + ordered_set_of_prompts = all_prompts_text if self.args.vllm_colocation else all_prompts_text[:: self.num_generations] with profiling_context(self, "vLLM.generate"): completion_ids = self.vllm_client.generate( prompts=ordered_set_of_prompts, @@ -718,13 +717,14 @@ def _generate_and_score_completions( else: completion_ids = [None] * len(all_prompts_text) # Broadcast the completions from the main process to all processes, ensuring each process receives its - # corresponding slice. 
- completion_ids = broadcast_object_list(completion_ids, from_process=0) - process_slice = slice( - self.accelerator.process_index * len(prompts), - (self.accelerator.process_index + 1) * len(prompts), - ) - completion_ids = completion_ids[process_slice] + # corresponding slice (if colocated, no need for broadcasting). + if not self.args.vllm_colocation: + completion_ids = broadcast_object_list(completion_ids, from_process=0) + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + completion_ids = completion_ids[process_slice] # Pad the completions, and concatenate them with the prompts completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] From e3a0734444192c6541ab6f71762f9df36137710d Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Mon, 24 Mar 2025 09:41:52 -0400 Subject: [PATCH 02/27] Fix typo --- trl/extras/vllm_proxy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/extras/vllm_proxy.py b/trl/extras/vllm_proxy.py index f738c3f0214..2946223cd82 100644 --- a/trl/extras/vllm_proxy.py +++ b/trl/extras/vllm_proxy.py @@ -42,7 +42,7 @@ def reset_prefix_cache(self): pass def get_vllm_client(args, accelerator, model) -> BaseVLLMClient: - from .vllm_colocation_client import VLLMColocationClient + from .vllm_coloc_client import VLLMColocationClient from .vllm_client import VLLMClient if args.vllm_colocation: From 72697c20b185138947ed9048b560315f8619748a Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Mon, 24 Mar 2025 10:05:45 -0400 Subject: [PATCH 03/27] Remove profiling --- trl/extras/vllm_coloc_client.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/trl/extras/vllm_coloc_client.py b/trl/extras/vllm_coloc_client.py index 6d195c137e8..4567efe5ada 100644 --- a/trl/extras/vllm_coloc_client.py +++ b/trl/extras/vllm_coloc_client.py @@ -19,7 +19,6 @@ from vllm import SamplingParams, LLM from accelerate import PartialState -from profiling import profiling_context class VLLMColocationClient(BaseVLLMClient): def __init__(self, accelerator, args, model): @@ -114,10 +113,10 @@ def generate( guided_decoding_regex=guided_decoding_regex, ) - with profiling_context(self, "vLLM.generate"): - all_outputs = self.llm.generate( - prompts, sampling_params=sampling_params, use_tqdm=False - ) + + all_outputs = self.llm.generate( + prompts, sampling_params=sampling_params, use_tqdm=False + ) completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] return completion_ids From 6762037427eea783c637a5b6f5bb3294daf8ec7c Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Mon, 24 Mar 2025 10:18:08 -0400 Subject: [PATCH 04/27] Fix default dtype --- trl/trainer/grpo_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 8ec24d40108..5d0e2295ab7 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -356,7 +356,7 @@ class GRPOConfig(TrainingArguments): }, ) vllm_dtype: Optional[str] = field( - default=None, + default="auto", metadata={ "help": "This parameter is deprecated and will be removed in version 0.18.0. To control the data type for " "vLLM generation, you should now use the `dtype` parameter in the vLLM server configuration." 
From 95ef38f7d1f4e07ef6c42a879bd874a54784a58c Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Mon, 24 Mar 2025 10:46:39 -0400 Subject: [PATCH 05/27] Remove profiling --- trl/trainer/grpo_trainer.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 14ace026640..797ea3951d6 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -702,18 +702,18 @@ def _generate_and_score_completions( # num_generations outputs for each one. This is faster than generating outputs for each duplicate # prompt individually (if colocated, work on your own batch). ordered_set_of_prompts = all_prompts_text if self.args.vllm_colocation else all_prompts_text[:: self.num_generations] - with profiling_context(self, "vLLM.generate"): - completion_ids = self.vllm_client.generate( - prompts=ordered_set_of_prompts, - n=self.num_generations, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=-1 if self.top_k is None else self.top_k, - min_p=0.0 if self.min_p is None else self.min_p, - max_tokens=self.max_completion_length, - guided_decoding_regex=self.guided_decoding_regex, - ) + # with profiling_context(self, "vLLM.generate"): + completion_ids = self.vllm_client.generate( + prompts=ordered_set_of_prompts, + n=self.num_generations, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=-1 if self.top_k is None else self.top_k, + min_p=0.0 if self.min_p is None else self.min_p, + max_tokens=self.max_completion_length, + guided_decoding_regex=self.guided_decoding_regex, + ) else: completion_ids = [None] * len(all_prompts_text) # Broadcast the completions from the main process to all processes, ensuring each process receives its From 65b3bc6c3eee77c7774803d36b317e11cf1753f8 Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Mon, 24 Mar 2025 11:15:02 -0400 Subject: [PATCH 06/27] Print for debugging --- trl/extras/vllm_coloc_client.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/trl/extras/vllm_coloc_client.py b/trl/extras/vllm_coloc_client.py index 4567efe5ada..006077ee915 100644 --- a/trl/extras/vllm_coloc_client.py +++ b/trl/extras/vllm_coloc_client.py @@ -117,7 +117,9 @@ def generate( all_outputs = self.llm.generate( prompts, sampling_params=sampling_params, use_tqdm=False ) + print("\n\n\n---- all out", all_outputs) completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] + print("\n\n\n---- completion out",completion_ids) return completion_ids def reset_prefix_cache(self): From 4f6aa276f309005bb9656b94c70cf0de7431bbd5 Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Mon, 24 Mar 2025 11:17:54 -0400 Subject: [PATCH 07/27] Fix bug - generate in all procs --- trl/trainer/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 797ea3951d6..061104be6f6 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -697,7 +697,7 @@ def _generate_and_score_completions( # Generate completions using vLLM: gather all prompts and use them in a single call in the main process (if colocated, work on your own batch) all_prompts_text = prompts_text if self.args.vllm_colocation else gather_object(prompts_text) - if self.accelerator.is_main_process: + if self.accelerator.is_main_process or self.args.vllm_colocation: # Since 'prompts' contains 'num_generations' 
duplicates, we first take unique prompts, and generate # num_generations outputs for each one. This is faster than generating outputs for each duplicate # prompt individually (if colocated, work on your own batch). From 85f8f40d74d5240064c0f37817f318e7a6c10c22 Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Mon, 24 Mar 2025 11:23:12 -0400 Subject: [PATCH 08/27] Fix guided decoding param --- trl/extras/vllm_coloc_client.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/trl/extras/vllm_coloc_client.py b/trl/extras/vllm_coloc_client.py index 006077ee915..90d66f38b73 100644 --- a/trl/extras/vllm_coloc_client.py +++ b/trl/extras/vllm_coloc_client.py @@ -17,6 +17,7 @@ import warnings from .vllm_proxy import BaseVLLMClient from vllm import SamplingParams, LLM +from vllm.sampling_params import GuidedDecodingParams from accelerate import PartialState @@ -102,6 +103,11 @@ def generate( `list[list[int]]`: List of lists of token IDs representing the model-generated completions for each prompt. """ + if guided_decoding_regex is not None: + guided_decoding = GuidedDecodingParams(backend="outlines", regex=guided_decoding_regex) + else: + guided_decoding = None + sampling_params = SamplingParams( n=1, # vLLM on each GPU generates only 1 in vllm_colocation mode repetition_penalty=repetition_penalty, @@ -110,8 +116,7 @@ def generate( top_k=top_k, min_p=min_p, max_tokens=max_tokens, - guided_decoding_regex=guided_decoding_regex, - + guided_decoding=guided_decoding, ) all_outputs = self.llm.generate( From 845328e61f532df021eb42c9ae36412abf21ba75 Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Mon, 24 Mar 2025 11:26:38 -0400 Subject: [PATCH 09/27] Fix reset prefix caching --- trl/extras/vllm_coloc_client.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/trl/extras/vllm_coloc_client.py b/trl/extras/vllm_coloc_client.py index 90d66f38b73..abcd8c98ddd 100644 --- a/trl/extras/vllm_coloc_client.py +++ b/trl/extras/vllm_coloc_client.py @@ -122,14 +122,12 @@ def generate( all_outputs = self.llm.generate( prompts, sampling_params=sampling_params, use_tqdm=False ) - print("\n\n\n---- all out", all_outputs) completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] - print("\n\n\n---- completion out",completion_ids) return completion_ids def reset_prefix_cache(self): """ Resets the prefix cache for the model. """ - self.llm.engine.reset_prefix_cache() + self.llm.reset() From 20d6fefef96b169f4250ff5748691c1d97b490f9 Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Mon, 24 Mar 2025 11:32:53 -0400 Subject: [PATCH 10/27] Fix reset prefix caching --- trl/extras/vllm_coloc_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/extras/vllm_coloc_client.py b/trl/extras/vllm_coloc_client.py index abcd8c98ddd..d02b56c787f 100644 --- a/trl/extras/vllm_coloc_client.py +++ b/trl/extras/vllm_coloc_client.py @@ -129,5 +129,5 @@ def reset_prefix_cache(self): """ Resets the prefix cache for the model. 
""" - self.llm.reset() + self.llm.reset_prefix_cache() From 47aca0475e7cbcada54831b62764502c689f5bc3 Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Tue, 25 Mar 2025 11:38:58 -0400 Subject: [PATCH 11/27] Revert dtype --- trl/trainer/grpo_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index ea5b447e3db..d812aea6b12 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -362,7 +362,7 @@ class GRPOConfig(TrainingArguments): }, ) vllm_dtype: Optional[str] = field( - default="auto", + default=None, metadata={ "help": "This parameter is deprecated and will be removed in version 0.18.0. To control the data type for " "vLLM generation, you should now use the `dtype` parameter in the vLLM server configuration." From fa14b211a4345f0f03ee7ba32e135fc672082e71 Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Tue, 25 Mar 2025 12:04:33 -0400 Subject: [PATCH 12/27] Add timeout arg --- trl/extras/vllm_proxy.py | 2 +- trl/trainer/grpo_trainer.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/trl/extras/vllm_proxy.py b/trl/extras/vllm_proxy.py index 2946223cd82..4f263473ad6 100644 --- a/trl/extras/vllm_proxy.py +++ b/trl/extras/vllm_proxy.py @@ -48,5 +48,5 @@ def get_vllm_client(args, accelerator, model) -> BaseVLLMClient: if args.vllm_colocation: return VLLMColocationClient(accelerator, args, model) else: - return VLLMClient(args.vllm_server_host, args.vllm_server_port, connection_timeout=120.0) + return VLLMClient(args.vllm_server_host, args.vllm_server_port, args.vllm_server_timeout) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index e4271c7f953..3f13f8ed4a3 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -473,9 +473,12 @@ def data_collator(features): # No data collation is needed in GRPO "`pip install vllm` to use it." 
) + self.vllm_client = None if self.accelerator.is_main_process or self.args.vllm_colocation: self.vllm_client = get_vllm_client(self.args, self.accelerator, model) + print("-----\n\n\n\nVLLM client initialized!") + # vLLM specific sampling arguments self.guided_decoding_regex = args.vllm_guided_decoding_regex From e0dfba85bf83aa75a97b6f65bb0b7548d0fb2a66 Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Tue, 25 Mar 2025 12:31:53 -0400 Subject: [PATCH 13/27] Fix vllm client init --- trl/extras/vllm_proxy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/extras/vllm_proxy.py b/trl/extras/vllm_proxy.py index 4f263473ad6..d2fb9ba604e 100644 --- a/trl/extras/vllm_proxy.py +++ b/trl/extras/vllm_proxy.py @@ -48,5 +48,5 @@ def get_vllm_client(args, accelerator, model) -> BaseVLLMClient: if args.vllm_colocation: return VLLMColocationClient(accelerator, args, model) else: - return VLLMClient(args.vllm_server_host, args.vllm_server_port, args.vllm_server_timeout) + return VLLMClient(args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout) From 0f98e5cbd70bc87efd6d6a31c5305ddb227ef515 Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Tue, 25 Mar 2025 12:43:57 -0400 Subject: [PATCH 14/27] Remove lazy import --- trl/extras/vllm_base.py | 42 +++++++++++++++++++++++++++++++++ trl/extras/vllm_client.py | 2 +- trl/extras/vllm_coloc_client.py | 3 ++- trl/extras/vllm_proxy.py | 34 +++----------------------- trl/trainer/grpo_trainer.py | 3 +++ 5 files changed, 51 insertions(+), 33 deletions(-) create mode 100644 trl/extras/vllm_base.py diff --git a/trl/extras/vllm_base.py b/trl/extras/vllm_base.py new file mode 100644 index 00000000000..038361102a4 --- /dev/null +++ b/trl/extras/vllm_base.py @@ -0,0 +1,42 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from abc import ABC, abstractmethod +from typing import Optional +import torch + +class BaseVLLMClient(ABC): + + @abstractmethod + def generate( + self, + prompts: list[str], + n: int = 1, + repetition_penalty: float = 1.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + min_p: float = 0.0, + max_tokens: int = 16, + guided_decoding_regex: Optional[str] = None, + ) -> list[list[str]]: + pass + + @abstractmethod + def update_named_param(self, name: str, weights: torch.Tensor): + pass + + @abstractmethod + def reset_prefix_cache(self): + pass \ No newline at end of file diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py index 6de19c6a394..b21d77e740b 100644 --- a/trl/extras/vllm_client.py +++ b/trl/extras/vllm_client.py @@ -21,7 +21,7 @@ from torch import nn from ..import_utils import is_requests_available, is_vllm_available -from .vllm_proxy import BaseVLLMClient +from .vllm_base import BaseVLLMClient if is_requests_available(): import requests diff --git a/trl/extras/vllm_coloc_client.py b/trl/extras/vllm_coloc_client.py index d02b56c787f..10fca000b1f 100644 --- a/trl/extras/vllm_coloc_client.py +++ b/trl/extras/vllm_coloc_client.py @@ -15,7 +15,8 @@ from typing import Optional import torch import warnings -from .vllm_proxy import BaseVLLMClient +from .vllm_base import BaseVLLMClient + from vllm import SamplingParams, LLM from vllm.sampling_params import GuidedDecodingParams diff --git a/trl/extras/vllm_proxy.py b/trl/extras/vllm_proxy.py index d2fb9ba604e..707b1b6fa80 100644 --- a/trl/extras/vllm_proxy.py +++ b/trl/extras/vllm_proxy.py @@ -12,39 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import ABC, abstractmethod -from typing import Optional -import torch - -class BaseVLLMClient(ABC): - - @abstractmethod - def generate( - self, - prompts: list[str], - n: int = 1, - repetition_penalty: float = 1.0, - temperature: float = 1.0, - top_p: float = 1.0, - top_k: int = -1, - min_p: float = 0.0, - max_tokens: int = 16, - guided_decoding_regex: Optional[str] = None, - ) -> list[list[str]]: - pass - - @abstractmethod - def update_named_param(self, name: str, weights: torch.Tensor): - pass - - @abstractmethod - def reset_prefix_cache(self): - pass +from .vllm_base import BaseVLLMClient +from .vllm_client import VLLMClient +from .vllm_coloc_client import VLLMColocationClient def get_vllm_client(args, accelerator, model) -> BaseVLLMClient: - from .vllm_coloc_client import VLLMColocationClient - from .vllm_client import VLLMClient - if args.vllm_colocation: return VLLMColocationClient(accelerator, args, model) else: diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 3f13f8ed4a3..2721e86c4e0 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -474,7 +474,10 @@ def data_collator(features): # No data collation is needed in GRPO ) self.vllm_client = None + print("args here", self.args) + print("coloc arg", self.args.vllm_colocation) if self.accelerator.is_main_process or self.args.vllm_colocation: + print("-----\n\n\n\nInitializing now!!!!") self.vllm_client = get_vllm_client(self.args, self.accelerator, model) print("-----\n\n\n\nVLLM client initialized!") From 837263c9c37a1f8913a361f277607cb368475482 Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Tue, 25 Mar 2025 12:46:18 -0400 Subject: [PATCH 15/27] Debugging client --- trl/extras/vllm_proxy.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/trl/extras/vllm_proxy.py 
b/trl/extras/vllm_proxy.py index 707b1b6fa80..142b50f6630 100644 --- a/trl/extras/vllm_proxy.py +++ b/trl/extras/vllm_proxy.py @@ -18,7 +18,9 @@ def get_vllm_client(args, accelerator, model) -> BaseVLLMClient: if args.vllm_colocation: + print("\n\n\n\nColoc client !") return VLLMColocationClient(accelerator, args, model) else: + print("\n\n\n\nOld client !") return VLLMClient(args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout) From 76ca767b5f45e701673aa7baa1a324ae6852e828 Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Tue, 25 Mar 2025 12:58:56 -0400 Subject: [PATCH 16/27] Add dtype auto as default --- trl/trainer/grpo_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index d812aea6b12..ea5b447e3db 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -362,7 +362,7 @@ class GRPOConfig(TrainingArguments): }, ) vllm_dtype: Optional[str] = field( - default=None, + default="auto", metadata={ "help": "This parameter is deprecated and will be removed in version 0.18.0. To control the data type for " "vLLM generation, you should now use the `dtype` parameter in the vLLM server configuration." From d96655f1ff9bd6ec39695214590834f2d1a09cd1 Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Tue, 25 Mar 2025 13:41:57 -0400 Subject: [PATCH 17/27] Remove prints and set default for vllm params --- trl/trainer/grpo_config.py | 4 ++-- trl/trainer/grpo_trainer.py | 6 ------ 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index ea5b447e3db..c389bc0ada4 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -354,7 +354,7 @@ class GRPOConfig(TrainingArguments): }, ) vllm_gpu_memory_utilization: Optional[float] = field( - default=None, + default=0.3, metadata={ "help": "This parameter is deprecated and will be removed in version 0.18.0. To control the GPU memory " "utilization for vLLM, you should now use the `gpu_memory_utilization` parameter in the vLLM server " @@ -377,7 +377,7 @@ class GRPOConfig(TrainingArguments): }, ) vllm_enable_prefix_caching: Optional[bool] = field( - default=None, + default=False, metadata={ "help": "This parameter is deprecated and will be removed in version 0.18.0. To control prefix caching in " "vLLM, you should now use the `enable_prefix_caching` parameter in the vLLM server configuration." diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 2721e86c4e0..e4271c7f953 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -473,15 +473,9 @@ def data_collator(features): # No data collation is needed in GRPO "`pip install vllm` to use it." 
) - self.vllm_client = None - print("args here", self.args) - print("coloc arg", self.args.vllm_colocation) if self.accelerator.is_main_process or self.args.vllm_colocation: - print("-----\n\n\n\nInitializing now!!!!") self.vllm_client = get_vllm_client(self.args, self.accelerator, model) - print("-----\n\n\n\nVLLM client initialized!") - # vLLM specific sampling arguments self.guided_decoding_regex = args.vllm_guided_decoding_regex From c3e58a30348dd1db2a40a4c9b239dc9e583d66a0 Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Tue, 25 Mar 2025 16:01:25 -0400 Subject: [PATCH 18/27] Remove debug statements --- trl/extras/vllm_coloc_client.py | 1 - trl/extras/vllm_proxy.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/trl/extras/vllm_coloc_client.py b/trl/extras/vllm_coloc_client.py index 10fca000b1f..b8558227bfa 100644 --- a/trl/extras/vllm_coloc_client.py +++ b/trl/extras/vllm_coloc_client.py @@ -24,7 +24,6 @@ class VLLMColocationClient(BaseVLLMClient): def __init__(self, accelerator, args, model): - ## ToDo: get accelerator and other things - to offload if/else from trainer to the clients self.args = args self.accelerator = accelerator self.model = model diff --git a/trl/extras/vllm_proxy.py b/trl/extras/vllm_proxy.py index 142b50f6630..707b1b6fa80 100644 --- a/trl/extras/vllm_proxy.py +++ b/trl/extras/vllm_proxy.py @@ -18,9 +18,7 @@ def get_vllm_client(args, accelerator, model) -> BaseVLLMClient: if args.vllm_colocation: - print("\n\n\n\nColoc client !") return VLLMColocationClient(accelerator, args, model) else: - print("\n\n\n\nOld client !") return VLLMClient(args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout) From 5b43f0fbdbc912ee2f56490cc744200bb7a575dc Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Wed, 26 Mar 2025 01:26:55 +0000 Subject: [PATCH 19/27] have just 1 vllm_client.py Signed-off-by: Yu Chin Fabian Lim --- trl/extras/vllm_base.py | 42 ---------- trl/extras/vllm_client.py | 137 +++++++++++++++++++++++++++++++- trl/extras/vllm_coloc_client.py | 133 ------------------------------- trl/extras/vllm_proxy.py | 24 ------ trl/trainer/grpo_trainer.py | 6 +- 5 files changed, 139 insertions(+), 203 deletions(-) delete mode 100644 trl/extras/vllm_base.py delete mode 100644 trl/extras/vllm_coloc_client.py delete mode 100644 trl/extras/vllm_proxy.py diff --git a/trl/extras/vllm_base.py b/trl/extras/vllm_base.py deleted file mode 100644 index 038361102a4..00000000000 --- a/trl/extras/vllm_base.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from abc import ABC, abstractmethod -from typing import Optional -import torch - -class BaseVLLMClient(ABC): - - @abstractmethod - def generate( - self, - prompts: list[str], - n: int = 1, - repetition_penalty: float = 1.0, - temperature: float = 1.0, - top_p: float = 1.0, - top_k: int = -1, - min_p: float = 0.0, - max_tokens: int = 16, - guided_decoding_regex: Optional[str] = None, - ) -> list[list[str]]: - pass - - @abstractmethod - def update_named_param(self, name: str, weights: torch.Tensor): - pass - - @abstractmethod - def reset_prefix_cache(self): - pass \ No newline at end of file diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py index b21d77e740b..70f4545002f 100644 --- a/trl/extras/vllm_client.py +++ b/trl/extras/vllm_client.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from abc import ABC, abstractmethod import atexit import logging import time @@ -21,20 +22,48 @@ from torch import nn from ..import_utils import is_requests_available, is_vllm_available -from .vllm_base import BaseVLLMClient + +from ..trainer.grpo_config import GRPOConfig if is_requests_available(): import requests from requests import ConnectionError - if is_vllm_available(): from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.utils import StatelessProcessGroup + from vllm import SamplingParams, LLM +from accelerate import Accelerator logger = logging.getLogger(__name__) +# abstract base class. All vllm clients must +# implement these methods +class BaseVLLMClient(ABC): + + @abstractmethod + def generate( + self, + prompts: list[str], + n: int = 1, + repetition_penalty: float = 1.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + min_p: float = 0.0, + max_tokens: int = 16, + guided_decoding_regex: Optional[str] = None, + ) -> list[list[str]]: + pass + + @abstractmethod + def update_named_param(self, name: str, weights: torch.Tensor): + pass + + @abstractmethod + def reset_prefix_cache(self): + pass class VLLMClient(BaseVLLMClient): """ @@ -280,3 +309,107 @@ def close_communicator(self): model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B").to("cuda") client.update_model_params(model) + + + +class VLLMColocationClient(BaseVLLMClient): + def __init__(self, args: GRPOConfig, model, vllm_device): + self.args: GRPOConfig = args + self.model = model + self.vllm_device = vllm_device + + self.llm = LLM( + model=self.model.name_or_path, + device=self.vllm_device, + gpu_memory_utilization=self.args.vllm_gpu_memory_utilization, + dtype=self.args.vllm_dtype, + enable_prefix_caching=self.args.vllm_enable_prefix_caching, + max_model_len=self.args.vllm_max_model_len, + distributed_executor_backend="external_launcher", + hf_overrides = { + 'max_position_embeddings': self.args.vllm_max_model_len + }, + ) + + def update_named_param(self, name: str, weights: torch.Tensor): + """ + Updates a specific named parameter in the model. + + Args: + name (`str`): + Name of the layer whose weights are being updated. + weights (`torch.Tensor`): + Tensor containing the updated weights. 
+ """ + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(name,weights)]) + + def generate( + self, + prompts: list[str], + n: int = 1, + repetition_penalty: float = 1.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + min_p: float = 0.0, + max_tokens: int = 16, + guided_decoding_regex: Optional[str] = None, + ) -> list[list[str]]: + """ + Generates model completions for the provided prompts. + + Args: + prompts (`list[str]`): + List of text prompts for which the model will generate completions. + n (`int`, *optional*, defaults to `1`): + Number of completions to generate for each prompt. + repetition_penalty (`float`, *optional*, defaults to `1.0`): + Parameter for repetition penalty. 1.0 means no penalty. + temperature (`float`, *optional*, defaults to `1.0`): + Temperature parameter for sampling. Higher values increase diversity. + top_p (`float`, *optional*, defaults to `1.0`): + Top-p sampling parameter.`1.0` means no truncation. + top_k (`int`, *optional*, defaults to `-1`): + Top-k sampling parameter. `-1` means no truncation. + min_p (`float`, *optional*, defaults to `0.0`): + Minimum probability for sampling. + max_tokens (`int`, *optional*, defaults to `16`): + Maximum number of tokens to generate for each prompt. + guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`): + Regular expression to guide the decoding process. + + Returns: + `list[list[int]]`: + List of lists of token IDs representing the model-generated completions for each prompt. + """ + sampling_params = SamplingParams( + n=1, # vLLM on each GPU generates only 1 in vllm_colocation mode + repetition_penalty=repetition_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + max_tokens=max_tokens, + guided_decoding=guided_decoding_regex, + ) + + all_outputs = self.llm.generate( + prompts, sampling_params=sampling_params, use_tqdm=False + ) + completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] + return completion_ids + + def reset_prefix_cache(self): + """ + Resets the prefix cache for the model. + """ + self.llm.reset_prefix_cache() + +# build appropriate client according to config +def get_vllm_client(args: GRPOConfig, model, accelerator: Accelerator) -> BaseVLLMClient: + if args.vllm_colocation: + return VLLMColocationClient(args, model, accelerator.device) + else: + return VLLMClient(args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout) + diff --git a/trl/extras/vllm_coloc_client.py b/trl/extras/vllm_coloc_client.py deleted file mode 100644 index b8558227bfa..00000000000 --- a/trl/extras/vllm_coloc_client.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Optional -import torch -import warnings -from .vllm_base import BaseVLLMClient - -from vllm import SamplingParams, LLM -from vllm.sampling_params import GuidedDecodingParams - -from accelerate import PartialState - -class VLLMColocationClient(BaseVLLMClient): - def __init__(self, accelerator, args, model): - self.args = args - self.accelerator = accelerator - self.model = model - self.vllm_device = None - self._initialize_colocated_vllm() - - def _initialize_colocated_vllm(self): - device_type = PartialState().default_device.type - self.vllm_device = f"{device_type}:{self.accelerator.process_index}" - - warnings.warn( - f"The requested device {self.vllm_device} is also being used for training. For higher throughput " - "and to avoid out-of-memory errors, it is recommended to use a dedicated device for vLLM. " - "If this is intentional, you may ignore this warning but should adjust " - "`vllm_gpu_memory_utilization` accordingly." - ) - - self.llm = LLM( - model=self.model.name_or_path, - device=self.vllm_device, - gpu_memory_utilization=self.args.vllm_gpu_memory_utilization, - dtype=self.args.vllm_dtype, - enable_prefix_caching=self.args.vllm_enable_prefix_caching, - max_model_len=self.args.vllm_max_model_len, - distributed_executor_backend="external_launcher", - ) - - def update_named_param(self, name: str, weights: torch.Tensor): - """ - Updates a specific named parameter in the model. - - Args: - name (`str`): - Name of the layer whose weights are being updated. - weights (`torch.Tensor`): - Tensor containing the updated weights. - """ - llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model - llm_model.load_weights([(name,weights)]) - - def generate( - self, - prompts: list[str], - n: int = 1, - repetition_penalty: float = 1.0, - temperature: float = 1.0, - top_p: float = 1.0, - top_k: int = -1, - min_p: float = 0.0, - max_tokens: int = 16, - guided_decoding_regex: Optional[str] = None, - ) -> list[list[str]]: - """ - Generates model completions for the provided prompts. - - Args: - prompts (`list[str]`): - List of text prompts for which the model will generate completions. - n (`int`, *optional*, defaults to `1`): - Number of completions to generate for each prompt. - repetition_penalty (`float`, *optional*, defaults to `1.0`): - Parameter for repetition penalty. 1.0 means no penalty. - temperature (`float`, *optional*, defaults to `1.0`): - Temperature parameter for sampling. Higher values increase diversity. - top_p (`float`, *optional*, defaults to `1.0`): - Top-p sampling parameter.`1.0` means no truncation. - top_k (`int`, *optional*, defaults to `-1`): - Top-k sampling parameter. `-1` means no truncation. - min_p (`float`, *optional*, defaults to `0.0`): - Minimum probability for sampling. - max_tokens (`int`, *optional*, defaults to `16`): - Maximum number of tokens to generate for each prompt. - guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`): - Regular expression to guide the decoding process. - - Returns: - `list[list[int]]`: - List of lists of token IDs representing the model-generated completions for each prompt. 
- """ - if guided_decoding_regex is not None: - guided_decoding = GuidedDecodingParams(backend="outlines", regex=guided_decoding_regex) - else: - guided_decoding = None - - sampling_params = SamplingParams( - n=1, # vLLM on each GPU generates only 1 in vllm_colocation mode - repetition_penalty=repetition_penalty, - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - max_tokens=max_tokens, - guided_decoding=guided_decoding, - ) - - all_outputs = self.llm.generate( - prompts, sampling_params=sampling_params, use_tqdm=False - ) - completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] - return completion_ids - - def reset_prefix_cache(self): - """ - Resets the prefix cache for the model. - """ - self.llm.reset_prefix_cache() - diff --git a/trl/extras/vllm_proxy.py b/trl/extras/vllm_proxy.py deleted file mode 100644 index 707b1b6fa80..00000000000 --- a/trl/extras/vllm_proxy.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .vllm_base import BaseVLLMClient -from .vllm_client import VLLMClient -from .vllm_coloc_client import VLLMColocationClient - -def get_vllm_client(args, accelerator, model) -> BaseVLLMClient: - if args.vllm_colocation: - return VLLMColocationClient(accelerator, args, model) - else: - return VLLMClient(args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout) - diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index e4271c7f953..3c9e750a952 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -43,7 +43,7 @@ from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template from ..extras.profiling import profiling_context, profiling_decorator -from ..extras.vllm_proxy import get_vllm_client +from ..extras.vllm_client import get_vllm_client from ..import_utils import is_deepspeed_available, is_rich_available, is_vllm_available from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation from .callbacks import SyncRefModelCallback @@ -474,7 +474,9 @@ def data_collator(features): # No data collation is needed in GRPO ) if self.accelerator.is_main_process or self.args.vllm_colocation: - self.vllm_client = get_vllm_client(self.args, self.accelerator, model) + self.vllm_client = get_vllm_client( + self.args, model, self.accelerator, + ) # vLLM specific sampling arguments self.guided_decoding_regex = args.vllm_guided_decoding_regex From 5aa7882de7bb2ba1cd9f1be57035ed58795945ab Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Wed, 26 Mar 2025 06:18:05 +0000 Subject: [PATCH 20/27] move controls into vllm_client objects Signed-off-by: Yu Chin Fabian Lim --- trl/extras/vllm_client.py | 100 +++++++++++++++++++++++++----------- trl/trainer/grpo_trainer.py | 59 +++++++-------------- 2 files changed, 88 insertions(+), 71 deletions(-) diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py 
index 70f4545002f..80f573312ed 100644 --- a/trl/extras/vllm_client.py +++ b/trl/extras/vllm_client.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import ABC, abstractmethod import atexit import logging import time @@ -35,14 +34,17 @@ from vllm import SamplingParams, LLM from accelerate import Accelerator +from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed logger = logging.getLogger(__name__) # abstract base class. All vllm clients must # implement these methods -class BaseVLLMClient(ABC): +class VLLMNoopClient: + + def __init__(self, process_index: int): + self.process_index = process_index - @abstractmethod def generate( self, prompts: list[str], @@ -55,17 +57,32 @@ def generate( max_tokens: int = 16, guided_decoding_regex: Optional[str] = None, ) -> list[list[str]]: - pass + orig_size = len(prompts) + prompts = gather_object(prompts) + completion_ids = [None] * len(prompts) + return self._broadcast_and_slice(completion_ids, orig_size) - @abstractmethod def update_named_param(self, name: str, weights: torch.Tensor): pass - @abstractmethod def reset_prefix_cache(self): pass -class VLLMClient(BaseVLLMClient): + def _gather(self, prompts): + return gather_object(prompts) + + def _broadcast_and_slice(self, completion_ids: list, slice_size: int): + # Broadcast the completions from the main process to all processes, ensuring each process receives its + # corresponding slice + + completion_ids = broadcast_object_list(completion_ids, from_process=0) + process_slice = slice( + self.process_index * slice_size, + (self.process_index + 1) * slice_size, + ) + return completion_ids[process_slice] + +class VLLMClient(VLLMNoopClient): """ A client class to interact with a vLLM server. @@ -109,13 +126,17 @@ class VLLMClient(BaseVLLMClient): """ def __init__( - self, host: str = "0.0.0.0", server_port: int = 8000, group_port: int = 51216, connection_timeout: float = 0.0 + self, host: str = "0.0.0.0", server_port: int = 8000, group_port: int = 51216, connection_timeout: float = 0.0, + distributed: bool = False ): + super().__init__(process_index=0) + if not is_requests_available(): raise ImportError("requests is not installed. Please install it with `pip install requests`.") if not is_vllm_available(): raise ImportError("vLLM is not installed. Please install it with `pip install vllm`.") + self.distributed = distributed self.session = requests.Session() self.host = host self.server_port = server_port @@ -197,6 +218,16 @@ def generate( `list[list[int]]`: List of lists of token IDs representing the model-generated completions for each prompt. """ + + if self.distributed: + orig_size = len(prompts) + prompts = self._gather(prompts) + + # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate + # num_generations outputs for each one. 
This is faster than generating outputs for each duplicate + # prompt individually + prompts = prompts[::n] + url = f"http://{self.host}:{self.server_port}/generate/" response = self.session.post( url, @@ -213,10 +244,15 @@ def generate( }, ) if response.status_code == 200: - return response.json()["completion_ids"] + completion_ids = response.json()["completion_ids"] else: raise Exception(f"Request failed: {response.status_code}, {response.text}") + if self.distributed: + completion_ids = self._broadcast_and_slice(completion_ids, orig_size) + + return completion_ids + def init_communicator(self): """ Initializes the weight update group in a distributed setup for model synchronization. @@ -294,25 +330,8 @@ def close_communicator(self): raise Exception(f"Request failed: {response.status_code}, {response.text}") -# Example usage -if __name__ == "__main__": - from vllm import SamplingParams - - client = VLLMClient() - # Generate completions - responses = client.generate(["Hello, AI!", "Tell me a joke"], n=4, max_tokens=32, sampling_params=SamplingParams()) - print("Responses:", responses) # noqa - - # Update model weights - from transformers import AutoModelForCausalLM - - model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B").to("cuda") - client.update_model_params(model) - - - -class VLLMColocationClient(BaseVLLMClient): +class VLLMColocationClient: def __init__(self, args: GRPOConfig, model, vllm_device): self.args: GRPOConfig = args self.model = model @@ -407,9 +426,30 @@ def reset_prefix_cache(self): self.llm.reset_prefix_cache() # build appropriate client according to config -def get_vllm_client(args: GRPOConfig, model, accelerator: Accelerator) -> BaseVLLMClient: +def get_vllm_client(args: GRPOConfig, model, accelerator: Accelerator) -> VLLMNoopClient: if args.vllm_colocation: return VLLMColocationClient(args, model, accelerator.device) - else: - return VLLMClient(args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout) + elif accelerator.is_main_process: + return VLLMClient( + args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout, + distributed=accelerator.num_processes > 1, + ) + return VLLMNoopClient(accelerator.process_index) + + +# Example usage for VLLMCLient +if __name__ == "__main__": + from vllm import SamplingParams + + client = VLLMClient() + + # Generate completions + responses = client.generate(["Hello, AI!", "Tell me a joke"], n=4, max_tokens=32, sampling_params=SamplingParams()) + print("Responses:", responses) # noqa + + # Update model weights + from transformers import AutoModelForCausalLM + + model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B").to("cuda") + client.update_model_params(model) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 3c9e750a952..3cef142ce4a 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -473,10 +473,9 @@ def data_collator(features): # No data collation is needed in GRPO "`pip install vllm` to use it." ) - if self.accelerator.is_main_process or self.args.vllm_colocation: - self.vllm_client = get_vllm_client( - self.args, model, self.accelerator, - ) + self.vllm_client = get_vllm_client( + self.args, model, self.accelerator, + ) # vLLM specific sampling arguments self.guided_decoding_regex = args.vllm_guided_decoding_regex @@ -486,8 +485,7 @@ def data_collator(features): # No data collation is needed in GRPO # When using vLLM, the main process is responsible for loading the model weights. 
This can cause process # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we # synchronize all processes after vLLM has been fully initialized (if colocated, no need to wait). - if not self.args.vllm_colocation: - self.accelerator.wait_for_everyone() + self.accelerator.wait_for_everyone() else: self.generation_config = GenerationConfig( max_new_tokens=self.max_completion_length, @@ -643,8 +641,7 @@ def _move_model_to_vllm(self): continue name = name.replace("modules_to_save.default.", "") - if self.accelerator.is_main_process or self.args.vllm_colocation: - self.vllm_client.update_named_param(name, param.data) + self.vllm_client.update_named_param(name, param.data) # Unmerge adapters while parameters are still gathered self.model.unmerge_adapter() @@ -653,12 +650,10 @@ def _move_model_to_vllm(self): # For non-PEFT models, simply gather and update each parameter individually. for name, param in self.model.named_parameters(): with gather_if_zero3([param]): - if self.accelerator.is_main_process or self.args.vllm_colocation: - self.vllm_client.update_named_param(name, param.data) + self.vllm_client.update_named_param(name, param.data) # Reset cache on main process (if colocated, reset cache on all vllms) - if self.accelerator.is_main_process or self.args.vllm_colocation: - self.vllm_client.reset_prefix_cache() + self.vllm_client.reset_prefix_cache() @profiling_decorator def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: @@ -699,35 +694,17 @@ def _generate_and_score_completions( self._last_loaded_step = self.state.global_step # Generate completions using vLLM: gather all prompts and use them in a single call in the main process (if colocated, work on your own batch) - all_prompts_text = prompts_text if self.args.vllm_colocation else gather_object(prompts_text) - if self.accelerator.is_main_process or self.args.vllm_colocation: - # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate - # num_generations outputs for each one. This is faster than generating outputs for each duplicate - # prompt individually (if colocated, work on your own batch). - ordered_set_of_prompts = all_prompts_text if self.args.vllm_colocation else all_prompts_text[:: self.num_generations] - # with profiling_context(self, "vLLM.generate"): - completion_ids = self.vllm_client.generate( - prompts=ordered_set_of_prompts, - n=self.num_generations, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=-1 if self.top_k is None else self.top_k, - min_p=0.0 if self.min_p is None else self.min_p, - max_tokens=self.max_completion_length, - guided_decoding_regex=self.guided_decoding_regex, - ) - else: - completion_ids = [None] * len(all_prompts_text) - # Broadcast the completions from the main process to all processes, ensuring each process receives its - # corresponding slice (if colocated, no need for broadcasting). 
- if not self.args.vllm_colocation: - completion_ids = broadcast_object_list(completion_ids, from_process=0) - process_slice = slice( - self.accelerator.process_index * len(prompts), - (self.accelerator.process_index + 1) * len(prompts), - ) - completion_ids = completion_ids[process_slice] + completion_ids = self.vllm_client.generate( + prompts=prompts_text, + n=self.num_generations, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=-1 if self.top_k is None else self.top_k, + min_p=0.0 if self.min_p is None else self.min_p, + max_tokens=self.max_completion_length, + guided_decoding_regex=self.guided_decoding_regex, + ) # Pad the completions, and concatenate them with the prompts completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] From 07c2b0f7dcee9fa1437b2439d6a8d2cef2a02ade Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Wed, 26 Mar 2025 10:04:22 -0400 Subject: [PATCH 21/27] Add comments and docstring --- trl/extras/vllm_client.py | 66 ++++++++++++++++++++++++++++++------- trl/trainer/grpo_trainer.py | 2 +- 2 files changed, 55 insertions(+), 13 deletions(-) diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py index 80f573312ed..772e0f74184 100644 --- a/trl/extras/vllm_client.py +++ b/trl/extras/vllm_client.py @@ -32,15 +32,26 @@ from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.utils import StatelessProcessGroup from vllm import SamplingParams, LLM + from vllm.sampling_params import GuidedDecodingParams from accelerate import Accelerator from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed logger = logging.getLogger(__name__) -# abstract base class. All vllm clients must -# implement these methods -class VLLMNoopClient: +class VLLMNoOpClient: + """ + A no-op vLLM client used in distributed training when the process is neither the main process + nor running in vLLM colocation mode. + + This stub client ensures compatibility in distributed setups without performing actual + inference or model updates. + + Methods like `generate` and `update_named_param` are implemented as no-ops or return default + values to maintain consistent interfaces across processes. + + This class should only be used internally by `get_vllm_client`. + """ def __init__(self, process_index: int): self.process_index = process_index @@ -82,7 +93,7 @@ def _broadcast_and_slice(self, completion_ids: list, slice_size: int): ) return completion_ids[process_slice] -class VLLMClient(VLLMNoopClient): +class VLLMClient(VLLMNoOpClient): """ A client class to interact with a vLLM server. @@ -329,9 +340,20 @@ def close_communicator(self): if response.status_code != 200: raise Exception(f"Request failed: {response.status_code}, {response.text}") +class VLLMColocationClient: + """ + A client class to interact with vLLM processes colocated with the training process. + This client bypasses remote communication and directly interacts with the in-process vLLM engine. + It supports weight updates and text generation functionalities similar to `VLLMClient`, but is optimized + for scenarios where vLLM is running in the same process or node as training. + + Args: + args (`GRPOConfig`): Configuration object containing vLLM parameters. + model (`transformers.PreTrainedModel`): The model being used. + vllm_device (`torch.device` or `str`): Device on which the model is loaded (e.g., "cuda:0"). 
+ """ -class VLLMColocationClient: def __init__(self, args: GRPOConfig, model, vllm_device): self.args: GRPOConfig = args self.model = model @@ -345,8 +367,8 @@ def __init__(self, args: GRPOConfig, model, vllm_device): enable_prefix_caching=self.args.vllm_enable_prefix_caching, max_model_len=self.args.vllm_max_model_len, distributed_executor_backend="external_launcher", - hf_overrides = { - 'max_position_embeddings': self.args.vllm_max_model_len + hf_overrides = { + 'max_position_embeddings': self.args.vllm_max_model_len # model config reflects model length just for consistency (vllm == 0.8.2) }, ) @@ -402,6 +424,12 @@ def generate( `list[list[int]]`: List of lists of token IDs representing the model-generated completions for each prompt. """ + # Guided decoding, if enabled + if guided_decoding_regex is not None: + guided_decoding = GuidedDecodingParams(backend="outlines", regex=guided_decoding_regex) + else: + guided_decoding = None + sampling_params = SamplingParams( n=1, # vLLM on each GPU generates only 1 in vllm_colocation mode repetition_penalty=repetition_penalty, @@ -410,7 +438,7 @@ def generate( top_k=top_k, min_p=min_p, max_tokens=max_tokens, - guided_decoding=guided_decoding_regex, + guided_decoding=guided_decoding, ) all_outputs = self.llm.generate( @@ -425,8 +453,23 @@ def reset_prefix_cache(self): """ self.llm.reset_prefix_cache() -# build appropriate client according to config -def get_vllm_client(args: GRPOConfig, model, accelerator: Accelerator) -> VLLMNoopClient: +def get_vllm_client(args: GRPOConfig, model, accelerator: Accelerator) -> VLLMNoOpClient: + """ + Returns the appropriate vLLM client based on the current configuration. + + This function acts as a proxy to initialize and return the correct vLLM client type: + - If colocation is enabled, it returns `VLLMColocationClient`, which interacts directly with + the colocated vLLM process for faster integration. + - If running in the main process (non-colocated mode), it returns `VLLMClient`, which communicates + with an external vLLM server. + - If not the main process and colocation is disabled, it returns a base client (`VLLMNoOpClient`) + for compatibility in distributed settings. + + Args: + args (`GRPOConfig`): Configuration object containing flags for colocation, server host, port, etc. + model (`transformers.PreTrainedModel`): The model to use, passed only for the colocated client. + accelerator (`Accelerator`): Hugging Face `Accelerator` object that helps with multi-GPU training. + """ if args.vllm_colocation: return VLLMColocationClient(args, model, accelerator.device) elif accelerator.is_main_process: @@ -434,8 +477,7 @@ def get_vllm_client(args: GRPOConfig, model, accelerator: Accelerator) -> VLLMNo args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout, distributed=accelerator.num_processes > 1, ) - return VLLMNoopClient(accelerator.process_index) - + return VLLMNoOpClient(accelerator.process_index) # Example usage for VLLMCLient if __name__ == "__main__": diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 3cef142ce4a..1682e3ac081 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -484,7 +484,7 @@ def data_collator(features): # No data collation is needed in GRPO # When using vLLM, the main process is responsible for loading the model weights. This can cause process # desynchronization and seems to lead to DeepSpeed hanging during initialization. 
To prevent this, we - # synchronize all processes after vLLM has been fully initialized (if colocated, no need to wait). + # synchronize all processes after vLLM has been fully initialized. self.accelerator.wait_for_everyone() else: self.generation_config = GenerationConfig( From 0551491b4a1f25574e687b59565022ef46b6e650 Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Thu, 27 Mar 2025 11:59:20 -0400 Subject: [PATCH 22/27] Remove hf overrides --- trl/extras/vllm_client.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py index 772e0f74184..0cbb9d2bdf6 100644 --- a/trl/extras/vllm_client.py +++ b/trl/extras/vllm_client.py @@ -367,9 +367,6 @@ def __init__(self, args: GRPOConfig, model, vllm_device): enable_prefix_caching=self.args.vllm_enable_prefix_caching, max_model_len=self.args.vllm_max_model_len, distributed_executor_backend="external_launcher", - hf_overrides = { - 'max_position_embeddings': self.args.vllm_max_model_len # model config reflects model length just for consistency (vllm == 0.8.2) - }, ) def update_named_param(self, name: str, weights: torch.Tensor): From 28b92b0994d18db770d25065b82e336711f4fc1e Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Fri, 18 Apr 2025 10:03:30 -0400 Subject: [PATCH 23/27] Fix imports --- trl/extras/vllm_client.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py index f7385db4dd1..90a33ae87b9 100644 --- a/trl/extras/vllm_client.py +++ b/trl/extras/vllm_client.py @@ -33,12 +33,11 @@ from vllm.distributed.utils import StatelessProcessGroup from vllm import SamplingParams, LLM from vllm.sampling_params import GuidedDecodingParams + if is_vllm_ascend_available(): + from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator from accelerate import Accelerator -from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed - -if is_vllm_ascend_available(): - from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator +from accelerate.utils import broadcast_object_list, gather_object logger = logging.getLogger(__name__) From 551aa7d827d409598b93f4093c1814b54dd1bafe Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Mon, 28 Apr 2025 14:38:04 -0400 Subject: [PATCH 24/27] Init communicator for vllm server client --- trl/extras/vllm_client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py index 552ea715715..3b990f37b09 100644 --- a/trl/extras/vllm_client.py +++ b/trl/extras/vllm_client.py @@ -486,10 +486,12 @@ def get_vllm_client(args: GRPOConfig, model, accelerator: Accelerator) -> VLLMNo if args.vllm_colocation: return VLLMColocationClient(args, model, accelerator.device) elif accelerator.is_main_process: - return VLLMClient( + vllm_client = VLLMClient( args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout, distributed=accelerator.num_processes > 1, ) + vllm_client.init_communicator() + return vllm_client return VLLMNoOpClient(accelerator.process_index) # Example usage for VLLMCLient From dff5fd512d5fc5ffcc229ca0b151e3c2b9b73353 Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Mon, 28 Apr 2025 15:15:39 -0400 Subject: [PATCH 25/27] Fix config --- trl/trainer/grpo_config.py | 35 +---------------------------------- 1 file changed, 1 insertion(+), 34 deletions(-) diff --git 
a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 532d9e0c2b3..21912daee47 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -446,37 +446,4 @@ class GRPOConfig(TrainingArguments): "`max_model_len` for vLLM, you should now use the `max_model_len` parameter in the vLLM server " "configuration." }, - ) - - def __post_init__(self): - super().__post_init__() - - if self.vllm_device is not None: - warnings.warn( - "`vllm_device` is deprecated and will be removed in version 0.18.0. To use vLLM, start a vLLM server " - "with the `trl vllm-serve` command.", - DeprecationWarning, - ) - - if self.vllm_gpu_memory_utilization is not None: - warnings.warn( - "`vllm_gpu_memory_utilization` is deprecated and will be removed in v0.18. To control the GPU memory " - "utilization for vLLM, you should now use the `gpu_memory_utilization` parameter in the vLLM server " - "configuration.", - DeprecationWarning, - ) - - if self.vllm_dtype is not None: - warnings.warn( - "`vllm_dtype` is deprecated and will be removed in version 0.18.0. To control the data type for vLLM " - "generation, you should now use the `dtype` parameter in the vLLM server configuration.", - DeprecationWarning, - ) - - if self.vllm_max_model_len is not None: - warnings.warn( - "`vllm_max_model_len` is deprecated and will be removed in version 0.18.0. To control the " - "`max_model_len` for vLLM, you should now use the `max_model_len` parameter in the vLLM server " - "configuration.", - DeprecationWarning, - ) \ No newline at end of file + ) \ No newline at end of file From e72262567461a25a79768f739245728334ce93c5 Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Mon, 28 Apr 2025 15:31:41 -0400 Subject: [PATCH 26/27] Fix config --- trl/extras/vllm_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py index 3b990f37b09..a980e915d8e 100644 --- a/trl/extras/vllm_client.py +++ b/trl/extras/vllm_client.py @@ -380,7 +380,6 @@ def __init__(self, args: GRPOConfig, model, vllm_device): device=self.vllm_device, gpu_memory_utilization=self.args.vllm_gpu_memory_utilization, dtype=self.args.vllm_dtype, - enable_prefix_caching=self.args.vllm_enable_prefix_caching, max_model_len=self.args.vllm_max_model_len, distributed_executor_backend="external_launcher", ) From 7edcef7591206439381d8f5b186e12a19292e556 Mon Sep 17 00:00:00 2001 From: Mert Toslali Date: Mon, 28 Apr 2025 16:41:34 -0400 Subject: [PATCH 27/27] Seed is needed for new vllm req --- trl/extras/vllm_client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py index a980e915d8e..c5483feec41 100644 --- a/trl/extras/vllm_client.py +++ b/trl/extras/vllm_client.py @@ -382,6 +382,7 @@ def __init__(self, args: GRPOConfig, model, vllm_device): dtype=self.args.vllm_dtype, max_model_len=self.args.vllm_max_model_len, distributed_executor_backend="external_launcher", + seed=0 ) def update_named_param(self, name: str, weights: torch.Tensor):
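
Taken together, the series replaces the trainer-side `if self.accelerator.is_main_process or self.args.vllm_colocation` branches with a single per-rank client object: `VLLMColocationClient` when colocation is enabled, `VLLMClient` on the main process in server mode, and `VLLMNoOpClient` on every other rank. The sketch below mirrors only that selection order; the dataclass is a stand-in for `GRPOConfig`, not the real class, and the string labels are purely illustrative.

# Minimal sketch of the per-rank dispatch performed by get_vllm_client after this
# series. StubConfig is a stand-in for GRPOConfig; only the field used here exists.
from dataclasses import dataclass


@dataclass
class StubConfig:
    vllm_colocation: bool = False


def pick_client_kind(cfg: StubConfig, is_main_process: bool) -> str:
    if cfg.vllm_colocation:
        return "VLLMColocationClient"   # in-process engine on every rank
    if is_main_process:
        return "VLLMClient"             # HTTP client talking to `trl vllm-serve`
    return "VLLMNoOpClient"             # interface-preserving stub on other ranks


if __name__ == "__main__":
    cfg = StubConfig(vllm_colocation=False)
    for rank in range(4):
        print(rank, pick_client_kind(cfg, is_main_process=(rank == 0)))
    # Prints VLLMClient for rank 0 and VLLMNoOpClient for ranks 1-3; with
    # vllm_colocation=True every rank would get VLLMColocationClient instead.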
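In server mode the completions still have to reach every process: each rank calls `generate` on its local batch (which already holds `num_generations` consecutive duplicates of each prompt), the main-process client gathers those batches, keeps every n-th prompt, requests `n` completions per unique prompt from the server, and the flat result is broadcast so that each rank keeps the contiguous slice matching its own batch. The slicing arithmetic is the easy part to get wrong, so the following self-contained simulation uses plain lists in place of `gather_object`/`broadcast_object_list` and made-up prompts.

# Simulates the gather -> generate -> broadcast -> slice flow handled by
# VLLMClient.generate together with _broadcast_and_slice. Plain lists stand in
# for the accelerate collectives; no GPUs or vLLM are involved.
num_processes = 2
num_generations = 2

# Each rank's local batch already holds num_generations duplicates per prompt.
prompts_per_rank = [["p0", "p0"], ["p1", "p1"]]

# Rank 0 after the gather: every rank's prompts, duplicates included.
all_prompts = [p for rank_prompts in prompts_per_rank for p in rank_prompts]
unique_prompts = all_prompts[::num_generations]          # ["p0", "p1"]

# Pretend server call: n completions per unique prompt, flattened in order,
# so len(completions) == len(all_prompts).
completions = [f"{p}-gen{i}" for p in unique_prompts for i in range(num_generations)]

# After the broadcast, every rank sees the same flat list and keeps its slice.
for rank in range(num_processes):
    local_len = len(prompts_per_rank[rank])
    keep = completions[rank * local_len : (rank + 1) * local_len]
    print(f"rank {rank} keeps {keep}")
# rank 0 keeps ['p0-gen0', 'p0-gen1']
# rank 1 keeps ['p1-gen0', 'p1-gen1']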
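For the colocated path nothing external has to be started: every rank builds its own `LLM` with `distributed_executor_backend="external_launcher"` and generates for its local batch with `n=1`, since the duplicates are already present locally. A hedged usage sketch follows, assuming the branch keeps the `vllm_colocation` flag name; the memory fraction is only a placeholder chosen to leave room for the training step.

# Hedged sketch of enabling the colocated path; field values are placeholders.
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="grpo-colocated",
    use_vllm=True,                     # vLLM-backed generation
    vllm_colocation=True,              # run the engine in-process on every rank
    vllm_gpu_memory_utilization=0.3,   # leave headroom for optimizer/activations
)
# Launched like any multi-GPU training run, e.g. `accelerate launch train_grpo.py`.
# In server mode (vllm_colocation=False) the engine is started separately with
# `trl vllm-serve`, and only the main process talks to it over HTTP.

The final patch also pins `seed=0` on the colocated engine, which the newer vLLM releases targeted here expect when the external-launcher backend is used.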
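One detail worth calling out from the colocated `generate`: vLLM's `SamplingParams` takes a `GuidedDecodingParams` object rather than the raw regex string, which is what the docstring patch corrects. A minimal standalone construction is shown below; the regex is an arbitrary example, not a pattern used by the trainer.

# Builds sampling parameters with regex-guided decoding, mirroring the fixed call.
from vllm import SamplingParams
from vllm.sampling_params import GuidedDecodingParams

regex = r"<answer>.*</answer>"   # illustrative pattern only
guided = GuidedDecodingParams(backend="outlines", regex=regex)
sampling_params = SamplingParams(
    n=1,              # one completion per (already duplicated) prompt in colocation mode
    temperature=1.0,
    max_tokens=16,
    guided_decoding=guided,
)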