Adding GPTNeoXBackbone #1056

Merged
merged 27 commits into from
Jun 25, 2023
Changes from all commits
Commits
27 commits
9412a83
added gpt-neo attention+decoder+backbone
shivance May 29, 2023
99a8296
fixed formatting + added backbone test
shivance May 29, 2023
afb7e1f
fixed rotary embedding and gpt neo attention layer
shivance Jun 6, 2023
f0f6383
updating decoder and backbone to current version
shivance Jun 6, 2023
bfd56fa
fixed decoder + backbone
shivance Jun 7, 2023
97a347d
fix forward pass
shivance Jun 10, 2023
5ead767
formatting + add checkpoint script
shivance Jun 10, 2023
5776ac1
fix tpu_test, formatting
shivance Jun 10, 2023
e0d343b
removed unnecessary layernorms, correct arguments, fix unit tests (te…
shivance Jun 12, 2023
451cdbc
fix dropout
shivance Jun 12, 2023
e37fb22
matching outputs with hf
shivance Jun 14, 2023
ead11c5
fix formating
shivance Jun 14, 2023
c7117a4
resolving few comments
shivance Jun 14, 2023
c72e629
fixed unit tests + formatting
shivance Jun 16, 2023
2341d0e
refactored rotary embedding
shivance Jun 16, 2023
6112357
revamped checkpoint conversion script
shivance Jun 16, 2023
66afa7c
code format
shivance Jun 16, 2023
f363f24
putting old checkpoint script back until preset
shivance Jun 16, 2023
7a66052
incorporated comments
shivance Jun 17, 2023
6f6f41e
code format
shivance Jun 17, 2023
f34ec47
resolved comments + fixed formatting
shivance Jun 17, 2023
34db7f7
added gpt neo x tokenizer
shivance Jun 17, 2023
1ecfe51
added docstrings
shivance Jun 21, 2023
b3f06e4
formatting fix
shivance Jun 21, 2023
a9f2230
addressing comments
shivance Jun 23, 2023
122a3fb
added tokenizer output verification
shivance Jun 23, 2023
e10ea50
Minor style fixes
mattdangerw Jun 24, 2023
2 changes: 2 additions & 0 deletions keras_nlp/models/__init__.py
@@ -78,6 +78,8 @@
)
from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
from keras_nlp.models.gpt_neo_x.gpt_neo_x_backbone import GPTNeoXBackbone
from keras_nlp.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
from keras_nlp.models.opt.opt_backbone import OPTBackbone
from keras_nlp.models.opt.opt_causal_lm import OPTCausalLM
from keras_nlp.models.opt.opt_causal_lm_preprocessor import (
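For context, a minimal, hedged sketch of how the two new exports above could be used. The toy vocabulary and merge rules are hypothetical placeholders, and the example assumes `GPTNeoXTokenizer` follows the same byte-pair constructor as `GPT2Tokenizer`:

```python
from keras_nlp.models import GPTNeoXBackbone, GPTNeoXTokenizer

# Hypothetical toy vocabulary and merge rules, only to illustrate the
# constructor shape; real checkpoints would ship their own files.
vocab = {"<|endoftext|>": 0, "a": 1, "Ġquick": 2, "Ġfox": 3}
merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick", "Ġ f", "o x", "Ġf ox"]

tokenizer = GPTNeoXTokenizer(vocabulary=vocab, merges=merges)
token_ids = tokenizer("a quick fox")  # Token ids under the toy vocabulary.
```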
13 changes: 13 additions & 0 deletions keras_nlp/models/gpt_neo_x/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
193 changes: 193 additions & 0 deletions keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py
@@ -0,0 +1,193 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow import keras

from keras_nlp.models.gpt_neo_x.rotary_embedding import RotaryEmbedding
from keras_nlp.utils.keras_utils import clone_initializer


class GPTNeoXAttention(keras.layers.Layer):
"""GPTNeoXAttention layer.

This is an implementation of the attention layer described in the
paper ["GPT-NeoX-20B: An Open-Source Autoregressive Language Model"](https://arxiv.org/abs/2204.06745).
Effectively, this layer implements multi-head self-attention with a rotary
embedding for encoding position information.

Args:
num_heads: int. Number of attention heads.
hidden_dim: int. Hidden dimension of the input, i.e., `hidden_states`.
dropout: float. Dropout probability.
kernel_initializer: string or `keras.initializers` initializer.
The kernel initializer for the dense layers.
bias_initializer: string or `keras.initializers` initializer.
The bias initializer for the dense layers.
rotary_percentage: float. The percentage by which query, key, value
matrices are to be rotated.
rotary_max_wavelength: int. The maximum angular wavelength of the
sine/cosine curves, for rotary embeddings.
max_sequence_length: int. The maximum input sequence length.
"""

def __init__(
self,
num_heads,
hidden_dim,
dropout=0.0,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
rotary_percentage=0.25,
rotary_max_wavelength=10000,
max_sequence_length=512,
**kwargs,
):
super().__init__(**kwargs)
self.num_heads = num_heads
self.hidden_dim = hidden_dim
self.rotary_percentage = rotary_percentage
self.dropout = dropout
self.attn_head_size = hidden_dim // num_heads
self.rotary_max_wavelength = rotary_max_wavelength
self.rotary_embedding = RotaryEmbedding(
self.rotary_percentage, rotary_max_wavelength
)
self.kernel_initializer = keras.initializers.get(kernel_initializer)
self.bias_initializer = keras.initializers.get(bias_initializer)
self.max_sequence_length = max_sequence_length

self._qkv_dense = keras.layers.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, self.num_heads, 3 * self.attn_head_size),
bias_axes="de",
**self._get_common_kwargs_for_sublayer(use_bias=True),
name="query",
)

self._attn_dropout_layer = keras.layers.Dropout(
self.dropout, name="attention_dropout"
)

self._softmax = keras.layers.Softmax(axis=-1, name="attention_softmax")

# Output.
self._output_dense = keras.layers.EinsumDense(
equation="abc,cd->abd",
output_shape=(None, self.hidden_dim),
bias_axes="d",
**self._get_common_kwargs_for_sublayer(use_bias=True),
name="attention_output",
)

def _get_common_kwargs_for_sublayer(self, use_bias=True):
common_kwargs = {}

kernel_initializer = clone_initializer(self.kernel_initializer)
bias_initializer = clone_initializer(self.bias_initializer)

common_kwargs["kernel_initializer"] = kernel_initializer
if use_bias:
common_kwargs["bias_initializer"] = bias_initializer

return common_kwargs

def _masked_softmax(self, attention_scores, attention_mask=None):
if attention_mask is not None:
mask_expansion_axis = -3
for _ in range(
attention_scores.shape.rank - attention_mask.shape.rank
):
attention_mask = tf.expand_dims(
attention_mask, axis=mask_expansion_axis
)
return self._softmax(attention_scores, attention_mask)

def _compute_attention(
self, query, key, value, attention_mask=None, training=None
):
attention_scores = tf.einsum("aecd,abcd->acbe", key, query)
norm_factor = tf.sqrt(
tf.constant(self.attn_head_size, dtype=tf.float32)
)
attention_scores /= norm_factor

attention_scores = self._masked_softmax(
attention_scores, attention_mask
)
attention_scores = self._attn_dropout_layer(
attention_scores, training=training
)
attention_output = tf.einsum("acbe,aecd->abcd", attention_scores, value)

return attention_output

def call(
self,
hidden_states,
attention_mask,
training=None,
):
query_key_value = self._qkv_dense(hidden_states)

query = query_key_value[..., : self.attn_head_size]
key = query_key_value[
..., self.attn_head_size : 2 * self.attn_head_size
]
value = query_key_value[..., 2 * self.attn_head_size :]

query, key = self.rotary_embedding(query, key)

attention_output = self._compute_attention(
query=query,
key=key,
value=value,
attention_mask=attention_mask,
training=training,
)

# Reshape `attention_output` to `(batch_size, sequence_length, hidden_dim)`.
attention_output = tf.reshape(
attention_output,
[
tf.shape(attention_output)[0],
tf.shape(attention_output)[1],
self.hidden_dim,
],
)

attention_output = self._output_dense(attention_output)

return attention_output

def get_config(self):
config = super().get_config()
config.update(
{
"num_heads": self.num_heads,
"hidden_dim": self.hidden_dim,
"dropout": self.dropout,
"kernel_initializer": keras.initializers.serialize(
self.kernel_initializer
),
"bias_initializer": keras.initializers.serialize(
self.bias_initializer
),
"rotary_percentage": self.rotary_percentage,
"rotary_max_wavelength": self.rotary_max_wavelength,
"max_sequence_length": self.max_sequence_length,
}
)
return config
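As a sanity check for reviewers, here is a hedged sketch of exercising the layer above in isolation. The shapes follow the `call` signature in this file; the toy dimensions and the causal mask construction are illustrative assumptions, not part of this PR:

```python
import tensorflow as tf

from keras_nlp.models.gpt_neo_x.gpt_neo_x_attention import GPTNeoXAttention

batch_size, seq_len, hidden_dim, num_heads = 2, 8, 64, 4

layer = GPTNeoXAttention(num_heads=num_heads, hidden_dim=hidden_dim)
hidden_states = tf.random.uniform((batch_size, seq_len, hidden_dim))

# A lower-triangular causal mask of shape (batch, query_len, key_len);
# the layer broadcasts it over the heads axis in `_masked_softmax`.
attention_mask = tf.linalg.band_part(
    tf.ones((batch_size, seq_len, seq_len)), -1, 0
)

outputs = layer(hidden_states, attention_mask=attention_mask)
print(outputs.shape)  # Expected: (2, 8, 64), i.e. (batch, sequence, hidden_dim).
```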
159 changes: 159 additions & 0 deletions keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py
@@ -0,0 +1,159 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow import keras

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.gpt_neo_x.gpt_neo_x_decoder import GPTNeoXDecoder


def _gpt_neo_x_kernel_initializer(stddev=0.02):
return keras.initializers.RandomNormal(stddev=stddev)


@keras_nlp_export("keras_nlp.models.GPTNeoXBackbone")
class GPTNeoXBackbone(Backbone):
"""GPT-2 core network with hyperparameters.

This network implements a Transformer-based decoder network,
Generative Pretrained Transformer-Neo-X (GPTNeoX), as described in
["GPT-NeoX-20B: An Open-Source Autoregressive Language Model"](https://arxiv.org/abs/2204.06745).
It includes the embedding lookups and transformer layers.

The default constructor gives a fully customizable, randomly initialized
GPT-NeoX model with any number of layers, heads, and embedding
dimensions.

Disclaimer: Pre-trained models are provided on an "as is" basis, without
warranties or conditions of any kind. The underlying model is provided by a
third party and subject to a separate license, available
[here](https://github.com/EleutherAI/gpt-neox/).

Args:
vocabulary_size: int. The size of the token vocabulary.
num_layers: int. The number of transformer layers.
num_heads: int. The number of attention heads for each transformer.
The hidden size must be divisible by the number of attention heads.
hidden_dim: int. The size of the transformer hidden state at the end
of each transformer layer.
intermediate_dim: int. The output dimension of the first Dense layer in
a two-layer feedforward network for each transformer.
dropout: float. Dropout probability for the Transformer decoder.
layer_norm_epsilon: float. A value added to the denominator for
numerical stability.
rotary_max_wavelength: int. The maximum angular wavelength of the
sine/cosine curves, for rotary embeddings.
rotary_percentage: float. The percentage by which query, key, value
matrices are to be rotated.
max_sequence_length: int. The maximum sequence length that this
decoder can consume. Defaults to `512`.
"""

def __init__(
self,
vocabulary_size,
num_layers,
num_heads,
hidden_dim,
intermediate_dim,
dropout=0.0,
rotary_percentage=0.25,
rotary_max_wavelength=10000,
layer_norm_epsilon=1e-5,
max_sequence_length=512,
**kwargs,
):
# Inputs
token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
padding_mask = keras.Input(
shape=(None,), dtype="int32", name="padding_mask"
)

# Embed tokens
token_embedding = keras.layers.Embedding(
input_dim=vocabulary_size,
output_dim=hidden_dim,
embeddings_initializer=_gpt_neo_x_kernel_initializer(stddev=0.01),
name="token_embedding",
)(token_ids)

x = keras.layers.Dropout(
dropout,
name="embeddings_dropout",
)(token_embedding)

# Apply successive transformer decoder blocks.
for i in range(num_layers):
x = GPTNeoXDecoder(
intermediate_dim=intermediate_dim,
num_heads=num_heads,
dropout=dropout,
max_sequence_length=max_sequence_length,
rotary_percentage=rotary_percentage,
rotary_max_wavelength=rotary_max_wavelength,
layer_norm_epsilon=layer_norm_epsilon,
activation=lambda x: keras.activations.gelu(
x, approximate=True
),
kernel_initializer=_gpt_neo_x_kernel_initializer(stddev=0.02),
name=f"transformer_layer_{i}",
)(x, decoder_padding_mask=padding_mask)

sequence_output = keras.layers.LayerNormalization(
name="layer_norm",
axis=-1,
epsilon=layer_norm_epsilon,
dtype=tf.float32,
)(x)

# Instantiate using Functional API Model constructor
super().__init__(
inputs={
"token_ids": token_ids,
"padding_mask": padding_mask,
},
outputs=sequence_output,
**kwargs,
)
# All references to `self` below this line
self.vocabulary_size = vocabulary_size
self.num_layers = num_layers
self.num_heads = num_heads
self.hidden_dim = hidden_dim
self.intermediate_dim = intermediate_dim
self.dropout = dropout
self.max_sequence_length = max_sequence_length
self.layer_norm_epsilon = layer_norm_epsilon

def get_config(self):
config = super().get_config()
config.update(
{
"vocabulary_size": self.vocabulary_size,
"num_layers": self.num_layers,
"num_heads": self.num_heads,
"hidden_dim": self.hidden_dim,
"intermediate_dim": self.intermediate_dim,
"dropout": self.dropout,
"max_sequence_length": self.max_sequence_length,
"layer_norm_epsilon": self.layer_norm_epsilon,
}
)
return config

@property
def token_embedding(self):
return self.get_layer("token_embedding")
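Finally, a short, hedged end-to-end sketch of the backbone defined above, with random weights and toy dimensions chosen purely for illustration (mirroring the usual KerasNLP backbone usage pattern):

```python
import tensorflow as tf
import keras_nlp

# Randomly initialized GPT-NeoX backbone with a small test configuration.
backbone = keras_nlp.models.GPTNeoXBackbone(
    vocabulary_size=10,
    num_layers=2,
    num_heads=2,
    hidden_dim=32,
    intermediate_dim=64,
    max_sequence_length=128,
)

input_data = {
    "token_ids": tf.random.uniform(shape=(1, 12), maxval=10, dtype="int32"),
    "padding_mask": tf.ones(shape=(1, 12), dtype="int32"),
}
sequence_output = backbone(input_data)
print(sequence_output.shape)  # Expected: (1, 12, 32).
```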