Add DistilRoberta Model to OSS (cherry picked commit) #1998

Merged: 3 commits, Dec 1, 2022
1 change: 1 addition & 0 deletions README.rst
@@ -114,6 +114,7 @@ Models
 The library currently consists of the following pre-trained models:

 * RoBERTa: `Base and Large Architecture <https://github.com/pytorch/fairseq/tree/main/examples/roberta#pre-trained-models>`_
+* `DistilRoBERTa <https://github.com/huggingface/transformers/blob/main/examples/research_projects/distillation/README.md>`_
 * XLM-RoBERTa: `Base and Large Architecture <https://github.com/pytorch/fairseq/tree/main/examples/xlmr#pre-trained-models>`_

 Tokenizers
10 changes: 3 additions & 7 deletions test/integration_tests/test_models.py
@@ -4,6 +4,7 @@
 from torchtext.models import (
     ROBERTA_BASE_ENCODER,
     ROBERTA_LARGE_ENCODER,
+    ROBERTA_DISTILLED_ENCODER,
     XLMR_BASE_ENCODER,
     XLMR_LARGE_ENCODER,
 )
@@ -15,13 +16,7 @@
     "xlmr_large": XLMR_LARGE_ENCODER,
     "roberta_base": ROBERTA_BASE_ENCODER,
     "roberta_large": ROBERTA_LARGE_ENCODER,
-}
-
-BUNDLERS = {
-    "xlmr_base": XLMR_BASE_ENCODER,
-    "xlmr_large": XLMR_LARGE_ENCODER,
-    "roberta_base": ROBERTA_BASE_ENCODER,
-    "roberta_large": ROBERTA_LARGE_ENCODER,
+    "roberta_distilled": ROBERTA_DISTILLED_ENCODER,
 }


@@ -32,6 +27,7 @@
         ("xlmr_large",),
         ("roberta_base",),
         ("roberta_large",),
+        ("roberta_distilled",),
     ],
 )
 class TestRobertaEncoders(TorchtextTestCase):
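The test body itself is outside the hunks shown above; a minimal sketch of how a parametrized check over `BUNDLERS` might look is given below. The helper name and assertions are illustrative assumptions, not the file's actual code:

```python
# Illustrative only: a minimal forward-pass check for one BUNDLERS entry.
import torch
from torchtext.functional import to_tensor


def _check_encoder(model_name: str) -> None:
    bundle = BUNDLERS[model_name]  # e.g. ROBERTA_DISTILLED_ENCODER
    model = bundle.get_model()
    model.eval()
    transform = bundle.transform()
    tokens = to_tensor(transform(["Hello world"]), padding_value=1)
    with torch.no_grad():
        output = model(tokens)
    # Encoder output has shape [batch, sequence_length, embedding_dim].
    assert output.shape[0] == 1
```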
Binary file not shown.
2 changes: 2 additions & 0 deletions torchtext/models/roberta/__init__.py
@@ -1,6 +1,7 @@
 from .bundler import (
     ROBERTA_BASE_ENCODER,
     ROBERTA_LARGE_ENCODER,
+    ROBERTA_DISTILLED_ENCODER,
     RobertaBundle,
     XLMR_BASE_ENCODER,
     XLMR_LARGE_ENCODER,
@@ -16,4 +17,5 @@
     "XLMR_LARGE_ENCODER",
     "ROBERTA_BASE_ENCODER",
     "ROBERTA_LARGE_ENCODER",
+    "ROBERTA_DISTILLED_ENCODER",
 ]
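The integration test above already imports the bundle from the top-level `torchtext.models` namespace, which re-exports this package (that re-export is not part of this diff). A quick illustrative sanity check of the two import paths:

```python
# Both paths resolve to the same bundle object once the package re-exports it.
from torchtext.models import ROBERTA_DISTILLED_ENCODER
from torchtext.models.roberta import ROBERTA_DISTILLED_ENCODER as SAME_BUNDLE

assert ROBERTA_DISTILLED_ENCODER is SAME_BUNDLE
```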
37 changes: 37 additions & 0 deletions torchtext/models/roberta/bundler.py
@@ -294,3 +294,40 @@ def encoderConf(self) -> RobertaEncoderConf:

 Please refer to :func:`torchtext.models.RobertaBundle` for the usage.
 """
+
+
+ROBERTA_DISTILLED_ENCODER = RobertaBundle(
+    _path=urljoin(_TEXT_BUCKET, "roberta.distilled.encoder.pt"),
+    _encoder_conf=RobertaEncoderConf(
+        num_encoder_layers=6,
+        padding_idx=1,
+    ),
+    transform=lambda: T.Sequential(
+        T.GPT2BPETokenizer(
+            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
+            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
+        ),
+        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
+        T.Truncate(510),
+        T.AddToken(token=0, begin=True),
+        T.AddToken(token=2, begin=False),
+    ),
+)
+
+ROBERTA_DISTILLED_ENCODER.__doc__ = """
+Roberta Encoder with Distilled Weights
+
+DistilRoBERTa is trained using knowledge distillation, a technique for compressing a large
+model, called the teacher, into a smaller model, called the student. By distilling RoBERTa,
+a smaller and faster Transformer model is obtained while retaining most of its performance.
+
+DistilRoBERTa was pretrained solely on OpenWebTextCorpus, a reproduction of OpenAI's WebText dataset.
+On average, DistilRoBERTa is twice as fast as RoBERTa Base.
+
+Originally published by Hugging Face under the Apache 2.0 License
+and redistributed with the same license.
+[`License <https://www.apache.org/licenses/LICENSE-2.0>`__,
+`Source <https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation>`__]
+
+Please refer to :func:`torchtext.models.RobertaBundle` for the usage.
+"""