Adding an AlbertMaskedLM task model and preprocessor (#725)
* albert lm init commit

* fixing preprocessor tests

* fixing the main model test + formatting + docstrings

* fixing bug in masked lm head

* fixing none condition in masked_lm_head_test

* fixing formatting

* fixing test_valid_call_with_embedding_weights

* minor docstring changes

* Minor fixes

* addressing some comments

* working on fixing unit tests for masking

* working on fixing unit tests for masking

* adding mask to preprocessor + fixing tests

* code format

* fixing classifier test failures

* fixing formatting

---------

Co-authored-by: Matt Watson <mattdangerw@gmail.com>
shivance and mattdangerw authored Feb 17, 2023
1 parent 6b1e37d commit 30cb703
Showing 11 changed files with 701 additions and 20 deletions.
5 changes: 4 additions & 1 deletion keras_nlp/layers/masked_lm_head.py
@@ -140,7 +140,10 @@ def __init__(
             self.vocabulary_size = shape[0]

     def build(self, input_shapes):
-        feature_size = input_shapes[-1]
+        if self.embedding_weights is not None:
+            feature_size = self.embedding_weights.shape[-1]
+        else:
+            feature_size = input_shapes[-1]

         self._dense = keras.layers.Dense(
             feature_size,
12 changes: 7 additions & 5 deletions keras_nlp/layers/masked_lm_head_test.py
@@ -46,15 +46,17 @@ def test_valid_call_with_embedding_weights(self):
             embedding_weights=embedding.embeddings,
             activation="softmax",
         )
-        encoded_tokens = keras.Input(shape=(10, 16))
+        # Use a different "hidden dim" for the model than "embedding dim"; we
+        # need to support this in the layer.
+        sequence = keras.Input(shape=(10, 32))
         positions = keras.Input(shape=(5,), dtype="int32")
-        outputs = head(encoded_tokens, mask_positions=positions)
-        model = keras.Model((encoded_tokens, positions), outputs)
-        token_data = tf.random.uniform(shape=(4, 10, 16))
+        outputs = head(sequence, mask_positions=positions)
+        model = keras.Model((sequence, positions), outputs)
+        sequence_data = tf.random.uniform(shape=(4, 10, 32))
         position_data = tf.random.uniform(
             shape=(4, 5), maxval=10, dtype="int32"
         )
-        model((token_data, position_data))
+        model((sequence_data, position_data))

     def test_get_config_and_from_config(self):
         head = masked_lm_head.MaskedLMHead(
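Taken together, the two changes above let `MaskedLMHead` reuse a token embedding matrix whose dimension differs from the encoder's hidden dimension: when the output projection is tied to the embedding, the intermediate dense layer has to project hidden states to the embedding dimension so the reverse-embedding multiply lines up. Below is a minimal standalone sketch of the scenario the updated test exercises; the shapes are illustrative, and it runs eagerly rather than through a functional model.

```python
import tensorflow as tf
from tensorflow import keras

import keras_nlp

# Token embedding: vocabulary of 100, embedding dim 16.
embedding = keras.layers.Embedding(100, 16)
embedding.build((4, 10))

# Tie the head's output projection to the embedding matrix. With this change,
# build() sizes the intermediate dense layer from the embedding dim (16)
# rather than the incoming feature dim (32).
head = keras_nlp.layers.MaskedLMHead(
    vocabulary_size=100,
    embedding_weights=embedding.embeddings,
    activation="softmax",
)

# Encoder outputs with a hidden dim (32) different from the embedding dim (16).
encoded_tokens = tf.random.uniform(shape=(4, 10, 32))
mask_positions = tf.random.uniform(shape=(4, 5), maxval=10, dtype=tf.int32)

probs = head(encoded_tokens, mask_positions=mask_positions)
print(probs.shape)  # (4, 5, 100): per-position probabilities over the vocab.
```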
4 changes: 4 additions & 0 deletions keras_nlp/models/__init__.py
@@ -13,6 +13,10 @@
 # limitations under the License.

 from keras_nlp.models.albert.albert_backbone import AlbertBackbone
+from keras_nlp.models.albert.albert_masked_lm import AlbertMaskedLM
+from keras_nlp.models.albert.albert_masked_lm_preprocessor import (
+    AlbertMaskedLMPreprocessor,
+)
 from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
 from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer
 from keras_nlp.models.bart.bart_backbone import BartBackbone
3 changes: 2 additions & 1 deletion keras_nlp/models/albert/albert_classifier_test.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for BERT classification model."""
+"""Tests for ALBERT classification model."""

 import io
 import os
@@ -57,6 +57,7 @@ def setUp(self):
             unk_piece="<unk>",
             bos_piece="[CLS]",
             eos_piece="[SEP]",
+            user_defined_symbols="[MASK]",
         )
         self.proto = bytes_io.getvalue()

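The only functional change to the classifier test is the new `user_defined_symbols="[MASK]"` argument: the toy SentencePiece vocabulary now reserves a piece for the mask token that masked language modeling preprocessing inserts. A sketch of the same setup outside the test harness follows; the corpus and vocabulary size are illustrative and simply mirror the test's pattern.

```python
import io

import sentencepiece
import tensorflow as tf

import keras_nlp

# Tiny illustrative corpus; 7 distinct words + 4 control pieces + "[MASK]" = 12.
sentences = tf.data.Dataset.from_tensor_slices(
    ["the quick brown fox.", "the earth is round."]
)

bytes_io = io.BytesIO()
sentencepiece.SentencePieceTrainer.train(
    sentence_iterator=sentences.as_numpy_iterator(),
    model_writer=bytes_io,
    vocab_size=12,
    model_type="WORD",
    pad_id=0,
    unk_id=1,
    bos_id=2,
    eos_id=3,
    pad_piece="<pad>",
    unk_piece="<unk>",
    bos_piece="[CLS]",
    eos_piece="[SEP]",
    # Reserve a dedicated piece for the mask token.
    user_defined_symbols="[MASK]",
)

tokenizer = keras_nlp.models.AlbertTokenizer(proto=bytes_io.getvalue())
# "[MASK]" now maps to a single reserved id instead of falling back to <unk>.
print(tokenizer("the [MASK] brown fox."))
```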
154 changes: 154 additions & 0 deletions keras_nlp/models/albert/albert_masked_lm.py
@@ -0,0 +1,154 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ALBERT masked LM model."""

import copy

from tensorflow import keras

from keras_nlp.layers.masked_lm_head import MaskedLMHead
from keras_nlp.models.albert.albert_backbone import AlbertBackbone
from keras_nlp.models.albert.albert_backbone import albert_kernel_initializer
from keras_nlp.models.albert.albert_masked_lm_preprocessor import (
    AlbertMaskedLMPreprocessor,
)
from keras_nlp.models.albert.albert_presets import backbone_presets
from keras_nlp.models.task import Task
from keras_nlp.utils.python_utils import classproperty


@keras.utils.register_keras_serializable(package="keras_nlp")
class AlbertMaskedLM(Task):
    """An end-to-end ALBERT model for the masked language modeling task.

    This model will train ALBERT on a masked language modeling task.
    The model will predict labels for a number of masked tokens in the
    input data. For usage of this model with pre-trained weights, see the
    `from_preset()` method.

    This model can optionally be configured with a `preprocessor` layer, in
    which case inputs can be raw string features during `fit()`, `predict()`,
    and `evaluate()`. Inputs will be tokenized and dynamically masked during
    training and evaluation. This is done by default when creating the model
    with `from_preset()`.

    Disclaimer: Pre-trained models are provided on an "as is" basis, without
    warranties or conditions of any kind.

    Args:
        backbone: A `keras_nlp.models.AlbertBackbone` instance.
        preprocessor: A `keras_nlp.models.AlbertMaskedLMPreprocessor` or
            `None`. If `None`, this model will not apply preprocessing, and
            inputs should be preprocessed before calling the model.

    Example usage:

    Raw string inputs and pretrained backbone.
    ```python
    # Create a dataset with raw string features. Labels are inferred.
    features = ["The quick brown fox jumped.", "I forgot my homework."]

    # Create an AlbertMaskedLM with a pretrained backbone and further train
    # on an MLM task.
    masked_lm = keras_nlp.models.AlbertMaskedLM.from_preset(
        "albert_base_en_uncased",
    )
    masked_lm.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    masked_lm.fit(x=features, batch_size=2)
    ```

    Preprocessed inputs and custom backbone.
    ```python
    # Create a preprocessed dataset where 0 is the mask token.
    preprocessed_features = {
        "segment_ids": tf.constant(
            [[0, 0, 0, 0, 0, 0, 0, 0]] * 2, shape=(2, 8)
        ),
        "token_ids": tf.constant(
            [[1, 2, 0, 4, 0, 6, 7, 8]] * 2, shape=(2, 8)
        ),
        "padding_mask": tf.constant(
            [[1, 1, 1, 1, 1, 1, 1, 1]] * 2, shape=(2, 8)
        ),
        "mask_positions": tf.constant([[2, 4]] * 2, shape=(2, 2))
    }
    # Labels are the original masked values.
    labels = [[3, 5]] * 2

    # Randomly initialize an ALBERT encoder.
    backbone = keras_nlp.models.AlbertBackbone(
        vocabulary_size=1000,
        num_layers=2,
        num_heads=2,
        embedding_dim=64,
        hidden_dim=64,
        intermediate_dim=128,
        max_sequence_length=128,
    )
    # Create an ALBERT masked LM and fit the data.
    masked_lm = keras_nlp.models.AlbertMaskedLM(
        backbone,
        preprocessor=None,
    )
    masked_lm.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        jit_compile=True,
    )
    masked_lm.fit(x=preprocessed_features, y=labels, batch_size=2)
    ```
    """

    def __init__(self, backbone, preprocessor=None, **kwargs):
        inputs = {
            **backbone.input,
            "mask_positions": keras.Input(
                shape=(None,), dtype="int32", name="mask_positions"
            ),
        }

        backbone_outputs = backbone(backbone.input)
        outputs = MaskedLMHead(
            vocabulary_size=backbone.vocabulary_size,
            embedding_weights=backbone.token_embedding.embeddings,
            intermediate_activation=lambda x: keras.activations.gelu(
                x, approximate=True
            ),
            kernel_initializer=albert_kernel_initializer(),
            name="mlm_head",
        )(backbone_outputs["sequence_output"], inputs["mask_positions"])

        super().__init__(
            inputs=inputs,
            outputs=outputs,
            include_preprocessing=preprocessor is not None,
            **kwargs
        )

        self.backbone = backbone
        self.preprocessor = preprocessor

    @classproperty
    def backbone_cls(cls):
        return AlbertBackbone

    @classproperty
    def preprocessor_cls(cls):
        return AlbertMaskedLMPreprocessor

    @classproperty
    def presets(cls):
        return copy.deepcopy(backbone_presets)
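The matching `AlbertMaskedLMPreprocessor` diff is not shown above, but the wiring in `__init__` shows what the task model expects: a `mask_positions` input alongside the usual backbone inputs. A hedged sketch of how the two pieces are meant to fit together follows; the `(x, y, sample_weight)` packing is assumed to match other KerasNLP masked LM preprocessors, and downloading the `albert_base_en_uncased` preset requires network access.

```python
from tensorflow import keras

import keras_nlp

# Build the preprocessor from the preset; it tokenizes, packs, and dynamically
# masks raw strings. sequence_length is just an illustrative override.
preprocessor = keras_nlp.models.AlbertMaskedLMPreprocessor.from_preset(
    "albert_base_en_uncased",
    sequence_length=64,
)

# Assumed output structure: x has "token_ids", "segment_ids", "padding_mask",
# and "mask_positions"; y holds the original ids at the masked positions;
# sample_weight flags which mask slots are real rather than padding.
x, y, sample_weight = preprocessor(["The quick brown fox jumped."])

# The task model built with from_preset() attaches the same preprocessor, so
# raw strings can go straight into fit().
masked_lm = keras_nlp.models.AlbertMaskedLM.from_preset("albert_base_en_uncased")
masked_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
masked_lm.fit(
    x=["The quick brown fox jumped.", "I forgot my homework."],
    batch_size=2,
)
```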
