The DistilBERT model was proposed in the blog post Smaller, faster, cheaper, lighter: Introducing DistilBERT, a distilled version of BERT, and the paper DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter. DistilBERT is a small, fast, cheap and light Transformer model trained by distilling BERT base. It has 40% fewer parameters than google-bert/bert-base-uncased, runs 60% faster, and preserves over 95% of BERT's performance as measured on the GLUE language understanding benchmark.
The abstract from the paper is the following:
As Transfer Learning from large-scale pre-trained models becomes more prevalent in Natural Language Processing (NLP), operating these large models in on-the-edge and/or under constrained computational training or inference budgets remains challenging. In this work, we propose a method to pre-train a smaller general-purpose language representation model, called DistilBERT, which can then be fine-tuned with good performances on a wide range of tasks like its larger counterparts. While most prior work investigated the use of distillation for building task-specific models, we leverage knowledge distillation during the pretraining phase and show that it is possible to reduce the size of a BERT model by 40%, while retaining 97% of its language understanding capabilities and being 60% faster. To leverage the inductive biases learned by larger models during pretraining, we introduce a triple loss combining language modeling, distillation and cosine-distance losses. Our smaller, faster and lighter model is cheaper to pre-train and we demonstrate its capabilities for on-device computations in a proof-of-concept experiment and a comparative on-device study.
This model was contributed by victorsanh. The JAX version of this model was contributed by kamalkraj. The original code can be found here.
- DistilBERT doesn't have `token_type_ids`, so you don't need to indicate which token belongs to which segment. Just separate your segments with the separation token `tokenizer.sep_token` (or `[SEP]`), as shown in the encoding example below.
- DistilBERT doesn't have options to select the input positions (`position_ids` input). This could be added if necessary though, just let us know if you need this option.
- Same as BERT but smaller. Trained by distillation of the pretrained BERT model, meaning it's been trained to predict the same probabilities as the larger model. The actual objective is a combination of (a PyTorch sketch of this combined loss follows the examples below):
  - finding the same probabilities as the teacher model
  - predicting the masked tokens correctly (but no next-sentence objective)
  - a cosine similarity between the hidden states of the student and the teacher model
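For illustration, here is a minimal sketch of encoding a segment pair with the `distilbert-base-uncased` checkpoint. Note that the encoding contains no `token_type_ids` and the segments are simply joined with `[SEP]` (the example sentences are arbitrary and the printed values are indicative):

```python
from transformers import DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# Encode two segments: DistilBERT joins them with [SEP] only and
# the encoding contains no token_type_ids.
encoded = tokenizer("How old are you?", "I'm six years old.")

print(list(encoded.keys()))            # ['input_ids', 'attention_mask']
print(tokenizer.sep_token)             # [SEP]
print(tokenizer.decode(encoded["input_ids"]))
# [CLS] how old are you? [SEP] i'm six years old. [SEP]
```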
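A minimal sketch of what this triple loss could look like in PyTorch is shown below. The tensor names, loss weights, and temperature are illustrative assumptions, not the exact values or code used to train the released checkpoints:

```python
import torch
import torch.nn.functional as F

# Hypothetical tensors standing in for one training batch:
#   student_logits, teacher_logits: (batch, seq_len, vocab_size)
#   labels: masked-token labels, -100 at unmasked positions
#   student_hidden, teacher_hidden: (batch, seq_len, hidden_size)
def distillation_loss(student_logits, teacher_logits, labels,
                      student_hidden, teacher_hidden,
                      temperature=2.0, alpha_ce=1.0, alpha_mlm=1.0, alpha_cos=1.0):
    # 1) soft-target loss: match the teacher's output distribution
    ce = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature ** 2
    # 2) masked language modeling loss on the true masked tokens
    mlm = F.cross_entropy(
        student_logits.view(-1, student_logits.size(-1)),
        labels.view(-1),
        ignore_index=-100,
    )
    # 3) cosine loss aligning student and teacher hidden states
    target = torch.ones(
        student_hidden.size(0) * student_hidden.size(1), device=student_hidden.device
    )
    cos = F.cosine_embedding_loss(
        student_hidden.view(-1, student_hidden.size(-1)),
        teacher_hidden.view(-1, teacher_hidden.size(-1)),
        target,
    )
    return alpha_ce * ce + alpha_mlm * mlm + alpha_cos * cos
```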
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the official documentation or the GPU Inference page for more information.

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
import torch
from transformers import DistilBertModel

model = DistilBertModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, attn_implementation="sdpa")
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (NVIDIA GeForce RTX 2060-8GB, PyTorch 2.3.1, Ubuntu 20.04) with `float16` and the `distilbert-base-uncased` model with a MaskedLM head, we saw the following speedups during training (first table) and inference (second table).
num_training_steps | batch_size | seq_len | is cuda | Time per batch (eager - s) | Time per batch (sdpa - s) | Speedup (%) | Eager peak mem (MB) | sdpa peak mem (MB) | Mem saving (%) |
---|---|---|---|---|---|---|---|---|---|
100 | 1 | 128 | False | 0.010 | 0.008 | 28.870 | 397.038 | 399.629 | -0.649 |
100 | 1 | 256 | False | 0.011 | 0.009 | 20.681 | 412.505 | 412.606 | -0.025 |
100 | 2 | 128 | False | 0.011 | 0.009 | 23.741 | 412.213 | 412.606 | -0.095 |
100 | 2 | 256 | False | 0.015 | 0.013 | 16.502 | 427.491 | 425.787 | 0.400 |
100 | 4 | 128 | False | 0.015 | 0.013 | 13.828 | 427.491 | 425.787 | 0.400 |
100 | 4 | 256 | False | 0.025 | 0.022 | 12.882 | 594.156 | 502.745 | 18.182 |
100 | 8 | 128 | False | 0.023 | 0.022 | 8.010 | 545.922 | 502.745 | 8.588 |
100 | 8 | 256 | False | 0.046 | 0.041 | 12.763 | 983.450 | 798.480 | 23.165 |
num_batches | batch_size | seq_len | is cuda | is half | use mask | Per token latency eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
---|---|---|---|---|---|---|---|---|---|---|---|
50 | 2 | 64 | True | True | True | 0.032 | 0.025 | 28.192 | 154.532 | 155.531 | -0.642 |
50 | 2 | 128 | True | True | True | 0.033 | 0.025 | 32.636 | 157.286 | 157.482 | -0.125 |
50 | 4 | 64 | True | True | True | 0.032 | 0.026 | 24.783 | 157.023 | 157.449 | -0.271 |
50 | 4 | 128 | True | True | True | 0.034 | 0.028 | 19.299 | 162.794 | 162.269 | 0.323 |
50 | 8 | 64 | True | True | True | 0.035 | 0.028 | 25.105 | 160.958 | 162.204 | -0.768 |
50 | 8 | 128 | True | True | True | 0.052 | 0.046 | 12.375 | 173.155 | 171.844 | 0.763 |
50 | 16 | 64 | True | True | True | 0.051 | 0.045 | 12.882 | 172.106 | 171.713 | 0.229 |
50 | 16 | 128 | True | True | True | 0.096 | 0.081 | 18.524 | 191.257 | 191.517 | -0.136 |
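The exact benchmark scripts are not reproduced here, but a minimal sketch of how such an eager-vs-SDPA comparison could be run is shown below (the model head, batch contents, and iteration counts are assumptions for illustration):

```python
import time
import torch
from transformers import AutoTokenizer, DistilBertForMaskedLM

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

def time_forward(model, inputs, n_iters=50):
    """Average forward-pass time per batch, in seconds."""
    with torch.no_grad():
        for _ in range(5):  # warmup
            model(**inputs)
        if device == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters):
            model(**inputs)
        if device == "cuda":
            torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
inputs = tokenizer(["Paris is the [MASK] of France."] * 8, return_tensors="pt").to(device)

results = {}
for impl in ("eager", "sdpa"):
    model = DistilBertForMaskedLM.from_pretrained(
        "distilbert-base-uncased", torch_dtype=dtype, attn_implementation=impl
    ).to(device)
    model.eval()
    results[impl] = time_forward(model, inputs)

speedup = (results["eager"] / results["sdpa"] - 1) * 100
print(f"eager: {results['eager'] * 1000:.2f} ms/batch, "
      f"sdpa: {results['sdpa'] * 1000:.2f} ms/batch, speedup: {speedup:.1f}%")
```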
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with DistilBERT. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
- A blog post on Getting Started with Sentiment Analysis using Python with DistilBERT.
- A blog post on how to train DistilBERT with Blurr for sequence classification.
- A blog post on how to use Ray to tune DistilBERT hyperparameters.
- A blog post on how to train DistilBERT with Hugging Face and Amazon SageMaker.
- A notebook on how to finetune DistilBERT for multi-label classification. 🌎
- A notebook on how to finetune DistilBERT for multiclass classification with PyTorch. 🌎
- A notebook on how to finetune DistilBERT for text classification in TensorFlow. 🌎
- [`DistilBertForSequenceClassification`] is supported by this example script and notebook.
- [`TFDistilBertForSequenceClassification`] is supported by this example script and notebook.
- [`FlaxDistilBertForSequenceClassification`] is supported by this example script and notebook.
- Text classification task guide
- [`DistilBertForTokenClassification`] is supported by this example script and notebook.
- [`TFDistilBertForTokenClassification`] is supported by this example script and notebook.
- [`FlaxDistilBertForTokenClassification`] is supported by this example script.
- Token classification chapter of the 🤗 Hugging Face Course.
- Token classification task guide
- [`DistilBertForMaskedLM`] is supported by this example script and notebook.
- [`TFDistilBertForMaskedLM`] is supported by this example script and notebook.
- [`FlaxDistilBertForMaskedLM`] is supported by this example script and notebook.
- Masked language modeling chapter of the 🤗 Hugging Face Course.
- Masked language modeling task guide
- [`DistilBertForQuestionAnswering`] is supported by this example script and notebook.
- [`TFDistilBertForQuestionAnswering`] is supported by this example script and notebook.
- [`FlaxDistilBertForQuestionAnswering`] is supported by this example script.
- Question answering chapter of the 🤗 Hugging Face Course.
- Question answering task guide

Multiple choice
- [`DistilBertForMultipleChoice`] is supported by this example script and notebook.
- [`TFDistilBertForMultipleChoice`] is supported by this example script and notebook.
- Multiple choice task guide
⚗️ Optimization
- A blog post on how to quantize DistilBERT with 🤗 Optimum and Intel.
- A blog post on optimizing Transformers for GPUs with 🤗 Optimum.
- A blog post on Optimizing Transformers with Hugging Face Optimum.
⚡️ Inference
- A blog post on how to accelerate BERT inference with Hugging Face Transformers and AWS Inferentia, using DistilBERT.
- A blog post on Serverless Inference with Hugging Face's Transformers, DistilBERT and Amazon SageMaker.
🚀 Deploy
- A blog post on how to deploy DistilBERT on Google Cloud.
- A blog post on how to deploy DistilBERT with Amazon SageMaker.
- A blog post on how to deploy BERT with Hugging Face Transformers, Amazon SageMaker, and a Terraform module.
First, make sure to install the latest version of Flash Attention 2.
pip install -U flash-attn --no-build-isolation
Also make sure that your hardware is compatible with Flash Attention 2. Read more about it in the official documentation of the flash-attn repository. Also make sure to load your model in half-precision (e.g. `torch.float16`).
To load and run a model using Flash Attention 2, refer to the snippet below:
>>> import torch
>>> from transformers import AutoTokenizer, AutoModel
>>> device = "cuda" # the device to load the model onto
>>> tokenizer = AutoTokenizer.from_pretrained('distilbert/distilbert-base-uncased')
>>> model = AutoModel.from_pretrained("distilbert/distilbert-base-uncased", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> text = "Replace me by any text you'd like."
>>> encoded_input = tokenizer(text, return_tensors='pt').to(device)
>>> model.to(device)
>>> output = model(**encoded_input)
[[autodoc]] DistilBertConfig
[[autodoc]] DistilBertTokenizer
[[autodoc]] DistilBertTokenizerFast
[[autodoc]] DistilBertModel - forward
[[autodoc]] DistilBertForMaskedLM - forward
[[autodoc]] DistilBertForSequenceClassification - forward
[[autodoc]] DistilBertForMultipleChoice - forward
[[autodoc]] DistilBertForTokenClassification - forward
[[autodoc]] DistilBertForQuestionAnswering - forward
[[autodoc]] TFDistilBertModel - call
[[autodoc]] TFDistilBertForMaskedLM - call
[[autodoc]] TFDistilBertForSequenceClassification - call
[[autodoc]] TFDistilBertForMultipleChoice - call
[[autodoc]] TFDistilBertForTokenClassification - call
[[autodoc]] TFDistilBertForQuestionAnswering - call
[[autodoc]] FlaxDistilBertModel - call
[[autodoc]] FlaxDistilBertForMaskedLM - call
[[autodoc]] FlaxDistilBertForSequenceClassification - call
[[autodoc]] FlaxDistilBertForMultipleChoice - call
[[autodoc]] FlaxDistilBertForTokenClassification - call
[[autodoc]] FlaxDistilBertForQuestionAnswering - call