diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py
index df9a522b86..de42c7ae8f 100644
--- a/torchtext/models/roberta/bundler.py
+++ b/torchtext/models/roberta/bundler.py
@@ -160,16 +160,35 @@ def encoderConf(self) -> RobertaEncoderConf:
         return self._encoder_conf
 
 
-XLMR_BASE_ENCODER = RobertaBundle(
-    _path=urljoin(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
-    _encoder_conf=RobertaEncoderConf(vocab_size=250002),
-    transform=lambda: T.Sequential(
+def xlmr_transform(truncate_length: int) -> Module:
+    """Standard transform for XLMR models."""
+    return T.Sequential(
         T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")),
         T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))),
-        T.Truncate(254),
+        T.Truncate(truncate_length),
         T.AddToken(token=0, begin=True),
         T.AddToken(token=2, begin=False),
-    ),
+    )
+
+
+def roberta_transform(truncate_length: int) -> Module:
+    """Standard transform for RoBERTa models."""
+    return T.Sequential(
+        T.GPT2BPETokenizer(
+            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
+            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
+        ),
+        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
+        T.Truncate(truncate_length),
+        T.AddToken(token=0, begin=True),
+        T.AddToken(token=2, begin=False),
+    )
+
+
+XLMR_BASE_ENCODER = RobertaBundle(
+    _path=urljoin(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
+    _encoder_conf=RobertaEncoderConf(vocab_size=250002),
+    transform=lambda: xlmr_transform(254),
 )
 
 XLMR_BASE_ENCODER.__doc__ = """
@@ -193,13 +212,7 @@ def encoderConf(self) -> RobertaEncoderConf:
     _encoder_conf=RobertaEncoderConf(
         vocab_size=250002, embedding_dim=1024, ffn_dimension=4096, num_attention_heads=16, num_encoder_layers=24
     ),
-    transform=lambda: T.Sequential(
-        T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))),
-        T.Truncate(510),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: xlmr_transform(510),
 )
 
 XLMR_LARGE_ENCODER.__doc__ = """
@@ -221,16 +234,7 @@ def encoderConf(self) -> RobertaEncoderConf:
 ROBERTA_BASE_ENCODER = RobertaBundle(
     _path=urljoin(_TEXT_BUCKET, "roberta.base.encoder.pt"),
     _encoder_conf=RobertaEncoderConf(vocab_size=50265),
-    transform=lambda: T.Sequential(
-        T.GPT2BPETokenizer(
-            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
-            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
-        ),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
-        T.Truncate(254),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: roberta_transform(254),
 )
 
 ROBERTA_BASE_ENCODER.__doc__ = """
@@ -263,16 +267,7 @@ def encoderConf(self) -> RobertaEncoderConf:
         num_attention_heads=16,
         num_encoder_layers=24,
     ),
-    transform=lambda: T.Sequential(
-        T.GPT2BPETokenizer(
-            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
-            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
-        ),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
-        T.Truncate(510),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: roberta_transform(510),
 )
 
 ROBERTA_LARGE_ENCODER.__doc__ = """
@@ -302,16 +297,7 @@ def encoderConf(self) -> RobertaEncoderConf:
         num_encoder_layers=6,
         padding_idx=1,
     ),
-    transform=lambda: T.Sequential(
-        T.GPT2BPETokenizer(
-            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
-            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
-        ),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
-        T.Truncate(510),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: roberta_transform(510),
 )
 
 ROBERTA_DISTILLED_ENCODER.__doc__ = """
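
Note: the refactor only deduplicates how the transform is constructed; it does not change how the bundles are consumed. A minimal usage sketch, assuming the existing bundle API (get_model(), transform()) and torchtext.functional.to_tensor, in the style of the docstring examples already in bundler.py:

    >>> import torch, torchtext
    >>> from torchtext.functional import to_tensor
    >>> xlmr_base = torchtext.models.XLMR_BASE_ENCODER
    >>> model = xlmr_base.get_model()
    >>> transform = xlmr_base.transform()  # now built via xlmr_transform(254)
    >>> input_batch = ["Hello world", "How are you!"]
    >>> model_input = to_tensor(transform(input_batch), padding_value=1)
    >>> output = model(model_input)  # shape: (batch, sequence length, embedding dim)

The RoBERTa bundles behave the same way, with their transform now built via roberta_transform(254) or roberta_transform(510).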