change from_config to build_model and other minor clean-ups #1452

Merged (4 commits) on Nov 28, 2021
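This PR renames RobertaModelBundle.from_config to RobertaModelBundle.build_model, renames the override_head keyword to override_checkpoint_head, removes the _get_model helper in favor of constructing RobertaModel directly, and renames the encoder forward argument mask to masked_tokens. A minimal sketch of the renamed entry point, assuming the post-merge API shown in the diff below (the checkpoint URL is the one already used in the bundler docstring; the snippet is illustrative, not a tested example):

from torchtext.models import RobertaEncoderConf, RobertaClassificationHead, RobertaModelBundle

# Encoder configuration and an un-initialized classification head.
encoder_conf = RobertaEncoderConf(vocab_size=250002)
classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)

# Previously RobertaModelBundle.from_config(...). head, freeze_encoder,
# checkpoint and the other options are keyword-only; strict is forwarded to
# load_state_dict and dl_kwargs to torch.hub.load_state_dict_from_url.
model = RobertaModelBundle.build_model(
    encoder_conf=encoder_conf,
    head=classifier_head,
    checkpoint="https://download.pytorch.org/models/text/xlmr.base.encoder.pt",
)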
16 changes: 8 additions & 8 deletions test/models/test_models.py
@@ -93,31 +93,31 @@ def test_xlmr_transform_jit(self):
expected = [[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]
torch.testing.assert_close(actual, expected)

def test_roberta_bundler_from_config(self):
def test_roberta_bundler_build_model(self):
from torchtext.models import RobertaEncoderConf, RobertaClassificationHead, RobertaModel, RobertaModelBundle
dummy_encoder_conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2)

# case: user provide encoder checkpoint state dict
dummy_encoder = RobertaModel(dummy_encoder_conf)
model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
model = RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf,
checkpoint=dummy_encoder.state_dict())
self.assertEqual(model.state_dict(), dummy_encoder.state_dict())

# case: user provide classifier checkpoint state dict when head is given and override_head is False (by default)
dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
model = RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf,
head=another_dummy_classifier_head,
checkpoint=dummy_classifier.state_dict())
self.assertEqual(model.state_dict(), dummy_classifier.state_dict())

# case: user provide classifier checkpoint state dict when head is given and override_head is set True
another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
model = RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf,
head=another_dummy_classifier_head,
checkpoint=dummy_classifier.state_dict(),
override_head=True)
override_checkpoint_head=True)
self.assertEqual(model.head.state_dict(), another_dummy_classifier_head.state_dict())

# case: user provide only encoder checkpoint state dict when head is given
@@ -126,7 +126,7 @@ def test_roberta_bundler_from_config(self):
encoder_state_dict = {}
for k, v in dummy_classifier.encoder.state_dict().items():
encoder_state_dict['encoder.' + k] = v
model = torchtext.models.RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf, head=dummy_classifier_head, checkpoint=encoder_state_dict)
model = torchtext.models.RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf, head=dummy_classifier_head, checkpoint=encoder_state_dict)
self.assertEqual(model.state_dict(), dummy_classifier.state_dict())

def test_roberta_bundler_train(self):
@@ -146,7 +146,7 @@ def _train(model):
# does not freeze encoder
dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
model = RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf,
head=dummy_classifier_head,
freeze_encoder=False,
checkpoint=dummy_classifier.state_dict())
@@ -162,7 +162,7 @@ def _train(model):
# freeze encoder
dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
model = RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf,
head=dummy_classifier_head,
freeze_encoder=True,
checkpoint=dummy_classifier.state_dict())
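For reference, the head-override case exercised in the test above reduces to the following standalone sketch (same dummy dimensions as the test; illustrative only, not an additional test added by this PR):

import torch
from torchtext.models import RobertaClassificationHead, RobertaEncoderConf, RobertaModel, RobertaModelBundle

# Same dummy configuration as in test_roberta_bundler_build_model.
conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64,
                          num_attention_heads=2, num_encoder_layers=2)
checkpoint_head = RobertaClassificationHead(num_classes=2, input_dim=16)
fresh_head = RobertaClassificationHead(num_classes=2, input_dim=16)
checkpoint = RobertaModel(conf, checkpoint_head).state_dict()

# With override_checkpoint_head=True the provided head's weights replace the
# head weights found in the checkpoint; by default the checkpoint wins.
model = RobertaModelBundle.build_model(encoder_conf=conf,
                                       head=fresh_head,
                                       checkpoint=checkpoint,
                                       override_checkpoint_head=True)
assert all(torch.equal(v, fresh_head.state_dict()[k])
           for k, v in model.head.state_dict().items())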
62 changes: 29 additions & 33 deletions torchtext/models/roberta/bundler.py
@@ -13,7 +13,6 @@
from .model import (
RobertaEncoderConf,
RobertaModel,
_get_model,
)

from .transforms import get_xlmr_transform
@@ -30,56 +29,50 @@ def _is_head_available_in_checkpoint(checkpoint, head_state_dict):
class RobertaModelBundle:
"""RobertaModelBundle(_params: torchtext.models.RobertaEncoderParams, _path: Optional[str] = None, _head: Optional[torch.nn.Module] = None, transform: Optional[Callable] = None)

Example - Pretrained encoder
Example - Pretrained base xlmr encoder
>>> import torch, torchtext
>>> from torchtext.functional import to_tensor
>>> xlmr_base = torchtext.models.XLMR_BASE_ENCODER
>>> model = xlmr_base.get_model()
>>> transform = xlmr_base.transform()
>>> model_input = torch.tensor(transform(["Hello World"]))
>>> output = model(model_input)
>>> output.shape
torch.Size([1, 4, 768])
>>> input_batch = ["Hello world", "How are you!"]
>>> from torchtext.functional import to_tensor
>>> model_input = to_tensor(transform(input_batch), padding_value=transform.pad_idx)
>>> output = model(model_input)
>>> output.shape
torch.Size([2, 6, 768])

Example - Pretrained encoder attached to un-initialized classification head
Example - Pretrained large xlmr encoder attached to un-initialized classification head
>>> import torch, torchtext
>>> from torchtext.models import RobertaClassificationHead
>>> from torchtext.functional import to_tensor
>>> xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
>>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim = xlmr_large.encoderConf.embedding_dim)
>>> classification_model = xlmr_large.get_model(head=classifier_head)
>>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim = 1024)
>>> model = xlmr_large.get_model(head=classifier_head)
>>> transform = xlmr_large.transform()
>>> model_input = torch.tensor(transform(["Hello World"]))
>>> output = classification_model(model_input)
>>> input_batch = ["Hello world", "How are you!"]
>>> model_input = to_tensor(transform(input_batch), padding_value=transform.pad_idx)
>>> output = model(model_input)
>>> output.shape
torch.Size([2, 2])

Example - User-specified configuration and checkpoint
>>> from torchtext.models import RobertaEncoderConf, RobertaModelBundle, RobertaClassificationHead
>>> model_weights_path = "https://download.pytorch.org/models/text/xlmr.base.encoder.pt"
>>> roberta_encoder_conf = RobertaEncoderConf(vocab_size=250002)
>>> roberta_bundle = RobertaModelBundle(_encoder_conf=roberta_encoder_conf, _path=model_weights_path)
>>> encoder = roberta_bundle.get_model()
>>> encoder_conf = RobertaEncoderConf(vocab_size=250002)
>>> classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
>>> classifier = roberta_bundle.get_model(head=classifier_head)
>>> # using from_config
>>> encoder = RobertaModelBundle.from_config(config=roberta_encoder_conf, checkpoint=model_weights_path)
>>> classifier = RobertaModelBundle.from_config(config=roberta_encoder_conf, head=classifier_head, checkpoint=model_weights_path)
>>> model = RobertaModelBundle.build_model(encoder_conf=encoder_conf, head=classifier_head, checkpoint=model_weights_path)
"""
_encoder_conf: RobertaEncoderConf
_path: Optional[str] = None
_head: Optional[Module] = None
transform: Optional[Callable] = None

def get_model(self,
*,
head: Optional[Module] = None,
load_weights: bool = True,
freeze_encoder: bool = False,
*,
dl_kwargs=None) -> RobertaModel:
dl_kwargs: Dict[str, Any] = None) -> RobertaModel:
r"""get_model(head: Optional[torch.nn.Module] = None, load_weights: bool = True, freeze_encoder: bool = False, *, dl_kwargs=None) -> torctext.models.RobertaModel

Args:
@@ -103,35 +96,38 @@ def get_model(self,
else:
input_head = self._head

return RobertaModelBundle.from_config(encoder_conf=self._encoder_conf,
return RobertaModelBundle.build_model(encoder_conf=self._encoder_conf,
head=input_head,
freeze_encoder=freeze_encoder,
checkpoint=self._path,
override_head=True,
checkpoint=self._path if load_weights else None,
override_checkpoint_head=True,
strict=True,
dl_kwargs=dl_kwargs)

@classmethod
def from_config(
def build_model(
cls,
encoder_conf: RobertaEncoderConf,
*,
head: Optional[Module] = None,
freeze_encoder: bool = False,
checkpoint: Optional[Union[str, Dict[str, torch.Tensor]]] = None,
*,
override_head: bool = False,
override_checkpoint_head: bool = False,
strict=True,
dl_kwargs: Dict[str, Any] = None,
) -> RobertaModel:
"""Class method to create model with user-defined encoder configuration and checkpoint
"""Class builder method

Args:
encoder_conf (RobertaEncoderConf): An instance of class RobertaEncoderConf that defines the encoder configuration
head (nn.Module): A module to be attached to the encoder to perform a specific task. (Default: ``None``)
freeze_encoder (bool): Indicates whether to freeze the encoder weights. (Default: ``False``)
checkpoint (str or Dict[str, torch.Tensor]): Path to or actual model state_dict. The state_dict can have partial weights, i.e. only for the encoder. (Default: ``None``)
override_head (bool): Override the checkpoint's head state dict (if present) with provided head state dict. (Default: ``False``)
override_checkpoint_head (bool): Override the checkpoint's head state dict (if present) with provided head state dict. (Default: ``False``)
strict (bool): Passed to :func:`torch.nn.Module.load_state_dict` method. (Default: ``True``)
dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: ``None``)
"""
model = _get_model(encoder_conf, head, freeze_encoder)
model = RobertaModel(encoder_conf, head, freeze_encoder)
if checkpoint is not None:
if torch.jit.isinstance(checkpoint, Dict[str, torch.Tensor]):
state_dict = checkpoint
@@ -145,10 +141,10 @@ def from_config(
regex = re.compile(r"^head\.")
head_state_dict = {k: v for k, v in model.state_dict().items() if regex.findall(k)}
# If the checkpoint does not contain head_state_dict, then we augment the checkpoint with the user-provided head state_dict
if not _is_head_available_in_checkpoint(state_dict, head_state_dict) or override_head:
if not _is_head_available_in_checkpoint(state_dict, head_state_dict) or override_checkpoint_head:
state_dict.update(head_state_dict)

model.load_state_dict(state_dict, strict=True)
model.load_state_dict(state_dict, strict=strict)

return model

@@ -168,7 +164,7 @@ def encoderConf(self) -> RobertaEncoderConf:

XLMR_BASE_ENCODER.__doc__ = (
'''
XLM-R Encoder with base configuration
XLM-R Encoder with Base configuration

Please refer to :func:`torchtext.models.RobertaModelBundle` for the usage.
'''
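One behavioral detail worth noting in get_model above: load_weights now decides whether the bundled checkpoint path is handed to build_model at all, and strict=True is forwarded explicitly. A rough sketch, assuming the XLMR_BASE_ENCODER bundle from the docstring above:

import torchtext

xlmr_base = torchtext.models.XLMR_BASE_ENCODER

# checkpoint=self._path is only passed when load_weights=True; otherwise
# build_model receives checkpoint=None and the encoder keeps its random
# initialization instead of downloading weights.
pretrained = xlmr_base.get_model()
randomly_initialized = xlmr_base.get_model(load_weights=False)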
38 changes: 16 additions & 22 deletions torchtext/models/roberta/model.py
@@ -42,6 +42,7 @@ def __init__(
dropout: float = 0.1,
scaling: Optional[float] = None,
normalize_before: bool = False,
freeze: bool = False,
):
super().__init__()
if not scaling:
Expand All @@ -62,17 +63,17 @@ def __init__(
return_all_layers=False,
)

@classmethod
def from_config(cls, config: RobertaEncoderConf):
return cls(**asdict(config))
if freeze:
for p in self.parameters():
p.requires_grad = False

def forward(self, tokens: Tensor, mask: Optional[Tensor] = None) -> Tensor:
def forward(self, tokens: Tensor, masked_tokens: Optional[Tensor] = None) -> Tensor:
output = self.transformer(tokens)
if torch.jit.isinstance(output, List[Tensor]):
output = output[-1]
output = output.transpose(1, 0)
if mask is not None:
output = output[mask.to(torch.bool), :]
if masked_tokens is not None:
output = output[masked_tokens.to(torch.bool), :]
return output


@@ -100,35 +101,28 @@ def forward(self, features):
class RobertaModel(Module):
"""

Example - Instantiate model with user-specified configuration
Example - Instantiating model object
>>> from torchtext.models import RobertaEncoderConf, RobertaModel, RobertaClassificationHead
>>> roberta_encoder_conf = RobertaEncoderConf(vocab_size=250002)
>>> encoder = RobertaModel(encoder_conf=roberta_encoder_conf)
>>> classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
>>> classifier = RobertaModel(encoder_conf=roberta_encoder_conf, head=classifier_head)
"""

def __init__(self, config: RobertaEncoderConf, head: Optional[Module] = None, freeze_encoder: bool = False):
def __init__(self,
encoder_conf: RobertaEncoderConf,
head: Optional[Module] = None,
freeze_encoder: bool = False):
super().__init__()
assert isinstance(config, RobertaEncoderConf)

self.encoder = RobertaEncoder.from_config(config)
if freeze_encoder:
for param in self.encoder.parameters():
param.requires_grad = False

logger.info("Encoder weights are frozen")
assert isinstance(encoder_conf, RobertaEncoderConf)

self.encoder = RobertaEncoder(**asdict(encoder_conf), freeze=freeze_encoder)
self.head = head

def forward(self, tokens: Tensor, mask: Optional[Tensor] = None) -> Tensor:
features = self.encoder(tokens, mask)
def forward(self, tokens: Tensor, masked_tokens: Optional[Tensor] = None) -> Tensor:
features = self.encoder(tokens, masked_tokens)
if self.head is None:
return features

x = self.head(features)
return x


def _get_model(config: RobertaEncoderConf, head: Optional[Module] = None, freeze_encoder: bool = False) -> RobertaModel:
return RobertaModel(config, head, freeze_encoder)
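With _get_model removed, callers construct RobertaModel directly; freezing now happens inside RobertaEncoder via the new freeze flag, and the forward argument is masked_tokens rather than mask. A short sketch of direct construction under the renamed parameters (illustrative; it mirrors the class docstring above):

from torchtext.models import RobertaEncoderConf, RobertaModel, RobertaClassificationHead

encoder_conf = RobertaEncoderConf(vocab_size=250002)
classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)

# freeze_encoder=True now sets requires_grad=False on the encoder parameters
# inside RobertaEncoder.__init__ rather than in RobertaModel.__init__.
classifier = RobertaModel(encoder_conf, classifier_head, freeze_encoder=True)

# forward(tokens, masked_tokens=None): when masked_tokens is given, only the
# masked positions of the encoder output are passed to the head.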