Adding GPTNeoXBackbone #1056

Merged · 27 commits · Jun 25, 2023
Changes from 22 commits

Commits (27):
9412a83  added gpt-neo attention+decoder+backbone · shivance, May 29, 2023
99a8296  fixed formatting + added backbone test · shivance, May 29, 2023
afb7e1f  fixed rotary embedding and gpt neo attention layer · shivance, Jun 6, 2023
f0f6383  updating decoder and backbone to current version · shivance, Jun 6, 2023
bfd56fa  fixed decoder + backbone · shivance, Jun 7, 2023
97a347d  fix forward pass · shivance, Jun 10, 2023
5ead767  formatting + add checkpoint script · shivance, Jun 10, 2023
5776ac1  fix tpu_test, formatting · shivance, Jun 10, 2023
e0d343b  removed unnecessary layernorms, correct arguments, fix unit tests (te… · shivance, Jun 12, 2023
451cdbc  fix dropout · shivance, Jun 12, 2023
e37fb22  matching outputs with hf · shivance, Jun 14, 2023
ead11c5  fix formating · shivance, Jun 14, 2023
c7117a4  resolving few comments · shivance, Jun 14, 2023
c72e629  fixed unit tests + formatting · shivance, Jun 16, 2023
2341d0e  refactored rotary embedding · shivance, Jun 16, 2023
6112357  revamped checkpoint conversion script · shivance, Jun 16, 2023
66afa7c  code format · shivance, Jun 16, 2023
f363f24  putting old checkpoint script back until preset · shivance, Jun 16, 2023
7a66052  incorporated comments · shivance, Jun 17, 2023
6f6f41e  code format · shivance, Jun 17, 2023
f34ec47  resolved comments + fixed formatting · shivance, Jun 17, 2023
34db7f7  added gpt neo x tokenizer · shivance, Jun 17, 2023
1ecfe51  added docstrings · shivance, Jun 21, 2023
b3f06e4  formatting fix · shivance, Jun 21, 2023
a9f2230  addressing comments · shivance, Jun 23, 2023
122a3fb  added tokenizer output verification · shivance, Jun 23, 2023
e10ea50  Minor style fixes · mattdangerw, Jun 24, 2023
1 change: 1 addition & 0 deletions keras_nlp/models/__init__.py
@@ -78,6 +78,7 @@
)
from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
from keras_nlp.models.gpt_neo_x.gpt_neo_x_backbone import GPTNeoXBackbone
from keras_nlp.models.opt.opt_backbone import OPTBackbone
from keras_nlp.models.opt.opt_causal_lm import OPTCausalLM
from keras_nlp.models.opt.opt_causal_lm_preprocessor import (
13 changes: 13 additions & 0 deletions keras_nlp/models/gpt_neo_x/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
152 changes: 152 additions & 0 deletions keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py
@@ -0,0 +1,152 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import tensorflow as tf
from tensorflow import keras

from keras_nlp.models.gpt_neo_x.rotary_embedding import RotaryEmbedding
from keras_nlp.utils.keras_utils import clone_initializer


class GPTNeoXAttention(keras.layers.Layer):
def __init__(
self,
num_heads,
hidden_dim,
dropout=0.1,
max_sequence_length=512,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
rotary_percentage=0.25,
rotary_max_wavelength=10000,
**kwargs,
):
super().__init__(**kwargs)
self.num_heads = num_heads
self.hidden_dim = hidden_dim
self.rotary_percentage = rotary_percentage
self.dropout = dropout
self.attn_head_size = hidden_dim // num_heads
self.rotary_ndims = int(self.attn_head_size * rotary_percentage)
Review comment (Member):

It seems to me like you can move this line down into the rotary layer itself; you can get at attn_head_size simply by reading the shape of the passed query and value, right?

I would pass percentage and max_wavelength directly as arguments to RotaryEmbedding, and keep all the logic there; that will keep things more compartmentalized.

Reply (@shivance, Collaborator, Author · Jun 23, 2023):

Outputs match after this refactor :)

(A sketch of the suggested RotaryEmbedding refactor follows this file's diff below.)

self.rotary_max_wavelength = rotary_max_wavelength
self.max_sequence_length = max_sequence_length
self.rotary_embedding = RotaryEmbedding(
self.rotary_ndims, rotary_max_wavelength
)
self.norm_factor = np.sqrt(self.attn_head_size)
self._kernel_initializer = keras.initializers.get(kernel_initializer)
self._bias_initializer = keras.initializers.get(bias_initializer)

self._qkv_dense = keras.layers.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, self.num_heads, 3 * self.attn_head_size),
bias_axes="de",
**self._get_common_kwargs_for_sublayer(use_bias=True),
name="query",
)

self._attn_dropout_layer = keras.layers.Dropout(
self.dropout, name="attention_dropout"
)

self._softmax = keras.layers.Softmax(axis=-1, name="attention_softmax")

# Output.
self._output_dense = keras.layers.EinsumDense(
equation="abc,cd->abd",
output_shape=(None, self.hidden_dim),
bias_axes="d",
**self._get_common_kwargs_for_sublayer(use_bias=True),
name="attention_output",
)

def _get_common_kwargs_for_sublayer(self, use_bias=True):
common_kwargs = {}

kernel_initializer = clone_initializer(self._kernel_initializer)
bias_initializer = clone_initializer(self._bias_initializer)

common_kwargs["kernel_initializer"] = kernel_initializer
if use_bias:
common_kwargs["bias_initializer"] = bias_initializer

return common_kwargs

def _masked_softmax(self, attention_scores, attention_mask=None):
if attention_mask is not None:
mask_expansion_axis = -3
for _ in range(
attention_scores.shape.rank - attention_mask.shape.rank
):
attention_mask = tf.expand_dims(
attention_mask, axis=mask_expansion_axis
)
return self._softmax(attention_scores, attention_mask)

def _compute_attention(
self, query, key, value, attention_mask=None, training=None
):
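        # Raw attention scores per head, shaped (batch, num_heads, query_len, key_len).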
attention_scores = tf.einsum("aecd,abcd->acbe", key, query)
attention_scores /= self.norm_factor

attention_scores = self._masked_softmax(
attention_scores, attention_mask
)
attention_scores = self._attn_dropout_layer(
attention_scores, training=training
)
attention_output = tf.einsum("acbe,aecd->abcd", attention_scores, value)

return attention_output, attention_scores

def call(
self,
hidden_states,
attention_mask,
return_attention_scores=False,
training=None,
):
query_key_value = self._qkv_dense(hidden_states)

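        # Split the fused projection into query, key and value slices along the last axis.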
query = query_key_value[..., : self.attn_head_size]
key = query_key_value[
..., self.attn_head_size : 2 * self.attn_head_size
]
value = query_key_value[..., 2 * self.attn_head_size :]

query, key = self.rotary_embedding(query, key)

attention_output, attention_scores = self._compute_attention(
query=query,
key=key,
value=value,
attention_mask=attention_mask,
training=training,
)

# Reshape `attention_output` to `(batch_size, sequence_length, hidden_dim)`.
attention_output = tf.reshape(
attention_output,
[
tf.shape(attention_output)[0],
tf.shape(attention_output)[1],
self.hidden_dim,
],
)

attention_output = self._output_dense(attention_output)

if return_attention_scores:
return attention_output, attention_scores
return attention_output
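The review thread above suggests letting the rotary layer derive its rotary dimension from the query/key shape and take the percentage and max wavelength directly as constructor arguments. Below is a minimal sketch of that interface, assuming TensorFlow 2.x Keras, statically known head size, and the GPT-NeoX "rotate-half" formulation; RotaryEmbeddingSketch and its argument names are hypothetical illustrations, not the RotaryEmbedding shipped in this PR.

import tensorflow as tf
from tensorflow import keras


class RotaryEmbeddingSketch(keras.layers.Layer):
    """Hypothetical rotary embedding that owns percentage and max_wavelength."""

    def __init__(self, rotary_percentage=0.25, max_wavelength=10000, **kwargs):
        super().__init__(**kwargs)
        self.rotary_percentage = rotary_percentage
        self.max_wavelength = max_wavelength

    def call(self, query, key):
        # Inputs: (batch, seq_len, num_heads, head_size). The rotary slice is
        # a fraction of head_size, read from the (static) input shape.
        head_size = query.shape[-1]
        rotary_ndims = int(head_size * self.rotary_percentage)

        q_rot, q_pass = query[..., :rotary_ndims], query[..., rotary_ndims:]
        k_rot, k_pass = key[..., :rotary_ndims], key[..., rotary_ndims:]

        seq_len = tf.shape(query)[1]
        positions = tf.cast(tf.range(seq_len), self.compute_dtype)
        inverse_freq = 1.0 / (
            tf.cast(self.max_wavelength, self.compute_dtype)
            ** (
                tf.cast(tf.range(0, rotary_ndims, 2), self.compute_dtype)
                / rotary_ndims
            )
        )
        # Outer product -> (seq_len, rotary_ndims // 2), duplicated to full width.
        freqs = tf.einsum("i,j->ij", positions, inverse_freq)
        embedding = tf.concat([freqs, freqs], axis=-1)
        # Broadcast over batch and heads: (1, seq_len, 1, rotary_ndims).
        cos = tf.cos(embedding)[None, :, None, :]
        sin = tf.sin(embedding)[None, :, None, :]

        def rotate_half(x):
            x1, x2 = tf.split(x, 2, axis=-1)
            return tf.concat([-x2, x1], axis=-1)

        q_rot = q_rot * cos + rotate_half(q_rot) * sin
        k_rot = k_rot * cos + rotate_half(k_rot) * sin
        return (
            tf.concat([q_rot, q_pass], axis=-1),
            tf.concat([k_rot, k_pass], axis=-1),
        )

With such a layer, the attention block would only need something like query, key = self.rotary_embedding(query, key), with no rotary_ndims bookkeeping in GPTNeoXAttention.__init__.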
122 changes: 122 additions & 0 deletions keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py
@@ -0,0 +1,122 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow import keras

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.gpt_neo_x.gpt_neo_x_decoder import GPTNeoXDecoder


def _gpt_neo_x_kernel_initializer(stddev=0.02):
return keras.initializers.RandomNormal(stddev=stddev)


@keras_nlp_export("keras_nlp.models.GPTNeoXBackbone")
class GPTNeoXBackbone(Backbone):
def __init__(
self,
vocabulary_size,
num_layers,
num_heads,
hidden_dim,
intermediate_dim,
dropout=0.0,
rotary_percentage=0.25,
rotary_max_wavelength=10000,
layer_norm_epsilon=1e-5,
max_sequence_length=512,
**kwargs,
):
# Inputs
token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
padding_mask = keras.Input(
shape=(None,), dtype="int32", name="padding_mask"
)

# Embed tokens, positions.
token_embedding = keras.layers.Embedding(
input_dim=vocabulary_size,
output_dim=hidden_dim,
embeddings_initializer=_gpt_neo_x_kernel_initializer(stddev=0.01),
name="token_embedding",
)(token_ids)

x = keras.layers.Dropout(
dropout,
name="embeddings_dropout",
)(token_embedding)

# Apply successive transformer decoder blocks.
for i in range(num_layers):
x = GPTNeoXDecoder(
intermediate_dim=intermediate_dim,
num_heads=num_heads,
dropout=dropout,
max_sequence_length=max_sequence_length,
rotary_percentage=rotary_percentage,
rotary_max_wavelength=rotary_max_wavelength,
layer_norm_epsilon=layer_norm_epsilon,
activation=lambda x: keras.activations.gelu(
x, approximate=True
),
kernel_initializer=_gpt_neo_x_kernel_initializer(stddev=0.02),
name=f"transformer_layer_{i}",
)(x, decoder_padding_mask=padding_mask)

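        # Final layer norm over the hidden dimension, computed in float32,
        # which helps numerical stability if mixed precision is enabled.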
sequence_output = keras.layers.LayerNormalization(
name="layer_norm",
axis=-1,
epsilon=layer_norm_epsilon,
dtype=tf.float32,
)(x)

# Instantiate using Functional API Model constructor
super().__init__(
inputs={
"token_ids": token_ids,
"padding_mask": padding_mask,
},
outputs=sequence_output,
**kwargs,
)
# All references to `self` below this line
self.vocabulary_size = vocabulary_size
self.num_layers = num_layers
self.num_heads = num_heads
self.hidden_dim = hidden_dim
self.intermediate_dim = intermediate_dim
self.dropout = dropout
self.max_sequence_length = max_sequence_length
self.layer_norm_epsilon = layer_norm_epsilon

def get_config(self):
config = super().get_config()
config.update(
{
"vocabulary_size": self.vocabulary_size,
"num_layers": self.num_layers,
"num_heads": self.num_heads,
"hidden_dim": self.hidden_dim,
"intermediate_dim": self.intermediate_dim,
"dropout": self.dropout,
"max_sequence_length": self.max_sequence_length,
"layer_norm_epsilon": self.layer_norm_epsilon,
}
)
return config

@property
def token_embedding(self):
return self.get_layer("token_embedding")
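For a quick end-to-end check of the constructor arguments above, the backbone can be built with toy hyperparameters and called on a dict of "token_ids" and "padding_mask". This is a usage sketch only; the sizes below are arbitrary and do not correspond to any released preset.

import tensorflow as tf

from keras_nlp.models import GPTNeoXBackbone

# Tiny, randomly initialized configuration, purely for shape checking.
backbone = GPTNeoXBackbone(
    vocabulary_size=1000,
    num_layers=2,
    num_heads=4,
    hidden_dim=64,
    intermediate_dim=256,
    max_sequence_length=128,
)

batch = {
    "token_ids": tf.random.uniform((2, 16), maxval=1000, dtype=tf.int32),
    "padding_mask": tf.ones((2, 16), dtype=tf.int32),
}
hidden_states = backbone(batch)
print(hidden_states.shape)  # (2, 16, 64): (batch, sequence, hidden_dim)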