Adding GPTNeoXBackbone #1056

Merged: 27 commits, merged on Jun 25, 2023.
The diff below shows changes from 2 of the 27 commits.

Commits
9412a83
added gpt-neo attention+decoder+backbone
shivance May 29, 2023
99a8296
fixed formatting + added backbone test
shivance May 29, 2023
afb7e1f
fixed rotary embedding and gpt neo attention layer
shivance Jun 6, 2023
f0f6383
updating decoder and backbone to current version
shivance Jun 6, 2023
bfd56fa
fixed decoder + backbone
shivance Jun 7, 2023
97a347d
fix forward pass
shivance Jun 10, 2023
5ead767
formatting + add checkpoint script
shivance Jun 10, 2023
5776ac1
fix tpu_test, formatting
shivance Jun 10, 2023
e0d343b
removed unnecessary layernorms, correct arguments, fix unit tests (te…
shivance Jun 12, 2023
451cdbc
fix dropout
shivance Jun 12, 2023
e37fb22
matching outputs with hf
shivance Jun 14, 2023
ead11c5
fix formating
shivance Jun 14, 2023
c7117a4
resolving few comments
shivance Jun 14, 2023
c72e629
fixed unit tests + formatting
shivance Jun 16, 2023
2341d0e
refactored rotary embedding
shivance Jun 16, 2023
6112357
revamped checkpoint conversion script
shivance Jun 16, 2023
66afa7c
code format
shivance Jun 16, 2023
f363f24
putting old checkpoint script back until preset
shivance Jun 16, 2023
7a66052
incorporated comments
shivance Jun 17, 2023
6f6f41e
code format
shivance Jun 17, 2023
f34ec47
resolved comments + fixed formatting
shivance Jun 17, 2023
34db7f7
added gpt neo x tokenizer
shivance Jun 17, 2023
1ecfe51
added docstrings
shivance Jun 21, 2023
b3f06e4
formatting fix
shivance Jun 21, 2023
a9f2230
addressing comments
shivance Jun 23, 2023
122a3fb
added tokenizer output verification
shivance Jun 23, 2023
e10ea50
Minor style fixes
mattdangerw Jun 24, 2023
6 changes: 1 addition & 5 deletions keras_nlp/layers/masked_lm_mask_generator.py
@@ -174,11 +174,7 @@ def call(self, inputs):
# convert dense to ragged.
inputs = tf.RaggedTensor.from_tensor(inputs)

-        (
-            token_ids,
-            mask_positions,
-            mask_ids,
-        ) = tf_text.mask_language_model(
+        (token_ids, mask_positions, mask_ids,) = tf_text.mask_language_model(
inputs,
item_selector=self._random_selector,
mask_values_chooser=self._mask_values_chooser,
140 changes: 140 additions & 0 deletions keras_nlp/models/gpt_neox/gpt_neox_attention.py
@@ -0,0 +1,140 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow import keras

from keras_nlp.layers.transformer_layer_utils import compute_causal_mask
from keras_nlp.models.gpt_neox.rotary_embedding import RotaryEmbedding


class GPTNeoXAttention(keras.layers.Layer):
def __init__(
self,
num_heads,
hidden_dim,
rotary_pct=0.25,
Member: Let's consider better names for these two arguments. rotary_percentage is probably more consistent with Keras' style, and rotary_emb_base is a little confusing. Are there better names we could consider from the paper or elsewhere?

Member: I guess one option is to document this the same as max_wavelength in our SinePositionEncoding layer: https://keras.io/api/keras_nlp/modeling_layers/sine_position_encoding/

I'm not sure it's the best name, but at least it will be consistent across the library. We could name these arguments rotary_percentage and rotary_max_wavelength here, and just percentage and max_wavelength on the rotary layer itself.

Collaborator Author: Fixed this!

max_position_embeddings=2048,
):

super().__init__()
self.num_heads = num_heads
self.hidden_dim = hidden_dim
self.head_dim = hidden_dim // num_heads
        self.rotary_pct = rotary_pct
        self.rotary_dim = int(self.head_dim * rotary_pct)
        self.max_position_embeddings = max_position_embeddings
        self.rotary_embedding = RotaryEmbedding(self.rotary_pct)
self.qkv = keras.layers.Dense(3 * self.hidden_dim)
self.dense = keras.layers.Dense(self.hidden_dim)

def _compute_attention(
self, query, key, value, attention_mask=None, head_mask=None
):

batch_size, _, query_len, _ = tf.shape(query)
key_len = tf.shape(key)[-2]
# causal_mask = self.bias[:, :, key_len - query_len : key_len, :key_len]
causal_mask = compute_causal_mask(batch_size, key_len, key_len)

query = tf.reshape(
query, [batch_size * self.num_heads, query_len, self.head_dim]
)
        key = tf.reshape(
            key, [batch_size * self.num_heads, key_len, self.head_dim]
        )

        # Batched matmul of queries against transposed keys:
        # [B * H, Q, D] x [B * H, D, K] -> [B * H, Q, K].
        attention_scores = tf.matmul(
            query, tf.transpose(key, perm=[0, 2, 1])
        )
attention_scores = tf.reshape(
attention_scores, [batch_size, self.num_heads, query_len, key_len]
)
mask_value = tf.constant(float("-inf"), dtype=attention_scores.dtype)
attention_scores = tf.where(causal_mask, attention_scores, mask_value)

if attention_mask is not None:
attention_scores += attention_mask

attention_scores = tf.cast(
tf.nn.softmax(attention_scores, axis=-1), dtype=value.dtype
)

if head_mask is not None:
attention_scores *= head_mask

attention_output = tf.matmul(attention_scores, value)
return attention_output, attention_scores

def call(
self,
hidden_states,
attention_mask,
head_mask,
layer_past,
return_attention_scores,
):

qkv = self.qkv(hidden_states)
        new_qkv_shape = tf.concat(
            [
                tf.shape(hidden_states)[:-1],
                [self.num_heads, 3 * self.head_dim],
            ],
            axis=0,
        )
        qkv = tf.reshape(qkv, new_qkv_shape)

        # Split the fused projection into per-head query, key and value.
        query = tf.transpose(qkv[..., : self.head_dim], (0, 2, 1, 3))
        key = tf.transpose(
            qkv[..., self.head_dim : 2 * self.head_dim], (0, 2, 1, 3)
        )
        value = tf.transpose(qkv[..., 2 * self.head_dim :], (0, 2, 1, 3))

query_rot, query_pass = (
Member: I wonder if we would be better off moving this slice-and-concat logic into the RotaryEmbedding call. Then our usage here could look a little more like:

query = self.rotary_embedding(query)
key = self.rotary_embedding(key)

And the rotary embedding layer could also hold the percentage argument, which would conceptually be quite clean. It looks like Falcon does this, roughly: https://huggingface.co/tiiuae/falcon-40b/blob/main/modelling_RW.py

Collaborator Author: Thanks for this wonderful suggestion!

query[..., : self.rotary_dim],
query[..., self.rotary_dim :],
)
key_rot, key_pass = (
key[..., : self.rotary_dim],
key[..., self.rotary_dim :],
)

query, key = self.rotary_embedding(query_rot, key_rot)
        query = tf.concat((query, query_pass), axis=-1)
        key = tf.concat((key, key_pass), axis=-1)

if layer_past is not None:
past_key, past_value = layer_past
key = tf.concat((past_key, key), axis=-2)
value = tf.concat((past_value, value), axis=-2)

attention_output, attention_scores = self._compute_attention(
query, key, value, attention_mask, head_mask
)
        # Merge the heads back: [B, H, Q, D] -> [B, Q, H * D].
        attention_output = tf.transpose(attention_output, (0, 2, 1, 3))
        new_shape = tf.concat(
            [
                tf.shape(attention_output)[:-2],
                [self.num_heads * self.head_dim],
            ],
            axis=0,
        )
        attention_output = tf.reshape(attention_output, new_shape)
attention_output = self.dense(attention_output)

if return_attention_scores:
return (attention_output, attention_scores)

return attention_output
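As suggested in the review thread above, the rotary slice-and-concat logic could live inside the RotaryEmbedding layer itself, so the attention code reduces to query = self.rotary_embedding(query) and key = self.rotary_embedding(key). Below is a minimal sketch of such a layer, assuming hypothetical percentage and max_wavelength arguments; it is not the RotaryEmbedding implementation added by this PR.

import tensorflow as tf
from tensorflow import keras


class RotaryEmbeddingSketch(keras.layers.Layer):
    """Hypothetical rotary embedding layer that owns the partial-rotation logic."""

    def __init__(self, percentage=0.25, max_wavelength=10000, **kwargs):
        super().__init__(**kwargs)
        self.percentage = percentage
        self.max_wavelength = max_wavelength

    def call(self, inputs):
        # `inputs` is assumed to be [batch, num_heads, seq_len, head_dim].
        head_dim = inputs.shape[-1]
        rotary_dim = int(head_dim * self.percentage)
        x_rot, x_pass = inputs[..., :rotary_dim], inputs[..., rotary_dim:]

        # Rotary angles: position * inverse frequency, duplicated across halves.
        seq_len = tf.shape(inputs)[-2]
        freq_range = tf.range(0, rotary_dim, 2, dtype="float32") / rotary_dim
        inverse_freq = 1.0 / tf.pow(float(self.max_wavelength), freq_range)
        positions = tf.range(seq_len, dtype="float32")
        angles = tf.einsum("i,j->ij", positions, inverse_freq)
        angles = tf.concat([angles, angles], axis=-1)
        cos = tf.cast(tf.cos(angles), inputs.dtype)
        sin = tf.cast(tf.sin(angles), inputs.dtype)

        # Rotate the first `rotary_dim` features; pass the rest through unchanged.
        half = rotary_dim // 2
        x1, x2 = x_rot[..., :half], x_rot[..., half:]
        rotated = tf.concat([-x2, x1], axis=-1)
        x_rot = x_rot * cos + rotated * sin
        return tf.concat([x_rot, x_pass], axis=-1)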
168 changes: 168 additions & 0 deletions keras_nlp/models/gpt_neox/gpt_neox_backbone.py
@@ -0,0 +1,168 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow import keras
from tensorflow.experimental import dtensor
from tensorflow.experimental.dtensor import Layout

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.layers.position_embedding import PositionEmbedding
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.gpt_neox.gpt_neox_decoder import GPTNeoXDecoder


def _gpt_2_kernel_initializer(stddev=0.02):
return keras.initializers.RandomNormal(stddev=stddev)


@keras_nlp_export("keras_nlp.models.GPTNeoXBackbone")
class GPTNeoXBackbone(Backbone):
def __init__(
self,
vocabulary_size,
num_layers,
num_heads,
hidden_dim,
intermediate_dim,
dropout=0.1,
max_sequence_length=1024,
**kwargs,
):
# Inputs
token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
padding_mask = keras.Input(
shape=(None,), dtype="int32", name="padding_mask"
)

# Embed tokens, positions.
token_embedding = keras.layers.Embedding(
input_dim=vocabulary_size,
output_dim=hidden_dim,
embeddings_initializer=_gpt_2_kernel_initializer(stddev=0.01),
name="token_embedding",
)(token_ids)

# Can't use `TokenAndPositionEmbedding` layer here because of different
# initializers.
position_embedding = PositionEmbedding(
initializer=_gpt_2_kernel_initializer(stddev=0.02),
sequence_length=max_sequence_length,
name="position_embedding",
)(token_embedding)

# Sum and apply dropout to embeddings.
x = keras.layers.Add(name="embeddings_add")(
(token_embedding, position_embedding)
)
x = keras.layers.Dropout(
dropout,
name="embeddings_dropout",
)(x)

# Apply successive transformer decoder blocks.
for i in range(num_layers):
x = GPTNeoXDecoder(
intermediate_dim=intermediate_dim,
num_heads=num_heads,
dropout=dropout,
layer_norm_epsilon=1e-05,
activation=lambda x: keras.activations.gelu(
x, approximate=True
),
kernel_initializer=_gpt_2_kernel_initializer(stddev=0.02),
normalize_first=True,
name=f"transformer_layer_{i}",
)(x, decoder_padding_mask=padding_mask)

sequence_output = keras.layers.LayerNormalization(
name="layer_norm",
axis=-1,
epsilon=1e-05,
dtype=tf.float32,
)(x)

# Instantiate using Functional API Model constructor
super().__init__(
inputs={
"token_ids": token_ids,
"padding_mask": padding_mask,
},
outputs=sequence_output,
**kwargs,
)
# All references to `self` below this line
self.vocabulary_size = vocabulary_size
self.num_layers = num_layers
self.num_heads = num_heads
self.hidden_dim = hidden_dim
self.intermediate_dim = intermediate_dim
self.dropout = dropout
self.max_sequence_length = max_sequence_length

def get_config(self):
config = super().get_config()
config.update(
{
"vocabulary_size": self.vocabulary_size,
"num_layers": self.num_layers,
"num_heads": self.num_heads,
"hidden_dim": self.hidden_dim,
"intermediate_dim": self.intermediate_dim,
"dropout": self.dropout,
"max_sequence_length": self.max_sequence_length,
}
)
return config

@property
def token_embedding(self):
return self.get_layer("token_embedding")

@classmethod
def create_layout_map(cls, mesh):

# We assert the mesh is 2D, and assume the first mesh dim is for data
# parallel and the second dim is for model parallel.
mesh_shape = mesh.shape()
if len(mesh_shape) != 2:
raise ValueError(
f"Expect to create layout based on 2D mesh, received {mesh}"
)
_, model_dim = mesh.dim_names
unshard_dim = dtensor.UNSHARDED

layout_map = keras.dtensor.experimental.LayoutMap(mesh=mesh)
# Embedding sharding
layout_map[r".*embeddings"] = Layout([unshard_dim, model_dim], mesh)

# Transformer block sharding
layout_map[r".*_(query|key|value)_dense.kernel"] = Layout(
[unshard_dim, unshard_dim, model_dim], mesh
)
layout_map[r".*_(query|key|value)_dense.bias"] = Layout(
[model_dim, unshard_dim], mesh
)
layout_map[r".*_feedforward_intermediate_dense.kernel"] = Layout(
[unshard_dim, model_dim], mesh
)
layout_map[r".*_feedforward_intermediate_dense.bias"] = Layout(
[model_dim], mesh
)
layout_map[r".*_feedforward_output_dense.kernel"] = Layout(
[model_dim, unshard_dim], mesh
)
layout_map[r".*_feedforward_output_dense.bias"] = Layout(
[unshard_dim], mesh
)
return layout_map
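For reference, a backbone defined this way can be exercised end to end with randomly initialized weights. A minimal usage sketch, assuming the keras_nlp.models.GPTNeoXBackbone export above and illustrative (non-preset) hyperparameters:

import numpy as np

import keras_nlp

# Illustrative toy configuration; real GPT-NeoX checkpoints are far larger.
backbone = keras_nlp.models.GPTNeoXBackbone(
    vocabulary_size=50432,
    num_layers=2,
    num_heads=4,
    hidden_dim=128,
    intermediate_dim=512,
    max_sequence_length=128,
)

input_data = {
    "token_ids": np.random.randint(0, 50432, size=(1, 12)).astype("int32"),
    "padding_mask": np.ones((1, 12), dtype="int32"),
}
# Final layer-normalized hidden states, shape (1, 12, 128).
sequence_output = backbone(input_data)

The create_layout_map classmethod follows the same pattern as GPT2Backbone's: given a 2D dtensor mesh with a data dimension and a model dimension, it returns a LayoutMap that shards the embedding and transformer weights across the model dimension.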