diff --git a/keras_nlp/layers/modeling/rotary_embedding.py b/keras_nlp/layers/modeling/rotary_embedding.py index 6f4ae449de..45f77ce494 100644 --- a/keras_nlp/layers/modeling/rotary_embedding.py +++ b/keras_nlp/layers/modeling/rotary_embedding.py @@ -85,10 +85,7 @@ def __init__( self.built = True def call(self, inputs, start_index=0): - rotary_dim = ops.shape(inputs)[-1] - cos_emb, sin_emb = self._compute_cos_sin_embedding( - inputs, rotary_dim, start_index - ) + cos_emb, sin_emb = self._compute_cos_sin_embedding(inputs, start_index) return self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb) def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb): @@ -96,34 +93,44 @@ def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb): half_rot_tensor = ops.concatenate((-x2, x1), axis=self.feature_axis) return (tensor * cos_emb) + (half_rot_tensor * sin_emb) - def _compute_cos_sin_embedding(self, x, rotary_dim, start_index): - freq_range = ops.arange(0, rotary_dim, 2) - freq_range = ops.cast(freq_range, self.compute_dtype) - freq_range = freq_range / ops.cast( - self.scaling_factor, self.compute_dtype - ) - inverse_freq = 1.0 / ( - self.max_wavelength - ** (freq_range / ops.cast(rotary_dim, self.compute_dtype)) - ) - seq_len = ops.shape(x)[self.sequence_axis] - tensor = ops.cast(ops.arange(seq_len), self.compute_dtype) + start_index - tensor = ops.cast(tensor, dtype=inverse_freq.dtype) - freq = ops.einsum("i, j -> ij", tensor, inverse_freq) - embedding = ops.concatenate((freq, freq), axis=self.feature_axis) - + def _compute_cos_sin_embedding(self, inputs, start_index=0): def get_axis(axis): - return axis if axis > 0 else len(x.shape) + axis + return axis if axis > 0 else len(inputs.shape) + axis feature_axis = get_axis(self.feature_axis) sequence_axis = get_axis(self.sequence_axis) - for axis in range(len(x.shape)): + rotary_dim = ops.shape(inputs)[feature_axis] + inverse_freq = self._get_inverse_freq(rotary_dim) + + seq_len = ops.shape(inputs)[self.sequence_axis] + tensor = ops.cast(ops.arange(seq_len), self.compute_dtype) + start_index + + tensor = ops.cast(tensor, dtype=inverse_freq.dtype) + freq = ops.einsum("i,j->ij", tensor, inverse_freq) + embedding = ops.concatenate((freq, freq), axis=-1) + + # Reshape the embedding to be broadcastable with input shape. 
+ if feature_axis < sequence_axis: + embedding = ops.transpose(embedding) + for axis in range(len(inputs.shape)): if axis != sequence_axis and axis != feature_axis: embedding = ops.expand_dims(embedding, axis) return ops.cos(embedding), ops.sin(embedding) + def _get_inverse_freq(self, rotary_dim): + freq_range = ops.arange(0, rotary_dim, 2) + freq_range = ops.cast(freq_range, self.compute_dtype) + freq_range = freq_range / ops.cast( + self.scaling_factor, self.compute_dtype + ) + inverse_freq = 1.0 / ( + self.max_wavelength + ** (freq_range / ops.cast(rotary_dim, self.compute_dtype)) + ) + return inverse_freq + def get_config(self): config = super().get_config() config.update( diff --git a/keras_nlp/layers/modeling/rotary_embedding_test.py b/keras_nlp/layers/modeling/rotary_embedding_test.py index 9874f69e5e..c0fc2906e7 100644 --- a/keras_nlp/layers/modeling/rotary_embedding_test.py +++ b/keras_nlp/layers/modeling/rotary_embedding_test.py @@ -97,6 +97,18 @@ def test_start_index(self): ) self.assertAllClose(full_output, sequential_output) + def test_permuted_axes(self): + batch_size, seq_length, feature_size = 2, 3, 4 + data = random.uniform(shape=(batch_size, seq_length, feature_size)) + layer = RotaryEmbedding(seq_length) + outputs = layer(data) + permuted_data = ops.transpose(data, (0, 2, 1)) + permuted_layer = RotaryEmbedding( + seq_length, sequence_axis=-1, feature_axis=-2 + ) + permuted_outputs = permuted_layer(permuted_data) + self.assertAllClose(outputs, ops.transpose(permuted_outputs, (0, 2, 1))) + def test_float16_dtype(self): embedding_layer = RotaryEmbedding(dtype="float16") seq_length = 100 diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py index 8f8e3a2ab3..ab04d8eae0 100644 --- a/keras_nlp/models/__init__.py +++ b/keras_nlp/models/__init__.py @@ -89,6 +89,7 @@ GPTNeoXPreprocessor, ) from keras_nlp.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer +from keras_nlp.models.llama.llama_backbone import LlamaBackbone from keras_nlp.models.mistral.mistral_backbone import MistralBackbone from keras_nlp.models.opt.opt_backbone import OPTBackbone from keras_nlp.models.opt.opt_causal_lm import OPTCausalLM diff --git a/keras_nlp/models/llama/__init__.py b/keras_nlp/models/llama/__init__.py new file mode 100644 index 0000000000..ba0c2545e4 --- /dev/null +++ b/keras_nlp/models/llama/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/models/llama/llama_attention.py b/keras_nlp/models/llama/llama_attention.py new file mode 100644 index 0000000000..a2604e5351 --- /dev/null +++ b/keras_nlp/models/llama/llama_attention.py @@ -0,0 +1,201 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.layers.modeling.rotary_embedding import RotaryEmbedding +from keras_nlp.utils.keras_utils import clone_initializer + + +class LlamaAttention(keras.layers.Layer): + """Grouped query attention for Llama models""" + + def __init__( + self, + num_query_heads, + num_key_value_heads, + rope_scaling_factor=1.0, + kernel_initializer="glorot_uniform", + rope_max_wavelength=10000, + max_sequence_length=512, + **kwargs, + ): + super().__init__(**kwargs) + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + + self.num_key_value_groups = num_query_heads // num_key_value_heads + + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.max_sequence_length = max_sequence_length + + self.rope_scaling_factor = rope_scaling_factor + self.rope_max_wavelength = rope_max_wavelength + + def build(self, inputs_shape): + self.hidden_dim = inputs_shape[-1] + self.attn_head_size = self.hidden_dim // self.num_query_heads + + # Einsum variables: + # b = batch size + # q = query length + # k = key/value length + # m = model dim + # u = num query heads + # v = num key/value heads + # h = head dim + self._query_dense = keras.layers.EinsumDense( + equation="bqm,muh->bquh", + output_shape=(None, self.num_query_heads, self.attn_head_size), + kernel_initializer=clone_initializer(self.kernel_initializer), + name="query", + ) + self._query_dense.build(inputs_shape) + self._key_dense = keras.layers.EinsumDense( + equation="bkm,mvh->bkvh", + output_shape=(None, self.num_key_value_heads, self.attn_head_size), + kernel_initializer=clone_initializer(self.kernel_initializer), + name="key", + ) + self._key_dense.build(inputs_shape) + + self._value_dense = keras.layers.EinsumDense( + equation="bkm,mvh->bkvh", + output_shape=(None, self.num_key_value_heads, self.attn_head_size), + kernel_initializer=clone_initializer(self.kernel_initializer), + name="value", + ) + self._value_dense.build(inputs_shape) + + self._softmax = keras.layers.Softmax(axis=-1, name="attention_softmax") + + self._output_dense = keras.layers.EinsumDense( + equation="bqm,mh->bqh", + output_shape=(None, self.hidden_dim), + kernel_initializer=clone_initializer(self.kernel_initializer), + name="attention_output", + ) + self._output_dense.build(inputs_shape) + + self._rotary_embedding_layer = RotaryEmbedding( + max_wavelength=self.rope_max_wavelength, + scaling_factor=self.rope_scaling_factor, + ) + self._rotary_embedding_layer.build(inputs_shape) + + self.built = True + + def call( + self, + hidden_states, + attention_mask=None, + cache=None, + cache_update_index=None, + ): + query = self._query_dense(hidden_states) + + if cache is not None: + key_cache = cache[:, 0, ...] + value_cache = cache[:, 1, ...] 
+ if cache_update_index is None: + key = key_cache + value = value_cache + else: + key_update = self._key_dense(hidden_states) + value_update = self._value_dense(hidden_states) + start = [0, cache_update_index, 0, 0] + key = ops.slice_update(key_cache, start, key_update) + value = ops.slice_update(value_cache, start, value_update) + cache = ops.stack((key, value), axis=1) + else: + if cache_update_index is not None: + raise ValueError( + "`cache_update_index` should not be set if `cache` is " + f"`None`. Received: cache={cache}, " + f"cache_update_index={cache_update_index}" + ) + key = self._key_dense(hidden_states) + value = self._value_dense(hidden_states) + + query = self._rotary_embedding_layer(query) + key = self._rotary_embedding_layer(key) + + key = ops.tile(key, [1, 1, self.num_key_value_groups, 1]) + value = ops.tile(value, [1, 1, self.num_key_value_groups, 1]) + + attention_output, attention_scores = self._compute_attention( + query, key, value, attention_mask + ) + + attention_output_shape = ops.shape(attention_output) + + attention_output = ops.reshape( + attention_output, + [ + attention_output_shape[0], + attention_output_shape[1], + self.hidden_dim, + ], + ) + + attention_output = self._output_dense(attention_output) + + if cache is not None: + return (attention_output, cache) + return attention_output + + def _masked_softmax(self, attention_scores, attention_mask=None): + if attention_mask is not None: + mask_expansion_axis = -3 + for _ in range( + len(attention_scores.shape) - len(attention_mask.shape) + ): + attention_mask = ops.expand_dims( + attention_mask, axis=mask_expansion_axis + ) + return self._softmax(attention_scores, attention_mask) + + def _compute_attention(self, query, key, value, attention_mask=None): + attention_scores = ops.einsum("aecd,abcd->acbe", key, query) + + norm_factor = ops.sqrt( + ops.convert_to_tensor(self.attn_head_size, self.compute_dtype) + ) + + attention_scores /= norm_factor + + attention_scores = self._masked_softmax( + attention_scores, attention_mask + ) + attention_output = ops.einsum( + "acbe,aecd->abcd", attention_scores, value + ) + + return attention_output, attention_scores + + def get_config(self): + config = super().get_config() + config.update( + { + "num_query_heads": self.num_query_heads, + "hidden_dim": self.hidden_dim, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_factor": self.rope_scaling_factor, + "num_key_value_heads": self.num_key_value_heads, + "max_sequence_length": self.max_sequence_length, + } + ) + return config diff --git a/keras_nlp/models/llama/llama_backbone.py b/keras_nlp/models/llama/llama_backbone.py new file mode 100644 index 0000000000..63438544cc --- /dev/null +++ b/keras_nlp/models/llama/llama_backbone.py @@ -0,0 +1,156 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.backend import keras
+from keras_nlp.backend import ops
+from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
+from keras_nlp.models.backbone import Backbone
+from keras_nlp.models.llama.llama_decoder import LlamaDecoder
+from keras_nlp.models.llama.llama_layernorm import LlamaLayerNorm
+
+
+def _llama_kernel_initializer(stddev=0.02):
+    return keras.initializers.RandomNormal(stddev=stddev)
+
+
+@keras_nlp_export("keras_nlp.models.LlamaBackbone")
+class LlamaBackbone(Backbone):
+    """
+    LLaMA core network with hyperparameters.
+
+    This network implements a Transformer-based decoder network,
+    LLaMA, as described in
+    ["LLaMA: Open and Efficient Foundation Language Models"](https://arxiv.org/abs/2302.13971).
+
+    The default constructor gives a fully customizable, randomly initialized
+    LLaMA model with any number of layers, heads, and embedding
+    dimensions. This backbone also supports LLaMA2 checkpoints.
+
+    Args:
+        vocabulary_size: int. The size of the token vocabulary.
+        num_layers: int. The number of transformer layers.
+        num_query_heads: int. The number of query attention heads for each
+            transformer layer. `hidden_dim` must be divisible by
+            `num_query_heads`.
+        hidden_dim: int. The hidden size of the transformer layers.
+        intermediate_dim: int. The output dimension of the first Dense layer in
+            the two-layer feedforward network of each transformer.
+        num_key_value_heads: int. The number of key/value attention heads used
+            for grouped query attention. If `num_key_value_heads` equals
+            `num_query_heads`, the model uses multi-head attention (MHA); if
+            `num_key_value_heads` is 1, it uses multi-query attention (MQA).
+        rope_scaling_factor: float. The scaling factor used to compute the
+            rotary embeddings.
+        rope_max_wavelength: int. The maximum angular wavelength of the
+            sine/cosine curves used for rotary embeddings.
+        layer_norm_epsilon: float. A small value added to the denominator for
+            numerical stability.
+        max_sequence_length: int. The maximum sequence length this model can
+            consume. If `None`, `max_sequence_length` defaults to the length of
+            the input sequence. This determines the variable shape for
+            positional embeddings.
+    """
+
+    def __init__(
+        self,
+        vocabulary_size,
+        num_layers,
+        num_query_heads,
+        hidden_dim,
+        intermediate_dim,
+        num_key_value_heads,
+        rope_scaling_factor=1.0,
+        rope_max_wavelength=10000,
+        layer_norm_epsilon=1e-5,
+        max_sequence_length=4096,
+        **kwargs,
+    ):
+        # Inputs
+        token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
+        padding_mask = keras.Input(
+            shape=(None,), dtype="int32", name="padding_mask"
+        )
+
+        # Embed tokens
+        token_embedding = ReversibleEmbedding(
+            input_dim=vocabulary_size,
+            output_dim=hidden_dim,
+            embeddings_initializer=_llama_kernel_initializer(stddev=0.01),
+            tie_weights=False,
+            name="token_embedding",
+        )(token_ids)
+
+        x = token_embedding
+
+        # Apply successive transformer decoder blocks.
+ for i in range(num_layers): + x = LlamaDecoder( + intermediate_dim=intermediate_dim, + num_query_heads=num_query_heads, + num_key_value_heads=num_key_value_heads, + rope_scaling_factor=rope_scaling_factor, + max_sequence_length=max_sequence_length, + rope_max_wavelength=rope_max_wavelength, + layer_norm_epsilon=layer_norm_epsilon, + activation=ops.silu, + kernel_initializer=_llama_kernel_initializer(stddev=0.02), + name=f"transformer_layer_{i}", + )(x, decoder_padding_mask=padding_mask) + + sequence_output = LlamaLayerNorm( + name="layer_norm", + epsilon=layer_norm_epsilon, + )(x) + + # Instantiate using Functional API Model constructor + super().__init__( + inputs={ + "token_ids": token_ids, + "padding_mask": padding_mask, + }, + outputs=sequence_output, + **kwargs, + ) + # All references to `self` below this line + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + self.num_query_heads = num_query_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.rope_max_wavelength = rope_max_wavelength + self.num_key_value_heads = num_key_value_heads + self.rope_scaling_factor = rope_scaling_factor + self.max_sequence_length = max_sequence_length + self.layer_norm_epsilon = layer_norm_epsilon + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_query_heads": self.num_query_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_factor": self.rope_scaling_factor, + "num_key_value_heads": self.num_key_value_heads, + "max_sequence_length": self.max_sequence_length, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config + + @property + def token_embedding(self): + return self.get_layer("token_embedding") diff --git a/keras_nlp/models/llama/llama_backbone_test.py b/keras_nlp/models/llama/llama_backbone_test.py new file mode 100644 index 0000000000..efff972c6b --- /dev/null +++ b/keras_nlp/models/llama/llama_backbone_test.py @@ -0,0 +1,52 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from keras_nlp.backend import ops +from keras_nlp.models.llama.llama_backbone import LlamaBackbone +from keras_nlp.tests.test_case import TestCase + + +class LlamaTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 10, + "num_layers": 2, + "num_query_heads": 4, + "num_key_value_heads": 2, + "hidden_dim": 8, + "intermediate_dim": 8, + "max_sequence_length": 10, + } + self.input_data = { + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=LlamaBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 5, 8), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=LlamaBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/llama/llama_decoder.py b/keras_nlp/models/llama/llama_decoder.py new file mode 100644 index 0000000000..47bac478cc --- /dev/null +++ b/keras_nlp/models/llama/llama_decoder.py @@ -0,0 +1,206 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.layers.modeling.transformer_layer_utils import ( + compute_causal_mask, +) +from keras_nlp.layers.modeling.transformer_layer_utils import ( + merge_padding_and_attention_mask, +) +from keras_nlp.models.llama.llama_attention import LlamaAttention +from keras_nlp.models.llama.llama_layernorm import LlamaLayerNorm +from keras_nlp.utils.keras_utils import clone_initializer + + +class LlamaDecoder(keras.layers.Layer): + """Llama decoder block.""" + + def __init__( + self, + intermediate_dim, + num_query_heads, + num_key_value_heads, + rope_scaling_factor=1.0, + activation="relu", + layer_norm_epsilon=1e-5, + kernel_initializer="glorot_uniform", + rope_max_wavelength=10000, + max_sequence_length=512, + **kwargs, + ): + super().__init__(**kwargs) + self.intermediate_dim = intermediate_dim + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor + + self.max_sequence_length = max_sequence_length + self.activation = keras.activations.get(activation) + self.layer_norm_epsilon = layer_norm_epsilon + self.kernel_initializer = keras.initializers.get(kernel_initializer) + + def build(self, decoder_sequence_shape): + self.hidden_dim = decoder_sequence_shape[-1] + + # Self attention layers. 
+ self._self_attention_layer = LlamaAttention( + num_query_heads=self.num_query_heads, + num_key_value_heads=self.num_key_value_heads, + rope_max_wavelength=self.rope_max_wavelength, + max_sequence_length=self.max_sequence_length, + rope_scaling_factor=self.rope_scaling_factor, + kernel_initializer=clone_initializer(self.kernel_initializer), + ) + self._self_attention_layer.build(decoder_sequence_shape) + + self._self_attention_layernorm = LlamaLayerNorm( + epsilon=self.layer_norm_epsilon, + ) + self._self_attention_layernorm.build(decoder_sequence_shape) + + # Feedforward layers. + self._feedforward_intermediate_dense = keras.layers.Dense( + self.intermediate_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + ) + self._feedforward_intermediate_dense.build(decoder_sequence_shape) + + self._feedforward_gate_dense = keras.layers.Dense( + self.intermediate_dim, + activation=self.activation, + kernel_initializer=clone_initializer(self.kernel_initializer), + ) + self._feedforward_gate_dense.build(decoder_sequence_shape) + + self._feedforward_output_dense = keras.layers.Dense( + self.hidden_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + ) + + intermediate_shape = list(decoder_sequence_shape) + intermediate_shape[-1] = self.intermediate_dim + self._feedforward_output_dense.build(tuple(intermediate_shape)) + + self._feedforward_layernorm = LlamaLayerNorm( + epsilon=self.layer_norm_epsilon, + ) + self._feedforward_layernorm.build(decoder_sequence_shape) + + self.built = True + + def call( + self, + decoder_sequence, + decoder_padding_mask=None, + decoder_attention_mask=None, + self_attention_cache=None, + self_attention_cache_update_index=None, + ): + self_attention_mask = self._compute_self_attention_mask( + decoder_sequence=decoder_sequence, + decoder_padding_mask=decoder_padding_mask, + decoder_attention_mask=decoder_attention_mask, + self_attention_cache=self_attention_cache, + self_attention_cache_update_index=self_attention_cache_update_index, + ) + residual = decoder_sequence + + x = self._self_attention_layernorm( + decoder_sequence, + ) + + x = self._self_attention_layer( + hidden_states=x, + attention_mask=self_attention_mask, + cache=self_attention_cache, + cache_update_index=self_attention_cache_update_index, + ) + + if self_attention_cache is not None: + x, self_attention_cache = x + + x = x + residual + residual = x + + x = self._feedforward_layernorm(x) + gate_output = self._feedforward_gate_dense(x) + + x = self._feedforward_intermediate_dense(x) + + x = self._feedforward_output_dense(ops.multiply(x, gate_output)) + + decoder_output = x + residual + + if self_attention_cache is not None: + return (decoder_output, self_attention_cache) + return decoder_output + + def _compute_self_attention_mask( + self, + decoder_sequence, + decoder_padding_mask, + decoder_attention_mask, + self_attention_cache=None, + self_attention_cache_update_index=None, + ): + decoder_mask = merge_padding_and_attention_mask( + decoder_sequence, decoder_padding_mask, decoder_attention_mask + ) + batch_size = ops.shape(decoder_sequence)[0] + input_length = output_length = ops.shape(decoder_sequence)[1] + # We need to handle a rectangular causal mask when doing cached + # decoding. For generative inference, `decoder_sequence` will + # generally be length 1, and `cache` will be the full generation length. 
+ if self_attention_cache is not None: + input_length = ops.shape(self_attention_cache)[2] + + causal_mask = compute_causal_mask( + batch_size, + input_length, + output_length, + 0 + if self_attention_cache_update_index is None + else self_attention_cache_update_index, + ) + return ( + ops.minimum(decoder_mask, causal_mask) + if decoder_mask is not None + else causal_mask + ) + + def compute_output_shape(self, decoder_sequence_shape): + return decoder_sequence_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "intermediate_dim": self.intermediate_dim, + "hidden_dim": self.hidden_dim, + "num_query_heads": self.num_query_heads, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_factor": self.rope_scaling_factor, + "num_key_value_heads": self.num_key_value_heads, + "max_sequence_length": self.max_sequence_length, + "activation": keras.activations.serialize(self.activation), + "layer_norm_epsilon": self.layer_norm_epsilon, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + } + ) + return config diff --git a/keras_nlp/models/llama/llama_layernorm.py b/keras_nlp/models/llama/llama_layernorm.py new file mode 100644 index 0000000000..0e85a45625 --- /dev/null +++ b/keras_nlp/models/llama/llama_layernorm.py @@ -0,0 +1,37 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.backend import keras +from keras_nlp.backend import ops + +# TODO: Should be replaced with LayerNormalization with `rms_scaling` param +# https://github.com/keras-team/keras-core/pull/726 + + +class LlamaLayerNorm(keras.layers.Layer): + def __init__(self, epsilon=1e-6, **kwargs): + super().__init__(**kwargs) + self.epsilon = epsilon + + def build(self, input_shape): + self.weight = self.add_weight( + name="weight", + shape=(input_shape[-1],), + initializer="ones", + ) + self.built = True + + def call(self, hidden_states): + variance = ops.mean(ops.square(hidden_states), axis=-1, keepdims=True) + hidden_states = hidden_states * 1 / ops.sqrt(variance + self.epsilon) + return self.weight * hidden_states diff --git a/tools/checkpoint_conversion/convert_llama_checkpoints.py b/tools/checkpoint_conversion/convert_llama_checkpoints.py new file mode 100644 index 0000000000..5eb3973f36 --- /dev/null +++ b/tools/checkpoint_conversion/convert_llama_checkpoints.py @@ -0,0 +1,141 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os
+
+# The Keras backend must be selected before keras_nlp (and keras) is imported,
+# so set the environment variable first.
+os.environ["KERAS_BACKEND"] = "torch"
+
+import torch
+from transformers import AutoModel
+
+from keras_nlp.models.llama.llama_backbone import LlamaBackbone
+
+# from huggingface_hub import login
+# Llama weights are currently gated and available on request only.
+# login(token='
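
For reference, a minimal sketch of the RotaryEmbedding behavior exercised by the new test_permuted_axes test in this diff: after the refactor, the rotary dimension is read from `feature_axis`, so a layer configured with swapped sequence/feature axes should produce the same rotations as the default layout. The shapes and values below are illustrative only, not defaults introduced by this change:

import numpy as np
from keras_nlp.layers.modeling.rotary_embedding import RotaryEmbedding

# (batch, sequence, feature) layout with the default axes.
x = np.random.uniform(size=(2, 3, 4)).astype("float32")
outputs = RotaryEmbedding()(x)

# The same data with the sequence and feature axes swapped.
permuted = RotaryEmbedding(sequence_axis=-1, feature_axis=-2)(
    np.transpose(x, (0, 2, 1))
)
# Transposing `permuted` back with (0, 2, 1) should match `outputs`
# up to floating point tolerance, as asserted in test_permuted_axes.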
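
Likewise, a minimal instantiation sketch for the new LlamaBackbone, mirroring the tiny hyperparameters and input shapes used in llama_backbone_test.py (these are test values, not presets):

import numpy as np
from keras_nlp.models.llama.llama_backbone import LlamaBackbone

backbone = LlamaBackbone(
    vocabulary_size=10,
    num_layers=2,
    num_query_heads=4,
    num_key_value_heads=2,
    hidden_dim=8,
    intermediate_dim=8,
    max_sequence_length=10,
)
outputs = backbone(
    {
        "token_ids": np.ones((2, 5), dtype="int32"),
        "padding_mask": np.ones((2, 5), dtype="int32"),
    }
)
# `outputs` has shape (2, 5, 8), i.e. (batch, sequence, hidden_dim).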