-
Notifications
You must be signed in to change notification settings - Fork 254
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding GPTNeoXBackbone
#1056
Merged
Merged
Adding GPTNeoXBackbone
#1056
Changes from 22 commits
Commits
Show all changes
27 commits
Select commit
Hold shift + click to select a range
9412a83
added gpt-neo attention+decoder+backbone
shivance 99a8296
fixed formatting + added backbone test
shivance afb7e1f
fixed rotary embedding and gpt neo attention layer
shivance f0f6383
updating decoder and backbone to current version
shivance bfd56fa
fixed decoder + backbone
shivance 97a347d
fix forward pass
shivance 5ead767
formatting + add checkpoint script
shivance 5776ac1
fix tpu_test, formatting
shivance e0d343b
removed unnecessary layernorms, correct arguments, fix unit tests (te…
shivance 451cdbc
fix dropout
shivance e37fb22
matching outputs with hf
shivance ead11c5
fix formating
shivance c7117a4
resolving few comments
shivance c72e629
fixed unit tests + formatting
shivance 2341d0e
refactored rotary embedding
shivance 6112357
revamped checkpoint conversion script
shivance 66afa7c
code format
shivance f363f24
putting old checkpoint script back until preset
shivance 7a66052
incorporated comments
shivance 6f6f41e
code format
shivance f34ec47
resolved comments + fixed formatting
shivance 34db7f7
added gpt neo x tokenizer
shivance 1ecfe51
added docstrings
shivance b3f06e4
formatting fix
shivance a9f2230
addressing comments
shivance 122a3fb
added tokenizer output verification
shivance e10ea50
Minor style fixes
mattdangerw File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright 2023 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
# Copyright 2023 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import numpy as np | ||
import tensorflow as tf | ||
from tensorflow import keras | ||
|
||
from keras_nlp.models.gpt_neo_x.rotary_embedding import RotaryEmbedding | ||
from keras_nlp.utils.keras_utils import clone_initializer | ||
|
||
|
||
class GPTNeoXAttention(keras.layers.Layer): | ||
def __init__( | ||
self, | ||
num_heads, | ||
hidden_dim, | ||
dropout=0.1, | ||
max_sequence_length=512, | ||
kernel_initializer="glorot_uniform", | ||
bias_initializer="zeros", | ||
rotary_percentage=0.25, | ||
rotary_max_wavelength=10000, | ||
**kwargs, | ||
): | ||
super().__init__(**kwargs) | ||
self.num_heads = num_heads | ||
self.hidden_dim = hidden_dim | ||
self.rotary_percentage = rotary_percentage | ||
self.dropout = dropout | ||
self.attn_head_size = hidden_dim // num_heads | ||
self.rotary_ndims = int(self.attn_head_size * rotary_percentage) | ||
self.rotary_max_wavelength = rotary_max_wavelength | ||
self.max_sequence_length = max_sequence_length | ||
self.rotary_embedding = RotaryEmbedding( | ||
self.rotary_ndims, rotary_max_wavelength | ||
) | ||
self.norm_factor = np.sqrt(self.attn_head_size) | ||
self._kernel_initializer = keras.initializers.get(kernel_initializer) | ||
self._bias_initializer = keras.initializers.get(bias_initializer) | ||
|
||
self._qkv_dense = keras.layers.EinsumDense( | ||
equation="abc,cde->abde", | ||
output_shape=(None, self.num_heads, 3 * self.attn_head_size), | ||
bias_axes="de", | ||
**self._get_common_kwargs_for_sublayer(use_bias=True), | ||
name="query", | ||
) | ||
|
||
self._attn_dropout_layer = keras.layers.Dropout( | ||
self.dropout, name="attention_dropout" | ||
) | ||
|
||
self._softmax = keras.layers.Softmax(axis=-1, name="attention_softmax") | ||
|
||
# Output. | ||
self._output_dense = keras.layers.EinsumDense( | ||
equation="abc,cd->abd", | ||
output_shape=(None, self.hidden_dim), | ||
bias_axes="d", | ||
**self._get_common_kwargs_for_sublayer(use_bias=True), | ||
name="attention_output", | ||
) | ||
|
||
def _get_common_kwargs_for_sublayer(self, use_bias=True): | ||
common_kwargs = {} | ||
|
||
kernel_initializer = clone_initializer(self._kernel_initializer) | ||
bias_initializer = clone_initializer(self._bias_initializer) | ||
|
||
common_kwargs["kernel_initializer"] = kernel_initializer | ||
if use_bias: | ||
common_kwargs["bias_initializer"] = bias_initializer | ||
|
||
return common_kwargs | ||
|
||
def _masked_softmax(self, attention_scores, attention_mask=None): | ||
if attention_mask is not None: | ||
mask_expansion_axis = -3 | ||
for _ in range( | ||
attention_scores.shape.rank - attention_mask.shape.rank | ||
): | ||
attention_mask = tf.expand_dims( | ||
attention_mask, axis=mask_expansion_axis | ||
) | ||
return self._softmax(attention_scores, attention_mask) | ||
|
||
def _compute_attention( | ||
self, query, key, value, attention_mask=None, training=None | ||
): | ||
attention_scores = tf.einsum("aecd,abcd->acbe", key, query) | ||
attention_scores /= self.norm_factor | ||
|
||
attention_scores = self._masked_softmax( | ||
attention_scores, attention_mask | ||
) | ||
attention_scores = self._attn_dropout_layer( | ||
attention_scores, training=training | ||
) | ||
attention_output = tf.einsum("acbe,aecd->abcd", attention_scores, value) | ||
|
||
return attention_output, attention_scores | ||
|
||
def call( | ||
self, | ||
hidden_states, | ||
attention_mask, | ||
return_attention_scores=False, | ||
training=None, | ||
): | ||
query_key_value = self._qkv_dense(hidden_states) | ||
|
||
query = query_key_value[..., : self.attn_head_size] | ||
key = query_key_value[ | ||
..., self.attn_head_size : 2 * self.attn_head_size | ||
] | ||
value = query_key_value[..., 2 * self.attn_head_size :] | ||
|
||
query, key = self.rotary_embedding(query, key) | ||
|
||
attention_output, attention_scores = self._compute_attention( | ||
query=query, | ||
key=key, | ||
value=value, | ||
attention_mask=attention_mask, | ||
training=training, | ||
) | ||
|
||
# Reshape `attention_output` to `(batch_size, sequence_length, hidden_dim)`. | ||
attention_output = tf.reshape( | ||
attention_output, | ||
[ | ||
tf.shape(attention_output)[0], | ||
tf.shape(attention_output)[1], | ||
self.hidden_dim, | ||
], | ||
) | ||
|
||
attention_output = self._output_dense(attention_output) | ||
|
||
if return_attention_scores: | ||
shivance marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return attention_output, attention_scores | ||
return attention_output | ||
shivance marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
# Copyright 2023 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import tensorflow as tf | ||
from tensorflow import keras | ||
|
||
from keras_nlp.api_export import keras_nlp_export | ||
from keras_nlp.models.backbone import Backbone | ||
from keras_nlp.models.gpt_neo_x.gpt_neo_x_decoder import GPTNeoXDecoder | ||
|
||
|
||
def _gpt_neo_x_kernel_initializer(stddev=0.02): | ||
return keras.initializers.RandomNormal(stddev=stddev) | ||
|
||
|
||
@keras_nlp_export("keras_nlp.models.GPTNeoXBackbone") | ||
class GPTNeoXBackbone(Backbone): | ||
def __init__( | ||
self, | ||
vocabulary_size, | ||
num_layers, | ||
num_heads, | ||
hidden_dim, | ||
intermediate_dim, | ||
dropout=0.0, | ||
rotary_percentage=0.25, | ||
rotary_max_wavelength=10000, | ||
layer_norm_epsilon=1e-5, | ||
max_sequence_length=512, | ||
**kwargs, | ||
): | ||
# Inputs | ||
token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids") | ||
padding_mask = keras.Input( | ||
shape=(None,), dtype="int32", name="padding_mask" | ||
) | ||
|
||
# Embed tokens, positions. | ||
shivance marked this conversation as resolved.
Show resolved
Hide resolved
|
||
token_embedding = keras.layers.Embedding( | ||
input_dim=vocabulary_size, | ||
output_dim=hidden_dim, | ||
embeddings_initializer=_gpt_neo_x_kernel_initializer(stddev=0.01), | ||
name="token_embedding", | ||
)(token_ids) | ||
|
||
x = keras.layers.Dropout( | ||
dropout, | ||
name="embeddings_dropout", | ||
)(token_embedding) | ||
|
||
# Apply successive transformer decoder blocks. | ||
for i in range(num_layers): | ||
x = GPTNeoXDecoder( | ||
intermediate_dim=intermediate_dim, | ||
num_heads=num_heads, | ||
dropout=dropout, | ||
max_sequence_length=max_sequence_length, | ||
rotary_percentage=rotary_percentage, | ||
rotary_max_wavelength=rotary_max_wavelength, | ||
layer_norm_epsilon=layer_norm_epsilon, | ||
activation=lambda x: keras.activations.gelu( | ||
x, approximate=True | ||
), | ||
kernel_initializer=_gpt_neo_x_kernel_initializer(stddev=0.02), | ||
name=f"transformer_layer_{i}", | ||
)(x, decoder_padding_mask=padding_mask) | ||
|
||
sequence_output = keras.layers.LayerNormalization( | ||
name="layer_norm", | ||
axis=-1, | ||
epsilon=layer_norm_epsilon, | ||
dtype=tf.float32, | ||
)(x) | ||
|
||
# Instantiate using Functional API Model constructor | ||
super().__init__( | ||
inputs={ | ||
"token_ids": token_ids, | ||
"padding_mask": padding_mask, | ||
}, | ||
outputs=sequence_output, | ||
**kwargs, | ||
) | ||
# All references to `self` below this line | ||
self.vocabulary_size = vocabulary_size | ||
self.num_layers = num_layers | ||
self.num_heads = num_heads | ||
self.hidden_dim = hidden_dim | ||
self.intermediate_dim = intermediate_dim | ||
self.dropout = dropout | ||
self.max_sequence_length = max_sequence_length | ||
self.layer_norm_epsilon = layer_norm_epsilon | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"vocabulary_size": self.vocabulary_size, | ||
"num_layers": self.num_layers, | ||
"num_heads": self.num_heads, | ||
"hidden_dim": self.hidden_dim, | ||
"intermediate_dim": self.intermediate_dim, | ||
"dropout": self.dropout, | ||
"max_sequence_length": self.max_sequence_length, | ||
"layer_norm_epsilon": self.layer_norm_epsilon, | ||
} | ||
) | ||
return config | ||
|
||
@property | ||
def token_embedding(self): | ||
return self.get_layer("token_embedding") |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems to me like you can move this line down into the rotary layer itself, you can get at
attn_head_size
simply by reading the shape of the passedquery
andvalue
right?I would pass
percentage
andmax_wavelength
directly as arguments toRotaryEmbedding
, and keep all the logic there, that will keep things more compartmentalized.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
outputs match after this refactor :)