3 changes: 0 additions & 3 deletions examples/scripts/cpo.py
@@ -64,7 +64,6 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

 from trl import CPOConfig, CPOTrainer, ModelConfig, ScriptArguments, get_peft_config
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


 # Enable logging in a Hugging Face Space
@@ -90,8 +89,6 @@
 # Dataset
 ################
 dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
-if tokenizer.chat_template is None:
-    tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

 ################
 # Training
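
Note: the same fallback removal repeats across the example scripts and tests below, so the scripts now assume the tokenizer already ships its own chat template. A quick way to verify a checkpoint before training — a minimal sketch, where the checkpoint name is a placeholder:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your-org/your-model")  # placeholder checkpoint
if tokenizer.chat_template is None:
    # The example scripts no longer patch this in; set one yourself if needed.
    raise ValueError(
        "This checkpoint ships no chat template; assign tokenizer.chat_template "
        "a Jinja template before running the example script."
    )
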
3 changes: 0 additions & 3 deletions examples/scripts/nash_md.py
@@ -73,7 +73,6 @@
     get_kbit_device_map,
     get_quantization_config,
 )
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


 # Enable logging in a Hugging Face Space
@@ -128,8 +127,6 @@
 )
 if tokenizer.pad_token is None:
     tokenizer.pad_token = tokenizer.eos_token
-if tokenizer.chat_template is None:
-    tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

 dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
3 changes: 0 additions & 3 deletions examples/scripts/online_dpo.py
@@ -69,7 +69,6 @@
     get_peft_config,
     get_quantization_config,
 )
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


 # Enable logging in a Hugging Face Space
@@ -131,8 +130,6 @@
     trust_remote_code=model_args.trust_remote_code,
     **model_kwargs,
 )
-if tokenizer.chat_template is None:
-    tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
 if tokenizer.pad_token_id is None:
     tokenizer.pad_token = tokenizer.eos_token
3 changes: 0 additions & 3 deletions examples/scripts/orpo.py
@@ -64,7 +64,6 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

 from trl import ModelConfig, ORPOConfig, ORPOTrainer, ScriptArguments, get_peft_config
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


 # Enable logging in a Hugging Face Space
@@ -91,8 +90,6 @@
 # Dataset
 ################
 dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
-if tokenizer.chat_template is None:
-    tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

 ################
 # Training
3 changes: 0 additions & 3 deletions examples/scripts/ppo/ppo.py
@@ -43,7 +43,6 @@
     get_peft_config,
     get_quantization_config,
 )
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


 # Enable logging in a Hugging Face Space
@@ -106,8 +105,6 @@
     model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code
 )
 tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-if tokenizer.chat_template is None:
-    tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
 value_model = AutoModelForSequenceClassification.from_pretrained(
     training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1
 )
3 changes: 0 additions & 3 deletions examples/scripts/ppo/ppo_tldr.py
@@ -43,7 +43,6 @@
     get_peft_config,
     get_quantization_config,
 )
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


 # Enable logging in a Hugging Face Space
@@ -113,8 +112,6 @@
     model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code
 )
 tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-if tokenizer.chat_template is None:
-    tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
 value_model = AutoModelForSequenceClassification.from_pretrained(
     training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1
 )
3 changes: 0 additions & 3 deletions examples/scripts/xpo.py
@@ -57,7 +57,6 @@
     get_kbit_device_map,
     get_quantization_config,
 )
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


 # Enable logging in a Hugging Face Space
@@ -113,8 +112,6 @@
 )
 if tokenizer.pad_token is None:
     tokenizer.pad_token = tokenizer.eos_token
-if tokenizer.chat_template is None:
-    tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

 dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
3 changes: 0 additions & 3 deletions tests/test_cpo_trainer.py
@@ -17,7 +17,6 @@
 from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer

 from trl import CPOConfig, CPOTrainer
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

 from .testing_utils import TrlTestCase, require_peft

@@ -33,15 +32,13 @@ def setup_method(self):
         model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration"
         self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
         self.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
-        self.t5_tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

     @pytest.mark.parametrize(
         "name, loss_type, config_name",
         [
             ("qwen", "sigmoid", "standard_preference"),
             ("t5", "hinge", "standard_implicit_prompt_preference"),
             ("qwen", "ipo", "conversational_preference"),
             ("t5", "ipo", "conversational_implicit_prompt_preference"),
             ("qwen", "simpo", "standard_preference"),
             ("t5", "simpo", "standard_implicit_prompt_preference"),
             ("qwen", "hinge", "conversational_preference"),
5 changes: 0 additions & 5 deletions tests/test_gkd_trainer.py
@@ -21,7 +21,6 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

 from trl import GKDConfig, GKDTrainer
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

 from .testing_utils import TrlTestCase, require_liger_kernel

@@ -206,10 +205,6 @@ def setup_method(self):
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
         self.tokenizer.pad_token = self.tokenizer.eos_token

-        # Ensure the tokenizer has a chat template
-        if not hasattr(self.tokenizer, "chat_template") or self.tokenizer.chat_template is None:
-            self.tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
-
     def test_gkd_trainer(self):
         training_args = GKDConfig(
             output_dir=self.tmp_dir,
3 changes: 0 additions & 3 deletions tests/test_orpo_trainer.py
@@ -17,7 +17,6 @@
 from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer

 from trl import ORPOConfig, ORPOTrainer
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

 from .testing_utils import TrlTestCase, require_peft

@@ -33,15 +32,13 @@ def setup_method(self):
         model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration"
         self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
         self.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
-        self.t5_tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

     @pytest.mark.parametrize(
         "name, config_name",
         [
             ("qwen", "standard_preference"),
             ("t5", "standard_implicit_prompt_preference"),
             ("qwen", "conversational_preference"),
             ("t5", "conversational_implicit_prompt_preference"),
         ],
     )
     def test_orpo_trainer(self, name, config_name):
4 changes: 0 additions & 4 deletions tests/test_ppo_trainer.py
@@ -19,7 +19,6 @@
 from transformers.utils import is_peft_available

 from trl import PPOConfig, PPOTrainer
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

 from .testing_utils import TrlTestCase, require_peft

@@ -37,9 +36,6 @@ def setup_method(self):
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, padding_side="left")
         self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})

-        if self.tokenizer.chat_template is None:
-            self.tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
-
         # Add reward and value models as in ppo.py
         reward_model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"
         self.value_model = AutoModelForSequenceClassification.from_pretrained(reward_model_id, num_labels=1)
3 changes: 0 additions & 3 deletions trl/trainer/utils.py
@@ -738,9 +738,6 @@ def print_rich_table(df: pd.DataFrame) -> None:
     console.print(table)


 SIMPLE_SFT_CHAT_TEMPLATE = "{% for message in messages %}{{' ' + message['content']}}{% endfor %}{{ eos_token }}"
 # SIMPLE_SFT_CHAT_TEMPLATE simply ends things with an EOS token, this helps the SFT model learn to end the completions with EOS tokens

-SIMPLE_CHAT_TEMPLATE = "{% for message in messages %}{{message['role'].capitalize() + ': ' + message['content'] + '\n\n'}}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
-
-
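
For context, the deleted SIMPLE_CHAT_TEMPLATE rendered each turn as `Role: content` followed by a blank line, with a trailing `Assistant:` when a generation prompt was requested. A minimal sketch of its output, inlining the deleted Jinja string and using one of the tiny test checkpoints referenced above (any tokenizer would do):

from transformers import AutoTokenizer

# The Jinja string deleted above, inlined here for illustration.
simple_chat_template = (
    "{% for message in messages %}"
    "{{ message['role'].capitalize() + ': ' + message['content'] + '\n\n' }}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
)

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-T5ForConditionalGeneration")
tokenizer.chat_template = simple_chat_template
messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
# User: Hi
#
# Assistant: Hello!
#
# Assistant: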