Add XLMR Base and Large pre-trained models and corresponding transformations (#1406)
Showing 15 changed files with 883 additions and 2 deletions.
Binary file not shown.
Binary file not shown.
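For orientation before the per-file diffs: this commit adds the pre-trained bundles torchtext.models.XLMR_BASE_ENCODER and XLMR_LARGE_ENCODER, batch helpers in torchtext.functional (to_tensor, truncate, add_token), and the SpmTokenizerTransform / VocabTransform transforms, plus tests and assets. A minimal end-to-end sketch, assembled from the docstring this commit adds to RobertaModelBundle (output shapes and transform.pad_idx are taken from that docstring, not re-verified here):

import torch
import torchtext
from torchtext.functional import to_tensor

# Load the pre-trained XLM-R base encoder and its matching text transform.
xlmr_base = torchtext.models.XLMR_BASE_ENCODER
model = xlmr_base.get_model().eval()
transform = xlmr_base.transform()

# Tokenize a batch of raw strings, pad to a dense tensor, and run the encoder.
input_batch = ["Hello world", "How are you!"]
model_input = to_tensor(transform(input_batch), padding_value=transform.pad_idx)
output = model(model_input)
print(output.shape)  # torch.Size([2, 6, 768]) per the docstring example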
@@ -0,0 +1,70 @@

import torchtext
import torch

from ..common.torchtext_test_case import TorchtextTestCase
from ..common.assets import get_asset_path


class TestModels(TorchtextTestCase):
    def test_xlmr_base_output(self):
        asset_name = "xlmr.base.output.pt"
        asset_path = get_asset_path(asset_name)
        xlmr_base = torchtext.models.XLMR_BASE_ENCODER
        model = xlmr_base.get_model()
        model = model.eval()
        model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
        actual = model(model_input)
        expected = torch.load(asset_path)
        torch.testing.assert_close(actual, expected)

    def test_xlmr_base_jit_output(self):
        asset_name = "xlmr.base.output.pt"
        asset_path = get_asset_path(asset_name)
        xlmr_base = torchtext.models.XLMR_BASE_ENCODER
        model = xlmr_base.get_model()
        model = model.eval()
        model_jit = torch.jit.script(model)
        model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
        actual = model_jit(model_input)
        expected = torch.load(asset_path)
        torch.testing.assert_close(actual, expected)
    def test_xlmr_large_output(self):
        asset_name = "xlmr.large.output.pt"
        asset_path = get_asset_path(asset_name)
        xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
        model = xlmr_large.get_model()
        model = model.eval()
        model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
        actual = model(model_input)
        expected = torch.load(asset_path)
        torch.testing.assert_close(actual, expected)

    def test_xlmr_large_jit_output(self):
        asset_name = "xlmr.large.output.pt"
        asset_path = get_asset_path(asset_name)
        xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
        model = xlmr_large.get_model()
        model = model.eval()
        model_jit = torch.jit.script(model)
        model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
        actual = model_jit(model_input)
        expected = torch.load(asset_path)
        torch.testing.assert_close(actual, expected)
    def test_xlmr_transform(self):
        xlmr_base = torchtext.models.XLMR_BASE_ENCODER
        transform = xlmr_base.transform()
        test_text = "XLMR base Model Comparison"
        actual = transform([test_text])
        expected = [[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]
        torch.testing.assert_close(actual, expected)

    def test_xlmr_transform_jit(self):
        xlmr_base = torchtext.models.XLMR_BASE_ENCODER
        transform = xlmr_base.transform()
        transform_jit = torch.jit.script(transform)
        test_text = "XLMR base Model Comparison"
        actual = transform_jit([test_text])
        expected = [[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]
        torch.testing.assert_close(actual, expected)
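These tests compare the encoder output against pre-computed tensors stored as binary assets (presumably the two "Binary file not shown" entries above, xlmr.base.output.pt and xlmr.large.output.pt). The generation script is not part of this diff; a hedged sketch of how such an asset could be produced, using the same input ids as the tests:

import torch
import torchtext

# Hypothetical regeneration of the expected-output asset; the actual script
# that created xlmr.base.output.pt is not included in this commit.
model = torchtext.models.XLMR_BASE_ENCODER.get_model().eval()
model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
with torch.no_grad():
    expected = model(model_input)
torch.save(expected, "xlmr.base.output.pt")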
@@ -0,0 +1,58 @@

import torch
from torchtext import functional
from .common.torchtext_test_case import TorchtextTestCase


class TestFunctional(TorchtextTestCase):
    def test_to_tensor(self):
        input = [[1, 2], [1, 2, 3]]
        padding_value = 0
        actual = functional.to_tensor(input, padding_value=padding_value)
        expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long)
        torch.testing.assert_close(actual, expected)

    def test_to_tensor_jit(self):
        input = [[1, 2], [1, 2, 3]]
        padding_value = 0
        to_tensor_jit = torch.jit.script(functional.to_tensor)
        actual = to_tensor_jit(input, padding_value=padding_value)
        expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long)
        torch.testing.assert_close(actual, expected)

    def test_truncate(self):
        input = [[1, 2], [1, 2, 3]]
        max_seq_len = 2
        actual = functional.truncate(input, max_seq_len=max_seq_len)
        expected = [[1, 2], [1, 2]]
        self.assertEqual(actual, expected)

    def test_truncate_jit(self):
        input = [[1, 2], [1, 2, 3]]
        max_seq_len = 2
        truncate_jit = torch.jit.script(functional.truncate)
        actual = truncate_jit(input, max_seq_len=max_seq_len)
        expected = [[1, 2], [1, 2]]
        self.assertEqual(actual, expected)

    def test_add_token(self):
        input = [[1, 2], [1, 2, 3]]
        token_id = 0
        actual = functional.add_token(input, token_id=token_id)
        expected = [[0, 1, 2], [0, 1, 2, 3]]
        self.assertEqual(actual, expected)

        actual = functional.add_token(input, token_id=token_id, begin=False)
        expected = [[1, 2, 0], [1, 2, 3, 0]]
        self.assertEqual(actual, expected)

    def test_add_token_jit(self):
        input = [[1, 2], [1, 2, 3]]
        token_id = 0
        add_token_jit = torch.jit.script(functional.add_token)
        actual = add_token_jit(input, token_id=token_id)
        expected = [[0, 1, 2], [0, 1, 2, 3]]
        self.assertEqual(actual, expected)

        actual = add_token_jit(input, token_id=token_id, begin=False)
        expected = [[1, 2, 0], [1, 2, 3, 0]]
        self.assertEqual(actual, expected)
@@ -0,0 +1,33 @@

import torch
from torchtext import transforms
from torchtext.vocab import vocab
from collections import OrderedDict

from .common.torchtext_test_case import TorchtextTestCase
from .common.assets import get_asset_path


class TestTransforms(TorchtextTestCase):
    def test_spmtokenizer_transform(self):
        asset_name = "spm_example.model"
        asset_path = get_asset_path(asset_name)
        transform = transforms.SpmTokenizerTransform(asset_path)
        actual = transform(["Hello World!, how are you?"])
        expected = [['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?']]
        self.assertEqual(actual, expected)

    def test_spmtokenizer_transform_jit(self):
        asset_name = "spm_example.model"
        asset_path = get_asset_path(asset_name)
        transform = transforms.SpmTokenizerTransform(asset_path)
        transform_jit = torch.jit.script(transform)
        actual = transform_jit(["Hello World!, how are you?"])
        expected = [['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?']]
        self.assertEqual(actual, expected)

    def test_vocab_transform(self):
        vocab_obj = vocab(OrderedDict([('a', 1), ('b', 1), ('c', 1)]))
        transform = transforms.VocabTransform(vocab_obj)
        actual = transform([['a', 'b', 'c']])
        expected = [[0, 1, 2]]
        self.assertEqual(actual, expected)
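The two transforms are tested separately here; in the XLM-R pipeline they are presumably chained so SentencePiece tokens are mapped to vocabulary ids (the actual composition lives in get_xlmr_transform, whose file is not shown in this excerpt). A rough sketch of such chaining under those assumptions; the model path and the toy vocabulary below are placeholders, not values from this diff:

import torch
from collections import OrderedDict
from torchtext import transforms
from torchtext.vocab import vocab

# Assumed composition: SentencePiece tokenization followed by vocab lookup.
spm_transform = transforms.SpmTokenizerTransform("spm_example.model")
vocab_obj = vocab(OrderedDict([('▁Hello', 1), ('▁World', 1), ('!', 1)]))
vocab_transform = transforms.VocabTransform(vocab_obj)

tokens = spm_transform(["Hello World!"])  # roughly [['▁Hello', '▁World', '!']], by analogy with the test above
token_ids = vocab_transform(tokens)       # [[0, 1, 2]] for this toy vocab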
@@ -0,0 +1,45 @@

import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from typing import List, Optional

__all__ = [
    'to_tensor',
    'truncate',
    'add_token',
]


def to_tensor(input: List[List[int]], padding_value: Optional[int] = None) -> Tensor:
    if padding_value is None:
        output = torch.tensor(input, dtype=torch.long)
        return output
    else:
        output = pad_sequence(
            [torch.tensor(ids, dtype=torch.long) for ids in input],
            batch_first=True,
            padding_value=float(padding_value)
        )
        return output


def truncate(input: List[List[int]], max_seq_len: int) -> List[List[int]]:
    output: List[List[int]] = []

    for ids in input:
        output.append(ids[:max_seq_len])

    return output


def add_token(input: List[List[int]], token_id: int, begin: bool = True) -> List[List[int]]:
    output: List[List[int]] = []

    if begin:
        for ids in input:
            output.append([token_id] + ids)
    else:
        for ids in input:
            output.append(ids + [token_id])

    return output
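For reference, the behavior of these helpers on a small ragged batch, using the same values exercised by the functional tests earlier in this diff:

from torchtext import functional

batch = [[1, 2], [1, 2, 3]]

functional.truncate(batch, max_seq_len=2)              # [[1, 2], [1, 2]]
functional.add_token(batch, token_id=0)                # [[0, 1, 2], [0, 1, 2, 3]]
functional.add_token(batch, token_id=0, begin=False)   # [[1, 2, 0], [1, 2, 3, 0]]
functional.to_tensor(batch, padding_value=0)           # tensor([[1, 2, 0], [1, 2, 3]])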
@@ -0,0 +1 @@

from .roberta import *  # noqa: F401, F403
@@ -0,0 +1,18 @@

from .model import (
    RobertaEncoderParams,
    RobertaClassificationHead,
)

from .bundler import (
    RobertaModelBundle,
    XLMR_BASE_ENCODER,
    XLMR_LARGE_ENCODER,
)

__all__ = [
    "RobertaEncoderParams",
    "RobertaClassificationHead",
    "RobertaModelBundle",
    "XLMR_BASE_ENCODER",
    "XLMR_LARGE_ENCODER",
]
@@ -0,0 +1,99 @@

import os
from dataclasses import dataclass
from functools import partial

from typing import Optional, Callable
from torch.hub import load_state_dict_from_url
from torch.nn import Module
import logging

logger = logging.getLogger(__name__)

from .model import (
    RobertaEncoderParams,
    RobertaModel,
    _get_model,
)

from .transforms import get_xlmr_transform

from torchtext import _TEXT_BUCKET


@dataclass
class RobertaModelBundle:
    """
    Example - Pretrained encoder
        >>> import torch, torchtext
        >>> xlmr_base = torchtext.models.XLMR_BASE_ENCODER
        >>> model = xlmr_base.get_model()
        >>> transform = xlmr_base.transform()
        >>> model_input = torch.tensor(transform(["Hello World"]))
        >>> output = model(model_input)
        >>> output.shape
        torch.Size([1, 4, 768])
        >>> input_batch = ["Hello world", "How are you!"]
        >>> from torchtext.functional import to_tensor
        >>> model_input = to_tensor(transform(input_batch), padding_value=transform.pad_idx)
        >>> output = model(model_input)
        >>> output.shape
        torch.Size([2, 6, 768])

    Example - Pretrained encoder attached to un-initialized classification head
        >>> import torch, torchtext
        >>> xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
        >>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim=xlmr_large.params.embedding_dim)
        >>> classification_model = xlmr_large.get_model(head=classifier_head)
        >>> transform = xlmr_large.transform()
        >>> model_input = torch.tensor(transform(["Hello World"]))
        >>> output = classification_model(model_input)
        >>> output.shape
        torch.Size([1, 2])
    """
    _params: RobertaEncoderParams
    _path: Optional[str] = None
    _head: Optional[Module] = None
    transform: Optional[Callable] = None
    def get_model(self, head: Optional[Module] = None, *, dl_kwargs=None) -> RobertaModel:

        if head is not None:
            input_head = head
            if self._head is not None:
                logger.info("A custom head module was provided, discarding the default head module.")
        else:
            input_head = self._head

        model = _get_model(self._params, input_head)

        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
        state_dict = load_state_dict_from_url(self._path, **dl_kwargs)
        if input_head is not None:
            model.load_state_dict(state_dict, strict=False)
        else:
            model.load_state_dict(state_dict, strict=True)
        return model
||
@property | ||
def params(self) -> RobertaEncoderParams: | ||
return self._params | ||
|
||
|
||
XLMR_BASE_ENCODER = RobertaModelBundle( | ||
_path=os.path.join(_TEXT_BUCKET, "xlmr.base.encoder.pt"), | ||
_params=RobertaEncoderParams(vocab_size=250002), | ||
transform=partial(get_xlmr_transform, | ||
vocab_path=os.path.join(_TEXT_BUCKET, "xlmr.vocab.pt"), | ||
spm_model_path=os.path.join(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"), | ||
) | ||
) | ||
|
||
XLMR_LARGE_ENCODER = RobertaModelBundle( | ||
_path=os.path.join(_TEXT_BUCKET, "xlmr.large.encoder.pt"), | ||
_params=RobertaEncoderParams(vocab_size=250002, embedding_dim=1024, ffn_dimension=4096, num_attention_heads=16, num_encoder_layers=24), | ||
transform=partial(get_xlmr_transform, | ||
vocab_path=os.path.join(_TEXT_BUCKET, "xlmr.vocab.pt"), | ||
spm_model_path=os.path.join(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"), | ||
) | ||
) |
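The docstring above already covers attaching a classification head; one additional detail worth noting is that dl_kwargs is forwarded verbatim to torch.hub.load_state_dict_from_url, so its standard arguments such as map_location or progress can be passed through. A hedged sketch combining the two (the dl_kwargs values are illustrative, not part of this diff):

import torch
import torchtext

xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
classifier_head = torchtext.models.RobertaClassificationHead(
    num_classes=2, input_dim=xlmr_large.params.embedding_dim
)

# dl_kwargs goes straight to torch.hub.load_state_dict_from_url;
# map_location and progress are standard torch.hub options, not new API here.
classification_model = xlmr_large.get_model(
    head=classifier_head, dl_kwargs={"map_location": "cpu", "progress": False}
)

transform = xlmr_large.transform()
model_input = torch.tensor(transform(["Hello World"]))
logits = classification_model(model_input)  # shape [1, 2] per the docstring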