Commit: Update bert_backbone.py
soma2000-lang authored Feb 7, 2023
1 parent 8aae1b8 commit 52e0b4c
Showing 1 changed file: keras_nlp/models/bert/bert_backbone.py (37 additions, 185 deletions)
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT backbone model."""

import copy

import tensorflow as tf
from tensorflow import keras

from keras_nlp.layers.position_embedding import PositionEmbedding
from keras_nlp.layers.transformer_encoder import TransformerEncoder
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.bert.bert_presets import backbone_presets
from keras_nlp.utils.python_utils import classproperty
from keras_nlp.utils.python_utils import format_docstring


def bert_kernel_initializer(stddev=0.02):
return keras.initializers.TruncatedNormal(stddev=stddev)


@keras.utils.register_keras_serializable(package="keras_nlp")
class BertBackbone(Backbone):
"""BERT encoder network.
This class implements a bi-directional Transformer-based encoder as
described in ["BERT: Pre-training of Deep Bidirectional Transformers for
Language Understanding"](https://arxiv.org/abs/1810.04805). It includes the
embedding lookups and transformer layers, but not the masked language model
or next sentence prediction heads.
The default constructor gives a fully customizable, randomly initialized BERT
encoder with any number of layers, heads, and embedding dimensions. To load
preset architectures and weights, use the `from_preset` constructor.
Disclaimer: Pre-trained models are provided on an "as is" basis, without
warranties or conditions of any kind.
Args:
vocabulary_size: int. The size of the token vocabulary.
num_layers: int. The number of transformer layers.
num_heads: int. The number of attention heads for each transformer.
The hidden size must be divisible by the number of attention heads.
hidden_dim: int. The size of the transformer encoding and pooler layers.
intermediate_dim: int. The output dimension of the first Dense layer in
a two-layer feedforward network for each transformer.
dropout: float. Dropout probability for the Transformer encoder.
        max_sequence_length: int. The maximum sequence length this encoder can
            consume. If None, `max_sequence_length` defaults to the sequence
            length of the input. This determines the variable shape for
            positional embeddings.
num_segments: int. The number of types that the 'segment_ids' input can
            take.

Examples:
```python
input_data = {
"token_ids": tf.ones(shape=(1, 12), dtype=tf.int64),
"segment_ids": tf.constant(
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
),
"padding_mask": tf.constant(
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
),
}
# Pretrained BERT encoder
model = keras_nlp.models.BertBackbone.from_preset("bert_base_en_uncased")
model = keras_nlp.models.BertBackbone.from_preset("base_base_en_uncased")
output = model(input_data)
# Randomly initialized BERT encoder with a custom config
model = keras_nlp.models.BertBackbone(
vocabulary_size=30552,
num_layers=12,
num_heads=12,
hidden_dim=768,
intermediate_dim=3072,
max_sequence_length=12,
)
output = model(input_data)
```
"""

def __init__(
self,
vocabulary_size,
num_layers,
num_heads,
hidden_dim,
intermediate_dim,
dropout=0.1,
max_sequence_length=512,
num_segments=2,
**kwargs,
    ):
# Index of classification token in the vocabulary
cls_token_index = 0
# Inputs
token_id_input = keras.Input(
shape=(None,), dtype="int32", name="token_ids"
)
segment_id_input = keras.Input(
shape=(None,), dtype="int32", name="segment_ids"
)
padding_mask = keras.Input(
shape=(None,), dtype="int32", name="padding_mask"
)

# Embed tokens, positions, and segment ids.
token_embedding_layer = keras.layers.Embedding(
input_dim=vocabulary_size,
output_dim=hidden_dim,
embeddings_initializer=bert_kernel_initializer(),
name="token_embedding",
)
token_embedding = token_embedding_layer(token_id_input)
position_embedding = PositionEmbedding(
initializer=bert_kernel_initializer(),
sequence_length=max_sequence_length,
name="position_embedding",
)(token_embedding)
segment_embedding = keras.layers.Embedding(
input_dim=num_segments,
output_dim=hidden_dim,
embeddings_initializer=bert_kernel_initializer(),
name="segment_embedding",
)(segment_id_input)

# Sum, normalize and apply dropout to embeddings.
x = keras.layers.Add()(
(token_embedding, position_embedding, segment_embedding)
)
x = keras.layers.LayerNormalization(
name="embeddings_layer_norm",
axis=-1,
epsilon=1e-12,
dtype=tf.float32,
)(x)
x = keras.layers.Dropout(
dropout,
name="embeddings_dropout",
)(x)

# Apply successive transformer encoder blocks.
for i in range(num_layers):
x = TransformerEncoder(
num_heads=num_heads,
intermediate_dim=intermediate_dim,
                activation=lambda x: keras.activations.gelu(
                    x, approximate=True
                ),
dropout=dropout,
layer_norm_epsilon=1e-12,
kernel_initializer=bert_kernel_initializer(),
name=f"transformer_layer_{i}",
)(x, padding_mask=padding_mask)

# Construct the two BERT outputs. The pooled output is a dense layer on
# top of the [CLS] token.
sequence_output = x
pooled_output = keras.layers.Dense(
hidden_dim,
kernel_initializer=bert_kernel_initializer(),
activation="tanh",
name="pooled_dense",
)(x[:, cls_token_index, :])

# Instantiate using Functional API Model constructor
super().__init__(
inputs={
"token_ids": token_id_input,
"segment_ids": segment_id_input,
"padding_mask": padding_mask,
},
outputs={
"sequence_output": sequence_output,
"pooled_output": pooled_output,
},
**kwargs,
)

# All references to `self` below this line
        self.vocabulary_size = vocabulary_size
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.max_sequence_length = max_sequence_length
        self.num_segments = num_segments
        self.dropout = dropout
        self.token_embedding = token_embedding_layer
        self.cls_token_index = cls_token_index

def get_config(self):
return {
"vocabulary_size": self.vocabulary_size,
"hidden_dim": self.hidden_dim,
"intermediate_dim": self.intermediate_dim,
"num_layers": self.num_layers,
"num_heads": self.num_heads,
"max_sequence_length": self.max_sequence_length,
"num_segments": self.num_segments,
"dropout": self.dropout,
"name": self.name,
"trainable": self.trainable,
}

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)

@classmethod
def from_preset(cls, preset, load_weights=True, **kwargs):
return super().from_preset(preset, load_weights, **kwargs)


# Borrow the generic `from_preset` docstring from `Backbone`, then fill in the
# BERT-specific details (class name and preset names) below.
BertBackbone.from_preset.__func__.__doc__ = Backbone.from_preset.__doc__
format_docstring(
model_name=BertBackbone.__name__,
example_preset_name="bert_base_en_uncased",
preset_names='", "'.join(BertBackbone.presets),
)(BertBackbone.from_preset.__func__)
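
For context, here is a minimal sketch (not part of this commit) showing how the backbone's two outputs are typically consumed downstream, e.g. by attaching a classification head to `pooled_output`. The tiny configuration and the layer names below are illustrative assumptions, not APIs defined in this file.

```python
# Minimal sketch, assuming keras_nlp is installed; the small config and the
# "classifier_head" layer name are illustrative, not part of this commit.
import tensorflow as tf
from tensorflow import keras

import keras_nlp

backbone = keras_nlp.models.BertBackbone(
    vocabulary_size=30552,
    num_layers=2,
    num_heads=2,
    hidden_dim=64,
    intermediate_dim=128,
    max_sequence_length=128,
)

# Fresh functional inputs keyed the same way as the backbone's input dict.
inputs = {
    "token_ids": keras.Input(shape=(None,), dtype="int32"),
    "segment_ids": keras.Input(shape=(None,), dtype="int32"),
    "padding_mask": keras.Input(shape=(None,), dtype="int32"),
}
features = backbone(inputs)  # dict with "sequence_output" and "pooled_output"
logits = keras.layers.Dense(2, name="classifier_head")(features["pooled_output"])
classifier = keras.Model(inputs, logits)

# Run a dummy batch shaped like the docstring example.
batch = {
    "token_ids": tf.ones((1, 12), dtype="int32"),
    "segment_ids": tf.zeros((1, 12), dtype="int32"),
    "padding_mask": tf.ones((1, 12), dtype="int32"),
}
print(classifier(batch).shape)  # (1, 2)
```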
