diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index 25f24373cc9..caf637b69b3 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -1606,6 +1606,36 @@ def test_training_sequence_importance_sampling(self):
             new_param = trainer.model.get_parameter(n)
             assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
 
+    def test_training_with_chat_template_kwargs(self):
+        dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,
+            per_device_train_batch_size=3,
+            num_generations=3,
+            max_completion_length=8,
+            report_to="none",
+            chat_template_kwargs={"enable_thinking": False},
+        )
+        trainer = GRPOTrainer(
+            model="trl-internal-testing/tiny-Qwen3ForCausalLM",
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
+
     def test_mismatched_reward_processing_classes_length(self):
         """Test that mismatched length between reward_funcs and reward_processing_classes raises error."""
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py
index a6cf327ac71..06583a2e089 100644
--- a/tests/test_rloo_trainer.py
+++ b/tests/test_rloo_trainer.py
@@ -1307,6 +1307,37 @@ def reward_func(completions, **kwargs):
             new_param = trainer.model.get_parameter(n)
             assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
 
+    def test_training_with_chat_template_kwargs(self):
+        dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")
+
+        training_args = RLOOConfig(
+            bf16=False,
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,
+            per_device_train_batch_size=3,
+            num_generations=3,
+            max_completion_length=8,
+            report_to="none",
+            chat_template_kwargs={"enable_thinking": False},
+        )
+        trainer = RLOOTrainer(
+            model="trl-internal-testing/tiny-Qwen3ForCausalLM",
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
+
     def test_mismatched_reward_processing_classes_length(self):
         """Test that mismatched length between reward_funcs and reward_processing_classes raises error."""
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index e6299927ab6..81aae3fcc5d 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -81,6 +81,13 @@ class GRPOConfig(TrainingArguments):
         min_p (`float`, *optional*):
             Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
             value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range.
+        generation_kwargs (`dict[str, Any]`, *optional*):
+            Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or
+            `SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the
+            generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict
+            with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them.
+        chat_template_kwargs (`dict[str, Any]`, *optional*):
+            Additional keyword arguments to pass to the `apply_chat_template` function when generating completions.
         repetition_penalty (`float`, *optional*, defaults to `1.0`):
             Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.
             Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat
@@ -91,11 +98,6 @@ class GRPOConfig(TrainingArguments):
             parameter is only effective when `use_vllm` is set to `False`.
         cache_implementation (`str`, *optional*):
             Implementation of the cache method for faster generation when `use_vllm` is set to `False`.
-        generation_kwargs (`dict[str, Any]`, *optional*):
-            Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or
-            `SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the
-            generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict
-            with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them.
 
     > Parameters that control generation acceleration powered by vLLM
 
@@ -375,6 +377,13 @@ class GRPOConfig(TrainingArguments):
             "conflict with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them."
         },
     )
+    chat_template_kwargs: Optional[dict] = field(
+        default=None,
+        metadata={
+            "help": "Additional keyword arguments to pass to the `apply_chat_template` function when generating "
+            "completions."
+        },
+    )
     repetition_penalty: float = field(
         default=1.0,
         metadata={
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index f72cfd97db3..8797415bc04 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -377,6 +377,7 @@ def __init__(
         self.max_prompt_length = args.max_prompt_length
         self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
         self.num_generations = args.num_generations  # = G in the GRPO paper
+        self.chat_template_kwargs = args.chat_template_kwargs or {}
         self.temperature = args.temperature
         self.top_p = args.top_p
         self.top_k = args.top_k
@@ -1066,7 +1067,10 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
             if isinstance(reward_func, nn.Module):  # Module (no PretrainedModel) for compat with compiled models
                 if is_conversational(inputs[0]):
                     messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
-                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
+                    texts = [
+                        apply_chat_template(x, reward_processing_class, **self.chat_template_kwargs)["text"]
+                        for x in messages
+                    ]
                 else:
                     texts = [p + c for p, c in zip(prompts, completions)]
                 reward_inputs = reward_processing_class(
@@ -1146,7 +1150,9 @@ def _generate_single_turn(self, prompts: list):
             if self.rollout_func is not None:
                 if is_conversational({"prompt": ordered_set_of_prompts[0]}):
                     ordered_set_of_prompts = [
-                        apply_chat_template({"prompt": p}, self.processing_class)["prompt"]
+                        apply_chat_template(
+                            {"prompt": p}, self.processing_class, **self.chat_template_kwargs
+                        )["prompt"]
                         for p in ordered_set_of_prompts
                     ]
                 output = self.rollout_func(
@@ -1157,7 +1163,11 @@ def _generate_single_turn(self, prompts: list):
             else:
                 if is_conversational({"prompt": ordered_set_of_prompts[0]}):
                     # FIXME: this endpoint doesn't exist in vllm_client
-                    output = self.vllm_client.chat(prompts=ordered_set_of_prompts, **sampling_params)
+                    output = self.vllm_client.chat(
+                        prompts=ordered_set_of_prompts,
+                        **sampling_params,
+                        chat_template_kwargs=self.chat_template_kwargs,
+                    )
                 else:
                     output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params)
             # Extract required fields and collect any extra fields for reward functions
@@ -1272,6 +1282,7 @@ def _generate_single_turn(self, prompts: list):
                     add_generation_prompt=True,
                     tokenize=True,
                     return_dict=True,
+                    **self.chat_template_kwargs,
                 )
             else:
                 processor_outputs = self.processing_class(text=prompts, **processor_kwargs)
@@ -1317,6 +1328,7 @@ def _generate_single_turn(self, prompts: list):
                     add_generation_prompt=True,
                     tokenize=True,
                     return_dict=True,
+                    **self.chat_template_kwargs,
                 )
             else:
                 generate_inputs = self.processing_class(text=prompts, **processor_kwargs)
@@ -1450,7 +1462,8 @@ def _generate_and_score_completions(
         # Get forward_kwargs for models with multimodal inputs
         if images is not None:
             prompts_text = [
-                apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
+                apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"]
+                for prompt in prompts
             ]
             prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
             prompt_inputs = super()._prepare_inputs(prompt_inputs)
diff --git a/trl/trainer/rloo_config.py b/trl/trainer/rloo_config.py
index 28ee1b6fb46..02a1dc9883f 100644
--- a/trl/trainer/rloo_config.py
+++ b/trl/trainer/rloo_config.py
@@ -81,6 +81,13 @@ class RLOOConfig(TrainingArguments):
         min_p (`float`, *optional*):
            Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
            value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range.
+        generation_kwargs (`dict[str, Any]`, *optional*):
+            Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or
+            `SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the
+            generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict
+            with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them.
+        chat_template_kwargs (`dict[str, Any]`, *optional*):
+            Additional keyword arguments to pass to the `apply_chat_template` function when generating completions.
         repetition_penalty (`float`, *optional*, defaults to `1.0`):
             Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.
             Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat
@@ -91,11 +98,6 @@ class RLOOConfig(TrainingArguments):
             parameter is only effective when `use_vllm` is set to `False`.
         cache_implementation (`str`, *optional*):
             Implementation of the cache method for faster generation when `use_vllm` is set to `False`.
-        generation_kwargs (`dict[str, Any]`, *optional*):
-            Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or
-            `SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the
-            generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict
-            with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them.
 
     > Parameters that control generation acceleration powered by vLLM
 
@@ -327,6 +329,13 @@ class RLOOConfig(TrainingArguments):
             "conflict with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them."
         },
     )
+    chat_template_kwargs: Optional[dict] = field(
+        default=None,
+        metadata={
+            "help": "Additional keyword arguments to pass to the `apply_chat_template` function when generating "
+            "completions."
+        },
+    )
     repetition_penalty: float = field(
         default=1.0,
         metadata={
diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py
index d4ce15fccd5..5c1be427738 100644
--- a/trl/trainer/rloo_trainer.py
+++ b/trl/trainer/rloo_trainer.py
@@ -361,6 +361,7 @@ def __init__(
         self.max_prompt_length = args.max_prompt_length
         self.max_completion_length = args.max_completion_length
         self.num_generations = args.num_generations
+        self.chat_template_kwargs = args.chat_template_kwargs or {}
         self.temperature = args.temperature
         self.top_p = args.top_p
         self.top_k = args.top_k
@@ -927,7 +928,10 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
             if isinstance(reward_func, nn.Module):  # Module (no PretrainedModel) for compat with compiled models
                 if is_conversational(inputs[0]):
                     messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
-                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
+                    texts = [
+                        apply_chat_template(x, reward_processing_class, **self.chat_template_kwargs)["text"]
+                        for x in messages
+                    ]
                 else:
                     texts = [p + c for p, c in zip(prompts, completions)]
                 reward_inputs = reward_processing_class(
@@ -1004,7 +1008,11 @@ def _generate_single_turn(self, prompts: list):
                 }
                 with profiling_context(self, "vLLM.generate"):
                     if is_conversational({"prompt": ordered_set_of_prompts[0]}):
-                        output = self.vllm_client.chat(prompts=ordered_set_of_prompts, **sampling_params)
+                        output = self.vllm_client.chat(
+                            prompts=ordered_set_of_prompts,
+                            **sampling_params,
+                            chat_template_kwargs=self.chat_template_kwargs,
+                        )
                     else:
                         output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params)
                 payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"])
@@ -1097,6 +1105,7 @@ def _generate_single_turn(self, prompts: list):
                     add_generation_prompt=True,
                     tokenize=True,
                     return_dict=True,
+                    **self.chat_template_kwargs,
                 )
             else:
                 processor_outputs = self.processing_class(text=prompts, **processor_kwargs)
@@ -1140,6 +1149,7 @@ def _generate_single_turn(self, prompts: list):
                     add_generation_prompt=True,
                     tokenize=True,
                     return_dict=True,
+                    **self.chat_template_kwargs,
                 )
             else:
                 generate_inputs = self.processing_class(text=prompts, **processor_kwargs)
@@ -1265,7 +1275,8 @@ def _generate_and_score_completions(
         # Get forward_kwargs for models with multimodal inputs
         if images is not None:
             prompts_text = [
-                apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
+                apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"]
+                for prompt in prompts
             ]
             prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
             prompt_inputs = super()._prepare_inputs(prompt_inputs)
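Usage sketch (not part of the patch): the snippet below mirrors the new test cases to show how `chat_template_kwargs` is intended to be used from user code. The tiny dataset and model names are the test fixtures from the diff and stand in for real choices; `output_dir` is a placeholder.

# Minimal sketch of the chat_template_kwargs option introduced in this diff.
# Assumes the same tiny test fixtures as the tests above; swap in a real dataset
# and model for actual training.
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")

training_args = GRPOConfig(
    output_dir="grpo-chat-template-kwargs",  # placeholder output directory
    num_generations=3,
    max_completion_length=8,
    report_to="none",
    # Forwarded to apply_chat_template during prompt formatting (and to
    # vllm_client.chat for conversational prompts when vLLM server mode is used),
    # e.g. to disable Qwen3-style thinking traces during rollouts.
    chat_template_kwargs={"enable_thinking": False},
)
trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen3ForCausalLM",
    reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()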