Skip to content

Commit

Permalink
Gkd trainer (#1814)
Browse files Browse the repository at this point in the history
* initial

* initial gkd script

* fix output dir name

* smaller max_new_tokens_response size

* fix tab

* use temperature from config

* initial docs

* initial test

* add generalized_jsd_loss

* some docs

* fix order of interpolation

* use log_target=True

* fix formatting

* docstrings

* add peft example

* more docs

* formatting

* fix ordering

* use unwrap_model_for_generation

* initial DataCollatorForLastCompletionLM

* add generation inputs

* logits from the completions

* add eps to probs

* select the logits after removing the padding

* formatting

* interpolate log_probs

* add back online sampling

* update tests

* fix typos

* Update docs/source/gkd_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/gkd_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/gkd_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/gkd_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/gkd_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/gkd_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/_toctree.yml

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* use Qwen2

* Update trl/trainer/gkd_config.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/gkd_config.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/gkd_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/gkd_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/gkd_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/gkd_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/gkd_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/gkd_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update tests/test_gkd_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/gkd_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* fixes

* renamed lamda to lmbda due to keyword

* fix config name

* move collator to utils

* fix formatting

* Update trl/trainer/gkd_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/gkd_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* the larger the lmbda the more on policy it should be

* Use JSD instead of KL

* use DataCollatorForChatML

* fix labels

* use torch_call

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* set default collator to DataCollatorForChatML

* return only the prompts

* fix labels of generated outputs

* formatting

* fix comment

* add missing _prepare_deepspeed

* no attention mask when generating

* update test

* set a sensible max_seq_length

* set default in the collator

* Update tests/test_gkd_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update tests/test_gkd_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* fix padding

* formatting

* Update tests/test_gkd_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* fix tests

* TestGeneralizedJSDLoss

* fix typos

* use a mask to calculate jsd loss

* use the super() training_step after the inputs are created

* fix the docs

* create generate_on_policy_outputs

* loss does not need labels

* use_cache is false when gradient checkpointing is True

* use self.assert

* fix toc

* generate_on_policy_outputs needs token_id

* use papers link

* teacher_model is in eval mode so no need for disabling dropout

* log completions and use_liger

* prompt from train if no eval

* fix logging and add cache empty

* add_generation_prompt=True

* fix prompts

* Update docs/source/gkd_trainer.md

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update docs/source/gkd_trainer.md

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update docs/source/gkd_trainer.md

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update examples/scripts/gkd.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* minor doc changes

* fix temp default

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* update docs

* fix dataset format

* fix dataset format

* no need for scores in generation

* teacher_model_init_kwargs

* Update _toctree.yml

* Update docs/source/gkd_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update tests/test_gkd_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/gkd_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* fix

* remove rich

* add determinstic test

* fix code

* use bigger teacher model

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
  • Loading branch information
4 people authored Sep 11, 2024
1 parent 642c4b1 commit 85696aa
Show file tree
Hide file tree
Showing 9 changed files with 940 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
title: Iterative SFT
- local: reward_trainer
title: Reward Model
- local: gkd_trainer
title: GKD Trainer
title: Trainers
- local: models
title: Model Classes
Expand Down
95 changes: 95 additions & 0 deletions docs/source/gkd_trainer.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Generalized Knowledge Distillation Trainer

## Overview

Generalized Knowledge Distillation (GKD) was proposed in [On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes](https://huggingface.co/papers/2306.13649) by Rishabh Agarwal, Nino Vieillard, Yongchao Zhou, Piotr Stanczyk, Sabela Ramos, Matthieu Geist, and Olivier Bachem.

The abstract from the paper is the following:

> Knowledge distillation (KD) is widely used for compressing a teacher model to reduce its inference cost and memory footprint, by training a smaller student model. However, current KD methods for auto-regressive sequence models suffer from distribution mismatch between output sequences seen during training and those generated by the student during inference. To address this issue, we introduce Generalized Knowledge Distillation (GKD). Instead of solely relying on a fixed set of output sequences, GKD trains the student on its self-generated output sequences by leveraging feedback from the teacher on such sequences. Unlike supervised KD approaches, GKD also offers the flexibility to employ alternative loss functions between the student and teacher, which can be useful when the student lacks the expressivity to mimic the teacher's distribution. Furthermore, GKD facilitates the seamless integration of distillation with RL fine-tuning (RLHF). We demonstrate the efficacy of GKD for distilling auto-regressive language models on summarization, translation, and arithmetic reasoning tasks, and task-agnostic distillation for instruction-tuning.


The key aspects of GKD are:
1. It addresses the train-inference distribution mismatch in auto-regressive sequence models by training the student model on its self-generated output sequences.
2. GKD allows flexibility in choosing different divergence measures between student and teacher models via the generalized Jensen-Shannon Divergence (JSD), which can be useful when the student lacks the capacity to fully mimic the teacher.

This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif) and [Lewis Tunstall](https://huggingface.co/lewtun).

## Usage tips

The GKD Trainer is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs two parameters to be set via the [`GKDConfig`] namely:
* `lmbda`: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When `lmbda=0.0`, the loss reduces to supervised JSD where the student is trained with the token-level probabilities of the teacher. When `lmbda=1.0`, the loss reduces to on-policy JSD, where the student generates output sequences and token-specific feedback on these sequences from the teacher. For values in between [0, 1] it is random between the two based on the `lmbda` value for each batch.
* `beta`: controls the interpolation in the generalized Jensen-Shannon Divergence. When `beta=0.0` the loss approximates forward KL divergence, while for `beta=1.0` the loss approximates reverse KL divergence. For values in between [0, 1] it interpolates between the two.

The authors find that on-policy data (high `lmbda`) performs better and the optimal `beta` varied depending on the task and evaluation method.
> [!WARNING]
> Make sure that `attn_implementation="flash_attention_2" when training [Gemma models](https://huggingface.co/models?other=gemma2). Otherwise you will encounter NaNs in the logits due to the [soft capping technique](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) adopted by this architecture.

The basic API is as follows:

```python
from datasets import Dataset
from trl import GKDConfig, GKDTrainer
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
)

NUM_DUMMY_SAMPLES = 100

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# The model to optimise
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# The teacher model to calculate the KL divergence against
teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")

train_dataset = Dataset.from_dict(
{
"messages": [
[
{"role": "user", "content": "Hi, how are you?"},
{"role": "assistant", "content": "I'm great thanks"},
]
]
* NUM_DUMMY_SAMPLES
}
)
eval_dataset = Dataset.from_dict(
{
"messages": [
[
{"role": "user", "content": "What colour is the sky?"},
{"role": "assistant", "content": "The sky is blue"},
]
]
* NUM_DUMMY_SAMPLES
}
)

args = GKDConfig(output_dir="gkd-model", per_device_train_batch_size=1)
trainer = GKDTrainer(
model=model,
teacher_model=teacher_model,
args=args,
tokenizer=tokenizer,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
trainer.train()
```

### Expected dataset format

The dataset should be formatted as a list of "messages" where each message is a list of dictionaries with the following keys:
* `role`: either `system`, `assistant` or `user`
* `content`: the message content


## GKDTrainer

[[autodoc]] GKDTrainer

## GKDConfig

[[autodoc]] GKDConfig
133 changes: 133 additions & 0 deletions examples/scripts/gkd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# flake8: noqa
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
# Full training:
python examples/scripts/gkd.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \
--dataset_name andito/chatbot_arena_completions \
--learning_rate 2e-5 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 8 \
--output_dir gkd-model \
--logging_steps 10 \
--num_train_epochs 1 \
--push_to_hub \
--gradient_checkpointing
# LoRA:
python examples/scripts/gkd.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \
--dataset_name andito/chatbot_arena_completions \
--learning_rate 2e-4 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 8 \
--output_dir gkd-model \
--logging_steps 10 \
--num_train_epochs 1 \
--push_to_hub \
--gradient_checkpointing \
--use_peft \
--lora_r 64 \
--lora_alpha 16
"""

from datasets import load_dataset
from transformers import AutoTokenizer

from trl import (
GKDConfig,
GKDTrainer,
ModelConfig,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
from trl.commands.cli_utils import SFTScriptArguments, TrlParser
from trl.trainer.callbacks import LogCompletionsCallback


if __name__ == "__main__":
parser = TrlParser((SFTScriptArguments, GKDConfig, ModelConfig))
args, training_args, model_config = parser.parse_args_and_config()

################
# Model & Tokenizer
################
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
revision=model_config.model_revision,
trust_remote_code=model_config.trust_remote_code,
attn_implementation=model_config.attn_implementation,
torch_dtype=model_config.torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
training_args.model_init_kwargs = model_kwargs

teacher_model_kwargs = dict(
revision=model_config.model_revision,
trust_remote_code=model_config.trust_remote_code,
attn_implementation=model_config.attn_implementation,
torch_dtype=model_config.torch_dtype,
use_cache=True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
training_args.teacher_model_init_kwargs = teacher_model_kwargs

tokenizer = AutoTokenizer.from_pretrained(
model_config.model_name_or_path,
trust_remote_code=model_config.trust_remote_code,
padding_side="left",
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

################
# Dataset
################
raw_datasets = load_dataset(args.dataset_name)
train_dataset = raw_datasets[args.dataset_train_split]
try:
eval_dataset = raw_datasets[args.dataset_test_split]
prompts = eval_dataset["messages"][:8]
except KeyError:
eval_dataset = None
prompts = train_dataset["messages"][:8]

# remove the last assistant message from the prompts messages and then apply chat template to the prompts
prompts = [prompts[i][:-1] for i in range(len(prompts))]
prompts = [tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) for prompt in prompts]

################
# Training
################
trainer = GKDTrainer(
model=model_config.model_name_or_path,
teacher_model=training_args.teacher_model_name_or_path,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
peft_config=get_peft_config(model_config),
)
log_completions_callback = LogCompletionsCallback(prompts)
trainer.add_callback(log_completions_callback)
trainer.train()

trainer.save_model(training_args.output_dir)
Loading

0 comments on commit 85696aa

Please sign in to comment.