8 changes: 4 additions & 4 deletions tests/test_grpo_trainer.py
@@ -761,8 +761,8 @@ def test_training_vllm_and_peft(self):

@require_vllm
@unittest.skip("We should add a mock for the vLLM server.")
-    def test_training_vllm_guided_decoding(self):
-        """Test that training works with vLLM for generation with guided decoding."""
+    def test_training_vllm_structured_outputs(self):
+        """Test that training works with vLLM for generation with structured outputs."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = GRPOConfig(
@@ -773,7 +773,7 @@ def test_training_vllm_guided_decoding(self):
max_completion_length=8, # reduce the completion length to reduce memory usage
report_to="none",
use_vllm=True,
-            vllm_guided_decoding_regex=r"<reasoning>\n.*\n</reasoning>\n<answer>\n.*\n</answer>",
+            vllm_structured_outputs_regex=r"<reasoning>\n.*\n</reasoning>\n<answer>\n.*\n</answer>",
)
trainer = GRPOTrainer(
model="Qwen/Qwen2.5-0.5B-Instruct", # tiny model is too small for vLLM
@@ -796,7 +796,7 @@ def test_training_vllm_guided_decoding(self):
@require_vllm
@unittest.skip("We should add a mock for the vLLM server.")
def test_training_vllm_importance_sampling_correction(self):
"""Test that training works with vLLM for generation with guided decoding."""
"""Test that training works with vLLM for generation with structured outputs."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = GRPOConfig(
6 changes: 3 additions & 3 deletions tests/test_rloo_trainer.py
@@ -628,8 +628,8 @@ def test_training_vllm_and_peft(self):

@require_vllm
@unittest.skip("We should add a mock for the vLLM server.")
-    def test_training_vllm_guided_decoding(self):
-        """Test that training works with vLLM for generation with guided decoding."""
+    def test_training_vllm_structured_outputs(self):
+        """Test that training works with vLLM for generation with structured outputs."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = RLOOConfig(
@@ -640,7 +640,7 @@ def test_training_vllm_guided_decoding(self):
max_completion_length=8, # reduce the completion length to reduce memory usage
report_to="none",
use_vllm=True,
-            vllm_guided_decoding_regex=r"<reasoning>\n.*\n</reasoning>\n<answer>\n.*\n</answer>",
+            vllm_structured_outputs_regex=r"<reasoning>\n.*\n</reasoning>\n<answer>\n.*\n</answer>",
)
trainer = RLOOTrainer(
model="Qwen/Qwen2.5-0.5B-Instruct", # tiny model is too small for vLLM
12 changes: 6 additions & 6 deletions trl/experimental/gfpo/gfpo_trainer.py
@@ -37,7 +37,7 @@

if is_vllm_available():
from vllm import SamplingParams
-    from vllm.sampling_params import GuidedDecodingParams
+    from vllm.sampling_params import StructuredOutputsParams


class GFPOTrainer(_GRPOTrainer):
@@ -194,7 +194,7 @@ def _generate_and_score_completions(self, inputs):
top_k=-1 if self.top_k is None else self.top_k,
min_p=0.0 if self.min_p is None else self.min_p,
max_tokens=self.max_completion_length,
-                    guided_decoding_regex=self.guided_decoding_regex,
+                    structured_outputs_regex=self.structured_outputs_regex,
generation_kwargs=self.args.generation_kwargs,
)
payload = (output["completion_ids"], output["logprobs"])
@@ -215,10 +215,10 @@ def _generate_and_score_completions(self, inputs):

        # Generate completions using colocated vLLM instances: each device holds a vLLM copy and works on its own batch of prompts
elif self.vllm_mode == "colocate":
-            if self.guided_decoding_regex:
-                guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex)
+            if self.structured_outputs_regex:
+                structured_outputs = StructuredOutputsParams(regex=self.structured_outputs_regex)
else:
-                guided_decoding = None
+                structured_outputs = None

generation_kwargs = {
"n": 1, # vLLM on each GPU generates only 1 in colocate mode
Expand All @@ -228,7 +228,7 @@ def _generate_and_score_completions(self, inputs):
"top_k": -1 if self.top_k is None else self.top_k,
"min_p": 0.0 if self.min_p is None else self.min_p,
"max_tokens": self.max_completion_length,
"guided_decoding": guided_decoding,
"structured_outputs": structured_outputs,
"logprobs": 0, # only return the logprob of the generated token
}
if self.args.generation_kwargs is not None:
6 changes: 3 additions & 3 deletions trl/extras/vllm_client.py
@@ -178,7 +178,7 @@ def generate(
top_k: int = -1,
min_p: float = 0.0,
max_tokens: int = 16,
-        guided_decoding_regex: Optional[str] = None,
+        structured_outputs_regex: Optional[str] = None,
generation_kwargs: Optional[dict] = None,
) -> list[list[int]]:
"""
@@ -203,7 +203,7 @@ def generate(
Minimum probability for sampling.
max_tokens (`int`, *optional*, defaults to `16`):
Maximum number of tokens to generate for each prompt.
-            guided_decoding_regex (`str`, *optional*):
+            structured_outputs_regex (`str`, *optional*):
Regular expression to guide the decoding process.
generation_kwargs (`dict`, *optional*):
Additional generation parameters to pass to the vLLM `SamplingParams`. This can include parameters like
@@ -240,7 +240,7 @@ def pil_to_base64(image):
"top_k": top_k,
"min_p": min_p,
"max_tokens": max_tokens,
"guided_decoding_regex": guided_decoding_regex,
"structured_outputs_regex": structured_outputs_regex,
"generation_kwargs": generation_kwargs or {},
},
)
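For reference, a minimal sketch of calling the renamed client argument. The server address, prompt, and regex are illustrative assumptions; the return keys (`completion_ids`, `logprobs`) match how the trainers in this PR consume the output.

```python
from trl.extras.vllm_client import VLLMClient

# Assumes a TRL vLLM server is already running, e.g.:
#   trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct
client = VLLMClient()  # default host/port are assumptions; pass host=/server_port= if yours differ

output = client.generate(
    prompts=["What is 2 + 2?"],
    n=1,
    max_tokens=32,
    # Renamed from `guided_decoding_regex`: constrain completions to this pattern.
    structured_outputs_regex=r"<answer>\n.*\n</answer>",
)
print(output["completion_ids"])  # token ids of each completion
```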
16 changes: 8 additions & 8 deletions trl/scripts/vllm_serve.py
@@ -60,7 +60,7 @@
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.parallel_state import get_world_group
from vllm.distributed.utils import StatelessProcessGroup
-from vllm.sampling_params import GuidedDecodingParams
+from vllm.sampling_params import StructuredOutputsParams
from vllm.utils import get_open_port

if is_vllm_ascend_available():
@@ -495,7 +495,7 @@ class GenerateRequest(BaseModel):
top_k: int = -1
min_p: float = 0.0
max_tokens: int = 16
-        guided_decoding_regex: Optional[str] = None
+        structured_outputs_regex: Optional[str] = None
generation_kwargs: dict = field(default_factory=dict)

class GenerateResponse(BaseModel):
@@ -524,7 +524,7 @@ async def generate(request: GenerateRequest):
- `min_p` (`float`, *optional*, defaults to `0.0`): Minimum probability threshold for sampling.
- `max_tokens` (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each
completion.
-        - `guided_decoding_regex` (`str`, *optional*): A regex pattern for guided decoding. If provided, the
+        - `structured_outputs_regex` (`str`, *optional*): A regex pattern for structured outputs. If provided, the
model will only generate tokens that match this regex pattern.
- `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM
`SamplingParams`. This can include parameters like `seed`, `frequency_penalty`, etc. If it contains
@@ -555,11 +555,11 @@ async def generate(request: GenerateRequest):
row["multi_modal_data"] = {"image": Image.open(BytesIO(base64.b64decode(image)))}
prompts.append(row)

-        # Guided decoding, if enabled
-        if request.guided_decoding_regex is not None:
-            guided_decoding = GuidedDecodingParams(backend="outlines", regex=request.guided_decoding_regex)
+        # Structured outputs, if enabled
+        if request.structured_outputs_regex is not None:
+            structured_outputs = StructuredOutputsParams(backend="outlines", regex=request.structured_outputs_regex)
else:
-            guided_decoding = None
+            structured_outputs = None

generation_kwargs = {
"n": request.n,
Expand All @@ -569,7 +569,7 @@ async def generate(request: GenerateRequest):
"top_k": request.top_k,
"min_p": request.min_p,
"max_tokens": request.max_tokens,
"guided_decoding": guided_decoding,
"structured_outputs": structured_outputs,
"logprobs": 0,
}
generation_kwargs.update(request.generation_kwargs)
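To show the renamed request field end to end, here is a hypothetical raw call against the `/generate/` endpoint defined above; the URL and payload values are assumptions, and the field names mirror `GenerateRequest`.

```python
import requests

# Hypothetical call to a locally running `trl vllm-serve` instance.
response = requests.post(
    "http://localhost:8000/generate/",
    json={
        "prompts": ["What is 2 + 2?"],
        "n": 1,
        "max_tokens": 32,
        # Renamed from `guided_decoding_regex`.
        "structured_outputs_regex": r"<answer>\n.*\n</answer>",
    },
)
print(response.json()["completion_ids"])
```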
8 changes: 4 additions & 4 deletions trl/trainer/grpo_config.py
@@ -114,8 +114,8 @@ class GRPOConfig(TrainingArguments):
Model implementation to use for vLLM. Must be one of `"transformers"` or `"vllm"`. `"transformers"`: Use
the `transformers` backend for model implementation. `"vllm"`: Use the `vllm` library for model
implementation.
-        vllm_guided_decoding_regex (`str`, *optional*):
-            Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
+        vllm_structured_outputs_regex (`str`, *optional*):
+            Regex for vLLM structured outputs. If `None` (default), structured outputs are disabled.

> Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)

@@ -428,9 +428,9 @@ class GRPOConfig(TrainingArguments):
"and woken for weight sync and generation."
},
)
-    vllm_guided_decoding_regex: Optional[str] = field(
+    vllm_structured_outputs_regex: Optional[str] = field(
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
metadata={"help": "Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled."},
)

# Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
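A minimal usage sketch of the renamed config field, mirroring the regex used in the tests above; the output directory is a placeholder. The same rename applies to `RLOOConfig` and `OnlineDPOConfig` below.

```python
from trl import GRPOConfig

# Placeholder values; only `use_vllm` and the renamed regex field matter here.
training_args = GRPOConfig(
    output_dir="grpo-structured-outputs",
    use_vllm=True,
    # Renamed from `vllm_guided_decoding_regex`.
    vllm_structured_outputs_regex=r"<reasoning>\n.*\n</reasoning>\n<answer>\n.*\n</answer>",
)
```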
14 changes: 7 additions & 7 deletions trl/trainer/grpo_trainer.py
@@ -85,7 +85,7 @@

if is_vllm_available():
from vllm import LLM, SamplingParams
-    from vllm.sampling_params import GuidedDecodingParams
+    from vllm.sampling_params import StructuredOutputsParams

if is_wandb_available():
import wandb
@@ -543,7 +543,7 @@ def __init__(
raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.")

# vLLM specific sampling arguments
-        self.guided_decoding_regex = args.vllm_guided_decoding_regex
+        self.structured_outputs_regex = args.vllm_structured_outputs_regex

self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation

@@ -1181,7 +1181,7 @@ def _generate_and_score_completions(
top_k=-1 if self.top_k is None else self.top_k,
min_p=0.0 if self.min_p is None else self.min_p,
max_tokens=self.max_completion_length,
-                    guided_decoding_regex=self.guided_decoding_regex,
+                    structured_outputs_regex=self.structured_outputs_regex,
generation_kwargs=self.args.generation_kwargs,
)
payload = (output["completion_ids"], output["logprobs"])
@@ -1202,10 +1202,10 @@ def _generate_and_score_completions(

        # Generate completions using colocated vLLM instances: each device holds a vLLM copy and works on its own batch of prompts
elif self.vllm_mode == "colocate":
-            if self.guided_decoding_regex:
-                guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex)
+            if self.structured_outputs_regex:
+                structured_outputs = StructuredOutputsParams(regex=self.structured_outputs_regex)
else:
-                guided_decoding = None
+                structured_outputs = None

generation_kwargs = {
"n": 1, # vLLM on each GPU generates only 1 in colocate mode
Expand All @@ -1215,7 +1215,7 @@ def _generate_and_score_completions(
"top_k": -1 if self.top_k is None else self.top_k,
"min_p": 0.0 if self.min_p is None else self.min_p,
"max_tokens": self.max_completion_length,
"guided_decoding": guided_decoding,
"structured_outputs": structured_outputs,
"logprobs": 0, # only return the logprob of the generated token
}
if self.args.generation_kwargs is not None:
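For context, the colocate branch above reduces to the following standalone vLLM call. This is a sketch under the assumption of a vLLM version that has completed the `GuidedDecodingParams` to `StructuredOutputsParams` rename; on older versions the old names still apply, and the model and regex here are illustrative.

```python
from vllm import LLM, SamplingParams
from vllm.sampling_params import StructuredOutputsParams

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")  # model choice is illustrative
sampling_params = SamplingParams(
    max_tokens=8,
    # Replaces SamplingParams(guided_decoding=GuidedDecodingParams(regex=...)).
    structured_outputs=StructuredOutputsParams(regex=r"<answer>\n.*\n</answer>"),
)
outputs = llm.generate(["What is 2 + 2?"], sampling_params)
print(outputs[0].outputs[0].text)
```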
8 changes: 4 additions & 4 deletions trl/trainer/online_dpo_config.py
@@ -117,8 +117,8 @@ class may differ from those in [`~transformers.TrainingArguments`].
server is running (start with `trl vllm-serve`).
- `"colocate"`: vLLM will run in the same process and share the training GPUs. This avoids the need for a
separate server but may cause resource contention with training.
-        vllm_guided_decoding_regex (`str`, *optional*):
-            Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
+        vllm_structured_outputs_regex (`str`, *optional*):
+            Regex for vLLM structured outputs. If `None` (default), structured outputs are disabled.

> Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)

@@ -304,9 +304,9 @@ class may differ from those in [`~transformers.TrainingArguments`].
"model implementation."
},
)
-    vllm_guided_decoding_regex: Optional[str] = field(
+    vllm_structured_outputs_regex: Optional[str] = field(
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
metadata={"help": "Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled."},
)
vllm_gpu_memory_utilization: Optional[float] = field(
default=0.55,
10 changes: 5 additions & 5 deletions trl/trainer/online_dpo_trainer.py
@@ -94,7 +94,7 @@

if is_vllm_available():
from vllm import LLM, SamplingParams
-    from vllm.sampling_params import GuidedDecodingParams
+    from vllm.sampling_params import StructuredOutputsParams

if is_wandb_available():
import wandb
@@ -528,7 +528,7 @@ def __init__(
else:
raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.")
# vLLM specific sampling arguments
-        self.guided_decoding_regex = args.vllm_guided_decoding_regex
+        self.structured_outputs_regex = args.vllm_structured_outputs_regex
self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation

# Set up vLLM generation config
@@ -544,8 +544,8 @@ def __init__(
}
if args.generation_kwargs is not None:
generation_params.update(args.generation_kwargs)
-        if self.guided_decoding_regex:
-            generation_params["guided_decoding"] = GuidedDecodingParams(regex=self.guided_decoding_regex)
+        if self.structured_outputs_regex:
+            generation_params["structured_outputs"] = StructuredOutputsParams(regex=self.structured_outputs_regex)
self.generation_config = SamplingParams(**generation_params)

# When using vLLM, the main process is responsible for loading the model weights. This can cause process
@@ -791,7 +791,7 @@ def _generate_vllm_server(self, prompts, images=None):
top_k=-1 if self.top_k is None else self.top_k,
min_p=0.0 if self.min_p is None else self.min_p,
max_tokens=self.generation_config.max_tokens,
-            guided_decoding_regex=self.guided_decoding_regex if hasattr(self, "guided_decoding_regex") else None,
+            structured_outputs_regex=self.structured_outputs_regex if hasattr(self, "structured_outputs_regex") else None,
generation_kwargs=self.args.generation_kwargs,
)
# Flatten: each prompt generates 2 completions
8 changes: 4 additions & 4 deletions trl/trainer/rloo_config.py
@@ -115,8 +115,8 @@ class RLOOConfig(TrainingArguments):
Model implementation to use for vLLM. Must be one of `"transformers"` or `"vllm"`. `"transformers"`: Use
the `transformers` backend for model implementation. `"vllm"`: Use the `vllm` library for model
implementation.
-        vllm_guided_decoding_regex (`str`, *optional*):
-            Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
+        vllm_structured_outputs_regex (`str`, *optional*):
+            Regex for vLLM structured outputs. If `None` (default), structured outputs are disabled.

> Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)

@@ -507,9 +507,9 @@ class RLOOConfig(TrainingArguments):
"model implementation."
},
)
-    vllm_guided_decoding_regex: Optional[str] = field(
+    vllm_structured_outputs_regex: Optional[str] = field(
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
metadata={"help": "Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled."},
)

# Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)