Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Basic distilling. #6527

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions src/llamafactory/hparams/finetuning_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,30 @@ class LoraArguments:
)


@dataclass
class DistillingArguments:
r"""
Arguments pertaining to the distilling training.
"""

distilling_lambda: float = field(
default=0.5,
metadata={"help": "The lambda parameter in the distilling loss."},
)
distilling_temperature: float = field(
default=1.0,
metadata={"help": "The temperature parameter in the distilling softmax."},
)
teacher_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the teacher model used for the distilling."},
)
teacher_model_adapters: Optional[str] = field(
default=None,
metadata={"help": "Path to the adapters of the teacher model."},
)


@dataclass
class RLHFArguments:
r"""
Expand Down Expand Up @@ -334,7 +358,13 @@ class SwanLabArguments:

@dataclass
class FinetuningArguments(
FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument, SwanLabArguments
FreezeArguments,
LoraArguments,
RLHFArguments,
GaloreArguments,
BAdamArgument,
SwanLabArguments,
DistillingArguments,
):
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
Expand All @@ -344,7 +374,7 @@ class FinetuningArguments(
default=False,
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
)
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto"] = field(
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto", "distilling"] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."},
)
Expand Down
18 changes: 18 additions & 0 deletions src/llamafactory/train/distilling/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .workflow import run_distilling


__all__ = ["run_distilling"]
190 changes: 190 additions & 0 deletions src/llamafactory/train/distilling/trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from transformers import Seq2SeqTrainer
from typing_extensions import override

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


if TYPE_CHECKING:
from torch.utils.data import Dataset
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.trainer import PredictionOutput

from ...hparams import FinetuningArguments


logger = logging.get_logger(__name__)


class CustomDistillingTrainer(Seq2SeqTrainer):
r"""
Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
"""

def __init__(
self,
teacher_model: Union["PreTrainedModel", torch.nn.Module],
finetuning_args: "FinetuningArguments",
processor: Optional["ProcessorMixin"],
**kwargs,
):
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
else:
self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer")

self.teacher_model = teacher_model
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it work on DDP setting?


super().__init__(**kwargs)
self.finetuning_args = finetuning_args

if processor is not None:
self.add_callback(SaveProcessorCallback(processor))

if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore

self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)

@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()

@override
def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)

@override
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)

return super()._get_train_sampler()

@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
labels = inputs.get("labels")
padding_mask = labels.eq(-100)
label_loss, outputs = super().compute_loss(model, inputs, return_outputs=True, **kwargs)
with torch.no_grad():
teacher_outputs = self.teacher_model(**inputs)
# Shape: (batch_size, seq_len, vocab_size)
teacher_prob = torch.nn.functional.softmax(
teacher_outputs.logits / self.finetuning_args.distilling_temperature, dim=-1
)
student_logprob = torch.nn.functional.log_softmax(
outputs.logits / self.finetuning_args.distilling_temperature, dim=-1
)
kl_losses = (teacher_prob * (teacher_prob.log() - student_logprob)).sum(dim=-1)
kl_losses.masked_fill_(padding_mask, 0)
num_active_elements = padding_mask.numel() - padding_mask.long().sum()
loss = (
self.finetuning_args.distilling_lambda
* kl_losses.mean()
/ (num_active_elements * student_logprob.shape[-1])
+ label_loss
)

if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
loss = loss / self.args.gradient_accumulation_steps

return (loss, outputs) if return_outputs else loss

@override
def prediction_step(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove the prediction in distill trainer since sft trainer already provided this function

self,
model: "torch.nn.Module",
inputs: Dict[str, Union["torch.Tensor", Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
**gen_kwargs,
) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r"""
Removes the prompt part in the generated tokens.

Subclass and override to inject custom behavior.
"""
if self.args.predict_with_generate: # do not pass labels to model when generate
labels = inputs.pop("labels", None)
else:
labels = inputs.get("labels")

loss, generated_tokens, _ = super().prediction_step(
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
)
if generated_tokens is not None and self.args.predict_with_generate:
generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id
generated_tokens = generated_tokens.contiguous()

return loss, generated_tokens, labels

def save_predictions(
self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
) -> None:
r"""
Saves model predictions to `output_dir`.

A custom behavior that not contained in Seq2SeqTrainer.
"""
if not self.is_world_process_zero():
return

output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info_rank0(f"Saving prediction results to {output_prediction_file}")

labels = np.where(
predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.processing_class.pad_token_id
)
preds = np.where(
predict_results.predictions != IGNORE_INDEX,
predict_results.predictions,
self.processing_class.pad_token_id,
)

for i in range(len(preds)):
pad_len = np.nonzero(preds[i] != self.processing_class.pad_token_id)[0]
if len(pad_len): # move pad token to last
preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)

decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)

with open(output_prediction_file, "w", encoding="utf-8") as f:
for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
Loading