Skip to content

Commit ac6cea8

Browse files
Fix add_generation_prompt arg for paged transformers in GRPO and RLOO trainers (#4370)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
1 parent 1e39eb6 commit ac6cea8

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

trl/trainer/grpo_trainer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,12 +1263,15 @@ def _generate_single_turn(self, prompts: list):
12631263
processor_kwargs = {
12641264
"max_length": self.max_prompt_length,
12651265
"truncation": True,
1266-
"add_generation_prompt": True,
12671266
"add_special_tokens": False,
12681267
}
12691268
if is_conversational({"prompt": prompts[0]}):
12701269
processor_outputs = self.processing_class.apply_chat_template(
1271-
conversation=prompts, **processor_kwargs, tokenize=True, return_dict=True
1270+
conversation=prompts,
1271+
**processor_kwargs,
1272+
add_generation_prompt=True,
1273+
tokenize=True,
1274+
return_dict=True,
12721275
)
12731276
else:
12741277
processor_outputs = self.processing_class(text=prompts, **processor_kwargs)

trl/trainer/rloo_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,13 +1088,13 @@ def _generate_single_turn(self, prompts: list):
10881088
processor_kwargs = {
10891089
"max_length": self.max_prompt_length,
10901090
"truncation": True,
1091-
"add_generation_prompt": True,
10921091
"add_special_tokens": False,
10931092
}
10941093
if is_conversational({"prompt": prompts[0]}):
10951094
processor_outputs = self.processing_class.apply_chat_template(
10961095
conversation=prompts,
10971096
**processor_kwargs,
1097+
add_generation_prompt=True,
10981098
tokenize=True,
10991099
return_dict=True,
11001100
)
@@ -1137,7 +1137,7 @@ def _generate_single_turn(self, prompts: list):
11371137
generate_inputs = self.processing_class.apply_chat_template(
11381138
conversation=prompts,
11391139
**processor_kwargs,
1140-
add_generation_kwargs=True,
1140+
add_generation_prompt=True,
11411141
tokenize=True,
11421142
return_dict=True,
11431143
)

0 commit comments

Comments (0)