Add presets for Electra and checkpoint conversion script (#1384)
* Added ElectraBackbone

* Added backbone tests for ELECTRA

* Fix config

* Add model import to __init__

* add electra tokenizer

* add tests for tokenizer

* add __init__ file

* add tokenizer and backbone to models __init__

* Fix Failing tokenization test

* Add example on usage of the tokenizer with custom vocabulary

* Add conversion script to convert weights from checkpoint

* Add electra preprocessor

* Add presets and tests

* Add presets config with model weights

* Add checkpoint conversion script

* Name conversion for electra models

* Update naming conventions according to preset names

* Fix failing tokenizer tests

* Update checkpoint conversion script according to kaggle

* Add validate function

* Kaggle preset

* update preset link

* Add electra presets

* Complete run_small_preset test for electra

* Add large variations of electra in presets

* Fix case issues with electra presets

* Fix format

---------

Co-authored-by: Matt Watson <mattdangerw@gmail.com>
2 people authored and abuelnasr0 committed Apr 2, 2024
1 parent 6703d76 commit 6a8166e
Showing 9 changed files with 684 additions and 2 deletions.
1 change: 1 addition & 0 deletions keras_nlp/models/__init__.py
@@ -72,6 +72,7 @@
DistilBertTokenizer,
)
from keras_nlp.models.electra.electra_backbone import ElectraBackbone
from keras_nlp.models.electra.electra_preprocessor import ElectraPreprocessor
from keras_nlp.models.electra.electra_tokenizer import ElectraTokenizer
from keras_nlp.models.f_net.f_net_backbone import FNetBackbone
from keras_nlp.models.f_net.f_net_classifier import FNetClassifier
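This hunk adds `ElectraPreprocessor` to the top-level exports alongside the existing `ElectraBackbone` and `ElectraTokenizer`. A minimal sketch of the resulting public imports (assuming a keras-nlp build that includes the ELECTRA modules):

```python
# Sketch only: symbols exported via keras_nlp/models/__init__.py after this change.
import keras_nlp

backbone_cls = keras_nlp.models.ElectraBackbone
preprocessor_cls = keras_nlp.models.ElectraPreprocessor  # newly exported here
tokenizer_cls = keras_nlp.models.ElectraTokenizer
```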
20 changes: 18 additions & 2 deletions keras_nlp/models/electra/electra_backbone.py
@@ -12,13 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.layers.modeling.position_embedding import PositionEmbedding
from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.electra.electra_presets import backbone_presets
from keras_nlp.utils.keras_utils import gelu_approximate
from keras_nlp.utils.python_utils import classproperty


def electra_kernel_initializer(stddev=0.02):
@@ -36,8 +40,9 @@ class ElectraBackbone(Backbone):
or classification task networks.
The default constructor gives a fully customizable, randomly initialized
ELECTRA encoder with any number of layers, heads, and embedding
dimensions. To load preset architectures and weights, use the
`from_preset()` constructor.
Disclaimer: Pre-trained models are provided on an "as is" basis, without
warranties or conditions of any kind. The underlying model is provided by a
@@ -70,6 +75,13 @@ class ElectraBackbone(Backbone):
"segment_ids": np.array([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]]),
"padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
}
# Pre-trained ELECTRA encoder.
model = keras_nlp.models.ElectraBackbone.from_preset(
"electra_base_discriminator_en"
)
model(input_data)
# Randomly initialized Electra encoder
backbone = keras_nlp.models.ElectraBackbone(
vocabulary_size=1000,
@@ -234,3 +246,7 @@ def get_config(self):
}
)
return config

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
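The new `presets` classproperty is what makes `from_preset()` work for ELECTRA. A minimal usage sketch (preset name and input values taken from the backbone tests below; loading weights requires a download):

```python
import numpy as np
import keras_nlp

# List the ELECTRA presets registered by this change.
print(list(keras_nlp.models.ElectraBackbone.presets.keys()))

# Load a pretrained discriminator and run a forward pass.
backbone = keras_nlp.models.ElectraBackbone.from_preset(
    "electra_small_discriminator_uncased_en"
)
outputs = backbone(
    {
        "token_ids": np.array([[101, 1996, 4248, 102]], dtype="int32"),
        "segment_ids": np.zeros((1, 4), dtype="int32"),
        "padding_mask": np.ones((1, 4), dtype="int32"),
    }
)
print(outputs["sequence_output"].shape)  # (1, 4, 256) for the small preset
```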
34 changes: 34 additions & 0 deletions keras_nlp/models/electra/electra_backbone_test.py
@@ -54,3 +54,37 @@ def test_saved_model(self):
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)

@pytest.mark.large
def test_smallest_preset(self):
self.run_preset_test(
cls=ElectraBackbone,
preset="electra_small_discriminator_uncased_en",
input_data={
"token_ids": ops.array([[101, 1996, 4248, 102]], dtype="int32"),
"segment_ids": ops.zeros((1, 4), dtype="int32"),
"padding_mask": ops.ones((1, 4), dtype="int32"),
},
expected_output_shape={
"sequence_output": (1, 4, 256),
"pooled_output": (1, 256),
},
# The forward pass from a preset should be stable!
expected_partial_output={
"sequence_output": (
ops.array([0.32287, 0.18754, -0.22272, -0.24177, 1.18977])
),
"pooled_output": (
ops.array([-0.02974, 0.23383, 0.08430, -0.19471, 0.14822])
),
},
)

@pytest.mark.extra_large
def test_all_presets(self):
for preset in ElectraBackbone.presets:
self.run_preset_test(
cls=ElectraBackbone,
preset=preset,
input_data=self.input_data,
)
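The `expected_partial_output` entries pin the leading values of each output so a preset's forward pass stays numerically stable across releases. Roughly, the check amounts to something like the sketch below; this is an illustration only, not the actual `run_preset_test` helper, and the exact slicing and tolerance in the helper may differ:

```python
import numpy as np

def check_partial_output(outputs):
    """Rough standalone illustration of the `expected_partial_output` idea."""
    # Reference values recorded for electra_small_discriminator_uncased_en.
    expected = np.array([0.32287, 0.18754, -0.22272, -0.24177, 1.18977])
    actual = np.asarray(outputs["sequence_output"]).ravel()[: expected.size]
    np.testing.assert_allclose(actual, expected, atol=1e-2)
```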
163 changes: 163 additions & 0 deletions keras_nlp/models/electra/electra_preprocessor.py
@@ -0,0 +1,163 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.layers.preprocessing.multi_segment_packer import (
MultiSegmentPacker,
)
from keras_nlp.models.electra.electra_presets import backbone_presets
from keras_nlp.models.electra.electra_tokenizer import ElectraTokenizer
from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.utils.keras_utils import (
convert_inputs_to_list_of_tensor_segments,
)
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
from keras_nlp.utils.python_utils import classproperty


@keras_nlp_export("keras_nlp.models.ElectraPreprocessor")
class ElectraPreprocessor(Preprocessor):
"""A ELECTRA preprocessing layer which tokenizes and packs inputs.
This preprocessing layer will do three things:
1. Tokenize any number of input segments using the `tokenizer`.
2. Pack the inputs together using a `keras_nlp.layers.MultiSegmentPacker`.
with the appropriate `"[CLS]"`, `"[SEP]"` and `"[PAD]"` tokens.
3. Construct a dictionary of with keys `"token_ids"` and `"padding_mask"`,
that can be passed directly to a ELECTRA model.
This layer can be used directly with `tf.data.Dataset.map` to preprocess
string data in the `(x, y, sample_weight)` format used by
`keras.Model.fit`.
Args:
tokenizer: A `keras_nlp.models.ElectraTokenizer` instance.
sequence_length: The length of the packed inputs.
truncate: string. The algorithm to truncate a list of batched segments
to fit within `sequence_length`. The value can be either
`round_robin` or `waterfall`:
- `"round_robin"`: Available space is assigned one token at a
time in a round-robin fashion to the inputs that still need
some, until the limit is reached.
- `"waterfall"`: The allocation of the budget is done using a
"waterfall" algorithm that allocates quota in a
left-to-right manner and fills up the buckets until we run
out of budget. It supports an arbitrary number of segments.
Call arguments:
x: A tensor of single string sequences, or a tuple of multiple
tensor sequences to be packed together. Inputs may be batched or
unbatched. For single sequences, raw python inputs will be converted
to tensors. For multiple sequences, pass tensors directly.
y: Any label data. Will be passed through unaltered.
sample_weight: Any label weight data. Will be passed through unaltered.
Examples:
Directly calling the layer on data.
```python
preprocessor = keras_nlp.models.ElectraPreprocessor.from_preset(
"electra_base_discriminator_en"
)
preprocessor(["The quick brown fox jumped.", "Call me Ishmael."])
# Custom vocabulary.
vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
vocab += ["The", "quick", "brown", "fox", "jumped", "."]
tokenizer = keras_nlp.models.ElectraTokenizer(vocabulary=vocab)
preprocessor = keras_nlp.models.ElectraPreprocessor(tokenizer)
preprocessor("The quick brown fox jumped.")
```
Mapping with `tf.data.Dataset`.
```python
preprocessor = keras_nlp.models.ElectraPreprocessor.from_preset(
"electra_base_discriminator_en"
)
first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."])
second = tf.constant(["The fox tripped.", "Oh look, a whale."])
label = tf.constant([1, 1])
# Map labeled single sentences.
ds = tf.data.Dataset.from_tensor_slices((first, label))
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
# Map unlabeled single sentences.
ds = tf.data.Dataset.from_tensor_slices(first)
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
# Map labeled sentence pairs.
ds = tf.data.Dataset.from_tensor_slices(((first, second), label))
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
# Map unlabeled sentence pairs.
ds = tf.data.Dataset.from_tensor_slices((first, second))
# Watch out for tf.data's default unpacking of tuples here!
# Best to invoke the `preprocessor` directly in this case.
ds = ds.map(
lambda first, second: preprocessor(x=(first, second)),
num_parallel_calls=tf.data.AUTOTUNE,
)
```
"""

def __init__(
self,
tokenizer,
sequence_length=512,
truncate="round_robin",
**kwargs,
):
super().__init__(**kwargs)
self.tokenizer = tokenizer
self.packer = MultiSegmentPacker(
start_value=self.tokenizer.cls_token_id,
end_value=self.tokenizer.sep_token_id,
pad_value=self.tokenizer.pad_token_id,
truncate=truncate,
sequence_length=sequence_length,
)

def get_config(self):
config = super().get_config()
config.update(
{
"sequence_length": self.packer.sequence_length,
"truncate": self.packer.truncate,
}
)
return config

def call(self, x, y=None, sample_weight=None):
x = convert_inputs_to_list_of_tensor_segments(x)
x = [self.tokenizer(segment) for segment in x]
token_ids, segment_ids = self.packer(x)
x = {
"token_ids": token_ids,
"segment_ids": segment_ids,
"padding_mask": token_ids != self.tokenizer.pad_token_id,
}
return pack_x_y_sample_weight(x, y, sample_weight)

@classproperty
def tokenizer_cls(cls):
return ElectraTokenizer

@classproperty
def presets(cls):
return copy.deepcopy({**backbone_presets})
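A short sketch of what the preprocessor produces for a single sentence with a small custom vocabulary (vocabulary and expected values mirror the test fixtures below):

```python
import keras_nlp

vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
vocab += ["THE", "QUICK", "BROWN", "FOX"]
vocab += ["the", "quick", "brown", "fox"]
tokenizer = keras_nlp.models.ElectraTokenizer(vocabulary=vocab)
preprocessor = keras_nlp.models.ElectraPreprocessor(tokenizer, sequence_length=8)

x = preprocessor("THE QUICK BROWN FOX.")
# x["token_ids"]    -> [2, 5, 6, 7, 8, 1, 3, 0]  # [CLS] THE QUICK BROWN FOX [UNK](".") [SEP] [PAD]
# x["segment_ids"]  -> [0, 0, 0, 0, 0, 0, 0, 0]
# x["padding_mask"] -> [1, 1, 1, 1, 1, 1, 1, 0]
```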
67 changes: 67 additions & 0 deletions keras_nlp/models/electra/electra_preprocessor_test.py
@@ -0,0 +1,67 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from keras_nlp.models.electra.electra_preprocessor import ElectraPreprocessor
from keras_nlp.models.electra.electra_tokenizer import ElectraTokenizer
from keras_nlp.tests.test_case import TestCase


class ElectraPreprocessorTest(TestCase):
def setUp(self):
self.vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
self.vocab += ["THE", "QUICK", "BROWN", "FOX"]
self.vocab += ["the", "quick", "brown", "fox"]
self.tokenizer = ElectraTokenizer(vocabulary=self.vocab)
self.init_kwargs = {
"tokenizer": self.tokenizer,
"sequence_length": 8,
}
self.input_data = (
["THE QUICK BROWN FOX."],
[1], # Pass through labels.
[1.0], # Pass through sample_weights.
)

def test_preprocessor_basics(self):
self.run_preprocessing_layer_test(
cls=ElectraPreprocessor,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output=(
{
"token_ids": [[2, 5, 6, 7, 8, 1, 3, 0]],
"segment_ids": [[0, 0, 0, 0, 0, 0, 0, 0]],
"padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]],
},
[1], # Pass through labels.
[1.0], # Pass through sample_weights.
),
)

def test_errors_for_2d_list_input(self):
preprocessor = ElectraPreprocessor(**self.init_kwargs)
ambiguous_input = [["one", "two"], ["three", "four"]]
with self.assertRaises(ValueError):
preprocessor(ambiguous_input)

@pytest.mark.extra_large
def test_all_presets(self):
for preset in ElectraPreprocessor.presets:
self.run_preset_test(
cls=ElectraPreprocessor,
preset=preset,
input_data=self.input_data,
)
