Added ELECTRA as a thin wrapper around BERT #358
Changes from 7 commits
curated_transformers/models/electra/__init__.py
@@ -0,0 +1 @@
from .encoder import ELECTRAEncoder
curated_transformers/models/electra/_hf.py
@@ -0,0 +1,103 @@
from typing import Any, Dict, List, Mapping, Optional, Tuple

from ...util.string import StringTransform, StringTransformations
from ..hf_hub.conversion import (
    CommonHFKeys,
    HFConfigKey,
    HFConfigKeyDefault,
    HFSpecificConfig,
    config_from_hf,
    config_to_hf,
)
from ..bert import BERTConfig as ELECTRAConfig

# Order-dependent.
HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [
    # Old HF parameter names (one-way transforms).
    StringTransformations.regex_sub((r"\.gamma$", ".weight"), backward=None),
    StringTransformations.regex_sub((r"\.beta$", ".bias"), backward=None),
    # Prefixes.
    StringTransformations.remove_prefix("electra.", reversible=False),
    StringTransformations.regex_sub(
        (r"^encoder\.(layer\.)", "\\1"),
        (r"^(layer\.)", "encoder.\\1"),
    ),
    # Layers.
    StringTransformations.regex_sub((r"^layer", "layers"), (r"^layers", "layer")),
    # Attention blocks.
    StringTransformations.regex_sub(
        (r"\.attention\.self\.(query|key|value)", ".mha.\\1"),
        (r"\.mha\.(query|key|value)", ".attention.self.\\1"),
    ),
    StringTransformations.sub(".attention.output.dense", ".mha.output"),
    StringTransformations.sub(
        r".attention.output.LayerNorm", ".attn_residual_layer_norm"
    ),
    # Pointwise feed-forward layers.
    StringTransformations.sub(".intermediate.dense", ".ffn.intermediate"),
    StringTransformations.regex_sub(
        (r"(\.\d+)\.output\.LayerNorm", "\\1.ffn_residual_layer_norm"),
        (r"(\.\d+)\.ffn_residual_layer_norm", "\\1.output.LayerNorm"),
    ),
    StringTransformations.regex_sub(
        (r"(\.\d+)\.output\.dense", "\\1.ffn.output"),
        (r"(\.\d+)\.ffn\.output", "\\1.output.dense"),
    ),
    # Embeddings.
    StringTransformations.replace(
        "embeddings.word_embeddings.weight", "embeddings.piece_embeddings.weight"
    ),
    StringTransformations.replace(
        "embeddings.token_type_embeddings.weight", "embeddings.type_embeddings.weight"
    ),
    StringTransformations.replace(
        "embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"
    ),
    StringTransformations.replace(
        "embeddings.LayerNorm.weight", "embeddings.embed_output_layer_norm.weight"
    ),
    StringTransformations.replace(
        "embeddings.LayerNorm.bias", "embeddings.embed_output_layer_norm.bias"
    ),
    StringTransformations.replace(
        "embeddings.LayerNorm.bias", "embeddings.embed_output_layer_norm.bias"
    ),
    StringTransformations.replace(
        "embeddings_project.bias", "embeddings.projection.bias"
    ),
    StringTransformations.replace(
        "embeddings_project.weight", "embeddings.projection.weight"
    ),
]

HF_CONFIG_KEYS: List[Tuple[HFConfigKey, Optional[HFConfigKeyDefault]]] = [
    (CommonHFKeys.ATTENTION_PROBS_DROPOUT_PROB, None),
    (CommonHFKeys.EMBEDDING_SIZE, None),
I had to add the embedding size here, which led to a conflict with the BERT model. This is what led to the thin wrapper class. It is definitely possible to avoid it by implementing if/else logic instead. This seemed like a reasonable compromise between not duplicating functionality and avoiding coupling, though I could imagine you would want ELECTRA to either be part of BERT or completely independent. (The embedding-size difference is illustrated in the sketch after this file's diff.)
    (CommonHFKeys.HIDDEN_DROPOUT_PROB, None),
    (CommonHFKeys.HIDDEN_SIZE, None),
    (CommonHFKeys.HIDDEN_ACT, None),
    (CommonHFKeys.INTERMEDIATE_SIZE, None),
    (CommonHFKeys.LAYER_NORM_EPS, None),
    (CommonHFKeys.NUM_ATTENTION_HEADS_UNIFORM, None),
    (CommonHFKeys.NUM_HIDDEN_LAYERS, None),
    (CommonHFKeys.VOCAB_SIZE, None),
    (CommonHFKeys.TYPE_VOCAB_SIZE, None),
    (CommonHFKeys.MAX_POSITION_EMBEDDINGS, None),
]

HF_SPECIFIC_CONFIG = HFSpecificConfig(
    architectures=["ElectraModel"], model_type="electra"
)


def _config_from_hf(hf_config: Mapping[str, Any]) -> ELECTRAConfig:
    kwargs = config_from_hf("ELECTRA", hf_config, HF_CONFIG_KEYS)
    return ELECTRAConfig(
        model_max_length=CommonHFKeys.MAX_POSITION_EMBEDDINGS.get_kwarg(kwargs),
        **kwargs,
    )


def _config_to_hf(curated_config: ELECTRAConfig) -> Dict[str, Any]:
    out = config_to_hf(curated_config, [k for k, _ in HF_CONFIG_KEYS])
    return HF_SPECIFIC_CONFIG.merge(out)
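To make the embedding-size conflict from the comment above concrete: ELECTRA checkpoints carry a separate embedding_size that is smaller than hidden_size and is projected up to the hidden width through embeddings_project, while BERT checkpoints have no such field. A minimal sketch with plain dicts, using the published google/electra-small-discriminator sizes and a hypothetical helper that is not part of this PR or of curated-transformers:

# Hypothetical illustration of the configuration difference, not code from this PR.
bert_hf_config = {"hidden_size": 768}  # BERT: embedding width equals hidden width
electra_hf_config = {"embedding_size": 128, "hidden_size": 256}  # ELECTRA-small: they differ


def needs_embedding_projection(hf_config: dict) -> bool:
    # ELECTRA checkpoints project their narrower embeddings up to hidden_size
    # via the embeddings_project.{weight,bias} parameters mapped above.
    return hf_config.get("embedding_size", hf_config["hidden_size"]) != hf_config["hidden_size"]


assert not needs_embedding_projection(bert_hf_config)
assert needs_embedding_projection(electra_hf_config)

This is the case that the CommonHFKeys.EMBEDDING_SIZE entry and the embeddings_project transforms have to cover, and it is why the conversion could not be shared with BERT completely unchanged.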
curated_transformers/models/electra/encoder.py
@@ -0,0 +1,43 @@
from typing import Any, Dict, Mapping, Type, TypeVar

from torch import Tensor

from ..bert import BERTConfig as ELECTRAConfig
from ..bert import BERTEncoder
from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf
from ._hf import HF_PARAM_KEY_TRANSFORMS, _config_from_hf, _config_to_hf

# Only provided as typing.Self in Python 3.11+.
Self = TypeVar("Self", bound="ELECTRAEncoder")


class ELECTRAEncoder(BERTEncoder):
    """
    ELECTRA (`Clark et al., 2020`_) encoder.

    .. _Clark et al., 2020 : https://arxiv.org/abs/2003.10555
    """

    @classmethod
    def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool:
        return config.get("model_type") == "electra"

    @classmethod
    def state_dict_from_hf(
        cls: Type[Self], params: Mapping[str, Tensor]
    ) -> Mapping[str, Tensor]:
        return state_dict_from_hf(params, HF_PARAM_KEY_TRANSFORMS)

    @classmethod
    def state_dict_to_hf(
        cls: Type[Self], params: Mapping[str, Tensor]
    ) -> Mapping[str, Tensor]:
        return state_dict_to_hf(params, HF_PARAM_KEY_TRANSFORMS)

    @classmethod
    def config_from_hf(cls, hf_config: Mapping[str, Any]) -> ELECTRAConfig:
        return _config_from_hf(hf_config)

    @classmethod
    def config_to_hf(cls, curated_config: ELECTRAConfig) -> Mapping[str, Any]:
        return _config_to_hf(curated_config)
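Because ELECTRAEncoder only overrides the conversion hooks, using it should look exactly like using the BERT encoder. A minimal usage sketch, assuming the from_hf_hub() constructor that BERTEncoder already provides through the hub-loading mixin (that constructor is not shown in this diff, so treat the exact call as an assumption):

from curated_transformers.models.electra.encoder import ELECTRAEncoder

# Assumed to be inherited from BERTEncoder: load weights from the Hugging Face Hub,
# converting parameter names and config with the hooks defined above.
encoder = ELECTRAEncoder.from_hf_hub(name="google/electra-small-discriminator")

# is_supported() is what lets auto-loading pick this class for checkpoints whose
# config declares model_type == "electra".
assert ELECTRAEncoder.is_supported({"model_type": "electra"})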
Tests for the ELECTRA encoder
@@ -0,0 +1,29 @@
import pytest

from curated_transformers.models.electra.encoder import ELECTRAEncoder

from ...compat import has_hf_transformers
from ...conftest import TORCH_DEVICES
from ..util import (
    assert_encoder_output_equals_hf,
)


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("with_torch_sdp", [False, True])
@pytest.mark.parametrize(
    "model_name",
    [
        "jonfd/electra-small-nordic",
        "Maltehb/aelaectra-danish-electra-small-cased",
        "google/electra-small-discriminator",
    ],
Comment on lines +15 to +19: I checked using a variety of models, but I imagine you might want to replace this with a dummy Electra model (one way to produce such a fixture is sketched after this file's diff).
)
def test_encoder(model_name: str, torch_device, with_torch_sdp):
    assert_encoder_output_equals_hf(
        ELECTRAEncoder,
        model_name,
        torch_device,
        with_torch_sdp=with_torch_sdp,
    )
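One way to produce the dummy checkpoint suggested in the comment above would be to save a tiny, randomly initialised model with Hugging Face transformers and push it to the Hub; the sizes and output path below are made up for illustration and are not part of this PR:

from transformers import ElectraConfig, ElectraModel

# A deliberately tiny, randomly initialised ELECTRA model; all sizes are arbitrary.
config = ElectraConfig(
    vocab_size=1024,
    embedding_size=32,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=128,
    max_position_embeddings=128,
)
model = ElectraModel(config)
model.save_pretrained("electra-tiny-random")  # hypothetical local path; could be pushed to the Hub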
Tests for the ELECTRA tokenizer (reusing BERTTokenizer)
@@ -0,0 +1,21 @@
import pytest

from curated_transformers.tokenizers.legacy.bert_tokenizer import (
    BERTTokenizer,
)

from ...compat import has_hf_transformers
from ..util import compare_tokenizer_outputs_with_hf_tokenizer


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize(
    "model_name",
    [
        "jonfd/electra-small-nordic",
        "Maltehb/aelaectra-danish-electra-small-cased",
        "google/electra-small-discriminator",
    ],
)
def test_from_hf_hub_equals_hf_tokenizer(model_name: str, sample_texts):
    compare_tokenizer_outputs_with_hf_tokenizer(sample_texts, model_name, BERTTokenizer)
There are only two notable differences from BERT here: this line, and embeddings_project.weight and embeddings_project.bias.
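To illustrate what the transform table in _hf.py does to real checkpoint keys, including the two ELECTRA-only embeddings_project keys called out above, here is a condensed re-implementation of the forward direction using plain re. The rules are a subset of HF_PARAM_KEY_TRANSFORMS and the expected values are derived by reading that table, not captured from running the PR code:

import re

# Forward-direction (HF -> curated) renaming rules, condensed from the table above.
FORWARD_RULES = [
    (r"^electra\.", ""),                                    # strip the "electra." prefix
    (r"^encoder\.(layer\.)", r"\1"),                        # drop the "encoder." level
    (r"^layer", "layers"),
    (r"\.attention\.self\.(query|key|value)", r".mha.\1"),
    (r"embeddings\.word_embeddings\.weight", "embeddings.piece_embeddings.weight"),
    (r"embeddings_project\.(weight|bias)", r"embeddings.projection.\1"),  # ELECTRA-only keys
]


def to_curated(key: str) -> str:
    for pattern, repl in FORWARD_RULES:
        key = re.sub(pattern, repl, key)
    return key


assert to_curated("electra.embeddings_project.weight") == "embeddings.projection.weight"
assert to_curated("electra.embeddings_project.bias") == "embeddings.projection.bias"
assert to_curated("electra.encoder.layer.0.attention.self.query.weight") == "layers.0.mha.query.weight"
assert to_curated("electra.embeddings.word_embeddings.weight") == "embeddings.piece_embeddings.weight"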