Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ELECTRA as a thin wrapper around BERT #358

Merged
merged 11 commits into from
Apr 2, 2024
9 changes: 7 additions & 2 deletions curated_transformers/models/bert/encoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Dict, Mapping, Optional, Tuple, Type, TypeVar
from typing import Any, Dict, Mapping, Optional, Type, TypeVar

import torch
from torch import Tensor
Expand Down Expand Up @@ -37,7 +37,12 @@ class BERTEncoder(TransformerEncoder[BERTConfig], FromHFHub[BERTConfig]):
.. _Devlin et al., 2018 : https://arxiv.org/abs/1810.04805
"""

def __init__(self, config: BERTConfig, *, device: Optional[torch.device] = None):
def __init__(
self,
config: BERTConfig,
*,
device: Optional[torch.device] = None,
):
"""
Construct a BERT encoder.

Expand Down
1 change: 1 addition & 0 deletions curated_transformers/models/electra/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .encoder import ELECTRAEncoder
103 changes: 103 additions & 0 deletions curated_transformers/models/electra/_hf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from typing import Any, Dict, List, Mapping, Optional, Tuple

from ...util.string import StringTransform, StringTransformations
from ..hf_hub.conversion import (
CommonHFKeys,
HFConfigKey,
HFConfigKeyDefault,
HFSpecificConfig,
config_from_hf,
config_to_hf,
)
from ..bert import BERTConfig as ELECTRAConfig

# Order-dependent.
HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [
# Old HF parameter names (one-way transforms).
StringTransformations.regex_sub((r"\.gamma$", ".weight"), backward=None),
StringTransformations.regex_sub((r"\.beta$", ".bias"), backward=None),
# Prefixes.
StringTransformations.remove_prefix("electra.", reversible=False),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is only two notably differences from BERT here, this like and embeddings_project.weight and embeddings_project.bias

StringTransformations.regex_sub(
(r"^encoder\.(layer\.)", "\\1"),
(r"^(layer\.)", "encoder.\\1"),
),
# Layers.
StringTransformations.regex_sub((r"^layer", "layers"), (r"^layers", "layer")),
# Attention blocks.
StringTransformations.regex_sub(
(r"\.attention\.self\.(query|key|value)", ".mha.\\1"),
(r"\.mha\.(query|key|value)", ".attention.self.\\1"),
),
StringTransformations.sub(".attention.output.dense", ".mha.output"),
StringTransformations.sub(
r".attention.output.LayerNorm", ".attn_residual_layer_norm"
),
# Pointwise feed-forward layers.
StringTransformations.sub(".intermediate.dense", ".ffn.intermediate"),
StringTransformations.regex_sub(
(r"(\.\d+)\.output\.LayerNorm", "\\1.ffn_residual_layer_norm"),
(r"(\.\d+)\.ffn_residual_layer_norm", "\\1.output.LayerNorm"),
),
StringTransformations.regex_sub(
(r"(\.\d+)\.output\.dense", "\\1.ffn.output"),
(r"(\.\d+)\.ffn\.output", "\\1.output.dense"),
),
# Embeddings.
StringTransformations.replace(
"embeddings.word_embeddings.weight", "embeddings.piece_embeddings.weight"
),
StringTransformations.replace(
"embeddings.token_type_embeddings.weight", "embeddings.type_embeddings.weight"
),
StringTransformations.replace(
"embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"
),
StringTransformations.replace(
"embeddings.LayerNorm.weight", "embeddings.embed_output_layer_norm.weight"
),
StringTransformations.replace(
"embeddings.LayerNorm.bias", "embeddings.embed_output_layer_norm.bias"
),
StringTransformations.replace(
"embeddings.LayerNorm.bias", "embeddings.embed_output_layer_norm.bias"
),
StringTransformations.replace(
"embeddings_project.bias", "embeddings.projection.bias"
),
StringTransformations.replace(
"embeddings_project.weight", "embeddings.projection.weight"
),
]

HF_CONFIG_KEYS: List[Tuple[HFConfigKey, Optional[HFConfigKeyDefault]]] = [
(CommonHFKeys.ATTENTION_PROBS_DROPOUT_PROB, None),
(CommonHFKeys.EMBEDDING_SIZE, None),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to add embedding size here, which led to a conflict with the BERT model. This was what led to the thin wrapper class. It is def. possible to avoid it by implementing an if else logic in _config_from_hf.

This seemed like a reasonable comprise between not duplicating functionality and avoid coupling, though I could imagine you would want ELECTRA to be a part of BERT or completely independent.

(CommonHFKeys.HIDDEN_DROPOUT_PROB, None),
(CommonHFKeys.HIDDEN_SIZE, None),
(CommonHFKeys.HIDDEN_ACT, None),
(CommonHFKeys.INTERMEDIATE_SIZE, None),
(CommonHFKeys.LAYER_NORM_EPS, None),
(CommonHFKeys.NUM_ATTENTION_HEADS_UNIFORM, None),
(CommonHFKeys.NUM_HIDDEN_LAYERS, None),
(CommonHFKeys.VOCAB_SIZE, None),
(CommonHFKeys.TYPE_VOCAB_SIZE, None),
(CommonHFKeys.MAX_POSITION_EMBEDDINGS, None),
]

HF_SPECIFIC_CONFIG = HFSpecificConfig(
architectures=["ElectraModel"], model_type="electra"
)


def _config_from_hf(hf_config: Mapping[str, Any]) -> ELECTRAConfig:
kwargs = config_from_hf("ELECTRA", hf_config, HF_CONFIG_KEYS)
return ELECTRAConfig(
model_max_length=CommonHFKeys.MAX_POSITION_EMBEDDINGS.get_kwarg(kwargs),
**kwargs,
)


def _config_to_hf(curated_config: ELECTRAConfig) -> Dict[str, Any]:
out = config_to_hf(curated_config, [k for k, _ in HF_CONFIG_KEYS])
return HF_SPECIFIC_CONFIG.merge(out)
43 changes: 43 additions & 0 deletions curated_transformers/models/electra/encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import Any, Dict, Mapping, Type, TypeVar

from torch import Tensor

from ..bert import BERTConfig as ELECTRAConfig
from ..bert import BERTEncoder
from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf
from ._hf import HF_PARAM_KEY_TRANSFORMS, _config_from_hf, _config_to_hf

# Only provided as typing.Self in Python 3.11+.
Self = TypeVar("Self", bound="ELECTRAEncoder")


class ELECTRAEncoder(BERTEncoder):
"""
ELECTRA (`Clark et al., 2020`_) encoder.

.. _Clark et al., 2020 : https://arxiv.org/abs/2003.10555
"""

@classmethod
def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool:
return config.get("model_type") == "electra"

@classmethod
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
) -> Mapping[str, Tensor]:
return state_dict_from_hf(params, HF_PARAM_KEY_TRANSFORMS)

@classmethod
def state_dict_to_hf(
cls: Type[Self], params: Mapping[str, Tensor]
) -> Mapping[str, Tensor]:
return state_dict_to_hf(params, HF_PARAM_KEY_TRANSFORMS)

@classmethod
def config_from_hf(cls, hf_config: Mapping[str, Any]) -> ELECTRAConfig:
return _config_from_hf(hf_config)

@classmethod
def config_to_hf(cls, curated_config: ELECTRAConfig) -> Mapping[str, Any]:
return _config_to_hf(curated_config)
Empty file.
29 changes: 29 additions & 0 deletions curated_transformers/tests/models/electra/test_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest

from curated_transformers.models.electra.encoder import ELECTRAEncoder

from ...compat import has_hf_transformers
from ...conftest import TORCH_DEVICES
from ..util import (
assert_encoder_output_equals_hf,
)


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("with_torch_sdp", [False, True])
@pytest.mark.parametrize(
"model_name",
[
"jonfd/electra-small-nordic",
"Maltehb/aelaectra-danish-electra-small-cased",
"google/electra-small-discriminator",
],
Comment on lines +15 to +19
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked using a variety of models, but I imagine you might want to replace this with a dummy Electra model.

)
def test_encoder(model_name: str, torch_device, with_torch_sdp):
assert_encoder_output_equals_hf(
ELECTRAEncoder,
"jonfd/electra-small-nordic",
KennethEnevoldsen marked this conversation as resolved.
Show resolved Hide resolved
torch_device,
with_torch_sdp=with_torch_sdp,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest

from curated_transformers.tokenizers.legacy.bert_tokenizer import (
BERTTokenizer,
)

from ...compat import has_hf_transformers
from ..util import compare_tokenizer_outputs_with_hf_tokenizer


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize(
"model_name", ["jonfd/electra-small-nordic", "Maltehb/aelaectra-danish-electra-small-cased", "google/electra-small-discriminator"]
)
def test_from_hf_hub_equals_hf_tokenizer(model_name: str, sample_texts):
compare_tokenizer_outputs_with_hf_tokenizer(
sample_texts, model_name, BERTTokenizer
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is simply to show/test that the electra models can use the BERT tokenizer

Loading