
DistilBERT

Overview

The DistilBERT model was proposed in the blog post Smaller, faster, cheaper, lighter: Introducing DistilBERT, a distilled version of BERT, and the paper DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter. DistilBERT is a small, fast, cheap and light Transformer model trained by distilling BERT base. It has 40% fewer parameters than google-bert/bert-base-uncased and runs 60% faster, while preserving over 95% of BERT's performance as measured on the GLUE language understanding benchmark.

The abstract from the paper is the following:

As Transfer Learning from large-scale pre-trained models becomes more prevalent in Natural Language Processing (NLP), operating these large models in on-the-edge and/or under constrained computational training or inference budgets remains challenging. In this work, we propose a method to pre-train a smaller general-purpose language representation model, called DistilBERT, which can then be fine-tuned with good performances on a wide range of tasks like its larger counterparts. While most prior work investigated the use of distillation for building task-specific models, we leverage knowledge distillation during the pretraining phase and show that it is possible to reduce the size of a BERT model by 40%, while retaining 97% of its language understanding capabilities and being 60% faster. To leverage the inductive biases learned by larger models during pretraining, we introduce a triple loss combining language modeling, distillation and cosine-distance losses. Our smaller, faster and lighter model is cheaper to pre-train and we demonstrate its capabilities for on-device computations in a proof-of-concept experiment and a comparative on-device study.

This model was contributed by victorsanh. The JAX version of this model was contributed by kamalkraj. The original code can be found here.

Usage tips

  • DistilBERT doesn't have token_type_ids, so you don't need to indicate which token belongs to which segment. Just separate your segments with the separation token tokenizer.sep_token (or [SEP]); see the encoding example after this list.

  • DistilBERT doesn't have options to select the input positions (position_ids input). This could be added if necessary; just let us know if you need this option.

  • Same as BERT but smaller. Trained by distillation of the pretrained BERT model, meaning it's been trained to predict the same probabilities as the larger model. The actual objective is a combination of (see the loss sketch after this list):

    • predicting the same probabilities as the teacher model
    • predicting the masked tokens correctly (but with no next-sentence objective)
    • a cosine similarity between the hidden states of the student and the teacher model
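Because there are no token_type_ids, encoding a sentence pair simply joins the two segments with [SEP]. A quick illustration (the decoded string shown is indicative; the exact output depends on the checkpoint's lowercasing):

>>> from transformers import AutoTokenizer

>>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")
>>> encoded = tokenizer("First segment.", "Second segment.")
>>> "token_type_ids" in encoded
False
>>> tokenizer.decode(encoded["input_ids"])
'[CLS] first segment. [SEP] second segment. [SEP]'

For illustration, here is a minimal sketch of how such a triple loss can be combined. The loss weights, temperature, and variable names below are illustrative assumptions, not the exact values or code used to train DistilBERT:

import torch
import torch.nn.functional as F

# Illustrative weights and temperature; the original training setup may use different values.
alpha_ce, alpha_mlm, alpha_cos = 1.0, 1.0, 1.0
temperature = 2.0

def triple_loss(student_logits, teacher_logits, student_hidden, teacher_hidden, labels):
    # 1) Distillation loss: match the teacher's temperature-scaled output distribution.
    loss_ce = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)

    # 2) Masked language modeling loss on the ground-truth masked tokens
    #    (labels set to -100 at unmasked positions are ignored).
    loss_mlm = F.cross_entropy(
        student_logits.view(-1, student_logits.size(-1)), labels.view(-1), ignore_index=-100
    )

    # 3) Cosine loss aligning student and teacher hidden states.
    student_h = student_hidden.view(-1, student_hidden.size(-1))
    teacher_h = teacher_hidden.view(-1, teacher_hidden.size(-1))
    target = student_h.new_ones(student_h.size(0))
    loss_cos = F.cosine_embedding_loss(student_h, teacher_h, target)

    return alpha_ce * loss_ce + alpha_mlm * loss_mlm + alpha_cos * loss_cos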

Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of torch.nn.functional. This function encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the official documentation or the GPU Inference page for more information.

SDPA is used by default for torch>=2.1.1 when an implementation is available, but you may also set attn_implementation="sdpa" in from_pretrained() to explicitly request SDPA to be used.

import torch
from transformers import DistilBertModel

model = DistilBertModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, attn_implementation="sdpa")

For the best speedups, we recommend loading the model in half-precision (e.g. torch.float16 or torch.bfloat16).

On a local benchmark (NVIDIA GeForce RTX 2060-8GB, PyTorch 2.3.1, OS Ubuntu 20.04) with float16 and the distilbert-base-uncased model with a MaskedLM head, we saw the following speedups during training and inference.

Training

| num_training_steps | batch_size | seq_len | is cuda | Time per batch (eager - s) | Time per batch (sdpa - s) | Speedup (%) | Eager peak mem (MB) | sdpa peak mem (MB) | Mem saving (%) |
|---|---|---|---|---|---|---|---|---|---|
| 100 | 1 | 128 | False | 0.010 | 0.008 | 28.870 | 397.038 | 399.629 | -0.649 |
| 100 | 1 | 256 | False | 0.011 | 0.009 | 20.681 | 412.505 | 412.606 | -0.025 |
| 100 | 2 | 128 | False | 0.011 | 0.009 | 23.741 | 412.213 | 412.606 | -0.095 |
| 100 | 2 | 256 | False | 0.015 | 0.013 | 16.502 | 427.491 | 425.787 | 0.400 |
| 100 | 4 | 128 | False | 0.015 | 0.013 | 13.828 | 427.491 | 425.787 | 0.400 |
| 100 | 4 | 256 | False | 0.025 | 0.022 | 12.882 | 594.156 | 502.745 | 18.182 |
| 100 | 8 | 128 | False | 0.023 | 0.022 | 8.010 | 545.922 | 502.745 | 8.588 |
| 100 | 8 | 256 | False | 0.046 | 0.041 | 12.763 | 983.450 | 798.480 | 23.165 |

Inference

| num_batches | batch_size | seq_len | is cuda | is half | use mask | Per token latency eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 50 | 2 | 64 | True | True | True | 0.032 | 0.025 | 28.192 | 154.532 | 155.531 | -0.642 |
| 50 | 2 | 128 | True | True | True | 0.033 | 0.025 | 32.636 | 157.286 | 157.482 | -0.125 |
| 50 | 4 | 64 | True | True | True | 0.032 | 0.026 | 24.783 | 157.023 | 157.449 | -0.271 |
| 50 | 4 | 128 | True | True | True | 0.034 | 0.028 | 19.299 | 162.794 | 162.269 | 0.323 |
| 50 | 8 | 64 | True | True | True | 0.035 | 0.028 | 25.105 | 160.958 | 162.204 | -0.768 |
| 50 | 8 | 128 | True | True | True | 0.052 | 0.046 | 12.375 | 173.155 | 171.844 | 0.763 |
| 50 | 16 | 64 | True | True | True | 0.051 | 0.045 | 12.882 | 172.106 | 171.713 | 0.229 |
| 50 | 16 | 128 | True | True | True | 0.096 | 0.081 | 18.524 | 191.257 | 191.517 | -0.136 |

Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with DistilBERT. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.

Multiple choice

⚗️ Optimization

⚡️ Inference

🚀 Deploy

Combining DistilBERT and Flash Attention 2

First, make sure to install the latest version of Flash Attention 2.

pip install -U flash-attn --no-build-isolation

Also make sure that your hardware is compatible with Flash Attention 2. Read more about it in the official documentation of the flash-attn repository. Also make sure to load your model in half-precision (e.g. torch.float16).

To load and run a model using Flash Attention 2, refer to the snippet below:

>>> import torch
>>> from transformers import AutoTokenizer, AutoModel

>>> device = "cuda" # the device to load the model onto

>>> tokenizer = AutoTokenizer.from_pretrained('distilbert/distilbert-base-uncased')
>>> model = AutoModel.from_pretrained("distilbert/distilbert-base-uncased", torch_dtype=torch.float16, attn_implementation="flash_attention_2")

>>> text = "Replace me by any text you'd like."

>>> encoded_input = tokenizer(text, return_tensors='pt').to(device)
>>> model.to(device)

>>> output = model(**encoded_input)
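The call above returns a BaseModelOutput; as a quick follow-up, the final hidden states can be read from it (for distilbert-base-uncased the hidden size is 768):

>>> last_hidden_state = output.last_hidden_state  # shape: (batch_size, sequence_length, 768)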

DistilBertConfig

[[autodoc]] DistilBertConfig

DistilBertTokenizer

[[autodoc]] DistilBertTokenizer

DistilBertTokenizerFast

[[autodoc]] DistilBertTokenizerFast

DistilBertModel

[[autodoc]] DistilBertModel - forward

DistilBertForMaskedLM

[[autodoc]] DistilBertForMaskedLM - forward

DistilBertForSequenceClassification

[[autodoc]] DistilBertForSequenceClassification - forward

DistilBertForMultipleChoice

[[autodoc]] DistilBertForMultipleChoice - forward

DistilBertForTokenClassification

[[autodoc]] DistilBertForTokenClassification - forward

DistilBertForQuestionAnswering

[[autodoc]] DistilBertForQuestionAnswering - forward

TFDistilBertModel

[[autodoc]] TFDistilBertModel - call

TFDistilBertForMaskedLM

[[autodoc]] TFDistilBertForMaskedLM - call

TFDistilBertForSequenceClassification

[[autodoc]] TFDistilBertForSequenceClassification - call

TFDistilBertForMultipleChoice

[[autodoc]] TFDistilBertForMultipleChoice - call

TFDistilBertForTokenClassification

[[autodoc]] TFDistilBertForTokenClassification - call

TFDistilBertForQuestionAnswering

[[autodoc]] TFDistilBertForQuestionAnswering - call

FlaxDistilBertModel

[[autodoc]] FlaxDistilBertModel - call

FlaxDistilBertForMaskedLM

[[autodoc]] FlaxDistilBertForMaskedLM - call

FlaxDistilBertForSequenceClassification

[[autodoc]] FlaxDistilBertForSequenceClassification - call

FlaxDistilBertForMultipleChoice

[[autodoc]] FlaxDistilBertForMultipleChoice - call

FlaxDistilBertForTokenClassification

[[autodoc]] FlaxDistilBertForTokenClassification - call

FlaxDistilBertForQuestionAnswering

[[autodoc]] FlaxDistilBertForQuestionAnswering - call