
Add XLMR and RoBERTa transforms as factory functions (#2102)
* Add XLMR and RoBERTa transforms as classes

* Remove unused import

* Add RoBERTa transform and XLMR transform as factory functions
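
For readers skimming the change: a minimal sketch of how the new factory functions are called. The truncate lengths mirror the values in the diff below; the import path assumes you pull the functions straight from the module this commit touches.

    from torchtext.models.roberta.bundler import roberta_transform, xlmr_transform

    # Base encoders truncate to 254 tokens, large encoders to 510;
    # BOS (0) and EOS (2) are appended after truncation, so sequences
    # top out at 256 and 512 tokens respectively.
    xlmr_base_transform = xlmr_transform(truncate_length=254)
    roberta_large_transform = roberta_transform(truncate_length=510)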
joecummings authored Mar 9, 2023
1 parent 9d42632 commit 46e7eef
Showing 1 changed file with 29 additions and 43 deletions.
torchtext/models/roberta/bundler.py
@@ -160,16 +160,35 @@ def encoderConf(self) -> RobertaEncoderConf:
         return self._encoder_conf
 
 
-XLMR_BASE_ENCODER = RobertaBundle(
-    _path=urljoin(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
-    _encoder_conf=RobertaEncoderConf(vocab_size=250002),
-    transform=lambda: T.Sequential(
+def xlmr_transform(truncate_length: int) -> Module:
+    """Standard transform for XLMR models."""
+    return T.Sequential(
         T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")),
         T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))),
-        T.Truncate(254),
+        T.Truncate(truncate_length),
         T.AddToken(token=0, begin=True),
         T.AddToken(token=2, begin=False),
-    ),
-)
+    )
+
+
+def roberta_transform(truncate_length: int) -> Module:
+    """Standard transform for RoBERTa models."""
+    return T.Sequential(
+        T.GPT2BPETokenizer(
+            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
+            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
+        ),
+        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
+        T.Truncate(truncate_length),
+        T.AddToken(token=0, begin=True),
+        T.AddToken(token=2, begin=False),
+    )
+
+
+XLMR_BASE_ENCODER = RobertaBundle(
+    _path=urljoin(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
+    _encoder_conf=RobertaEncoderConf(vocab_size=250002),
+    transform=lambda: xlmr_transform(254),
+)
 
 XLMR_BASE_ENCODER.__doc__ = """
@@ -193,13 +212,7 @@ def encoderConf(self) -> RobertaEncoderConf:
     _encoder_conf=RobertaEncoderConf(
         vocab_size=250002, embedding_dim=1024, ffn_dimension=4096, num_attention_heads=16, num_encoder_layers=24
     ),
-    transform=lambda: T.Sequential(
-        T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))),
-        T.Truncate(510),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: xlmr_transform(510),
 )
 
 XLMR_LARGE_ENCODER.__doc__ = """
@@ -221,16 +234,7 @@ def encoderConf(self) -> RobertaEncoderConf:
 ROBERTA_BASE_ENCODER = RobertaBundle(
     _path=urljoin(_TEXT_BUCKET, "roberta.base.encoder.pt"),
     _encoder_conf=RobertaEncoderConf(vocab_size=50265),
-    transform=lambda: T.Sequential(
-        T.GPT2BPETokenizer(
-            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
-            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
-        ),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
-        T.Truncate(254),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: roberta_transform(254),
 )
 
 ROBERTA_BASE_ENCODER.__doc__ = """
@@ -263,16 +267,7 @@ def encoderConf(self) -> RobertaEncoderConf:
         num_attention_heads=16,
         num_encoder_layers=24,
     ),
-    transform=lambda: T.Sequential(
-        T.GPT2BPETokenizer(
-            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
-            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
-        ),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
-        T.Truncate(510),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: roberta_transform(510),
 )
 
 ROBERTA_LARGE_ENCODER.__doc__ = """
@@ -302,16 +297,7 @@ def encoderConf(self) -> RobertaEncoderConf:
         num_encoder_layers=6,
         padding_idx=1,
     ),
-    transform=lambda: T.Sequential(
-        T.GPT2BPETokenizer(
-            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
-            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
-        ),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
-        T.Truncate(510),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: roberta_transform(510),
 )
 
 ROBERTA_DISTILLED_ENCODER.__doc__ = """
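
The bundles behave exactly as before; each transform lambda now just delegates to a factory. A hedged end-to-end sketch of downstream usage, following the pattern in the torchtext docs (input strings and shapes are illustrative):

    import torch
    import torchtext
    from torchtext.models import XLMR_BASE_ENCODER

    model = XLMR_BASE_ENCODER.get_model()
    model.eval()
    transform = XLMR_BASE_ENCODER.transform()  # now built by xlmr_transform(254)

    # Tokenize, numericalize, truncate, and add BOS/EOS in one call.
    batch = transform(["Hello world", "How are you!"])
    model_input = torchtext.functional.to_tensor(batch, padding_value=1)
    with torch.no_grad():
        features = model(model_input)  # (batch, seq_len, 768) for the base encoder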
