Adding GPTNeoXBackbone
#1056
@@ -0,0 +1,140 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow import keras

from keras_nlp.layers.transformer_layer_utils import compute_causal_mask
from keras_nlp.models.gpt_neox.rotary_embedding import RotaryEmbedding


class GPTNeoXAttention(keras.layers.Layer):
    def __init__(
        self,
        num_heads,
        hidden_dim,
        rotary_pct=0.25,
        max_position_embeddings=2048,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // num_heads
        self.rotary_pct = rotary_pct
        # Only the leading `rotary_dim` features of each head receive rotary
        # embeddings; the rest pass through unchanged.
        self.rotary_dim = int(self.head_dim * rotary_pct)
        self.max_position_embeddings = max_position_embeddings
        self.rotary_embedding = RotaryEmbedding(self.rotary_pct)
        self.qkv = keras.layers.Dense(3 * self.hidden_dim)
        self.dense = keras.layers.Dense(self.hidden_dim)

    def _compute_attention(
        self, query, key, value, attention_mask=None, head_mask=None
    ):
        batch_size = tf.shape(query)[0]
        query_len = tf.shape(query)[-2]
        key_len = tf.shape(key)[-2]
        # Causal mask over key/query positions, expanded so it broadcasts
        # across the attention heads.
        causal_mask = compute_causal_mask(batch_size, key_len, key_len)
        causal_mask = tf.cast(causal_mask[:, tf.newaxis, :, :], "bool")

        query = tf.reshape(
            query, [batch_size * self.num_heads, query_len, self.head_dim]
        )
        key = tf.reshape(
            key, [batch_size * self.num_heads, key_len, self.head_dim]
        )

        # Raw dot-product attention scores.
        attention_scores = tf.matmul(query, key, transpose_b=True)
        attention_scores = tf.reshape(
            attention_scores, [batch_size, self.num_heads, query_len, key_len]
        )
        # Mask out future positions before the softmax.
        mask_value = tf.constant(float("-inf"), dtype=attention_scores.dtype)
        attention_scores = tf.where(causal_mask, attention_scores, mask_value)

        if attention_mask is not None:
            attention_scores += attention_mask

        attention_scores = tf.cast(
            tf.nn.softmax(attention_scores, axis=-1), dtype=value.dtype
        )

        if head_mask is not None:
            attention_scores *= head_mask

        attention_output = tf.matmul(attention_scores, value)
        return attention_output, attention_scores

    def call(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        layer_past,
        return_attention_scores,
    ):
        qkv = self.qkv(hidden_states)
        new_qkv_shape = tf.concat(
            [
                tf.shape(hidden_states)[:-1],
                [self.num_heads, 3 * self.head_dim],
            ],
            axis=0,
        )
        qkv = tf.reshape(qkv, new_qkv_shape)

        # Split the fused projection into query, key and value, each of shape
        # `(batch_size, num_heads, seq_len, head_dim)`.
        query = tf.transpose(qkv[..., : self.head_dim], (0, 2, 1, 3))
        key = tf.transpose(
            qkv[..., self.head_dim : 2 * self.head_dim], (0, 2, 1, 3)
        )
        value = tf.transpose(qkv[..., 2 * self.head_dim :], (0, 2, 1, 3))

        # Split query and key into the part that receives rotary embeddings
        # and the part that passes through unchanged.
        query_rot, query_pass = (
            query[..., : self.rotary_dim],
            query[..., self.rotary_dim :],
        )
        key_rot, key_pass = (
            key[..., : self.rotary_dim],
            key[..., self.rotary_dim :],
        )

Review comment: I wonder if we would be better off moving this slice and concat logic into the rotary embedding layer itself. And the rotary embedding layer could also hold the percentage argument, which would conceptually be quite clean. Looks like Falcon is doing this roughly -> https://huggingface.co/tiiuae/falcon-40b/blob/main/modelling_RW.py (see the sketch after this file).

Reply: Thanks for this wonderful suggestion!
        query, key = self.rotary_embedding(query_rot, key_rot)
        query = tf.concat((query, query_pass), axis=-1)
        key = tf.concat((key, key_pass), axis=-1)

        # Reuse cached key/value tensors from previous decoding steps.
        if layer_past is not None:
            past_key, past_value = layer_past
            key = tf.concat((past_key, key), axis=-2)
            value = tf.concat((past_value, value), axis=-2)

        attention_output, attention_scores = self._compute_attention(
            query, key, value, attention_mask, head_mask
        )

        # Merge the attention heads back into a single hidden dimension.
        attention_output = tf.transpose(attention_output, (0, 2, 1, 3))
        new_shape = tf.concat(
            [
                tf.shape(attention_output)[:-2],
                [self.num_heads * self.head_dim],
            ],
            axis=0,
        )
        attention_output = tf.reshape(attention_output, new_shape)
        attention_output = self.dense(attention_output)

        if return_attention_scores:
            return (attention_output, attention_scores)

        return attention_output
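Picking up the suggestion in the review thread above, here is a minimal sketch of what a `RotaryEmbedding` layer that owns the percentage and the slice/concat could look like. This is not the layer in this PR: the class body, the `max_wavelength` default, the single-tensor `call` signature, and the "rotate half" formulation are all assumptions made for illustration, and the static last dimension of the input is assumed known.

```python
import tensorflow as tf
from tensorflow import keras


class RotaryEmbedding(keras.layers.Layer):
    """Sketch of a rotary embedding layer that owns the partial-rotation split."""

    def __init__(self, percentage=0.25, max_wavelength=10000, **kwargs):
        super().__init__(**kwargs)
        self.percentage = percentage
        self.max_wavelength = max_wavelength

    def call(self, inputs):
        # `inputs` is expected to be shaped
        # `(batch_size, num_heads, seq_len, head_dim)`.
        rotary_dim = int(inputs.shape[-1] * self.percentage)
        x_rot, x_pass = inputs[..., :rotary_dim], inputs[..., rotary_dim:]

        # Inverse frequencies for each rotated feature pair.
        seq_len = tf.shape(inputs)[-2]
        positions = tf.cast(tf.range(seq_len), self.compute_dtype)
        freq_exponents = (
            tf.cast(tf.range(0, rotary_dim, 2), self.compute_dtype) / rotary_dim
        )
        inverse_freq = 1.0 / (self.max_wavelength**freq_exponents)
        angles = tf.einsum("i,j->ij", positions, inverse_freq)

        # Shape `(1, 1, seq_len, rotary_dim)` so it broadcasts over batch/heads.
        cos = tf.concat((tf.cos(angles), tf.cos(angles)), axis=-1)[None, None]
        sin = tf.concat((tf.sin(angles), tf.sin(angles)), axis=-1)[None, None]

        # "Rotate half" formulation of rotary embeddings.
        half = rotary_dim // 2
        x1, x2 = x_rot[..., :half], x_rot[..., half:]
        x_rotated = x_rot * cos + tf.concat((-x2, x1), axis=-1) * sin

        # Unrotated features are passed through untouched.
        return tf.concat((x_rotated, x_pass), axis=-1)
```

With something like this, the attention layer could reduce the call site to `query = self.rotary_embedding(query)` and `key = self.rotary_embedding(key)`, with no slicing or concatenation in `GPTNeoXAttention.call`.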
@@ -0,0 +1,168 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow import keras
from tensorflow.experimental import dtensor
from tensorflow.experimental.dtensor import Layout

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.layers.position_embedding import PositionEmbedding
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.gpt_neox.gpt_neox_decoder import GPTNeoXDecoder


def _gpt_neox_kernel_initializer(stddev=0.02):
    return keras.initializers.RandomNormal(stddev=stddev)


@keras_nlp_export("keras_nlp.models.GPTNeoXBackbone")
class GPTNeoXBackbone(Backbone):
    """GPT-NeoX core network with hyperparameters."""

    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_heads,
        hidden_dim,
        intermediate_dim,
        dropout=0.1,
        max_sequence_length=1024,
        **kwargs,
    ):
        # Inputs
        token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
        padding_mask = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )

        # Embed tokens, positions.
        token_embedding = keras.layers.Embedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            embeddings_initializer=_gpt_neox_kernel_initializer(stddev=0.01),
            name="token_embedding",
        )(token_ids)

        # Can't use `TokenAndPositionEmbedding` layer here because of different
        # initializers.
        position_embedding = PositionEmbedding(
            initializer=_gpt_neox_kernel_initializer(stddev=0.02),
            sequence_length=max_sequence_length,
            name="position_embedding",
        )(token_embedding)

        # Sum and apply dropout to embeddings.
        x = keras.layers.Add(name="embeddings_add")(
            (token_embedding, position_embedding)
        )
        x = keras.layers.Dropout(
            dropout,
            name="embeddings_dropout",
        )(x)

        # Apply successive transformer decoder blocks.
        for i in range(num_layers):
            x = GPTNeoXDecoder(
                intermediate_dim=intermediate_dim,
                num_heads=num_heads,
                dropout=dropout,
                layer_norm_epsilon=1e-05,
                activation=lambda x: keras.activations.gelu(
                    x, approximate=True
                ),
                kernel_initializer=_gpt_neox_kernel_initializer(stddev=0.02),
                normalize_first=True,
                name=f"transformer_layer_{i}",
            )(x, decoder_padding_mask=padding_mask)

        sequence_output = keras.layers.LayerNormalization(
            name="layer_norm",
            axis=-1,
            epsilon=1e-05,
            dtype=tf.float32,
        )(x)

        # Instantiate using Functional API Model constructor.
        super().__init__(
            inputs={
                "token_ids": token_ids,
                "padding_mask": padding_mask,
            },
            outputs=sequence_output,
            **kwargs,
        )
        # All references to `self` below this line.
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.dropout = dropout
        self.max_sequence_length = max_sequence_length

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "dropout": self.dropout,
                "max_sequence_length": self.max_sequence_length,
            }
        )
        return config

    @property
    def token_embedding(self):
        return self.get_layer("token_embedding")

    @classmethod
    def create_layout_map(cls, mesh):
        # We assert the mesh is 2D, and assume the first mesh dim is for data
        # parallel and the second dim is for model parallel.
        mesh_shape = mesh.shape()
        if len(mesh_shape) != 2:
            raise ValueError(
                f"Expect to create layout based on 2D mesh, received {mesh}"
            )
        _, model_dim = mesh.dim_names
        unshard_dim = dtensor.UNSHARDED

        layout_map = keras.dtensor.experimental.LayoutMap(mesh=mesh)
        # Embedding sharding
        layout_map[r".*embeddings"] = Layout([unshard_dim, model_dim], mesh)

        # Transformer block sharding
        layout_map[r".*_(query|key|value)_dense.kernel"] = Layout(
            [unshard_dim, unshard_dim, model_dim], mesh
        )
        layout_map[r".*_(query|key|value)_dense.bias"] = Layout(
            [model_dim, unshard_dim], mesh
        )
        layout_map[r".*_feedforward_intermediate_dense.kernel"] = Layout(
            [unshard_dim, model_dim], mesh
        )
        layout_map[r".*_feedforward_intermediate_dense.bias"] = Layout(
            [model_dim], mesh
        )
        layout_map[r".*_feedforward_output_dense.kernel"] = Layout(
            [model_dim, unshard_dim], mesh
        )
        layout_map[r".*_feedforward_output_dense.bias"] = Layout(
            [unshard_dim], mesh
        )
        return layout_map
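For reference, a hedged usage sketch of how this backbone would be constructed and called, following the pattern of other KerasNLP backbones. The hyperparameter values and input shapes below are illustrative only and are not taken from any GPT-NeoX preset.

```python
import tensorflow as tf

# Illustrative, small configuration (not a real GPT-NeoX preset).
backbone = GPTNeoXBackbone(
    vocabulary_size=50432,
    num_layers=2,
    num_heads=4,
    hidden_dim=128,
    intermediate_dim=512,
    max_sequence_length=128,
)

input_data = {
    "token_ids": tf.random.uniform(
        shape=(1, 12), maxval=50432, dtype="int32"
    ),
    "padding_mask": tf.ones((1, 12), dtype="int32"),
}
# Output has shape `(batch_size, sequence_length, hidden_dim)` -> (1, 12, 128).
sequence_output = backbone(input_data)

# For model-parallel experiments, `create_layout_map` expects a 2D DTensor
# mesh whose second dimension is the model dimension, e.g. (device setup
# omitted here):
# mesh = tf.experimental.dtensor.create_mesh([("batch", 1), ("model", 8)])
# layout_map = GPTNeoXBackbone.create_layout_map(mesh)
```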
Review comment: Let's consider better names for these two arguments. Probably `rotary_percentage` is more consistent with Keras' style, and `rotary_emb_base` is a little confusing; are there better names we could consider from the paper or elsewhere?

Review comment (follow-up): I guess one option is to document this the same as `max_wavelength` in our SinePositionEncoding layer: https://keras.io/api/keras_nlp/modeling_layers/sine_position_encoding/ I'm not sure it's the best name, but at least it will be consistent across the library. We could name these arguments `rotary_percentage` and `rotary_max_wavelength` here, and just `percentage` and `max_wavelength` on the rotary layer itself.

Reply: Fixed this!
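For clarity, a short sketch of the constructor signatures this thread converges on. The defaults shown are assumptions for illustration, not values taken from the PR.

```python
from tensorflow import keras


class RotaryEmbedding(keras.layers.Layer):
    # On the rotary layer itself: plain `percentage` and `max_wavelength`,
    # mirroring the terminology of `keras_nlp.layers.SinePositionEncoding`.
    def __init__(self, percentage=0.25, max_wavelength=10000, **kwargs):
        super().__init__(**kwargs)
        self.percentage = percentage
        self.max_wavelength = max_wavelength


class GPTNeoXAttention(keras.layers.Layer):
    # On the attention layer: the same arguments with a `rotary_` prefix.
    def __init__(
        self,
        num_heads,
        hidden_dim,
        rotary_percentage=0.25,
        rotary_max_wavelength=10000,
        max_position_embeddings=2048,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.rotary_embedding = RotaryEmbedding(
            percentage=rotary_percentage,
            max_wavelength=rotary_max_wavelength,
        )
```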