30 changes: 30 additions & 0 deletions tests/test_grpo_trainer.py
@@ -1606,6 +1606,36 @@ def test_training_sequence_importance_sampling(self):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

def test_training_with_chat_template_kwargs(self):
dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")

training_args = GRPOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1,
per_device_train_batch_size=3,
num_generations=3,
max_completion_length=8,
report_to="none",
chat_template_kwargs={"enable_thinking": False},
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen3ForCausalLM",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

assert trainer.state.log_history[-1]["train_loss"] is not None

# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

def test_mismatched_reward_processing_classes_length(self):
"""Test that mismatched length between reward_funcs and reward_processing_classes raises error."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
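For context on the test above: `enable_thinking` is a keyword that Qwen3-style chat templates accept to toggle the `<think>` block, and extra keyword arguments passed to a tokenizer's `apply_chat_template` are forwarded to the Jinja template. A minimal sketch of what the trainer relies on, using the model name from the test; whether the tiny testing checkpoint ships a template that honors this flag is an assumption here (the full-size Qwen3 templates do):

from transformers import AutoTokenizer

# Model name taken from the test above; the flag itself is what
# GRPOConfig(chat_template_kwargs={"enable_thinking": False}) forwards.
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM")

messages = [{"role": "user", "content": "What color is the sky?"}]

# Extra kwargs to apply_chat_template are passed through to the chat template;
# Qwen3 templates omit the <think> block when enable_thinking is False.
prompt = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=False,
    enable_thinking=False,
)
print(prompt)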
31 changes: 31 additions & 0 deletions tests/test_rloo_trainer.py
@@ -1307,6 +1307,37 @@ def reward_func(completions, **kwargs):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

def test_training_with_chat_template_kwargs(self):
dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")

training_args = RLOOConfig(
bf16=False,
output_dir=self.tmp_dir,
learning_rate=0.1,
per_device_train_batch_size=3,
num_generations=3,
max_completion_length=8,
report_to="none",
chat_template_kwargs={"enable_thinking": False},
)
trainer = RLOOTrainer(
model="trl-internal-testing/tiny-Qwen3ForCausalLM",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

assert trainer.state.log_history[-1]["train_loss"] is not None

# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

def test_mismatched_reward_processing_classes_length(self):
"""Test that mismatched length between reward_funcs and reward_processing_classes raises error."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
19 changes: 14 additions & 5 deletions trl/trainer/grpo_config.py
@@ -81,6 +81,13 @@ class GRPOConfig(TrainingArguments):
min_p (`float`, *optional*):
Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range.
generation_kwargs (`dict[str, Any]`, *optional*):
Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or
`SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the
generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict
with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them.
chat_template_kwargs (`dict[str, Any]`, *optional*):
Additional keyword arguments to pass to the `apply_chat_template` function when generating completions.
repetition_penalty (`float`, *optional*, defaults to `1.0`):
Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.
Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat
@@ -91,11 +98,6 @@ class GRPOConfig(TrainingArguments):
parameter is only effective when `use_vllm` is set to `False`.
cache_implementation (`str`, *optional*):
Implementation of the cache method for faster generation when `use_vllm` is set to `False`.
generation_kwargs (`dict[str, Any]`, *optional*):
Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or
`SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the
generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict
with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them.

> Parameters that control generation acceleration powered by vLLM

@@ -375,6 +377,13 @@ class GRPOConfig(TrainingArguments):
"conflict with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them."
},
)
chat_template_kwargs: Optional[dict] = field(
default=None,
metadata={
"help": "Additional keyword arguments to pass to the `apply_chat_template` function when generating "
"completions."
},
)
repetition_penalty: float = field(
default=1.0,
metadata={
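Taken together with the existing `generation_kwargs` field, the new option is set the same way on the config object. A minimal usage sketch; the values below are illustrative, not defaults:

from trl import GRPOConfig

# chat_template_kwargs is forwarded verbatim to apply_chat_template;
# generation_kwargs is merged into GenerationConfig / SamplingParams.
training_args = GRPOConfig(
    output_dir="grpo-output",
    chat_template_kwargs={"enable_thinking": False},
    generation_kwargs={"num_beams": 1},
)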
21 changes: 17 additions & 4 deletions trl/trainer/grpo_trainer.py
@@ -377,6 +377,7 @@ def __init__(
self.max_prompt_length = args.max_prompt_length
self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
self.num_generations = args.num_generations # = G in the GRPO paper
self.chat_template_kwargs = args.chat_template_kwargs or {}
self.temperature = args.temperature
self.top_p = args.top_p
self.top_k = args.top_k
@@ -1066,7 +1067,10 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
if isinstance(reward_func, nn.Module): # Module (no PretrainedModel) for compat with compiled models
if is_conversational(inputs[0]):
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
texts = [
apply_chat_template(x, reward_processing_class, **self.chat_template_kwargs)["text"]
for x in messages
]
else:
texts = [p + c for p, c in zip(prompts, completions)]
reward_inputs = reward_processing_class(
@@ -1146,7 +1150,9 @@ def _generate_single_turn(self, prompts: list):
if self.rollout_func is not None:
if is_conversational({"prompt": ordered_set_of_prompts[0]}):
ordered_set_of_prompts = [
apply_chat_template({"prompt": p}, self.processing_class)["prompt"]
apply_chat_template(
{"prompt": p}, self.processing_class, **self.chat_template_kwargs
)["prompt"]
for p in ordered_set_of_prompts
]
output = self.rollout_func(
@@ -1157,7 +1163,11 @@ def _generate_single_turn(self, prompts: list):
else:
if is_conversational({"prompt": ordered_set_of_prompts[0]}):
# FIXME: this endpoint doesn't exist in vllm_client
output = self.vllm_client.chat(prompts=ordered_set_of_prompts, **sampling_params)
output = self.vllm_client.chat(
prompts=ordered_set_of_prompts,
**sampling_params,
chat_template_kwargs=self.chat_template_kwargs,
)
else:
output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params)
# Extract required fields and collect any extra fields for reward functions
@@ -1272,6 +1282,7 @@ def _generate_single_turn(self, prompts: list):
add_generation_prompt=True,
tokenize=True,
return_dict=True,
**self.chat_template_kwargs,
)
else:
processor_outputs = self.processing_class(text=prompts, **processor_kwargs)
@@ -1317,6 +1328,7 @@ def _generate_single_turn(self, prompts: list):
add_generation_prompt=True,
tokenize=True,
return_dict=True,
**self.chat_template_kwargs,
)
else:
generate_inputs = self.processing_class(text=prompts, **processor_kwargs)
@@ -1450,7 +1462,8 @@ def _generate_and_score_completions(
# Get forward_kwargs for models with multimodal inputs
if images is not None:
prompts_text = [
apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"]
for prompt in prompts
]
prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
prompt_inputs = super()._prepare_inputs(prompt_inputs)
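The trainer-side changes above all follow one pattern: store the dict once in `__init__` and unpack it into every `apply_chat_template` call (and pass it to the vLLM chat endpoint). A stripped-down sketch of that pattern, not the trainer's actual code, assuming `trl.data_utils.apply_chat_template` accepts extra template kwargs as the call sites above imply:

from trl.data_utils import apply_chat_template


class ChatTemplatePrompter:
    """Illustrative helper only; mirrors how GRPOTrainer forwards chat_template_kwargs."""

    def __init__(self, processing_class, chat_template_kwargs=None):
        self.processing_class = processing_class
        # Same default as the trainer: an absent config value becomes an empty dict.
        self.chat_template_kwargs = chat_template_kwargs or {}

    def render_prompt(self, prompt):
        # Same call shape as the prompt path in _generate_and_score_completions.
        return apply_chat_template(
            {"prompt": prompt}, self.processing_class, **self.chat_template_kwargs
        )["prompt"]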
19 changes: 14 additions & 5 deletions trl/trainer/rloo_config.py
@@ -81,6 +81,13 @@ class RLOOConfig(TrainingArguments):
min_p (`float`, *optional*):
Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range.
generation_kwargs (`dict[str, Any]`, *optional*):
Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or
`SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the
generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict
with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them.
chat_template_kwargs (`dict[str, Any]`, *optional*):
Additional keyword arguments to pass to the `apply_chat_template` function when generating completions.
repetition_penalty (`float`, *optional*, defaults to `1.0`):
Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.
Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat
@@ -91,11 +98,6 @@ class RLOOConfig(TrainingArguments):
parameter is only effective when `use_vllm` is set to `False`.
cache_implementation (`str`, *optional*):
Implementation of the cache method for faster generation when `use_vllm` is set to `False`.
generation_kwargs (`dict[str, Any]`, *optional*):
Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or
`SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the
generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict
with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them.

> Parameters that control generation acceleration powered by vLLM

@@ -327,6 +329,13 @@ class RLOOConfig(TrainingArguments):
"conflict with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them."
},
)
chat_template_kwargs: Optional[dict] = field(
default=None,
metadata={
"help": "Additional keyword arguments to pass to the `apply_chat_template` function when generating "
"completions."
},
)
repetition_penalty: float = field(
default=1.0,
metadata={
17 changes: 14 additions & 3 deletions trl/trainer/rloo_trainer.py
@@ -361,6 +361,7 @@ def __init__(
self.max_prompt_length = args.max_prompt_length
self.max_completion_length = args.max_completion_length
self.num_generations = args.num_generations
self.chat_template_kwargs = args.chat_template_kwargs or {}
self.temperature = args.temperature
self.top_p = args.top_p
self.top_k = args.top_k
@@ -927,7 +928,10 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
if isinstance(reward_func, nn.Module): # Module (no PretrainedModel) for compat with compiled models
if is_conversational(inputs[0]):
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
texts = [
apply_chat_template(x, reward_processing_class, **self.chat_template_kwargs)["text"]
for x in messages
]
else:
texts = [p + c for p, c in zip(prompts, completions)]
reward_inputs = reward_processing_class(
@@ -1004,7 +1008,11 @@ def _generate_single_turn(self, prompts: list):
}
with profiling_context(self, "vLLM.generate"):
if is_conversational({"prompt": ordered_set_of_prompts[0]}):
output = self.vllm_client.chat(prompts=ordered_set_of_prompts, **sampling_params)
output = self.vllm_client.chat(
prompts=ordered_set_of_prompts,
**sampling_params,
chat_template_kwargs=self.chat_template_kwargs,
)
else:
output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params)
payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"])
@@ -1097,6 +1105,7 @@ def _generate_single_turn(self, prompts: list):
add_generation_prompt=True,
tokenize=True,
return_dict=True,
**self.chat_template_kwargs,
)
else:
processor_outputs = self.processing_class(text=prompts, **processor_kwargs)
@@ -1140,6 +1149,7 @@ def _generate_single_turn(self, prompts: list):
add_generation_prompt=True,
tokenize=True,
return_dict=True,
**self.chat_template_kwargs,
)
else:
generate_inputs = self.processing_class(text=prompts, **processor_kwargs)
@@ -1265,7 +1275,8 @@ def _generate_and_score_completions(
# Get forward_kwargs for models with multimodal inputs
if images is not None:
prompts_text = [
apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"]
for prompt in prompts
]
prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
prompt_inputs = super()._prepare_inputs(prompt_inputs)