From f4a7a87166968e012f816179e5d0dc9ab5901af3 Mon Sep 17 00:00:00 2001
From: Joe Cummings
Date: Wed, 8 Mar 2023 12:32:07 -0500
Subject: [PATCH 1/3] Add XLMR and RoBERTa transforms as classes

---
 torchtext/models/roberta/bundler.py   | 50 ++++----------------------
 torchtext/models/roberta/transform.py | 52 +++++++++++++++++++++++++++
 2 files changed, 58 insertions(+), 44 deletions(-)
 create mode 100644 torchtext/models/roberta/transform.py

diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py
index df9a522b86..3f2074fc67 100644
--- a/torchtext/models/roberta/bundler.py
+++ b/torchtext/models/roberta/bundler.py
@@ -14,6 +14,7 @@
 from torchtext import _TEXT_BUCKET
 
 from .model import RobertaEncoderConf, RobertaModel
+from .transform import RobertaTransform, XLMRTransform
 
 
 def _is_head_available_in_checkpoint(checkpoint, head_state_dict):
@@ -163,13 +164,7 @@ def encoderConf(self) -> RobertaEncoderConf:
 XLMR_BASE_ENCODER = RobertaBundle(
     _path=urljoin(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
     _encoder_conf=RobertaEncoderConf(vocab_size=250002),
-    transform=lambda: T.Sequential(
-        T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))),
-        T.Truncate(254),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: XLMRTransform(truncate_length=254),
 )
 
 XLMR_BASE_ENCODER.__doc__ = """
@@ -193,13 +188,7 @@ def encoderConf(self) -> RobertaEncoderConf:
     _encoder_conf=RobertaEncoderConf(
         vocab_size=250002, embedding_dim=1024, ffn_dimension=4096, num_attention_heads=16, num_encoder_layers=24
     ),
-    transform=lambda: T.Sequential(
-        T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))),
-        T.Truncate(510),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: XLMRTransform(truncate_length=510),
 )
 
 XLMR_LARGE_ENCODER.__doc__ = """
@@ -221,16 +210,7 @@ def encoderConf(self) -> RobertaEncoderConf:
 ROBERTA_BASE_ENCODER = RobertaBundle(
     _path=urljoin(_TEXT_BUCKET, "roberta.base.encoder.pt"),
     _encoder_conf=RobertaEncoderConf(vocab_size=50265),
-    transform=lambda: T.Sequential(
-        T.GPT2BPETokenizer(
-            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
-            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
-        ),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
-        T.Truncate(254),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: RobertaTransform(truncate_length=254),
 )
 
 ROBERTA_BASE_ENCODER.__doc__ = """
@@ -263,16 +243,7 @@ def encoderConf(self) -> RobertaEncoderConf:
         num_attention_heads=16,
         num_encoder_layers=24,
     ),
-    transform=lambda: T.Sequential(
-        T.GPT2BPETokenizer(
-            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
-            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
-        ),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
-        T.Truncate(510),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: RobertaTransform(truncate_length=510),
 )
 
 ROBERTA_LARGE_ENCODER.__doc__ = """
@@ -302,16 +273,7 @@ def encoderConf(self) -> RobertaEncoderConf:
         num_encoder_layers=6,
         padding_idx=1,
     ),
-    transform=lambda: T.Sequential(
-        T.GPT2BPETokenizer(
-            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
-            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
-        ),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
-        T.Truncate(510),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: RobertaTransform(truncate_length=510),
 )
 
 ROBERTA_DISTILLED_ENCODER.__doc__ = """
diff --git a/torchtext/models/roberta/transform.py b/torchtext/models/roberta/transform.py
new file mode 100644
index 0000000000..77efeb4350
--- /dev/null
+++ b/torchtext/models/roberta/transform.py
@@ -0,0 +1,52 @@
+from typing import List, Union
+from urllib.parse import urljoin
+
+import torch
+import torchtext.transforms as T
+from torch.nn import Module
+from torchtext import _TEXT_BUCKET
+from torchtext._download_hooks import load_state_dict_from_url
+
+
+class RobertaTransform(Module):
+    """Standard transform for RoBERTa model."""
+
+    def __init__(self, truncate_length: int) -> None:
+        super().__init__()
+        self.transform = T.Sequential(
+            T.GPT2BPETokenizer(
+                encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
+                vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
+            ),
+            T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
+            T.Truncate(truncate_length),
+            T.AddToken(token=0, begin=True),
+            T.AddToken(token=2, begin=False),
+            T.ToTensor(padding_value=0),
+        )
+
+    def forward(self, inputs: Union[str, List[str]]) -> torch.Tensor:
+        transformed_outputs = self.transform(inputs)
+        assert torch.jit.isinstance(transformed_outputs, torch.Tensor)
+        return transformed_outputs
+
+
+class XLMRTransform(Module):
+    """Standard transform for XLMR model."""
+
+    def __init__(self, truncate_length: int) -> None:
+        super().__init__()
+        self.tokenizer = T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"))
+        self.transform = T.Sequential(
+            T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")),
+            T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))),
+            T.Truncate(truncate_length),
+            T.AddToken(token=0, begin=True),
+            T.AddToken(token=2, begin=False),
+            T.ToTensor(padding_value=0),
+        )
+
+    def forward(self, inputs: Union[str, List[str]]) -> torch.Tensor:
+        transformed_outputs = self.transform(inputs)
+        assert torch.jit.isinstance(transformed_outputs, torch.Tensor)
+        return transformed_outputs
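A minimal usage sketch for the class-based transforms this patch introduces (not part of the diff; it assumes network access, since the tokenizer assets and vocab are downloaded from _TEXT_BUCKET at construction time). Note that, unlike the lambdas they replace, both classes end with T.ToTensor(padding_value=0), so the pipeline now returns a padded tensor of token ids rather than lists of ids:

    from torchtext.models import XLMR_BASE_ENCODER

    # With this patch applied, the bundle's transform() factory
    # returns an XLMRTransform instance.
    transform = XLMR_BASE_ENCODER.transform()

    # Both a single string and a batch of strings are accepted; shorter
    # sequences in a batch are padded with 0 by the trailing ToTensor step.
    token_ids = transform(["Hello world", "How are you?"])
    print(token_ids.shape)  # torch.Size([2, longest_sequence_in_batch])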
"gpt2_bpe_encoder.json"), - vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"), - ), - T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))), - T.Truncate(510), - T.AddToken(token=0, begin=True), - T.AddToken(token=2, begin=False), - ), + transform=lambda: RobertaTransform(truncate_length=510), ) ROBERTA_DISTILLED_ENCODER.__doc__ = """ diff --git a/torchtext/models/roberta/transform.py b/torchtext/models/roberta/transform.py new file mode 100644 index 0000000000..77efeb4350 --- /dev/null +++ b/torchtext/models/roberta/transform.py @@ -0,0 +1,52 @@ +from typing import List, Union +from urllib.parse import urljoin + +import torch +import torchtext.transforms as T +from torch.nn import Module +from torchtext import _TEXT_BUCKET +from torchtext._download_hooks import load_state_dict_from_url + + +class RobertaTransform(Module): + """Standard transform for RoBERTa model.""" + + def __init__(self, truncate_length: int) -> None: + super().__init__() + self.transform = T.Sequential( + T.GPT2BPETokenizer( + encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"), + vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"), + ), + T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))), + T.Truncate(truncate_length), + T.AddToken(token=0, begin=True), + T.AddToken(token=2, begin=False), + T.ToTensor(padding_value=0), + ) + + def forward(self, inputs: Union[str, List[str]]) -> torch.Tensor: + transformed_outputs = self.transform(inputs) + assert torch.jit.isinstance(transformed_outputs, torch.Tensor) + return transformed_outputs + + +class XLMRTransform(Module): + """Standard transform for XLMR model.""" + + def __init__(self, truncate_length: int) -> None: + super().__init__() + self.tokenizer = T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")) + self.transform = T.Sequential( + T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")), + T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))), + T.Truncate(truncate_length), + T.AddToken(token=0, begin=True), + T.AddToken(token=2, begin=False), + T.ToTensor(padding_value=0), + ) + + def forward(self, inputs: Union[str, List[str]]) -> torch.Tensor: + transformed_outputs = self.transform(inputs) + assert torch.jit.isinstance(transformed_outputs, torch.Tensor) + return transformed_outputs From 498036c0119f141459d1737fa798c32b4e7758e6 Mon Sep 17 00:00:00 2001 From: Joe Cummings Date: Wed, 8 Mar 2023 13:57:48 -0500 Subject: [PATCH 2/3] Remove unused import --- torchtext/models/roberta/bundler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py index 3f2074fc67..5680b105c0 100644 --- a/torchtext/models/roberta/bundler.py +++ b/torchtext/models/roberta/bundler.py @@ -10,7 +10,6 @@ logger = logging.getLogger(__name__) -import torchtext.transforms as T from torchtext import _TEXT_BUCKET from .model import RobertaEncoderConf, RobertaModel From e35cb8e8834d5b74ed12bb2b68fd671f30b49c93 Mon Sep 17 00:00:00 2001 From: Joe Cummings Date: Thu, 9 Mar 2023 11:15:46 -0500 Subject: [PATCH 3/3] Add roberta transform and xlmr transform as factory methods --- torchtext/models/roberta/bundler.py | 37 +++++++++++++++---- torchtext/models/roberta/transform.py | 52 --------------------------- 2 files changed, 31 insertions(+), 58 deletions(-) delete mode 100644 torchtext/models/roberta/transform.py diff --git a/torchtext/models/roberta/bundler.py 
From e35cb8e8834d5b74ed12bb2b68fd671f30b49c93 Mon Sep 17 00:00:00 2001
From: Joe Cummings
Date: Thu, 9 Mar 2023 11:15:46 -0500
Subject: [PATCH 3/3] Add roberta transform and xlmr transform as factory
 methods

---
 torchtext/models/roberta/bundler.py   | 37 +++++++++++++++----
 torchtext/models/roberta/transform.py | 52 ---------------------------
 2 files changed, 31 insertions(+), 58 deletions(-)
 delete mode 100644 torchtext/models/roberta/transform.py

diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py
index 5680b105c0..de42c7ae8f 100644
--- a/torchtext/models/roberta/bundler.py
+++ b/torchtext/models/roberta/bundler.py
@@ -10,10 +10,10 @@
 
 logger = logging.getLogger(__name__)
 
+import torchtext.transforms as T
 from torchtext import _TEXT_BUCKET
 
 from .model import RobertaEncoderConf, RobertaModel
-from .transform import RobertaTransform, XLMRTransform
 
 
 def _is_head_available_in_checkpoint(checkpoint, head_state_dict):
@@ -160,10 +160,35 @@ def encoderConf(self) -> RobertaEncoderConf:
         return self._encoder_conf
 
 
+def xlmr_transform(truncate_length: int) -> Module:
+    """Standard transform for XLMR models."""
+    return T.Sequential(
+        T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")),
+        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))),
+        T.Truncate(truncate_length),
+        T.AddToken(token=0, begin=True),
+        T.AddToken(token=2, begin=False),
+    )
+
+
+def roberta_transform(truncate_length: int) -> Module:
+    """Standard transform for RoBERTa models."""
+    return T.Sequential(
+        T.GPT2BPETokenizer(
+            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
+            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
+        ),
+        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
+        T.Truncate(truncate_length),
+        T.AddToken(token=0, begin=True),
+        T.AddToken(token=2, begin=False),
+    )
+
+
 XLMR_BASE_ENCODER = RobertaBundle(
     _path=urljoin(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
     _encoder_conf=RobertaEncoderConf(vocab_size=250002),
-    transform=lambda: XLMRTransform(truncate_length=254),
+    transform=lambda: xlmr_transform(254),
 )
 
 XLMR_BASE_ENCODER.__doc__ = """
@@ -187,7 +212,7 @@ def encoderConf(self) -> RobertaEncoderConf:
     _encoder_conf=RobertaEncoderConf(
         vocab_size=250002, embedding_dim=1024, ffn_dimension=4096, num_attention_heads=16, num_encoder_layers=24
     ),
-    transform=lambda: XLMRTransform(truncate_length=510),
+    transform=lambda: xlmr_transform(510),
 )
 
 XLMR_LARGE_ENCODER.__doc__ = """
@@ -209,7 +234,7 @@ def encoderConf(self) -> RobertaEncoderConf:
 ROBERTA_BASE_ENCODER = RobertaBundle(
     _path=urljoin(_TEXT_BUCKET, "roberta.base.encoder.pt"),
     _encoder_conf=RobertaEncoderConf(vocab_size=50265),
-    transform=lambda: RobertaTransform(truncate_length=254),
+    transform=lambda: roberta_transform(254),
 )
 
 ROBERTA_BASE_ENCODER.__doc__ = """
@@ -242,7 +267,7 @@ def encoderConf(self) -> RobertaEncoderConf:
         num_attention_heads=16,
         num_encoder_layers=24,
     ),
-    transform=lambda: RobertaTransform(truncate_length=510),
+    transform=lambda: roberta_transform(510),
 )
 
 ROBERTA_LARGE_ENCODER.__doc__ = """
@@ -272,7 +297,7 @@ def encoderConf(self) -> RobertaEncoderConf:
         num_encoder_layers=6,
         padding_idx=1,
     ),
-    transform=lambda: RobertaTransform(truncate_length=510),
+    transform=lambda: roberta_transform(510),
 )
 
 ROBERTA_DISTILLED_ENCODER.__doc__ = """
diff --git a/torchtext/models/roberta/transform.py b/torchtext/models/roberta/transform.py
deleted file mode 100644
index 77efeb4350..0000000000
--- a/torchtext/models/roberta/transform.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from typing import List, Union
-from urllib.parse import urljoin
-
-import torch
-import torchtext.transforms as T
-from torch.nn import Module
-from torchtext import _TEXT_BUCKET
-from torchtext._download_hooks import load_state_dict_from_url
-
-
-class RobertaTransform(Module):
-    """Standard transform for RoBERTa model."""
-
-    def __init__(self, truncate_length: int) -> None:
-        super().__init__()
-        self.transform = T.Sequential(
-            T.GPT2BPETokenizer(
-                encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
-                vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
-            ),
-            T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
-            T.Truncate(truncate_length),
-            T.AddToken(token=0, begin=True),
-            T.AddToken(token=2, begin=False),
-            T.ToTensor(padding_value=0),
-        )
-
-    def forward(self, inputs: Union[str, List[str]]) -> torch.Tensor:
-        transformed_outputs = self.transform(inputs)
-        assert torch.jit.isinstance(transformed_outputs, torch.Tensor)
-        return transformed_outputs
-
-
-class XLMRTransform(Module):
-    """Standard transform for XLMR model."""
-
-    def __init__(self, truncate_length: int) -> None:
-        super().__init__()
-        self.tokenizer = T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"))
-        self.transform = T.Sequential(
-            T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")),
-            T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))),
-            T.Truncate(truncate_length),
-            T.AddToken(token=0, begin=True),
-            T.AddToken(token=2, begin=False),
-            T.ToTensor(padding_value=0),
-        )
-
-    def forward(self, inputs: Union[str, List[str]]) -> torch.Tensor:
-        transformed_outputs = self.transform(inputs)
-        assert torch.jit.isinstance(transformed_outputs, torch.Tensor)
-        return transformed_outputs
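After this final revision, the bundles build their transforms inside bundler.py through the module-level xlmr_transform/roberta_transform factories, and the pipelines once again stop at T.AddToken with no ToTensor step. A sketch of end-to-end usage under those assumptions (network access is required to fetch the transform assets and encoder weights; padding_value=1 follows the padding_idx=1 convention visible in the distilled encoder config):

    import torch
    from torchtext.functional import to_tensor
    from torchtext.models import ROBERTA_BASE_ENCODER

    transform = ROBERTA_BASE_ENCODER.transform()  # T.Sequential from roberta_transform(254)
    model = ROBERTA_BASE_ENCODER.get_model()
    model.eval()

    # The factory-built pipeline returns token-id lists, so pad them
    # into a batch tensor before calling the encoder.
    batch = to_tensor(transform(["Hello world", "How are you?"]), padding_value=1)
    with torch.no_grad():
        features = model(batch)  # shape: [batch_size, seq_len, embedding_dim]
    print(features.shape)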