This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Reinit layers of pretrained transformer in cached_transformers.get() (#5505)

* Arg to re-init some layers of pretrained transformer

* Better error handling for bad indices

* Add a test for reinit_layers

* Type cast to appease mypy

* Get number of hidden layers in model agnostic way

* Support a list of regexes in reinit_modules

* Add tests for when reinit_modules is a list of str

* Fix broken test for re-initializing modules

* Break reinit_modules unit test into two

* Update changelog

* Tests for when reinit_modules should have no effect

* Revert pretrained transformer embedder to main

* Move reinit_modules feature to cached transformers

* Better error message for invalid reinit_modules argument

Co-authored-by: Pete <epwalsh10@gmail.com>

* Correct error message to say tuple, not list

* Add note about layer indices failing in docstring

* Correct "list" with "tuple" in error message

Co-authored-by: Pete <petew@allenai.org>
Co-authored-by: Pete <epwalsh10@gmail.com>
3 people authored Dec 23, 2021
1 parent ec1fb69 commit 06ec7f9
Showing 4 changed files with 184 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Added a way to resize the vocabulary in the T5 module
- Added an argument `reinit_modules` to `cached_transformers.get()` that allows you to re-initialize the pretrained weights of a transformer model, using layer indices or regex strings.

### Fixed

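As an editorial illustration (not part of the diff), here is a minimal usage sketch of the new argument, using the same model names and values exercised by the tests added in this commit:

from allennlp.common import cached_transformers

# Re-initialize the last two layers of BERT by passing an integer...
bert = cached_transformers.get("bert-base-cased", make_copy=True, reinit_modules=2)

# ...or name the layer indices explicitly with a tuple of integers.
bert = cached_transformers.get("bert-base-cased", make_copy=True, reinit_modules=(10, 11))

# For models whose modules are organized differently (e.g. GPT-2), pass a tuple of
# regex strings; every module whose name matches a regex is re-initialized.
gpt2 = cached_transformers.get("sshleifer/tiny-gpt2", make_copy=True, reinit_modules=("wpe",))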
76 changes: 71 additions & 5 deletions allennlp/common/cached_transformers.py
@@ -1,10 +1,11 @@
import logging
import re
import warnings
from typing import NamedTuple, Optional, Dict, Tuple
from typing import Dict, NamedTuple, Optional, Tuple, Union, cast

import transformers
from transformers import AutoModel, AutoConfig

from allennlp.common.checks import ConfigurationError
from transformers import AutoConfig, AutoModel

logger = logging.getLogger(__name__)

@@ -13,6 +14,7 @@ class TransformerSpec(NamedTuple):
model_name: str
override_weights_file: Optional[str] = None
override_weights_strip_prefix: Optional[str] = None
reinit_modules: Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]] = None


_model_cache: Dict[TransformerSpec, transformers.PreTrainedModel] = {}
@@ -23,6 +25,7 @@ def get(
make_copy: bool,
override_weights_file: Optional[str] = None,
override_weights_strip_prefix: Optional[str] = None,
reinit_modules: Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]] = None,
load_weights: bool = True,
**kwargs,
) -> transformers.PreTrainedModel:
@@ -43,13 +46,28 @@ def get(
with `torch.save()`.
override_weights_strip_prefix : `str`, optional (default = `None`)
If set, strip the given prefix from the state dict when loading it.
reinit_modules: `Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]]`, optional (default = `None`)
If this is an integer, the last `reinit_modules` layers of the transformer will be
re-initialized. If this is a tuple of integers, the layers indexed by `reinit_modules` will
be re-initialized. Note that because the module structure of the transformer `model_name` can
differ, we cannot guarantee that providing an integer or tuple of integers will work. If
this fails, you can instead provide a tuple of strings, which will be treated as regexes and
any module with a name matching the regex will be re-initialized. Re-initializing the last
few layers of a pretrained transformer can reduce the instability of fine-tuning on small
datasets and may improve performance (https://arxiv.org/abs/2006.05987v3). Has no effect
if `load_weights` is `False` or `override_weights_file` is not `None`.
load_weights : `bool`, optional (default = `True`)
If set to `False`, no weights will be loaded. This is helpful when you only
want to initialize the architecture, like when you've already fine-tuned a model
and are going to load the weights from a state dict elsewhere.
"""
global _model_cache
spec = TransformerSpec(model_name, override_weights_file, override_weights_strip_prefix)
spec = TransformerSpec(
model_name,
override_weights_file,
override_weights_strip_prefix,
reinit_modules,
)
transformer = _model_cache.get(spec, None)
if transformer is None:
if not load_weights:
@@ -59,15 +77,27 @@ def get(
"but 'load_weights' is set to False, so 'override_weights_file' will be ignored.",
UserWarning,
)
if reinit_modules is not None:
warnings.warn(
"You specified 'reinit_modules' in allennlp.common.cached_transformers.get(), "
"but 'load_weights' is set to False, so 'reinit_modules' will be ignored.",
UserWarning,
)
transformer = AutoModel.from_config(
AutoConfig.from_pretrained(
model_name,
**kwargs,
)
)
elif override_weights_file is not None:
from allennlp.common.file_utils import cached_path
if reinit_modules is not None:
warnings.warn(
"You specified 'reinit_modules' in allennlp.common.cached_transformers.get(), "
"but 'override_weights_file' is not None, so 'reinit_modules' will be ignored.",
UserWarning,
)
import torch
from allennlp.common.file_utils import cached_path

override_weights_file = cached_path(override_weights_file)
override_weights = torch.load(override_weights_file)
@@ -110,6 +140,42 @@ def strip_prefix(s):
transformer.module.load_state_dict(override_weights)
else:
transformer.load_state_dict(override_weights)
elif reinit_modules is not None:
transformer = AutoModel.from_pretrained(
model_name,
**kwargs,
)
num_layers = transformer.config.num_hidden_layers
if isinstance(reinit_modules, int):
reinit_modules = tuple(range(num_layers - reinit_modules, num_layers))
if all(isinstance(x, int) for x in reinit_modules):
# This type cast is necessary to avoid a mypy error.
reinit_modules = cast(Tuple[int], reinit_modules)
if any(layer_idx < 0 or layer_idx >= num_layers for layer_idx in reinit_modules):
raise ValueError(
f"A layer index in reinit_modules ({reinit_modules}) is invalid."
f" Must be between 0 and the maximum layer index ({num_layers - 1})."
)
# Some transformer models organize their modules differently, so if this fails,
# raise an error with a helpful message.
try:
for layer_idx in reinit_modules:
transformer.encoder.layer[layer_idx].apply(transformer._init_weights)
except AttributeError:
raise ConfigurationError(
f"Unable to re-initialize the layers of transformer model"
f" {model_name} using layer indices. Please provide a tuple of"
" strings corresponding to the names of the layers to re-initialize."
)
elif all(isinstance(x, str) for x in reinit_modules):
for regex in reinit_modules:
for name, module in transformer.named_modules():
if re.search(regex, name):
module.apply(transformer._init_weights)
else:
raise ValueError(
"reinit_modules must be either an integer, a tuple of strings, or a tuple of integers."
)
else:
transformer = AutoModel.from_pretrained(
model_name,
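To make the regex branch concrete, the following standalone sketch performs the same re-initialization outside of the cache, using only the calls the new code relies on (AutoModel.from_pretrained, named_modules, apply, and the model's _init_weights hook); the model name and pattern mirror the tests further down:

import re

from transformers import AutoModel

model = AutoModel.from_pretrained("sshleifer/tiny-gpt2")
reinit_modules = ("wpe",)  # regexes; any module whose name matches is reset

for regex in reinit_modules:
    for name, module in model.named_modules():
        if re.search(regex, name):
            # Re-applies the model's own weight-initialization scheme to this submodule.
            module.apply(model._init_weights)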
18 changes: 14 additions & 4 deletions allennlp/modules/token_embedders/pretrained_transformer_embedder.py
@@ -1,16 +1,14 @@
import logging
import math
from typing import Optional, Tuple, Dict, Any

from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from transformers import XLNetConfig

from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.modules.scalar_mix import ScalarMix
from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
from allennlp.nn.util import batched_index_select
from transformers import XLNetConfig

logger = logging.getLogger(__name__)

@@ -54,6 +52,16 @@ class PretrainedTransformerEmbedder(TokenEmbedder):
with `torch.save()`.
override_weights_strip_prefix: `Optional[str]`, optional (default = `None`)
If set, strip the given prefix from the state dict when loading it.
reinit_modules: `Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]]`, optional (default = `None`)
If this is an integer, the last `reinit_modules` layers of the transformer will be
re-initialized. If this is a tuple of integers, the layers indexed by `reinit_modules` will
be re-initialized. Note that because the module structure of the transformer `model_name` can
differ, we cannot guarantee that providing an integer or tuple of integers will work. If
this fails, you can instead provide a tuple of strings, which will be treated as regexes and
any module with a name matching the regex will be re-initialized. Re-initializing the last
few layers of a pretrained transformer can reduce the instability of fine-tuning on small
datasets and may improve performance (https://arxiv.org/abs/2006.05987v3). Has no effect
if `load_weights` is `False` or `override_weights_file` is not `None`.
load_weights: `bool`, optional (default = `True`)
Whether to load the pretrained weights. If you're loading your model/predictor from an AllenNLP archive
it usually makes sense to set this to `False` (via the `overrides` parameter)
@@ -84,6 +92,7 @@ def __init__(
last_layer_only: bool = True,
override_weights_file: Optional[str] = None,
override_weights_strip_prefix: Optional[str] = None,
reinit_modules: Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]] = None,
load_weights: bool = True,
gradient_checkpointing: Optional[bool] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
@@ -97,6 +106,7 @@ def __init__(
True,
override_weights_file=override_weights_file,
override_weights_strip_prefix=override_weights_strip_prefix,
reinit_modules=reinit_modules,
load_weights=load_weights,
**(transformer_kwargs or {}),
)
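For completeness, a hedged sketch of the embedder-level usage; the model_name keyword and the import path are assumptions based on the class's existing API rather than something shown in this diff:

from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

# reinit_modules is forwarded unchanged to cached_transformers.get(), so the same
# int / tuple-of-ints / tuple-of-regexes values are accepted here. It has no effect
# if load_weights is False or override_weights_file is set.
embedder = PretrainedTransformerEmbedder(
    model_name="bert-base-cased",
    reinit_modules=2,  # re-initialize the last two encoder layers
)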
103 changes: 98 additions & 5 deletions tests/common/cached_transformers_test.py
@@ -1,12 +1,12 @@
import pytest
import torch
import os
import json
import os

import pytest
import torch
from allennlp.common import cached_transformers
from allennlp.common.checks import ConfigurationError
from allennlp.common.testing import AllenNlpTestCase

from transformers import AutoModel, AutoConfig
from transformers import AutoConfig, AutoModel


class TestCachedTransformers(AllenNlpTestCase):
@@ -72,6 +72,99 @@ def test_from_pretrained_avoids_weights_download_if_override_weights(self):
for p1, p2 in zip(transformer.parameters(), override_transformer.parameters()):
assert p1.data.ne(p2.data).sum() == 0

def test_reinit_modules_no_op(self):
# Test the case where reinit_modules is None (default)
preinit_weights = torch.cat(
[
# Comparing all weights of the model is rather complicated, so arbitrarily
# compare the weights of attention module.
layer.attention.output.dense.weight
for layer in cached_transformers.get("bert-base-cased", True).encoder.layer
]
)
postinit_weights = torch.cat(
[
layer.attention.output.dense.weight
for layer in cached_transformers.get("bert-base-cased", True).encoder.layer
]
)
assert torch.equal(postinit_weights, preinit_weights)

def test_reinit_modules_with_layer_indices(self):
# Comparing all weights of the model is rather complicated, so arbitrarily compare the
# weights of attention module.
preinit_weights = torch.cat(
[
layer.attention.output.dense.weight
for layer in cached_transformers.get("bert-base-cased", True).encoder.layer
]
)

# Test the case when reinit_modules is a valid int.
postinit_weights = torch.cat(
[
layer.attention.output.dense.weight
for layer in cached_transformers.get(
"bert-base-cased", True, reinit_modules=2
).encoder.layer
]
)
assert torch.equal(postinit_weights[:10], preinit_weights[:10])
assert not torch.equal(postinit_weights[10:], preinit_weights[10:])

# Test the case when reinit_modules is a valid tuple of integers.
postinit_weights = torch.cat(
[
layer.attention.output.dense.weight
for layer in cached_transformers.get(
"bert-base-cased", True, reinit_modules=(10, 11)
).encoder.layer
]
)
assert torch.equal(postinit_weights[:10], preinit_weights[:10])
assert not torch.equal(postinit_weights[10:], preinit_weights[10:])

# Should raise a ValueError because reinit_modules contains at least one index that is
# greater than the model's maximum number of layers.
with pytest.raises(ValueError):
_ = cached_transformers.get("bert-base-cased", True, reinit_modules=1000)
with pytest.raises(ValueError):
_ = cached_transformers.get("bert-base-cased", True, reinit_modules=(1, 1000))
# The argument cannot mix layer indices and regex strings.
with pytest.raises(ValueError):
_ = cached_transformers.get("bert-base-cased", True, reinit_modules=(1, "attentions"))
# This model has a non-standard structure, so if a layer index or tuple of layer indices
# is provided, we raise a ConfigurationError.
with pytest.raises(ConfigurationError):
_ = cached_transformers.get("sshleifer/tiny-gpt2", True, reinit_modules=1)
with pytest.raises(ConfigurationError):
_ = cached_transformers.get("sshleifer/tiny-gpt2", True, reinit_modules=(1, 2))

def test_reinit_modules_with_regex_strings(self):
# Comparing all weights of the model is rather complicated, so arbitrarily compare the
# weights of wpe module.
reinit_module = "wpe"
# This MUST be a deep copy, otherwise the parameters will be re-initialized and the
# test will break.
preinit_weights = list(
cached_transformers.get("sshleifer/tiny-gpt2", True)
.get_submodule(reinit_module)
.parameters()
)

postinit_weights = list(
cached_transformers.get(
"sshleifer/tiny-gpt2",
True,
reinit_modules=(reinit_module,),
)
.get_submodule(reinit_module)
.parameters()
)
assert all(
(not torch.equal(pre, post) for pre, post in zip(preinit_weights, postinit_weights))
)

def test_from_pretrained_no_load_weights(self):
_ = cached_transformers.get(
"epwalsh/bert-xsmall-dummy", False, load_weights=False, cache_dir=self.TEST_DIR
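A final note on why the regex test above insists on a deep copy: reinit_modules is part of TransformerSpec, so calls that differ only in that argument populate separate cache entries, and make_copy=True is assumed here to return a copy of the cached model so that re-initialization in one call cannot mutate weights returned by another. A small sketch of that interaction:

from allennlp.common import cached_transformers

plain = cached_transformers.get("sshleifer/tiny-gpt2", True)
reinit = cached_transformers.get("sshleifer/tiny-gpt2", True, reinit_modules=("wpe",))

# Different cache keys (and copies), so the plain model keeps its pretrained wpe weights.
assert plain is not reinit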
