From 9f2fb3f00cd9a4cc8d41d2e9cbfa5e9bf9533224 Mon Sep 17 00:00:00 2001 From: parmeet Date: Sun, 28 Nov 2021 17:59:41 -0500 Subject: [PATCH] change from_config to build_model and other minor clean-ups (#1452) --- test/models/test_models.py | 16 ++++---- torchtext/models/roberta/bundler.py | 62 ++++++++++++++--------------- torchtext/models/roberta/model.py | 38 ++++++++---------- 3 files changed, 53 insertions(+), 63 deletions(-) diff --git a/test/models/test_models.py b/test/models/test_models.py index 67876c75f..2eb432745 100644 --- a/test/models/test_models.py +++ b/test/models/test_models.py @@ -93,13 +93,13 @@ def test_xlmr_transform_jit(self): expected = [[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]] torch.testing.assert_close(actual, expected) - def test_roberta_bundler_from_config(self): + def test_roberta_bundler_build_model(self): from torchtext.models import RobertaEncoderConf, RobertaClassificationHead, RobertaModel, RobertaModelBundle dummy_encoder_conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2) # case: user provide encoder checkpoint state dict dummy_encoder = RobertaModel(dummy_encoder_conf) - model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf, + model = RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf, checkpoint=dummy_encoder.state_dict()) self.assertEqual(model.state_dict(), dummy_encoder.state_dict()) @@ -107,17 +107,17 @@ def test_roberta_bundler_from_config(self): dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16) another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16) dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head) - model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf, + model = RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf, head=another_dummy_classifier_head, checkpoint=dummy_classifier.state_dict()) self.assertEqual(model.state_dict(), dummy_classifier.state_dict()) # case: user provide classifier checkpoint state dict when head is given and override_head is set True another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16) - model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf, + model = RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf, head=another_dummy_classifier_head, checkpoint=dummy_classifier.state_dict(), - override_head=True) + override_checkpoint_head=True) self.assertEqual(model.head.state_dict(), another_dummy_classifier_head.state_dict()) # case: user provide only encoder checkpoint state dict when head is given @@ -126,7 +126,7 @@ def test_roberta_bundler_from_config(self): encoder_state_dict = {} for k, v in dummy_classifier.encoder.state_dict().items(): encoder_state_dict['encoder.' 
+ k] = v - model = torchtext.models.RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf, head=dummy_classifier_head, checkpoint=encoder_state_dict) + model = torchtext.models.RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf, head=dummy_classifier_head, checkpoint=encoder_state_dict) self.assertEqual(model.state_dict(), dummy_classifier.state_dict()) def test_roberta_bundler_train(self): @@ -146,7 +146,7 @@ def _train(model): # does not freeze encoder dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16) dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head) - model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf, + model = RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf, head=dummy_classifier_head, freeze_encoder=False, checkpoint=dummy_classifier.state_dict()) @@ -162,7 +162,7 @@ def _train(model): # freeze encoder dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16) dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head) - model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf, + model = RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf, head=dummy_classifier_head, freeze_encoder=True, checkpoint=dummy_classifier.state_dict()) diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py index d1b699e4c..58774aacd 100644 --- a/torchtext/models/roberta/bundler.py +++ b/torchtext/models/roberta/bundler.py @@ -13,7 +13,6 @@ from .model import ( RobertaEncoderConf, RobertaModel, - _get_model, ) from .transforms import get_xlmr_transform @@ -30,44 +29,38 @@ def _is_head_available_in_checkpoint(checkpoint, head_state_dict): class RobertaModelBundle: """RobertaModelBundle(_params: torchtext.models.RobertaEncoderParams, _path: Optional[str] = None, _head: Optional[torch.nn.Module] = None, transform: Optional[Callable] = None) - Example - Pretrained encoder + Example - Pretrained base xlmr encoder >>> import torch, torchtext + >>> from torchtext.functional import to_tensor >>> xlmr_base = torchtext.models.XLMR_BASE_ENCODER >>> model = xlmr_base.get_model() >>> transform = xlmr_base.transform() - >>> model_input = torch.tensor(transform(["Hello World"])) - >>> output = model(model_input) - >>> output.shape - torch.Size([1, 4, 768]) >>> input_batch = ["Hello world", "How are you!"] - >>> from torchtext.functional import to_tensor >>> model_input = to_tensor(transform(input_batch), padding_value=transform.pad_idx) >>> output = model(model_input) >>> output.shape torch.Size([2, 6, 768]) - Example - Pretrained encoder attached to un-initialized classification head + Example - Pretrained large xlmr encoder attached to un-initialized classification head >>> import torch, torchtext + >>> from torchtext.models import RobertaClassificationHead + >>> from torchtext.functional import to_tensor >>> xlmr_large = torchtext.models.XLMR_LARGE_ENCODER - >>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim = xlmr_large.encoderConf.embedding_dim) - >>> classification_model = xlmr_large.get_model(head=classifier_head) + >>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim = 1024) + >>> model = xlmr_large.get_model(head=classifier_head) >>> transform = xlmr_large.transform() - >>> model_input = torch.tensor(transform(["Hello World"])) - >>> output = classification_model(model_input) + >>> input_batch = ["Hello world", "How are you!"] + >>> 
model_input = to_tensor(transform(input_batch), padding_value=transform.pad_idx) + >>> output = model(model_input) >>> output.shape torch.Size([1, 2]) Example - User-specified configuration and checkpoint >>> from torchtext.models import RobertaEncoderConf, RobertaModelBundle, RobertaClassificationHead >>> model_weights_path = "https://download.pytorch.org/models/text/xlmr.base.encoder.pt" - >>> roberta_encoder_conf = RobertaEncoderConf(vocab_size=250002) - >>> roberta_bundle = RobertaModelBundle(_encoder_conf=roberta_encoder_conf, _path=model_weights_path) - >>> encoder = roberta_bundle.get_model() + >>> encoder_conf = RobertaEncoderConf(vocab_size=250002) >>> classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768) - >>> classifier = roberta_bundle.get_model(head=classifier_head) - >>> # using from_config - >>> encoder = RobertaModelBundle.from_config(config=roberta_encoder_conf, checkpoint=model_weights_path) - >>> classifier = RobertaModelBundle.from_config(config=roberta_encoder_conf, head=classifier_head, checkpoint=model_weights_path) + >>> model = RobertaModelBundle.build_model(encoder_conf=encoder_conf, head=classifier_head, checkpoint=model_weights_path) """ _encoder_conf: RobertaEncoderConf _path: Optional[str] = None @@ -75,11 +68,11 @@ class RobertaModelBundle: transform: Optional[Callable] = None def get_model(self, + *, head: Optional[Module] = None, load_weights: bool = True, freeze_encoder: bool = False, - *, - dl_kwargs=None) -> RobertaModel: + dl_kwargs: Dict[str, Any] = None) -> RobertaModel: r"""get_model(head: Optional[torch.nn.Module] = None, load_weights: bool = True, freeze_encoder: bool = False, *, dl_kwargs=None) -> torctext.models.RobertaModel Args: @@ -103,35 +96,38 @@ def get_model(self, else: input_head = self._head - return RobertaModelBundle.from_config(encoder_conf=self._encoder_conf, + return RobertaModelBundle.build_model(encoder_conf=self._encoder_conf, head=input_head, freeze_encoder=freeze_encoder, - checkpoint=self._path, - override_head=True, + checkpoint=self._path if load_weights else None, + override_checkpoint_head=True, + strict=True, dl_kwargs=dl_kwargs) @classmethod - def from_config( + def build_model( cls, encoder_conf: RobertaEncoderConf, + *, head: Optional[Module] = None, freeze_encoder: bool = False, checkpoint: Optional[Union[str, Dict[str, torch.Tensor]]] = None, - *, - override_head: bool = False, + override_checkpoint_head: bool = False, + strict=True, dl_kwargs: Dict[str, Any] = None, ) -> RobertaModel: - """Class method to create model with user-defined encoder configuration and checkpoint + """Class builder method Args: encoder_conf (RobertaEncoderConf): An instance of class RobertaEncoderConf that defined the encoder configuration head (nn.Module): A module to be attached to the encoder to perform specific task. (Default: ``None``) freeze_encoder (bool): Indicates whether to freeze the encoder weights. (Default: ``False``) checkpoint (str or Dict[str, torch.Tensor]): Path to or actual model state_dict. state_dict can have partial weights i.e only for encoder. (Default: ``None``) - override_head (bool): Override the checkpoint's head state dict (if present) with provided head state dict. (Default: ``False``) + override_checkpoint_head (bool): Override the checkpoint's head state dict (if present) with provided head state dict. (Default: ``False``) + strict (bool): Passed to :func: `torch.nn.Module.load_state_dict` method. 
(Default: ``True``) dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: ``None``) """ - model = _get_model(encoder_conf, head, freeze_encoder) + model = RobertaModel(encoder_conf, head, freeze_encoder) if checkpoint is not None: if torch.jit.isinstance(checkpoint, Dict[str, torch.Tensor]): state_dict = checkpoint @@ -145,10 +141,10 @@ def from_config( regex = re.compile(r"^head\.") head_state_dict = {k: v for k, v in model.state_dict().items() if regex.findall(k)} # If checkpoint does not contains head_state_dict, then we augment the checkpoint with user-provided head state_dict - if not _is_head_available_in_checkpoint(state_dict, head_state_dict) or override_head: + if not _is_head_available_in_checkpoint(state_dict, head_state_dict) or override_checkpoint_head: state_dict.update(head_state_dict) - model.load_state_dict(state_dict, strict=True) + model.load_state_dict(state_dict, strict=strict) return model @@ -168,7 +164,7 @@ def encoderConf(self) -> RobertaEncoderConf: XLMR_BASE_ENCODER.__doc__ = ( ''' - XLM-R Encoder with base configuration + XLM-R Encoder with Base configuration Please refer to :func:`torchtext.models.RobertaModelBundle` for the usage. ''' diff --git a/torchtext/models/roberta/model.py b/torchtext/models/roberta/model.py index 60a436255..e4dd8ddc8 100644 --- a/torchtext/models/roberta/model.py +++ b/torchtext/models/roberta/model.py @@ -42,6 +42,7 @@ def __init__( dropout: float = 0.1, scaling: Optional[float] = None, normalize_before: bool = False, + freeze: bool = False, ): super().__init__() if not scaling: @@ -62,17 +63,17 @@ def __init__( return_all_layers=False, ) - @classmethod - def from_config(cls, config: RobertaEncoderConf): - return cls(**asdict(config)) + if freeze: + for p in self.parameters(): + p.requires_grad = False - def forward(self, tokens: Tensor, mask: Optional[Tensor] = None) -> Tensor: + def forward(self, tokens: Tensor, masked_tokens: Optional[Tensor] = None) -> Tensor: output = self.transformer(tokens) if torch.jit.isinstance(output, List[Tensor]): output = output[-1] output = output.transpose(1, 0) - if mask is not None: - output = output[mask.to(torch.bool), :] + if masked_tokens is not None: + output = output[masked_tokens.to(torch.bool), :] return output @@ -100,7 +101,7 @@ def forward(self, features): class RobertaModel(Module): """ - Example - Instantiate model with user-specified configuration + Example - Instatiating model object >>> from torchtext.models import RobertaEncoderConf, RobertaModel, RobertaClassificationHead >>> roberta_encoder_conf = RobertaEncoderConf(vocab_size=250002) >>> encoder = RobertaModel(config=roberta_encoder_conf) @@ -108,27 +109,20 @@ class RobertaModel(Module): >>> classifier = RobertaModel(config=roberta_encoder_conf, head=classifier_head) """ - def __init__(self, config: RobertaEncoderConf, head: Optional[Module] = None, freeze_encoder: bool = False): + def __init__(self, + encoder_conf: RobertaEncoderConf, + head: Optional[Module] = None, + freeze_encoder: bool = False): super().__init__() - assert isinstance(config, RobertaEncoderConf) - - self.encoder = RobertaEncoder.from_config(config) - if freeze_encoder: - for param in self.encoder.parameters(): - param.requires_grad = False - - logger.info("Encoder weights are frozen") + assert isinstance(encoder_conf, RobertaEncoderConf) + self.encoder = RobertaEncoder(**asdict(encoder_conf), freeze=freeze_encoder) self.head = head - def forward(self, tokens: Tensor, mask: Optional[Tensor] = None) -> 
Tensor: - features = self.encoder(tokens, mask) + def forward(self, tokens: Tensor, masked_tokens: Optional[Tensor] = None) -> Tensor: + features = self.encoder(tokens, masked_tokens) if self.head is None: return features x = self.head(features) return x - - -def _get_model(config: RobertaEncoderConf, head: Optional[Module] = None, freeze_encoder: bool = False) -> RobertaModel: - return RobertaModel(config, head, freeze_encoder)
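
Not part of the patch: a minimal migration sketch, grounded in the new `build_model` signature and the dummy configuration used in the updated tests above, showing what a downstream caller changes when moving off the removed `from_config` classmethod. Names such as `ckpt_head` and `new_head` are illustrative only.

# Sketch (illustrative, not from the commit): migrating a caller from
# RobertaModelBundle.from_config to RobertaModelBundle.build_model.
from torchtext.models import (
    RobertaClassificationHead,
    RobertaEncoderConf,
    RobertaModel,
    RobertaModelBundle,
)

# Same tiny encoder configuration as the updated unit tests.
encoder_conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64,
                                  num_attention_heads=2, num_encoder_layers=2)

# A throwaway classifier whose state_dict stands in for a real checkpoint.
ckpt_head = RobertaClassificationHead(num_classes=2, input_dim=16)
checkpoint = RobertaModel(encoder_conf, ckpt_head).state_dict()

# Before this patch:
#   model = RobertaModelBundle.from_config(encoder_conf=encoder_conf, head=new_head,
#                                          checkpoint=checkpoint, override_head=True)
# After this patch: everything past encoder_conf is keyword-only, override_head is
# renamed to override_checkpoint_head, and strict is forwarded to load_state_dict.
new_head = RobertaClassificationHead(num_classes=2, input_dim=16)
model = RobertaModelBundle.build_model(
    encoder_conf=encoder_conf,
    head=new_head,
    freeze_encoder=False,
    checkpoint=checkpoint,
    override_checkpoint_head=True,  # keep new_head's weights over the checkpoint's head
    strict=True,
)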