From a7c3fbd79bea5260a1f5e47c2a704f8df1548bd9 Mon Sep 17 00:00:00 2001
From: Nelson Liu
Date: Fri, 9 Apr 2021 14:32:41 -0700
Subject: [PATCH 1/5] Add eval_mode argument to pretrained transformer embedder

---
 .../pretrained_transformer_embedder.py      | 16 ++++++++++++----
 .../pretrained_transformer_embedder_test.py |  2 +-
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py
index 9b344a4f8e9..18855009649 100644
--- a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py
+++ b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py
@@ -39,8 +39,12 @@ class PretrainedTransformerEmbedder(TokenEmbedder):
         want to use the encoder.
     train_parameters: `bool`, optional (default = `True`)
         If this is `True`, the transformer weights get updated during training. If this is `False`, the
-        transformer weights are not updated during training and the dropout and batch normalization layers
-        are set to evaluation mode.
+        transformer weights are not updated during training.
+    eval_mode: `bool`, optional (default = `False`)
+        If this is True, the model is always set to evaluation mode (e.g., the dropout is disabled and the
+        batch normalization layer statistics are not updated). If this is False, the dropout and batch
+        normalization layers are only set to evaluation model when when the model is evaluating on development
+        or train data.
     last_layer_only: `bool`, optional (default = `True`)
         When `True` (the default), only the final layer of the pretrained transformer is taken
         for the embeddings. But if set to `False`, a scalar mix of all of the layers
@@ -66,6 +70,7 @@ def __init__(
         max_length: int = None,
         sub_module: str = None,
         train_parameters: bool = True,
+        eval_mode: bool = False,
         last_layer_only: bool = True,
         override_weights_file: Optional[str] = None,
         override_weights_strip_prefix: Optional[str] = None,
@@ -125,15 +130,18 @@ def __init__(
 
         self.train_parameters = train_parameters
         if not train_parameters:
-            self.transformer_model.eval()
             for param in self.transformer_model.parameters():
                 param.requires_grad = False
 
+        self.eval_mode = eval_mode
+        if eval_mode:
+            self.transformer_model.eval()
+
     @overrides
     def train(self, mode: bool = True):
         self.training = mode
         for name, module in self.named_children():
-            if not self.train_parameters and name == "transformer_model":
+            if self.eval_mode and name == "transformer_model":
                 module.eval()
             else:
                 module.train(mode)
diff --git a/tests/modules/token_embedders/pretrained_transformer_embedder_test.py b/tests/modules/token_embedders/pretrained_transformer_embedder_test.py
index 89ab0c77eeb..7c5911e3070 100644
--- a/tests/modules/token_embedders/pretrained_transformer_embedder_test.py
+++ b/tests/modules/token_embedders/pretrained_transformer_embedder_test.py
@@ -333,7 +333,7 @@ def test_embeddings_resize(self):
 
     def test_eval_mode(self):
         token_embedder = PretrainedTransformerEmbedder(
-            "epwalsh/bert-xsmall-dummy", train_parameters=False
+            "epwalsh/bert-xsmall-dummy", eval_mode=True
         )
         assert token_embedder.training and not token_embedder.transformer_model.training
 

From ab0769ea2f23aab25e4f382d75afd0e97e716090 Mon Sep 17 00:00:00 2001
From: Nelson Liu
Date: Fri, 9 Apr 2021 14:34:51 -0700
Subject: [PATCH 2/5] Edit changelog entry

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e46f64f50fb..bcc73761e2b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Ported the following Huggingface `LambdaLR`-based schedulers: `ConstantLearningRateScheduler`, `ConstantWithWarmupLearningRateScheduler`, `CosineWithWarmupLearningRateScheduler`, `CosineHardRestartsWithWarmupLearningRateScheduler`.
 - Added new `sub_token_mode` parameter to `pretrained_transformer_mismatched_embedder` class to support first sub-token embedding
+- Added new `eval_mode` in `PretrainedTransformerEmbedder`. If it is set to `True`, the transformer is _always_ run in evaluation mode, which, e.g., disables dropout and does not update batch normalization statistics.
 
 ### Changed
 
@@ -18,7 +19,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Allow the order of examples in the task cards to be specified explicitly
 - `histogram_interval` parameter is now deprecated in `TensorboardWriter`, please use `distribution_interval` instead.
 - Memory usage is not logged in tensorboard during training now. `ConsoleLoggerCallback` should be used instead.
-- If `train_parameters` in PretrainedTransformerEmbedder is `False`, the transformer's dropout and batch normalization layers are now set to evaluation mode.
 - If you use the `min_count` parameter of the Vocabulary, but you specify a namespace that does not exist, the vocabulary creation will raise a `ConfigurationError`.
 - Documentation updates made to SoftmaxLoss regarding padding and the expected shapes of the input and output tensors of `forward`.
 - Moved the data preparation script for coref into allennlp-models.

From 7b85b9e54624572e5ecf9f2a55148fb6334c37fa Mon Sep 17 00:00:00 2001
From: Nelson Liu
Date: Fri, 9 Apr 2021 17:24:43 -0700
Subject: [PATCH 3/5] Lint

---
 .../token_embedders/pretrained_transformer_embedder_test.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/tests/modules/token_embedders/pretrained_transformer_embedder_test.py b/tests/modules/token_embedders/pretrained_transformer_embedder_test.py
index 7c5911e3070..00356d80ae8 100644
--- a/tests/modules/token_embedders/pretrained_transformer_embedder_test.py
+++ b/tests/modules/token_embedders/pretrained_transformer_embedder_test.py
@@ -332,9 +332,7 @@ def test_embeddings_resize(self):
         )
 
     def test_eval_mode(self):
-        token_embedder = PretrainedTransformerEmbedder(
-            "epwalsh/bert-xsmall-dummy", eval_mode=True
-        )
+        token_embedder = PretrainedTransformerEmbedder("epwalsh/bert-xsmall-dummy", eval_mode=True)
         assert token_embedder.training and not token_embedder.transformer_model.training
 
         class TrainableModule(torch.nn.Module):

From fba04321cd825ab85691143e2399c9f50e977703 Mon Sep 17 00:00:00 2001
From: Evan Pete Walsh
Date: Mon, 12 Apr 2021 16:41:15 -0700
Subject: [PATCH 4/5] Update allennlp/modules/token_embedders/pretrained_transformer_embedder.py

---
 .../token_embedders/pretrained_transformer_embedder.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py
index 18855009649..4b6e77cd112 100644
--- a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py
+++ b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py
@@ -41,8 +41,8 @@ class PretrainedTransformerEmbedder(TokenEmbedder):
         If this is `True`, the transformer weights get updated during training. If this is `False`, the
         transformer weights are not updated during training.
     eval_mode: `bool`, optional (default = `False`)
-        If this is True, the model is always set to evaluation mode (e.g., the dropout is disabled and the
-        batch normalization layer statistics are not updated). If this is False, the dropout and batch
+        If this is `True`, the model is always set to evaluation mode (e.g., the dropout is disabled and the
+        batch normalization layer statistics are not updated). If this is `False`, the dropout and batch
         normalization layers are only set to evaluation model when when the model is evaluating on development
         or train data.
     last_layer_only: `bool`, optional (default = `True`)

From 36a8c2d75f0724306882232538f72d1bc1cfd46f Mon Sep 17 00:00:00 2001
From: Nelson Liu
Date: Mon, 12 Apr 2021 22:15:06 -0700
Subject: [PATCH 5/5] Apply suggestions from code review

---
 .../token_embedders/pretrained_transformer_embedder.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py
index 4b6e77cd112..1ce457b1adb 100644
--- a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py
+++ b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py
@@ -42,9 +42,9 @@ class PretrainedTransformerEmbedder(TokenEmbedder):
         transformer weights are not updated during training.
     eval_mode: `bool`, optional (default = `False`)
         If this is `True`, the model is always set to evaluation mode (e.g., the dropout is disabled and the
-        batch normalization layer statistics are not updated). If this is `False`, the dropout and batch
-        normalization layers are only set to evaluation model when when the model is evaluating on development
-        or train data.
+        batch normalization layer statistics are not updated). If this is `False`, such dropout and batch
+        normalization layers are only set to evaluation mode when the model is evaluating on development
+        or test data.
     last_layer_only: `bool`, optional (default = `True`)
         When `True` (the default), only the final layer of the pretrained transformer is taken
         for the embeddings. But if set to `False`, a scalar mix of all of the layers
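
For reference, a minimal usage sketch of the behavior this patch series adds. It is not part of the patches themselves; it mirrors the updated test_eval_mode test and assumes an allennlp installation that includes these changes, with the "epwalsh/bert-xsmall-dummy" dummy checkpoint from the test suite available:

    from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

    # eval_mode=True keeps the wrapped transformer in evaluation mode even while
    # the surrounding AllenNLP model is training.
    embedder = PretrainedTransformerEmbedder("epwalsh/bert-xsmall-dummy", eval_mode=True)

    embedder.train()  # sets embedder.training = True; the override keeps the transformer submodule in eval mode
    assert embedder.training
    assert not embedder.transformer_model.training  # dropout disabled, batch-norm statistics frozen

    # Unlike train_parameters=False, eval_mode=True on its own does not freeze the weights:
    assert all(p.requires_grad for p in embedder.transformer_model.parameters())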