|
| 1 | +# General Online Logit Distillation (GOLD) Trainer |
| 2 | + |
| 3 | +[](https://huggingface.co/models?other=sft,gold) |
| 4 | + |
| 5 | +## Overview |
| 6 | + |
| 7 | +General Online Logit Distillation (GOLD) is an extension of Universal Logit Distillation (ULD) that supports |
| 8 | +student/teacher pairs with different tokenizers. It aligns the textual spans produced by both tokenizers and merges the |
| 9 | +associated logits so no completion tokens are dropped. This enables cross-tokenizer knowledge distillation, including |
| 10 | +mixed model families (for example, LLaMA students with Qwen teachers). |
| 11 | + |
| 12 | +Key capabilities: |
| 13 | + |
| 14 | +1. **Cross-tokenizer alignment** – GOLD incrementally decodes the student and teacher tokens, groups passages with the same visible text, and merges probabilities inside each group. This guarantees loss terms are computed over the full completion even when token boundaries differ. |
| 15 | +2. **Hybrid ULD loss** – when `uld_use_hybrid_loss` is enabled, GOLD compares exact vocabulary matches directly and falls back to the original sorted-probability ULD loss for unmatched tokens. This improves stability for students whose vocabularies only partially overlap with the teacher. |
| 16 | +3. **Seamless integration with GKD** – GOLD inherits the on-policy vs. off-policy scheduling from the [`GKDTrainer`](./gkd_trainer.md), so you can combine sequence-level KD, generalized JSD, and cross-tokenizer distillation in a single training run. |
| 17 | + |
| 18 | +> [!NOTE] |
| 19 | +> GOLD is currently part of the `trl.experimental` namespace. APIs may change without notice while the feature is iterated on. |
| 20 | +
|
| 21 | +## Usage tips |
| 22 | + |
| 23 | +The [`GOLDTrainer`] subclasses [`SFTTrainer`] and accepts the same datasets as other TRL trainers (lists of ChatML style |
| 24 | +messages). Important configuration flags on [`GOLDConfig`] include: |
| 25 | + |
| 26 | +* `use_uld_loss` – toggles Universal Logit Distillation. Set this to `True` for cross-tokenizer setups. |
| 27 | +* `teacher_tokenizer_name_or_path` – required when `use_uld_loss=True`; GOLD uses the teacher tokenizer to align tokens. |
| 28 | +* `uld_use_hybrid_loss`, `uld_hybrid_matched_weight`, `uld_hybrid_unmatched_weight` – enables and weights the hybrid |
| 29 | + matched/unmatched loss. |
| 30 | +* `beta`, `lmbda`, `seq_kd` – inherited from `GKDConfig`, controlling the generalized JSD interpolation and on-policy |
| 31 | + sampling ratio. |
| 32 | + |
| 33 | +A minimal end-to-end example: |
| 34 | + |
| 35 | +```python |
| 36 | +from datasets import load_dataset |
| 37 | +from trl.experimental.gold import GOLDConfig, GOLDTrainer |
| 38 | + |
| 39 | +train_dataset = load_dataset( |
| 40 | + "HuggingFaceTB/OpenR1-Math-220k-default-verified", |
| 41 | + "all", |
| 42 | + split="train[:1024]", |
| 43 | +) |
| 44 | + |
| 45 | +trainer = GOLDTrainer( |
| 46 | + model="meta-llama/Llama-3.2-1B-Instruct", |
| 47 | + teacher_model="Qwen/Qwen2.5-0.5B-Instruct", |
| 48 | + args=GOLDConfig(output_dir="gold-model", use_uld_loss=True, teacher_tokenizer_name_or_path="Qwen/Qwen2.5-0.5B-Instruct"), |
| 49 | + train_dataset=train_dataset, |
| 50 | +) |
| 51 | +trainer.train() |
| 52 | +``` |
| 53 | + |
| 54 | +For quick-start workflows you can rely on string identifiers as shown above—the trainer will load the model and tokenizer for you. Explicitly instantiating `AutoModelForCausalLM`, `AutoTokenizer`, or populating `GOLDConfig` is recommended only for advanced use cases where you need fine-grained control over initialization. |
| 55 | + |
| 56 | +A more explicit setup might look like this when you need to customise model loading, tokenizer settings, or training arguments: |
| 57 | + |
| 58 | +```python |
| 59 | +from datasets import load_dataset |
| 60 | +from trl import GOLDConfig, GOLDTrainer |
| 61 | +from transformers import AutoModelForCausalLM, AutoTokenizer |
| 62 | + |
| 63 | +student_name = "meta-llama/Llama-3.2-1B-Instruct" |
| 64 | +teacher_name = "Qwen/Qwen2.5-0.5B-Instruct" |
| 65 | + |
| 66 | +tokenizer = AutoTokenizer.from_pretrained(student_name) |
| 67 | +if tokenizer.pad_token is None: |
| 68 | + tokenizer.pad_token = tokenizer.eos_token |
| 69 | + |
| 70 | +model = AutoModelForCausalLM.from_pretrained(student_name) |
| 71 | +teacher_model = AutoModelForCausalLM.from_pretrained(teacher_name) |
| 72 | + |
| 73 | +train_dataset = load_dataset( |
| 74 | + "HuggingFaceTB/OpenR1-Math-220k-default-verified", |
| 75 | + "all", |
| 76 | + split="train[:1024]", |
| 77 | +) |
| 78 | + |
| 79 | +training_args = GOLDConfig( |
| 80 | + output_dir="gold-model", |
| 81 | + per_device_train_batch_size=1, |
| 82 | + teacher_model=teacher_name, |
| 83 | + teacher_tokenizer_name_or_path=teacher_name, |
| 84 | + use_uld_loss=True, |
| 85 | + uld_use_hybrid_loss=True, |
| 86 | + uld_hybrid_matched_weight=0.6, |
| 87 | + uld_hybrid_unmatched_weight=0.4, |
| 88 | +) |
| 89 | + |
| 90 | +trainer = GOLDTrainer( |
| 91 | + model=model, |
| 92 | + teacher_model=teacher_model, |
| 93 | + args=training_args, |
| 94 | + processing_class=tokenizer, |
| 95 | + train_dataset=train_dataset, |
| 96 | +) |
| 97 | +trainer.train() |
| 98 | +``` |
| 99 | + |
| 100 | +### Expected dataset type |
| 101 | + |
| 102 | +GOLD requires a [conversational](dataset_formats#conversational) [language modeling](dataset_formats#language_modeling) dataset, e.g.: |
| 103 | + |
| 104 | +```python |
| 105 | +{"messages": [{"role": "user", "content": "What color is the sky?"}, |
| 106 | + {"role": "assistant", "content": "It is blue."}]} |
| 107 | +``` |
| 108 | + |
| 109 | +`GOLDTrainer` keeps the raw messages so the ChatML collator can construct prompts and completions with the correct |
| 110 | +boundaries. |
| 111 | + |
| 112 | +## GOLDTrainer |
| 113 | + |
| 114 | +[[autodoc]] experimental.gold.GOLDTrainer |
| 115 | + - train |
| 116 | + - generate_on_policy_outputs |
| 117 | + - save_model |
| 118 | + - push_to_hub |
| 119 | + |
| 120 | +## GOLDConfig |
| 121 | + |
| 122 | +[[autodoc]] experimental.gold.GOLDConfig |
0 commit comments