
Commit 6d6a603

Merge branch 'main' into Issues_#4403
2 parents 9e3f098 + 43253b2 commit 6d6a603

13 files changed (+44, −39 lines)


docs/source/paper_index.md

Lines changed: 44 additions & 0 deletions
@@ -605,3 +605,47 @@ def add_margin(example):
## Distillation

Papers relating to training a student model with the help of a teacher model.

### On-Policy Distillation

**📰 Blog**: https://thinkingmachines.ai/blog/on-policy-distillation/

On-Policy Distillation involves the student model generating rollouts for each batch of training data. For every token of those rollouts, we obtain the probability distributions from both the student and the teacher model, and the student is then optimized to minimize the reverse Kullback-Leibler (KL) divergence between its own token distributions and those of the teacher.

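To make the objective concrete, the following is a minimal sketch of the per-token reverse KL term described above. It is illustrative only: the function name and tensor shapes are assumptions, not TRL's internal implementation.

```python
import torch
import torch.nn.functional as F


def per_token_reverse_kl(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    """Reverse KL, i.e. KL(student || teacher), computed for every rollout token.

    Both inputs are assumed to have shape (batch, seq_len, vocab_size).
    """
    student_logp = F.log_softmax(student_logits, dim=-1)
    teacher_logp = F.log_softmax(teacher_logits, dim=-1)
    # Sum over the vocabulary of p_student * (log p_student - log p_teacher).
    return (student_logp.exp() * (student_logp - teacher_logp)).sum(dim=-1)


# Toy example: random logits stand in for student/teacher outputs on a rollout.
student_logits = torch.randn(2, 5, 32)
teacher_logits = torch.randn(2, 5, 32)
loss = per_token_reverse_kl(student_logits, teacher_logits).mean()
```

Because this divergence is available for every generated token, the training signal is dense, in contrast to the sequence-level rewards typical of reinforcement learning (see the comparison below).
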
| Method                 | Sampling   | Reward signal |
|------------------------|------------|---------------|
| Supervised finetuning  | off-policy | dense         |
| Reinforcement learning | on-policy  | sparse        |
| On-policy distillation | on-policy  | dense         |

On-policy distillation has been shown to outperform both SFT and GRPO, and it can be used to restore generalization capabilities lost during SFT.

Additionally, on-policy distillation is more compute-efficient and less prone to overfitting when trained on limited data.

To train a model with on-policy distillation in TRL, you can use the [`GKDTrainer`] together with the following [`GKDConfig`]:

```python
from trl import GKDConfig

config = GKDConfig(
    lmbda=1.0,  # the student produces rollouts for all batches
    beta=1.0,  # use reverse KL as the loss function
    teacher_model_name_or_path="teacher-model",  # specify the teacher model
)
```

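For orientation, here is a hedged end-to-end sketch of handing such a config to the [`GKDTrainer`]. The student identifier and dataset name are placeholders, and the constructor arguments (`teacher_model`, `processing_class`, `train_dataset`) assume a recent TRL release:

```python
from datasets import load_dataset
from transformers import AutoTokenizer
from trl import GKDConfig, GKDTrainer

# Placeholder names -- substitute your own student model, teacher model, and dataset.
student_id = "student-model"
tokenizer = AutoTokenizer.from_pretrained(student_id)
train_dataset = load_dataset("your-org/your-chat-dataset", split="train")  # conversational "messages" format

training_args = GKDConfig(
    lmbda=1.0,
    beta=1.0,
    teacher_model_name_or_path="teacher-model",
    output_dir="on-policy-distillation",
)

trainer = GKDTrainer(
    model=student_id,  # the student; a model instance also works
    teacher_model="teacher-model",  # passed explicitly; some versions also read it from the config
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```
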
Alternatively, you can use the [`GOLDTrainer`] and [`GOLDConfig`] to perform on-policy distillation with a similar configuration:

```python
from trl.experimental import GOLDConfig

config = GOLDConfig(
    lmbda=1.0,  # the student produces rollouts for all batches
    beta=1.0,  # use reverse KL as the loss function
    teacher_model_name_or_path="teacher-model",  # specify the teacher model
)
```

examples/scripts/cpo.py

Lines changed: 0 additions & 3 deletions
@@ -64,7 +64,6 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from trl import CPOConfig, CPOTrainer, ModelConfig, ScriptArguments, get_peft_config
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


# Enable logging in a Hugging Face Space
@@ -90,8 +89,6 @@
# Dataset
################
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
-if tokenizer.chat_template is None:
-    tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

################
# Training

examples/scripts/nash_md.py

Lines changed: 0 additions & 3 deletions
@@ -73,7 +73,6 @@
    get_kbit_device_map,
    get_quantization_config,
)
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


# Enable logging in a Hugging Face Space
@@ -128,8 +127,6 @@
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
-if tokenizer.chat_template is None:
-    tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

examples/scripts/online_dpo.py

Lines changed: 0 additions & 3 deletions
@@ -69,7 +69,6 @@
    get_peft_config,
    get_quantization_config,
)
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


# Enable logging in a Hugging Face Space
@@ -131,8 +130,6 @@
    trust_remote_code=model_args.trust_remote_code,
    **model_kwargs,
)
-if tokenizer.chat_template is None:
-    tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

examples/scripts/orpo.py

Lines changed: 0 additions & 3 deletions
@@ -64,7 +64,6 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from trl import ModelConfig, ORPOConfig, ORPOTrainer, ScriptArguments, get_peft_config
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


# Enable logging in a Hugging Face Space
@@ -91,8 +90,6 @@
# Dataset
################
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
-if tokenizer.chat_template is None:
-    tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

################
# Training

examples/scripts/ppo/ppo.py

Lines changed: 0 additions & 3 deletions
@@ -43,7 +43,6 @@
    get_peft_config,
    get_quantization_config,
)
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


# Enable logging in a Hugging Face Space
@@ -106,8 +105,6 @@
    model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code
)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-if tokenizer.chat_template is None:
-    tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
value_model = AutoModelForSequenceClassification.from_pretrained(
    training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1
)

examples/scripts/ppo/ppo_tldr.py

Lines changed: 0 additions & 3 deletions
@@ -43,7 +43,6 @@
    get_peft_config,
    get_quantization_config,
)
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


# Enable logging in a Hugging Face Space
@@ -113,8 +112,6 @@
    model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code
)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-if tokenizer.chat_template is None:
-    tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
value_model = AutoModelForSequenceClassification.from_pretrained(
    training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1
)

examples/scripts/xpo.py

Lines changed: 0 additions & 3 deletions
@@ -57,7 +57,6 @@
    get_kbit_device_map,
    get_quantization_config,
)
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


# Enable logging in a Hugging Face Space
@@ -113,8 +112,6 @@
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
-if tokenizer.chat_template is None:
-    tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

tests/test_cpo_trainer.py

Lines changed: 0 additions & 3 deletions
@@ -17,7 +17,6 @@
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer

from trl import CPOConfig, CPOTrainer
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

from .testing_utils import TrlTestCase, require_peft

@@ -33,15 +32,13 @@ def setup_method(self):
        model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration"
        self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        self.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
-        self.t5_tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

    @pytest.mark.parametrize(
        "name, loss_type, config_name",
        [
            ("qwen", "sigmoid", "standard_preference"),
            ("t5", "hinge", "standard_implicit_prompt_preference"),
            ("qwen", "ipo", "conversational_preference"),
-            ("t5", "ipo", "conversational_implicit_prompt_preference"),
            ("qwen", "simpo", "standard_preference"),
            ("t5", "simpo", "standard_implicit_prompt_preference"),
            ("qwen", "hinge", "conversational_preference"),

tests/test_gkd_trainer.py

Lines changed: 0 additions & 5 deletions
@@ -21,7 +21,6 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from trl import GKDConfig, GKDTrainer
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

from .testing_utils import TrlTestCase, require_liger_kernel

@@ -206,10 +205,6 @@ def setup_method(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.tokenizer.pad_token = self.tokenizer.eos_token

-        # Ensure the tokenizer has a chat template
-        if not hasattr(self.tokenizer, "chat_template") or self.tokenizer.chat_template is None:
-            self.tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
-
    def test_gkd_trainer(self):
        training_args = GKDConfig(
            output_dir=self.tmp_dir,
