Skip to content

Commit 6f41b18

Browse files
fix: Remove chat template setting from non-SFT trainer scripts (#4437)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com> Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
1 parent 8d64144 commit 6f41b18

File tree

12 files changed

+0
-39
lines changed

12 files changed

+0
-39
lines changed

examples/scripts/cpo.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@
6464
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
6565

6666
from trl import CPOConfig, CPOTrainer, ModelConfig, ScriptArguments, get_peft_config
67-
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
6867

6968

7069
# Enable logging in a Hugging Face Space
@@ -90,8 +89,6 @@
9089
# Dataset
9190
################
9291
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
93-
if tokenizer.chat_template is None:
94-
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
9592

9693
################
9794
# Training

examples/scripts/nash_md.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@
7373
get_kbit_device_map,
7474
get_quantization_config,
7575
)
76-
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
7776

7877

7978
# Enable logging in a Hugging Face Space
@@ -128,8 +127,6 @@
128127
)
129128
if tokenizer.pad_token is None:
130129
tokenizer.pad_token = tokenizer.eos_token
131-
if tokenizer.chat_template is None:
132-
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
133130

134131
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
135132

examples/scripts/online_dpo.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@
6969
get_peft_config,
7070
get_quantization_config,
7171
)
72-
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
7372

7473

7574
# Enable logging in a Hugging Face Space
@@ -131,8 +130,6 @@
131130
trust_remote_code=model_args.trust_remote_code,
132131
**model_kwargs,
133132
)
134-
if tokenizer.chat_template is None:
135-
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
136133
if tokenizer.pad_token_id is None:
137134
tokenizer.pad_token = tokenizer.eos_token
138135

examples/scripts/orpo.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@
6464
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
6565

6666
from trl import ModelConfig, ORPOConfig, ORPOTrainer, ScriptArguments, get_peft_config
67-
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
6867

6968

7069
# Enable logging in a Hugging Face Space
@@ -91,8 +90,6 @@
9190
# Dataset
9291
################
9392
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
94-
if tokenizer.chat_template is None:
95-
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
9693

9794
################
9895
# Training

examples/scripts/ppo/ppo.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
get_peft_config,
4444
get_quantization_config,
4545
)
46-
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
4746

4847

4948
# Enable logging in a Hugging Face Space
@@ -106,8 +105,6 @@
106105
model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code
107106
)
108107
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
109-
if tokenizer.chat_template is None:
110-
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
111108
value_model = AutoModelForSequenceClassification.from_pretrained(
112109
training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1
113110
)

examples/scripts/ppo/ppo_tldr.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
get_peft_config,
4444
get_quantization_config,
4545
)
46-
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
4746

4847

4948
# Enable logging in a Hugging Face Space
@@ -113,8 +112,6 @@
113112
model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code
114113
)
115114
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
116-
if tokenizer.chat_template is None:
117-
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
118115
value_model = AutoModelForSequenceClassification.from_pretrained(
119116
training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1
120117
)

examples/scripts/xpo.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@
5757
get_kbit_device_map,
5858
get_quantization_config,
5959
)
60-
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
6160

6261

6362
# Enable logging in a Hugging Face Space
@@ -113,8 +112,6 @@
113112
)
114113
if tokenizer.pad_token is None:
115114
tokenizer.pad_token = tokenizer.eos_token
116-
if tokenizer.chat_template is None:
117-
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
118115

119116
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
120117

tests/test_cpo_trainer.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
1818

1919
from trl import CPOConfig, CPOTrainer
20-
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
2120

2221
from .testing_utils import TrlTestCase, require_peft
2322

@@ -33,15 +32,13 @@ def setup_method(self):
3332
model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration"
3433
self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
3534
self.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
36-
self.t5_tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
3735

3836
@pytest.mark.parametrize(
3937
"name, loss_type, config_name",
4038
[
4139
("qwen", "sigmoid", "standard_preference"),
4240
("t5", "hinge", "standard_implicit_prompt_preference"),
4341
("qwen", "ipo", "conversational_preference"),
44-
("t5", "ipo", "conversational_implicit_prompt_preference"),
4542
("qwen", "simpo", "standard_preference"),
4643
("t5", "simpo", "standard_implicit_prompt_preference"),
4744
("qwen", "hinge", "conversational_preference"),

tests/test_gkd_trainer.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
2222

2323
from trl import GKDConfig, GKDTrainer
24-
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
2524

2625
from .testing_utils import TrlTestCase, require_liger_kernel
2726

@@ -206,10 +205,6 @@ def setup_method(self):
206205
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
207206
self.tokenizer.pad_token = self.tokenizer.eos_token
208207

209-
# Ensure the tokenizer has a chat template
210-
if not hasattr(self.tokenizer, "chat_template") or self.tokenizer.chat_template is None:
211-
self.tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
212-
213208
def test_gkd_trainer(self):
214209
training_args = GKDConfig(
215210
output_dir=self.tmp_dir,

tests/test_orpo_trainer.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
1818

1919
from trl import ORPOConfig, ORPOTrainer
20-
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
2120

2221
from .testing_utils import TrlTestCase, require_peft
2322

@@ -33,15 +32,13 @@ def setup_method(self):
3332
model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration"
3433
self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
3534
self.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
36-
self.t5_tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
3735

3836
@pytest.mark.parametrize(
3937
"name, config_name",
4038
[
4139
("qwen", "standard_preference"),
4240
("t5", "standard_implicit_prompt_preference"),
4341
("qwen", "conversational_preference"),
44-
("t5", "conversational_implicit_prompt_preference"),
4542
],
4643
)
4744
def test_orpo_trainer(self, name, config_name):

0 commit comments

Comments
 (0)