
Commit

Added backbone tests for ELECTRA
pranavvp16 committed Oct 31, 2023
1 parent 879020a commit 30c1a85
Showing 2 changed files with 98 additions and 21 deletions.
65 changes: 44 additions & 21 deletions keras_nlp/models/electra/electra_backbone.py
@@ -16,9 +16,8 @@

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
-from keras_nlp.layers.modeling.token_and_position_embedding import (
-    PositionEmbedding, ReversibleEmbedding
-)
+from keras_nlp.layers.modeling.position_embedding import PositionEmbedding
+from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder
from keras_nlp.models.backbone import Backbone
from keras_nlp.utils.python_utils import classproperty
@@ -46,6 +45,42 @@ class ElectraBackbone(Backbone):
    warranties or conditions of any kind. The underlying model is provided by a
    third party and subject to a separate license, available
    [here](https://huggingface.co/docs/transformers/model_doc/electra#overview).

    Args:
        vocabulary_size: int. The size of the token vocabulary.
        num_layers: int. The number of transformer layers.
        num_heads: int. The number of attention heads for each transformer.
            The hidden size must be divisible by the number of attention heads.
        hidden_size: int. The size of the transformer encoding and pooler layers.
        embedding_size: int. The size of the token embeddings.
        intermediate_size: int. The output dimension of the first Dense layer in
            a two-layer feedforward network for each transformer.
        dropout: float. Dropout probability for the Transformer encoder.
        max_sequence_length: int. The maximum sequence length that this encoder
            can consume. If `None`, `max_sequence_length` defaults to the length
            of the input sequence. This determines the variable shape for
            positional embeddings.
        num_segments: int. The number of types that the "segment_ids" input can
            take.

    Examples:
    ```python
    input_data = {
        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
        "segment_ids": np.array([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]]),
        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
    }

    # Randomly initialized ELECTRA encoder.
    backbone = keras_nlp.models.ElectraBackbone(
        vocabulary_size=1000,
        num_layers=2,
        num_heads=2,
        embedding_size=32,
        hidden_size=32,
        intermediate_size=64,
        dropout=0.1,
        max_sequence_length=512,
    )
    # Returns a dict with keys "sequence_output" and "pooled_output".
    outputs = backbone(input_data)
    ```
    """

    def __init__(
@@ -55,7 +90,7 @@ def __init__(
        num_heads,
        embedding_size,
        hidden_size,
-        intermediate_dim,
+        intermediate_size,
        dropout=0.1,
        max_sequence_length=512,
        num_segments=2,
@@ -83,14 +118,12 @@
        )
        token_embedding = token_embedding_layer(token_id_input)
        position_embedding = PositionEmbedding(
-            input_dim=max_sequence_length,
-            output_dim=embedding_size,
-            merge_mode="add",
-            embeddings_initializer=electra_kernel_initializer(),
+            initializer=electra_kernel_initializer(),
+            sequence_length=max_sequence_length,
            name="position_embedding",
        )(token_embedding)
        segment_embedding = keras.layers.Embedding(
-            input_dim=max_sequence_length,
+            input_dim=num_segments,
            output_dim=embedding_size,
            embeddings_initializer=electra_kernel_initializer(),
            name="segment_embedding",
@@ -124,7 +157,7 @@
        for i in range(num_layers):
            x = TransformerEncoder(
                num_heads=num_heads,
-                intermediate_dim=intermediate_dim,
+                intermediate_dim=intermediate_size,
                activation="gelu",
                dropout=dropout,
                layer_norm_epsilon=1e-12,
@@ -161,7 +194,7 @@
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.embedding_size = embedding_size
-        self.intermediate_dim = intermediate_dim
+        self.intermediate_dim = intermediate_size
        self.dropout = dropout
        self.max_sequence_length = max_sequence_length
        self.num_segments = num_segments
@@ -186,13 +219,3 @@ def get_config(self):
            }
        )
        return config
54 changes: 54 additions & 0 deletions keras_nlp/models/electra/electra_backbone_test.py
@@ -0,0 +1,54 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from keras_nlp.backend import ops
from keras_nlp.models.electra.electra_backbone import ElectraBackbone
from keras_nlp.tests.test_case import TestCase

class ElectraBackboneTest(TestCase):
    def setUp(self):
        self.init_kwargs = {
            "vocabulary_size": 10,
            "num_layers": 2,
            "num_heads": 2,
            "hidden_size": 2,
            "embedding_size": 2,
            "intermediate_size": 4,
            "max_sequence_length": 5,
        }
        self.input_data = {
            "token_ids": ops.ones((2, 5), dtype="int32"),
            "segment_ids": ops.zeros((2, 5), dtype="int32"),
            "padding_mask": ops.ones((2, 5), dtype="int32"),
        }

    def test_backbone_basics(self):
        self.run_backbone_test(
            cls=ElectraBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape={
                "sequence_output": (2, 5, 2),
                "pooled_output": (2, 2),
            },
        )

    @pytest.mark.large
    def test_saved_model(self):
        self.run_model_saving_test(
            cls=ElectraBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )
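As a rough sketch of what `run_backbone_test` exercises here (not part of the commit, and assuming the backbone returns a dict keyed by `"sequence_output"` and `"pooled_output"`, as the expected shapes above indicate), the backbone can be driven directly with the same tiny configuration:

```python
from keras_nlp.backend import ops
from keras_nlp.models.electra.electra_backbone import ElectraBackbone

# Mirror the test's setUp(): a tiny backbone and a batch of 2 sequences
# of length 5.
backbone = ElectraBackbone(
    vocabulary_size=10,
    num_layers=2,
    num_heads=2,
    hidden_size=2,
    embedding_size=2,
    intermediate_size=4,
    max_sequence_length=5,
)
outputs = backbone(
    {
        "token_ids": ops.ones((2, 5), dtype="int32"),
        "segment_ids": ops.zeros((2, 5), dtype="int32"),
        "padding_mask": ops.ones((2, 5), dtype="int32"),
    }
)
# Per expected_output_shape above: a per-token sequence output and a
# pooled output.
print(outputs["sequence_output"].shape)  # (2, 5, 2)
print(outputs["pooled_output"].shape)    # (2, 2)
```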
