-
Notifications
You must be signed in to change notification settings - Fork 4.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature: Basic distilling. #6527
Open
marko1616
wants to merge
3
commits into
hiyouga:main
Choose a base branch
from
marko1616:feat/distilling
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+383
−2
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright 2024 the LlamaFactory team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from .workflow import run_distilling | ||
|
||
|
||
__all__ = ["run_distilling"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. | ||
# | ||
# This code is inspired by the HuggingFace's transformers library. | ||
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import json | ||
import os | ||
from types import MethodType | ||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union | ||
|
||
import numpy as np | ||
import torch | ||
from transformers import Seq2SeqTrainer | ||
from typing_extensions import override | ||
|
||
from ...extras import logging | ||
from ...extras.constants import IGNORE_INDEX | ||
from ...extras.packages import is_transformers_version_greater_than | ||
from ..callbacks import SaveProcessorCallback | ||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler | ||
|
||
|
||
if TYPE_CHECKING: | ||
from torch.utils.data import Dataset | ||
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin | ||
from transformers.trainer import PredictionOutput | ||
|
||
from ...hparams import FinetuningArguments | ||
|
||
|
||
logger = logging.get_logger(__name__) | ||
|
||
|
||
class CustomDistillingTrainer(Seq2SeqTrainer): | ||
r""" | ||
Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
teacher_model: Union["PreTrainedModel", torch.nn.Module], | ||
finetuning_args: "FinetuningArguments", | ||
processor: Optional["ProcessorMixin"], | ||
**kwargs, | ||
): | ||
if is_transformers_version_greater_than("4.46"): | ||
kwargs["processing_class"] = kwargs.pop("tokenizer") | ||
else: | ||
self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer") | ||
|
||
self.teacher_model = teacher_model | ||
|
||
super().__init__(**kwargs) | ||
self.finetuning_args = finetuning_args | ||
|
||
if processor is not None: | ||
self.add_callback(SaveProcessorCallback(processor)) | ||
|
||
if finetuning_args.use_badam: | ||
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore | ||
|
||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) | ||
self.add_callback(BAdamCallback) | ||
|
||
@override | ||
def create_optimizer(self) -> "torch.optim.Optimizer": | ||
if self.optimizer is None: | ||
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args) | ||
return super().create_optimizer() | ||
|
||
@override | ||
def create_scheduler( | ||
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None | ||
) -> "torch.optim.lr_scheduler.LRScheduler": | ||
create_custom_scheduler(self.args, num_training_steps, optimizer) | ||
return super().create_scheduler(num_training_steps, optimizer) | ||
|
||
@override | ||
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]: | ||
if self.finetuning_args.disable_shuffling: | ||
return torch.utils.data.SequentialSampler(self.train_dataset) | ||
|
||
return super()._get_train_sampler() | ||
|
||
@override | ||
def compute_loss( | ||
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs | ||
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: | ||
labels = inputs.get("labels") | ||
padding_mask = labels.eq(-100) | ||
label_loss, outputs = super().compute_loss(model, inputs, return_outputs=True, **kwargs) | ||
with torch.no_grad(): | ||
teacher_outputs = self.teacher_model(**inputs) | ||
# Shape: (batch_size, seq_len, vocab_size) | ||
teacher_prob = torch.nn.functional.softmax( | ||
teacher_outputs.logits / self.finetuning_args.distilling_temperature, dim=-1 | ||
) | ||
student_logprob = torch.nn.functional.log_softmax( | ||
outputs.logits / self.finetuning_args.distilling_temperature, dim=-1 | ||
) | ||
kl_losses = (teacher_prob * (teacher_prob.log() - student_logprob)).sum(dim=-1) | ||
kl_losses.masked_fill_(padding_mask, 0) | ||
num_active_elements = padding_mask.numel() - padding_mask.long().sum() | ||
loss = ( | ||
self.finetuning_args.distilling_lambda | ||
* kl_losses.mean() | ||
/ (num_active_elements * student_logprob.shape[-1]) | ||
+ label_loss | ||
) | ||
|
||
if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False): | ||
loss = loss / self.args.gradient_accumulation_steps | ||
|
||
return (loss, outputs) if return_outputs else loss | ||
|
||
@override | ||
def prediction_step( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can remove the prediction in distill trainer since sft trainer already provided this function |
||
self, | ||
model: "torch.nn.Module", | ||
inputs: Dict[str, Union["torch.Tensor", Any]], | ||
prediction_loss_only: bool, | ||
ignore_keys: Optional[List[str]] = None, | ||
**gen_kwargs, | ||
) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]: | ||
r""" | ||
Removes the prompt part in the generated tokens. | ||
|
||
Subclass and override to inject custom behavior. | ||
""" | ||
if self.args.predict_with_generate: # do not pass labels to model when generate | ||
labels = inputs.pop("labels", None) | ||
else: | ||
labels = inputs.get("labels") | ||
|
||
loss, generated_tokens, _ = super().prediction_step( | ||
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs | ||
) | ||
if generated_tokens is not None and self.args.predict_with_generate: | ||
generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id | ||
generated_tokens = generated_tokens.contiguous() | ||
|
||
return loss, generated_tokens, labels | ||
|
||
def save_predictions( | ||
self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True | ||
) -> None: | ||
r""" | ||
Saves model predictions to `output_dir`. | ||
|
||
A custom behavior that not contained in Seq2SeqTrainer. | ||
""" | ||
if not self.is_world_process_zero(): | ||
return | ||
|
||
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") | ||
logger.info_rank0(f"Saving prediction results to {output_prediction_file}") | ||
|
||
labels = np.where( | ||
predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.processing_class.pad_token_id | ||
) | ||
preds = np.where( | ||
predict_results.predictions != IGNORE_INDEX, | ||
predict_results.predictions, | ||
self.processing_class.pad_token_id, | ||
) | ||
|
||
for i in range(len(preds)): | ||
pad_len = np.nonzero(preds[i] != self.processing_class.pad_token_id)[0] | ||
if len(pad_len): # move pad token to last | ||
preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1) | ||
|
||
decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False) | ||
decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens) | ||
decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens) | ||
|
||
with open(output_prediction_file, "w", encoding="utf-8") as f: | ||
for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels): | ||
f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n") |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it work on DDP setting?