This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Add eval_mode argument to pretrained transformer embedder #5111

Merged
merged 6 commits on Apr 13, 2021
Changes from 3 commits
CHANGELOG.md (2 changes: 1 addition & 1 deletion)
@@ -11,14 +11,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

  - Ported the following Huggingface `LambdaLR`-based schedulers: `ConstantLearningRateScheduler`, `ConstantWithWarmupLearningRateScheduler`, `CosineWithWarmupLearningRateScheduler`, `CosineHardRestartsWithWarmupLearningRateScheduler`.
  - Added new `sub_token_mode` parameter to `pretrained_transformer_mismatched_embedder` class to support first sub-token embedding
+ - Added new `eval_mode` in `PretrainedTransformerEmbedder`. If it is set to `True`, the transformer is _always_ run in evaluation mode, which, e.g., disables dropout and does not update batch normalization statistics.

  ### Changed

  - Sanity checks in the `GradientDescentTrainer` can now be turned off by setting the `run_sanity_checks` parameter to `False`.
  - Allow the order of examples in the task cards to be specified explicitly
  - `histogram_interval` parameter is now deprecated in `TensorboardWriter`, please use `distribution_interval` instead.
  - Memory usage is not logged in tensorboard during training now. `ConsoleLoggerCallback` should be used instead.
- - If `train_parameters` in PretrainedTransformerEmbedder is `False`, the transformer's dropout and batch normalization layers are now set to evaluation mode.
  - If you use the `min_count` parameter of the Vocabulary, but you specify a namespace that does not exist, the vocabulary creation will raise a `ConfigurationError`.
  - Documentation updates made to SoftmaxLoss regarding padding and the expected shapes of the input and output tensors of `forward`.
  - Moved the data preparation script for coref into allennlp-models.
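To make the new CHANGELOG entry above concrete, here is a minimal usage sketch. It is not part of the diff: the import path `allennlp.modules.token_embedders` is assumed, and the `epwalsh/bert-xsmall-dummy` model name is borrowed from the PR's test further down.

```python
from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

# With eval_mode=True the wrapped transformer is switched to evaluation mode at
# construction time, so dropout is disabled and batch-norm statistics are frozen.
# Its weights remain trainable because train_parameters defaults to True.
embedder = PretrainedTransformerEmbedder("epwalsh/bert-xsmall-dummy", eval_mode=True)

assert not embedder.transformer_model.training  # the transformer itself is in eval mode
assert embedder.training                        # the embedder module still reports training mode
```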
@@ -39,8 +39,12 @@ class PretrainedTransformerEmbedder(TokenEmbedder):
      want to use the encoder.
  train_parameters: `bool`, optional (default = `True`)
      If this is `True`, the transformer weights get updated during training. If this is `False`, the
-     transformer weights are not updated during training and the dropout and batch normalization layers
-     are set to evaluation mode.
+     transformer weights are not updated during training.
+ eval_mode: `bool`, optional (default = `False`)
+     If this is `True`, the model is always set to evaluation mode (e.g., the dropout is disabled and the
+     batch normalization layer statistics are not updated). If this is `False`, the dropout and batch
+     normalization layers are only set to evaluation mode when the model is evaluating on development
+     or test data.
  last_layer_only: `bool`, optional (default = `True`)
      When `True` (the default), only the final layer of the pretrained transformer is taken
      for the embeddings. But if set to `False`, a scalar mix of all of the layers
@@ -66,6 +70,7 @@ def __init__(
          max_length: int = None,
          sub_module: str = None,
          train_parameters: bool = True,
+         eval_mode: bool = False,
          last_layer_only: bool = True,
          override_weights_file: Optional[str] = None,
          override_weights_strip_prefix: Optional[str] = None,
@@ -125,15 +130,18 @@ def __init__(

          self.train_parameters = train_parameters
          if not train_parameters:
-             self.transformer_model.eval()
              for param in self.transformer_model.parameters():
                  param.requires_grad = False

+         self.eval_mode = eval_mode
+         if eval_mode:
+             self.transformer_model.eval()
+
      @overrides
      def train(self, mode: bool = True):
          self.training = mode
          for name, module in self.named_children():
-             if not self.train_parameters and name == "transformer_model":
+             if self.eval_mode and name == "transformer_model":
                  module.eval()
              else:
                  module.train(mode)
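The overridden `train()` above is what keeps the transformer in evaluation mode even when a surrounding model is put back into training mode. The sketch below illustrates this with a hypothetical wrapper module; the wrapper class and its linear head are illustrative only, while the embedder API is taken from the diff.

```python
import torch

from allennlp.modules.token_embedders import PretrainedTransformerEmbedder


class WrapperModel(torch.nn.Module):
    """Hypothetical model holding the embedder plus an ordinary trainable head."""

    def __init__(self) -> None:
        super().__init__()
        self.embedder = PretrainedTransformerEmbedder("epwalsh/bert-xsmall-dummy", eval_mode=True)
        self.head = torch.nn.Linear(10, 2)  # arbitrary head, for illustration only


model = WrapperModel()
model.train()  # recursively calls train(True) on all children

assert model.head.training                            # ordinary submodule follows train()
assert model.embedder.training                        # the embedder records training mode...
assert not model.embedder.transformer_model.training  # ...but keeps the transformer in eval mode
```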
@@ -332,9 +332,7 @@ def test_embeddings_resize(self):
          )

      def test_eval_mode(self):
-         token_embedder = PretrainedTransformerEmbedder(
-             "epwalsh/bert-xsmall-dummy", train_parameters=False
-         )
+         token_embedder = PretrainedTransformerEmbedder("epwalsh/bert-xsmall-dummy", eval_mode=True)
          assert token_embedder.training and not token_embedder.transformer_model.training

          class TrainableModule(torch.nn.Module):
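The updated test above only exercises `eval_mode=True`. As a hypothetical complement (not part of the PR), the sketch below checks the other half of the split in responsibilities implied by the `__init__` diff: `train_parameters=False` now freezes the weights without switching the transformer to evaluation mode, while `eval_mode=True` does the opposite.

```python
from allennlp.modules.token_embedders import PretrainedTransformerEmbedder


def test_eval_mode_and_train_parameters_are_independent():
    # eval_mode=True: transformer in eval mode, weights still trainable.
    eval_only = PretrainedTransformerEmbedder("epwalsh/bert-xsmall-dummy", eval_mode=True)
    assert not eval_only.transformer_model.training
    assert all(p.requires_grad for p in eval_only.transformer_model.parameters())

    # train_parameters=False: weights frozen, but (after this PR) the transformer
    # is no longer forced into eval mode.
    frozen = PretrainedTransformerEmbedder("epwalsh/bert-xsmall-dummy", train_parameters=False)
    assert not any(p.requires_grad for p in frozen.transformer_model.parameters())
    assert frozen.transformer_model.training
```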