
Commit ecb2811

Add MiniLLM Trainer (#4504)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
1 parent 89e4688 commit ecb2811

8 files changed: +727 −6 lines changed

docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -115,6 +115,8 @@
     title: GSPO-token
   - local: judges
     title: Judges
+  - local: minillm
+    title: MiniLLM
   - local: papo_trainer
     title: PAPO
   - local: xpo_trainer

docs/source/minillm.md

Lines changed: 67 additions & 0 deletions
# MiniLLM Trainer

[![All_models-MiniLLM-blue](https://img.shields.io/badge/All_models-MiniLLM-blue)](https://huggingface.co/models?other=minillm,trl)

## Overview

TRL supports the MiniLLM Trainer for distilling large language models into smaller ones using reverse KLD for better precision, quality, and performance, as described in the paper [Knowledge Distillation of Large Language Models](https://huggingface.co/papers/2306.08543) by [Yuxian Gu](https://huggingface.co/t1101675), [Li Dong](https://huggingface.co/unilm), [Furu Wei](https://huggingface.co/thegenerality), and Minlie Huang.

The abstract from the paper is the following:

> Knowledge Distillation (KD) is a promising technique for reducing the high computational demand of large language models (LLMs). However, previous KD methods are primarily applied to white-box classification models or training small models to imitate black-box model APIs like ChatGPT. How to effectively distill the knowledge from white-box generative LLMs is still under-explored, which becomes more and more important with the prosperity of LLMs. In this work, we propose MiniLLM that distills smaller language models from generative larger language models. We first replace the forward Kullback-Leibler divergence (KLD) objective in the standard KD approaches with reverse KLD, which is more suitable for KD on generative language models, to prevent the student model from overestimating the low-probability regions of the teacher distribution. Then, we derive an effective optimization approach to learn this objective. Extensive experiments in the instruction-following setting show that the MiniLLM models generate more precise responses with the higher overall quality, lower exposure bias, better calibration, and higher long-text generation performance. Our method is also scalable for different model families with 120M to 13B parameters. We will release our code and model checkpoints at https://aka.ms/MiniLLM.

This post-training method was contributed by [Yuxian Gu](https://huggingface.co/t1101675).

It is a generalized version of [Thinking Machines Lab's On-Policy Distillation](https://thinkingmachines.ai/blog/on-policy-distillation/), with the option to add distribution-level single-step distillation signals (like GKD when `beta=1`) and long-context reverse KLD signals:

$$
\begin{align}
L_{\text{MiniLLM}}&=\alpha_1\mathbb{E}_{x\sim \pi_{\theta}}\sum_{t'=t}^{|x|}\frac{\gamma^{t'-t}}{\sum_{t'}\gamma^{t'-t}}\left[\log \frac{\pi_{\theta}(x_{t'+1}|x_{1..t'})}{\pi_{\text{teacher}}(x_{t'+1}|x_{1..t'})}\right] \\
&+ \alpha_2\mathbb{E}_{x\sim \pi_{\theta}} \text{KL}\left[\pi_\theta(\cdot|x_{1..t})||\pi_{\text{teacher}}(\cdot | x_{1..t})\right].
\end{align}
$$

When \\( \alpha_1=1 \\), \\( \alpha_2=0 \\), and \\( \gamma=0 \\), which corresponds to

```python
from trl.experimental.minillm import MiniLLMConfig

training_args = MiniLLMConfig(
    rkl_advantage=True,
    single_step_decomposition=False,
    gamma=0.0,
)
```

\\( L_{\text{MiniLLM}} \\) becomes the on-policy KD loss implemented in [Tinker](https://github.com/thinking-machines-lab/tinker-cookbook/blob/5d08be6d130596b7bedd02197861c41fa81ea436/tinker_cookbook/distillation/train_on_policy.py#L88):

$$
L_{\text{tinker}}=\mathbb{E}_{x\sim \pi_{\theta}}\left[\log \frac{\pi_{\theta}(x_{t'+1}|x_{1..t'})}{\pi_{\text{teacher}}(x_{t'+1}|x_{1..t'})}\right].
$$

When \\( \alpha_1=0 \\) and \\( \alpha_2=1 \\), which corresponds to

```python
from trl.experimental.minillm import MiniLLMConfig

training_args = MiniLLMConfig(
    rkl_advantage=False,
    single_step_decomposition=True,
)
```

\\( L_{\text{MiniLLM}} \\) becomes the reverse KLD version of the GKD loss as in the [GKD Trainer](./gkd.md):

$$
L_{\text{GKD-RKL}}=\mathbb{E}_{x\sim \pi_{\theta}} \text{KL}\left[\pi_\theta(\cdot|x_{1..t})||\pi_{\text{teacher}}(\cdot | x_{1..t})\right].
$$
58+
## MiniLLMTrainer
59+
60+
[[autodoc]] experimental.minillm.MiniLLMTrainer
61+
- train
62+
- save_model
63+
- push_to_hub
64+
65+
## MiniLLMConfig
66+
67+
[[autodoc]] experimental.minillm.MiniLLMConfig

docs/source/paper_index.md

Lines changed: 26 additions & 0 deletions
### Knowledge Distillation of Large Language Models

**📜 Paper**: https://huggingface.co/papers/2306.08543

MiniLLM is the first on-policy knowledge distillation method; it minimizes the sequence-level reverse KLD between the teacher and the student model and is optimized with reinforcement learning.

It is a generalized version of [Thinking Machines Lab's On-Policy Distillation](https://thinkingmachines.ai/blog/on-policy-distillation/), with the option to add distribution-level single-step distillation signals (like GKD when `beta=1`) and long-context reverse KLD signals.

You can use the [`experimental.MiniLLMTrainer`] and [`experimental.MiniLLMConfig`] to perform MiniLLM distillation as follows:

```python
from datasets import load_dataset
from trl.experimental.minillm import MiniLLMTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

trainer = MiniLLMTrainer(
    model="Qwen/Qwen3-0.6B",
    teacher_model="Qwen/Qwen3-1.7B",
    train_dataset=dataset,
)
trainer.train()
```
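To illustrate the extra signals mentioned above, the sketch below shows the weighting knobs exposed by [`experimental.MiniLLMConfig`]. The values shown are the defaults and are purely illustrative.

```python
from trl.experimental.minillm import MiniLLMConfig

# Illustrative values (the defaults): rkl_advantage adds the long-context
# reverse KLD signal, single_step_decomposition adds the distribution-level
# single-step term, gamma discounts future rewards, and length_normalization
# normalizes rewards by sequence length.
training_args = MiniLLMConfig(
    rkl_advantage=True,
    single_step_decomposition=True,
    gamma=0.0,
    length_normalization=True,
)
```

Pass the resulting `training_args` to `MiniLLMTrainer` via the `args` argument.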
For more details, see the [MiniLLM Trainer documentation](minillm).
Lines changed: 57 additions & 0 deletions
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
from datasets import load_dataset

from trl.experimental.minillm import MiniLLMConfig, MiniLLMTrainer

from ..testing_utils import TrlTestCase


@pytest.mark.low_priority
class TestMiniLLMTrainer(TrlTestCase):
    def test_train(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        # Initialize the trainer
        training_args = MiniLLMConfig(
            output_dir=self.tmp_dir,
            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
            num_generations=3,  # reduce the number of generations to reduce memory usage
            max_completion_length=32,  # reduce the completion length to reduce memory usage
            report_to="none",
        )
        trainer = MiniLLMTrainer(
            model="trl-internal-testing/small-Qwen3ForCausalLM",
            teacher_model="trl-internal-testing/tiny-Qwen3ForCausalLM",
            args=training_args,
            train_dataset=dataset,
        )

        # Save the initial parameters to compare them later
        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

        # Train the model
        trainer.train()

        # Check that the training loss is not None
        assert trainer.state.log_history[-1]["train_loss"] is not None

        # Check the params have changed
        for n, param in previous_trainable_params.items():
            new_param = trainer.model.get_parameter(n)
            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
Lines changed: 19 additions & 0 deletions
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .minillm_config import MiniLLMConfig
from .minillm_trainer import MiniLLMTrainer


__all__ = ["MiniLLMConfig", "MiniLLMTrainer"]
Lines changed: 150 additions & 0 deletions
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from dataclasses import dataclass, field
from typing import Any

from transformers import TrainingArguments

from ...trainer.grpo_config import GRPOConfig


@dataclass
class MiniLLMConfig(GRPOConfig):
    """
    Configuration class for [`MiniLLMTrainer`].

    This class includes only the parameters that are specific to MiniLLM training. For a full list of training
    arguments, please refer to the [`~transformers.TrainingArguments`] and [`GRPOConfig`] documentation.

    Args:
        temperature (`float`, *optional*, defaults to `0.9`):
            Temperature for sampling. The higher the temperature, the more random the completions.
        lmbda (`float`, *optional*, defaults to `0.5`):
            Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
            student-generated outputs).
        beta (`float`, *optional*, defaults to `0.5`):
            Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
            beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
        max_new_tokens (`int`, *optional*, defaults to `128`):
            Maximum number of tokens to generate per completion.
        teacher_model_name_or_path (`str`, *optional*):
            Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being
            trained.
        teacher_model_init_kwargs (`dict[str, Any]`, *optional*):
            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
            from a string.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model.
        seq_kd (`bool`, *optional*, defaults to `False`):
            Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on
            teacher-generated output).
    """

    teacher_model_init_kwargs: dict[str, Any] | None = field(
        default=None,
        metadata={
            "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
            "teacher model from a string."
        },
    )
    disable_dropout: bool = field(
        default=True,
        metadata={"help": "Whether to disable dropouts in `model`."},
    )
    rkl_advantage: bool = field(
        default=True,
        metadata={"help": "Whether to add the reverse KL advantage to the reward advantage."},
    )
    single_step_decomposition: bool = field(
        default=True,
        metadata={"help": "Whether to use single-step decomposition for the KL divergence computation."},
    )
    kd_temperature: float = field(
        default=1.0,
        metadata={
            "help": "Temperature for knowledge distillation. Higher temperatures produce softer probability "
            "distributions over classes."
        },
    )
    gamma: float = field(
        default=0.0,
        metadata={"help": "Discount factor for future rewards in reinforcement learning."},
    )
    length_normalization: bool = field(
        default=True,
        metadata={"help": "Whether to apply length normalization to the rewards."},
    )

    def __post_init__(self):
        # We do not use the post_init of GRPOConfig because:
        # 1. num_generations can be < 2 in MiniLLMConfig; `scale_rewards` must be set to "none" to avoid NaN.
        self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16

        TrainingArguments.__post_init__(self)

        self.scale_rewards = {True: "group", False: "none"}.get(self.scale_rewards, self.scale_rewards)
        if self.num_generations == 1:
            self.scale_rewards = "none"

        num_processes = self.world_size
        # The current default effective batch size
        if self.generation_batch_size is None and self.steps_per_generation is None:
            self.steps_per_generation = self.gradient_accumulation_steps
            self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation
        elif self.generation_batch_size is not None and self.steps_per_generation is None:
            # Just ensure the value is divisible by the global batch size
            if self.generation_batch_size % (self.per_device_train_batch_size * num_processes) != 0:
                raise ValueError(
                    f"generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch "
                    f"size ({self.per_device_train_batch_size * num_processes})."
                )
            self.steps_per_generation = self.generation_batch_size // (
                self.per_device_train_batch_size * num_processes
            )
        elif self.generation_batch_size is None and self.steps_per_generation is not None:
            self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation
        else:
            raise ValueError(
                "'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time"
            )

        if self.do_eval and self.eval_strategy != "no":
            # Just ensure the value is divisible by the global batch size
            if (self.per_device_eval_batch_size * num_processes) % self.num_generations != 0:
                raise ValueError(
                    f"The global eval batch size ({self.per_device_eval_batch_size} * {num_processes}) must be "
                    f"divisible by num_generations ({self.num_generations})."
                )

        # The generation batch must contain full prompt groups (no partials), so it must be divisible by
        # num_generations.
        if self.generation_batch_size % self.num_generations != 0:
            raise ValueError(
                f"generation_batch_size ({self.generation_batch_size}) must be divisible by num_generations "
                f"({self.num_generations})."
            )

        if self.use_liger_loss is not None:
            warnings.warn(
                "The `use_liger_loss` argument is deprecated and will be removed in version 0.28.0. Please use "
                "`use_liger_kernel` instead.",
                FutureWarning,
                stacklevel=2,
            )
            self.use_liger_kernel = self.use_liger_loss

        if self.delta is not None and self.use_liger_kernel:
            raise ValueError("Liger kernel does not support two-sided GRPO loss yet.")
