
Commit 7a57fd4

MiniLLM: Fix arguments in config & add to documentation index (#4518)
1 parent a145eaf commit 7a57fd4

File tree

2 files changed (+12, -16 lines)

docs/source/index.md (1 addition, 0 deletions)

```diff
@@ -49,6 +49,7 @@ Below is the current list of TRL trainers, organized by method type (⚡️ = vL
 ### Knowledge distillation
 
 - [`GKDTrainer`]
+- [`experimental.minillm.MiniLLMTrainer`] 🧪
 
 </div>
 </div>
```

trl/experimental/minillm/minillm_config.py (11 additions, 16 deletions)

```diff
@@ -30,27 +30,22 @@ class MiniLLMConfig(GRPOConfig):
     arguments, please refer to the [`~transformers.TrainingArguments`] and [`GRPOConfig`] documentation.
 
     Args:
-        temperature (`float`, *optional*, defaults to `0.9`):
-            Temperature for sampling. The higher the temperature, the more random the completions.
-        lmbda (`float`, *optional*, defaults to `0.5`):
-            Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
-            student-generated outputs).
-        beta (`float`, *optional*, defaults to `0.5`):
-            Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
-            beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
-        max_new_tokens (`int`, *optional*, defaults to `128`):
-            Maximum number of tokens to generate per completion.
-        teacher_model_name_or_path (`str`, *optional*):
-            Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being
-            trained.
         teacher_model_init_kwargs (`dict[str, Any]]`, *optional*):
             Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
             from a string.
         disable_dropout (`bool`, *optional*, defaults to `True`):
             Whether to disable dropout in the model.
-        seq_kd (`bool`, *optional*, defaults to `False`):
-            Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on
-            teacher-generated output).
+        rkl_advantage (`bool`, *optional*, defaults to `True`):
+            Whether to add the reverse KL advantage to the reward advantage.
+        single_step_decomposition (`bool`, *optional*, defaults to `True`):
+            Whether to use single-step decomposition for the KL divergence computation.
+        kd_temperature (`float`, *optional*, defaults to `1.0`):
+            Temperature for knowledge distillation. Higher temperatures produce softer probability distributions over
+            classes.
+        gamma (`float`, *optional*, defaults to `0.0`):
+            Discount factor for future rewards in reinforcement learning.
+        length_normalization (`bool`, *optional*, defaults to `True`):
+            Whether to apply length normalization to the rewards.
     """
 
     teacher_model_init_kwargs: dict[str, Any] | None = field(
```
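For context, the corrected docstring now matches the arguments the config actually accepts. A minimal sketch of constructing the config with these arguments, assuming the import path follows the file location `trl/experimental/minillm/minillm_config.py` shown in this diff (the `output_dir` value and the `teacher_model_init_kwargs` contents here are illustrative, not from the commit):

```python
# Sketch only: import path assumed from the module location in this commit;
# verify against your installed TRL version.
from trl.experimental.minillm import MiniLLMConfig

config = MiniLLMConfig(
    output_dir="minillm-student",       # inherited from transformers.TrainingArguments
    rkl_advantage=True,                 # add the reverse KL advantage to the reward advantage
    single_step_decomposition=True,     # single-step decomposition of the KL computation
    kd_temperature=1.0,                 # >1.0 softens the distillation distributions
    gamma=0.0,                          # discount factor for future rewards
    length_normalization=True,          # length-normalize the rewards
    teacher_model_init_kwargs={"torch_dtype": "bfloat16"},  # forwarded to from_pretrained
)
```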

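The new `kd_temperature` argument follows the standard temperature-scaling idea from knowledge distillation: logits are divided by the temperature before softmax, so higher temperatures flatten the resulting distribution. A framework-free sketch (the function name here is illustrative, not TRL's API):

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Temperature-scaled softmax: divide logits by the temperature before
    normalizing, so higher temperatures yield softer distributions."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]
sharp = softmax_with_temperature(logits, temperature=1.0)
soft = softmax_with_temperature(logits, temperature=4.0)
# The top class keeps less probability mass at the higher temperature.
print(max(sharp), max(soft))
```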