change from_config to build_model and other minor clean-ups (#1452)
parmeet committed Nov 28, 2021
1 parent aea6ad6 commit 9f2fb3f
Showing 3 changed files with 53 additions and 63 deletions.
16 changes: 8 additions & 8 deletions test/models/test_models.py
@@ -93,31 +93,31 @@ def test_xlmr_transform_jit(self):
expected = [[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]
torch.testing.assert_close(actual, expected)

def test_roberta_bundler_from_config(self):
def test_roberta_bundler_build_model(self):
from torchtext.models import RobertaEncoderConf, RobertaClassificationHead, RobertaModel, RobertaModelBundle
dummy_encoder_conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2)

# case: user provides an encoder checkpoint state dict
dummy_encoder = RobertaModel(dummy_encoder_conf)
model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
model = RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf,
checkpoint=dummy_encoder.state_dict())
self.assertEqual(model.state_dict(), dummy_encoder.state_dict())

# case: user provides a classifier checkpoint state dict when head is given and override_checkpoint_head is False (by default)
dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
model = RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf,
head=another_dummy_classifier_head,
checkpoint=dummy_classifier.state_dict())
self.assertEqual(model.state_dict(), dummy_classifier.state_dict())

# case: user provides a classifier checkpoint state dict when head is given and override_checkpoint_head is set to True
another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
model = RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf,
head=another_dummy_classifier_head,
checkpoint=dummy_classifier.state_dict(),
override_head=True)
override_checkpoint_head=True)
self.assertEqual(model.head.state_dict(), another_dummy_classifier_head.state_dict())

# case: user provides only an encoder checkpoint state dict when head is given
@@ -126,7 +126,7 @@ def test_roberta_bundler_from_config(self):
encoder_state_dict = {}
for k, v in dummy_classifier.encoder.state_dict().items():
encoder_state_dict['encoder.' + k] = v
model = torchtext.models.RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf, head=dummy_classifier_head, checkpoint=encoder_state_dict)
model = torchtext.models.RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf, head=dummy_classifier_head, checkpoint=encoder_state_dict)
self.assertEqual(model.state_dict(), dummy_classifier.state_dict())

def test_roberta_bundler_train(self):
@@ -146,7 +146,7 @@ def _train(model):
# does not freeze encoder
dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
model = RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf,
head=dummy_classifier_head,
freeze_encoder=False,
checkpoint=dummy_classifier.state_dict())
@@ -162,7 +162,7 @@ def _train(model):
# freeze encoder
dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
model = RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf,
head=dummy_classifier_head,
freeze_encoder=True,
checkpoint=dummy_classifier.state_dict())
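For reference, a minimal usage sketch of the renamed entry point, mirroring the dummy configuration used in the tests above. It assumes torchtext as of this commit; the variable names are illustrative only.

import torch
from torchtext.models import (
    RobertaClassificationHead,
    RobertaEncoderConf,
    RobertaModel,
    RobertaModelBundle,
)

# Tiny configuration so the example builds quickly (same values as the tests).
encoder_conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64,
                                  num_attention_heads=2, num_encoder_layers=2)

# A classifier whose state dict stands in for a downloaded checkpoint.
head = RobertaClassificationHead(num_classes=2, input_dim=16)
classifier = RobertaModel(encoder_conf, head)

# Rebuild the model from the configuration plus that checkpoint, replacing the
# checkpoint's head weights with a freshly initialized head.
fresh_head = RobertaClassificationHead(num_classes=2, input_dim=16)
model = RobertaModelBundle.build_model(encoder_conf=encoder_conf,
                                       head=fresh_head,
                                       checkpoint=classifier.state_dict(),
                                       override_checkpoint_head=True)

# The loaded head should now match the freshly initialized one.
for key, value in model.head.state_dict().items():
    torch.testing.assert_close(value, fresh_head.state_dict()[key])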
62 changes: 29 additions & 33 deletions torchtext/models/roberta/bundler.py
@@ -13,7 +13,6 @@
from .model import (
RobertaEncoderConf,
RobertaModel,
_get_model,
)

from .transforms import get_xlmr_transform
@@ -30,56 +29,50 @@ def _is_head_available_in_checkpoint(checkpoint, head_state_dict):
class RobertaModelBundle:
"""RobertaModelBundle(_params: torchtext.models.RobertaEncoderParams, _path: Optional[str] = None, _head: Optional[torch.nn.Module] = None, transform: Optional[Callable] = None)
Example - Pretrained encoder
Example - Pretrained base xlmr encoder
>>> import torch, torchtext
>>> from torchtext.functional import to_tensor
>>> xlmr_base = torchtext.models.XLMR_BASE_ENCODER
>>> model = xlmr_base.get_model()
>>> transform = xlmr_base.transform()
>>> model_input = torch.tensor(transform(["Hello World"]))
>>> output = model(model_input)
>>> output.shape
torch.Size([1, 4, 768])
>>> input_batch = ["Hello world", "How are you!"]
>>> from torchtext.functional import to_tensor
>>> model_input = to_tensor(transform(input_batch), padding_value=transform.pad_idx)
>>> output = model(model_input)
>>> output.shape
torch.Size([2, 6, 768])
Example - Pretrained encoder attached to un-initialized classification head
Example - Pretrained large xlmr encoder attached to un-initialized classification head
>>> import torch, torchtext
>>> from torchtext.models import RobertaClassificationHead
>>> from torchtext.functional import to_tensor
>>> xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
>>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim = xlmr_large.encoderConf.embedding_dim)
>>> classification_model = xlmr_large.get_model(head=classifier_head)
>>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim = 1024)
>>> model = xlmr_large.get_model(head=classifier_head)
>>> transform = xlmr_large.transform()
>>> model_input = torch.tensor(transform(["Hello World"]))
>>> output = classification_model(model_input)
>>> input_batch = ["Hello world", "How are you!"]
>>> model_input = to_tensor(transform(input_batch), padding_value=transform.pad_idx)
>>> output = model(model_input)
>>> output.shape
torch.Size([1, 2])
Example - User-specified configuration and checkpoint
>>> from torchtext.models import RobertaEncoderConf, RobertaModelBundle, RobertaClassificationHead
>>> model_weights_path = "https://download.pytorch.org/models/text/xlmr.base.encoder.pt"
>>> roberta_encoder_conf = RobertaEncoderConf(vocab_size=250002)
>>> roberta_bundle = RobertaModelBundle(_encoder_conf=roberta_encoder_conf, _path=model_weights_path)
>>> encoder = roberta_bundle.get_model()
>>> encoder_conf = RobertaEncoderConf(vocab_size=250002)
>>> classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
>>> classifier = roberta_bundle.get_model(head=classifier_head)
>>> # using from_config
>>> encoder = RobertaModelBundle.from_config(config=roberta_encoder_conf, checkpoint=model_weights_path)
>>> classifier = RobertaModelBundle.from_config(config=roberta_encoder_conf, head=classifier_head, checkpoint=model_weights_path)
>>> model = RobertaModelBundle.build_model(encoder_conf=encoder_conf, head=classifier_head, checkpoint=model_weights_path)
"""
_encoder_conf: RobertaEncoderConf
_path: Optional[str] = None
_head: Optional[Module] = None
transform: Optional[Callable] = None

def get_model(self,
*,
head: Optional[Module] = None,
load_weights: bool = True,
freeze_encoder: bool = False,
*,
dl_kwargs=None) -> RobertaModel:
dl_kwargs: Dict[str, Any] = None) -> RobertaModel:
r"""get_model(head: Optional[torch.nn.Module] = None, load_weights: bool = True, freeze_encoder: bool = False, *, dl_kwargs=None) -> torctext.models.RobertaModel
Args:
@@ -103,35 +96,38 @@ def get_model(self,
else:
input_head = self._head

return RobertaModelBundle.from_config(encoder_conf=self._encoder_conf,
return RobertaModelBundle.build_model(encoder_conf=self._encoder_conf,
head=input_head,
freeze_encoder=freeze_encoder,
checkpoint=self._path,
override_head=True,
checkpoint=self._path if load_weights else None,
override_checkpoint_head=True,
strict=True,
dl_kwargs=dl_kwargs)

@classmethod
def from_config(
def build_model(
cls,
encoder_conf: RobertaEncoderConf,
*,
head: Optional[Module] = None,
freeze_encoder: bool = False,
checkpoint: Optional[Union[str, Dict[str, torch.Tensor]]] = None,
*,
override_head: bool = False,
override_checkpoint_head: bool = False,
strict=True,
dl_kwargs: Dict[str, Any] = None,
) -> RobertaModel:
"""Class method to create model with user-defined encoder configuration and checkpoint
"""Class builder method
Args:
encoder_conf (RobertaEncoderConf): An instance of class RobertaEncoderConf that defines the encoder configuration
head (nn.Module): A module to be attached to the encoder to perform a specific task. (Default: ``None``)
freeze_encoder (bool): Indicates whether to freeze the encoder weights. (Default: ``False``)
checkpoint (str or Dict[str, torch.Tensor]): Path to or actual model state_dict. The state_dict can have partial weights, i.e. only for the encoder. (Default: ``None``)
override_head (bool): Override the checkpoint's head state dict (if present) with provided head state dict. (Default: ``False``)
override_checkpoint_head (bool): Override the checkpoint's head state dict (if present) with provided head state dict. (Default: ``False``)
strict (bool): Passed to :func:`torch.nn.Module.load_state_dict` method. (Default: ``True``)
dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: ``None``)
"""
model = _get_model(encoder_conf, head, freeze_encoder)
model = RobertaModel(encoder_conf, head, freeze_encoder)
if checkpoint is not None:
if torch.jit.isinstance(checkpoint, Dict[str, torch.Tensor]):
state_dict = checkpoint
@@ -145,10 +141,10 @@ def from_config(
regex = re.compile(r"^head\.")
head_state_dict = {k: v for k, v in model.state_dict().items() if regex.findall(k)}
# If checkpoint does not contain head_state_dict, then we augment the checkpoint with user-provided head state_dict
if not _is_head_available_in_checkpoint(state_dict, head_state_dict) or override_head:
if not _is_head_available_in_checkpoint(state_dict, head_state_dict) or override_checkpoint_head:
state_dict.update(head_state_dict)

model.load_state_dict(state_dict, strict=True)
model.load_state_dict(state_dict, strict=strict)

return model

@@ -168,7 +164,7 @@ def encoderConf(self) -> RobertaEncoderConf:

XLMR_BASE_ENCODER.__doc__ = (
'''
XLM-R Encoder with base configuration
XLM-R Encoder with Base configuration
Please refer to :func:`torchtext.models.RobertaModelBundle` for the usage.
'''
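A short sketch of the bundle-level behaviour after this clean-up, assuming torchtext as of this commit: get_model() now forwards to build_model(), and load_weights=False simply drops the checkpoint (checkpoint=None). The direct construction of the bundle with its underscore fields follows the pattern from the (pre-change) docstring example; the tiny configuration is illustrative, not a released bundle.

from torchtext.models import (
    RobertaClassificationHead,
    RobertaEncoderConf,
    RobertaModelBundle,
)

encoder_conf = RobertaEncoderConf(vocab_size=100, embedding_dim=16, ffn_dimension=64,
                                  num_attention_heads=2, num_encoder_layers=2)
head = RobertaClassificationHead(num_classes=2, input_dim=16)

# A bundle with no checkpoint path: only randomly initialized weights are available.
bundle = RobertaModelBundle(_encoder_conf=encoder_conf, _head=head)

# No checkpoint path is attached, so skip weight loading; get_model() delegates to
# build_model(..., checkpoint=None) and freezes the encoder before returning it.
model = bundle.get_model(load_weights=False, freeze_encoder=True)
print(all(not p.requires_grad for p in model.encoder.parameters()))  # True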
38 changes: 16 additions & 22 deletions torchtext/models/roberta/model.py
@@ -42,6 +42,7 @@ def __init__(
dropout: float = 0.1,
scaling: Optional[float] = None,
normalize_before: bool = False,
freeze: bool = False,
):
super().__init__()
if not scaling:
@@ -62,17 +63,17 @@ def __init__(
return_all_layers=False,
)

@classmethod
def from_config(cls, config: RobertaEncoderConf):
return cls(**asdict(config))
if freeze:
for p in self.parameters():
p.requires_grad = False

def forward(self, tokens: Tensor, mask: Optional[Tensor] = None) -> Tensor:
def forward(self, tokens: Tensor, masked_tokens: Optional[Tensor] = None) -> Tensor:
output = self.transformer(tokens)
if torch.jit.isinstance(output, List[Tensor]):
output = output[-1]
output = output.transpose(1, 0)
if mask is not None:
output = output[mask.to(torch.bool), :]
if masked_tokens is not None:
output = output[masked_tokens.to(torch.bool), :]
return output


@@ -100,35 +101,28 @@ def forward(self, features):
class RobertaModel(Module):
"""
Example - Instantiate model with user-specified configuration
Example - Instantiating a model object
>>> from torchtext.models import RobertaEncoderConf, RobertaModel, RobertaClassificationHead
>>> roberta_encoder_conf = RobertaEncoderConf(vocab_size=250002)
>>> encoder = RobertaModel(config=roberta_encoder_conf)
>>> classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
>>> classifier = RobertaModel(config=roberta_encoder_conf, head=classifier_head)
"""

def __init__(self, config: RobertaEncoderConf, head: Optional[Module] = None, freeze_encoder: bool = False):
def __init__(self,
encoder_conf: RobertaEncoderConf,
head: Optional[Module] = None,
freeze_encoder: bool = False):
super().__init__()
assert isinstance(config, RobertaEncoderConf)

self.encoder = RobertaEncoder.from_config(config)
if freeze_encoder:
for param in self.encoder.parameters():
param.requires_grad = False

logger.info("Encoder weights are frozen")
assert isinstance(encoder_conf, RobertaEncoderConf)

self.encoder = RobertaEncoder(**asdict(encoder_conf), freeze=freeze_encoder)
self.head = head

def forward(self, tokens: Tensor, mask: Optional[Tensor] = None) -> Tensor:
features = self.encoder(tokens, mask)
def forward(self, tokens: Tensor, masked_tokens: Optional[Tensor] = None) -> Tensor:
features = self.encoder(tokens, masked_tokens)
if self.head is None:
return features

x = self.head(features)
return x


def _get_model(config: RobertaEncoderConf, head: Optional[Module] = None, freeze_encoder: bool = False) -> RobertaModel:
return RobertaModel(config, head, freeze_encoder)
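To round off the model.py changes, a small sketch (assuming torchtext as of this commit) of the constructor that replaces _get_model: RobertaModel now receives the encoder configuration directly, forwards freeze_encoder to RobertaEncoder via its new freeze flag, and its forward accepts the renamed masked_tokens argument.

import torch
from torchtext.models import RobertaClassificationHead, RobertaEncoderConf, RobertaModel

encoder_conf = RobertaEncoderConf(vocab_size=100, embedding_dim=16, ffn_dimension=64,
                                  num_attention_heads=2, num_encoder_layers=2)
head = RobertaClassificationHead(num_classes=2, input_dim=16)

# freeze_encoder=True is now applied inside RobertaEncoder rather than in RobertaModel.
model = RobertaModel(encoder_conf, head, freeze_encoder=True)
print(any(p.requires_grad for p in model.encoder.parameters()))  # False

# Forward pass over a batch with one token sequence; masked_tokens defaults to None.
tokens = torch.randint(low=0, high=100, size=(1, 8))
logits = model(tokens)
print(logits.shape)  # torch.Size([1, 2])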
