
Commit 4e9ab9f

kashif and qgallouedec authored
👑 [experimental] GOLD Trainer (#4349)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
1 parent b82a8f4 commit 4e9ab9f

File tree

8 files changed: +3428, -21 lines


docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions
```diff
@@ -111,6 +111,8 @@
 - sections:
   - local: bco_trainer
     title: BCO
+  - local: gold_trainer
+    title: GOLD
   - local: openenv
     title: OpenEnv Integration
   title: Experimental
```

docs/source/gold_trainer.md

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
# General Online Logit Distillation (GOLD) Trainer

[![All_models-GOLD-blue](https://img.shields.io/badge/All_models-GOLD-blue)](https://huggingface.co/models?other=sft,gold)

## Overview

General Online Logit Distillation (GOLD) is an extension of Universal Logit Distillation (ULD) that supports
student/teacher pairs with different tokenizers. It aligns the textual spans produced by both tokenizers and merges the
associated logits so that no completion tokens are dropped. This enables cross-tokenizer knowledge distillation,
including mixed model families (for example, a LLaMA student with a Qwen teacher).
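To see why alignment is needed, compare how the tokenizer pair used throughout this page splits the same text. This
snippet is purely illustrative (any two tokenizers with different vocabularies show the effect, and the LLaMA
checkpoint may require accepting its license on the Hub):

```python
from transformers import AutoTokenizer

# The same completion gets different token boundaries under each vocabulary,
# so logits cannot be compared position by position.
student_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
teacher_tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

text = "The answer is 42."
print(student_tok.tokenize(text))  # token counts and boundaries generally differ
print(teacher_tok.tokenize(text))
```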
Key capabilities:

1. **Cross-tokenizer alignment** – GOLD incrementally decodes the student and teacher tokens, groups spans with the same visible text, and merges probabilities within each group. This guarantees that loss terms are computed over the full completion even when token boundaries differ (see the sketch after this list).
2. **Hybrid ULD loss** – when `uld_use_hybrid_loss` is enabled, GOLD compares exact vocabulary matches directly and falls back to the original sorted-probability ULD loss for unmatched tokens. This improves stability for students whose vocabularies only partially overlap with the teacher's.
3. **Seamless integration with GKD** – GOLD inherits the on-policy vs. off-policy scheduling from the [`GKDTrainer`](./gkd_trainer.md), so you can combine sequence-level KD, generalized JSD, and cross-tokenizer distillation in a single training run.
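A minimal sketch of the span-grouping idea (an illustration only, not GOLD's actual implementation, which works on the
tokenizers' decoded text and merges the associated probabilities):

```python
# Illustration only: group two token streams over the same text into spans
# whose visible text matches, so per-span quantities can be merged.
def align_spans(student_tokens: list[str], teacher_tokens: list[str]):
    spans, i, j = [], 0, 0
    s_buf = t_buf = ""
    s_start = t_start = 0
    while i < len(student_tokens) or j < len(teacher_tokens):
        # Grow whichever side is behind until the visible text matches.
        if len(s_buf) <= len(t_buf) and i < len(student_tokens):
            s_buf += student_tokens[i]
            i += 1
        elif j < len(teacher_tokens):
            t_buf += teacher_tokens[j]
            j += 1
        else:
            break  # leftover text on one side cannot be aligned
        if s_buf and s_buf == t_buf:
            # In GOLD, the probabilities of each group's tokens get merged.
            spans.append((list(range(s_start, i)), list(range(t_start, j))))
            s_buf, t_buf, s_start, t_start = "", "", i, j
    return spans

# "New York" split differently by two tokenizers still forms one shared span:
print(align_spans(["New", " York"], ["New ", "Yo", "rk"]))  # [([0, 1], [0, 1, 2])]
```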
> [!NOTE]
> GOLD is currently part of the `trl.experimental` namespace. APIs may change without notice while the feature is being iterated on.
## Usage tips

The [`GOLDTrainer`] subclasses [`SFTTrainer`] and accepts the same datasets as other TRL trainers (lists of ChatML-style
messages). Important configuration flags on [`GOLDConfig`] include:

* `use_uld_loss` – toggles Universal Logit Distillation. Set this to `True` for cross-tokenizer setups.
* `teacher_tokenizer_name_or_path` – required when `use_uld_loss=True`; GOLD uses the teacher tokenizer to align tokens.
* `uld_use_hybrid_loss`, `uld_hybrid_matched_weight`, `uld_hybrid_unmatched_weight` – enable and weight the hybrid
  matched/unmatched loss (sketched below).
* `beta`, `lmbda`, `seq_kd` – inherited from `GKDConfig`; they control the generalized JSD interpolation and the
  on-policy sampling ratio.
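For intuition, here is a rough sketch of how such a hybrid loss can combine its two terms. This is not the trainer's
exact loss (shapes, reductions, and the matching logic are simplified): `p_student`/`p_teacher` are probability vectors
for a single position, and `s_idx`/`t_idx` are hypothetical index lists of vocabulary entries shared by both tokenizers:

```python
import torch
import torch.nn.functional as F

def hybrid_uld_loss(p_student, p_teacher, s_idx, t_idx,
                    matched_weight=0.6, unmatched_weight=0.4):
    # Matched term: probabilities of shared vocabulary entries are compared
    # directly, pairwise along the shared-index lists.
    matched = (p_student[s_idx] - p_teacher[t_idx]).abs().mean()

    # Unmatched term: classic ULD, an L1 distance between the two sorted
    # probability tails, zero-padded so differently sized vocabularies match.
    s_mask = torch.ones_like(p_student, dtype=torch.bool)
    s_mask[s_idx] = False
    t_mask = torch.ones_like(p_teacher, dtype=torch.bool)
    t_mask[t_idx] = False
    s_rest = p_student[s_mask].sort(descending=True).values
    t_rest = p_teacher[t_mask].sort(descending=True).values
    size = max(s_rest.numel(), t_rest.numel())
    s_rest = F.pad(s_rest, (0, size - s_rest.numel()))
    t_rest = F.pad(t_rest, (0, size - t_rest.numel()))
    unmatched = (s_rest - t_rest).abs().mean()

    return matched_weight * matched + unmatched_weight * unmatched

# Vocabularies of different sizes (10 vs. 12) with three shared entries:
p_s = torch.softmax(torch.randn(10), dim=-1)
p_t = torch.softmax(torch.randn(12), dim=-1)
print(hybrid_uld_loss(p_s, p_t, s_idx=[0, 1, 2], t_idx=[3, 4, 5]))
```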
A minimal end-to-end example:

```python
from datasets import load_dataset
from trl.experimental.gold import GOLDConfig, GOLDTrainer

train_dataset = load_dataset(
    "HuggingFaceTB/OpenR1-Math-220k-default-verified",
    "all",
    split="train[:1024]",
)

trainer = GOLDTrainer(
    model="meta-llama/Llama-3.2-1B-Instruct",
    teacher_model="Qwen/Qwen2.5-0.5B-Instruct",
    args=GOLDConfig(
        output_dir="gold-model",
        use_uld_loss=True,
        teacher_tokenizer_name_or_path="Qwen/Qwen2.5-0.5B-Instruct",
    ),
    train_dataset=train_dataset,
)
trainer.train()
```
For quick-start workflows you can rely on string identifiers as shown above: the trainer loads the model and tokenizer
for you. Explicitly instantiating `AutoModelForCausalLM` and `AutoTokenizer`, or populating a fuller `GOLDConfig`, is
recommended only for advanced use cases where you need fine-grained control over initialization.

A more explicit setup might look like the following when you need to customize model loading, tokenizer settings, or
training arguments:
```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl.experimental.gold import GOLDConfig, GOLDTrainer

student_name = "meta-llama/Llama-3.2-1B-Instruct"
teacher_name = "Qwen/Qwen2.5-0.5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(student_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(student_name)
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_name)

train_dataset = load_dataset(
    "HuggingFaceTB/OpenR1-Math-220k-default-verified",
    "all",
    split="train[:1024]",
)

training_args = GOLDConfig(
    output_dir="gold-model",
    per_device_train_batch_size=1,
    teacher_model=teacher_name,
    teacher_tokenizer_name_or_path=teacher_name,
    use_uld_loss=True,
    uld_use_hybrid_loss=True,
    uld_hybrid_matched_weight=0.6,
    uld_hybrid_unmatched_weight=0.4,
)

trainer = GOLDTrainer(
    model=model,
    teacher_model=teacher_model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```
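Because GOLD inherits GKD's scheduling, `lmbda` decides per batch whether training uses on-policy student generations
or the dataset completions. A rough sketch of that decision under GKD's convention (illustrative names, not trainer
internals):

```python
import random

# Illustrative only: per-batch data-source selection in GKD-style training.
# lmbda=0.0 -> always dataset completions (off-policy);
# lmbda=1.0 -> always fresh student generations (on-policy).
def pick_batch_source(lmbda: float) -> str:
    return "student_generations" if random.random() < lmbda else "dataset_completions"
```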
### Expected dataset type

GOLD requires a [conversational](dataset_formats#conversational) [language modeling](dataset_formats#language_modeling) dataset, e.g.:

```python
{"messages": [{"role": "user", "content": "What color is the sky?"},
              {"role": "assistant", "content": "It is blue."}]}
```

`GOLDTrainer` keeps the raw messages so the ChatML collator can construct prompts and completions with the correct
boundaries.
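If you do not have a Hub dataset handy, a small in-memory dataset in this format works too (a hypothetical two-example
toy set, just to show the expected structure):

```python
from datasets import Dataset

# Toy conversational dataset with the {"messages": [...]} structure above.
train_dataset = Dataset.from_list([
    {"messages": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is blue."},
    ]},
    {"messages": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4."},
    ]},
])
```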
## GOLDTrainer

[[autodoc]] experimental.gold.GOLDTrainer
    - train
    - generate_on_policy_outputs
    - save_model
    - push_to_hub

## GOLDConfig

[[autodoc]] experimental.gold.GOLDConfig
