From 4ea8c2311f4dff2db640e8b411410e2f671108a4 Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Tue, 19 Dec 2023 17:37:07 -0700 Subject: [PATCH] Add MistralAI's 7B Transformer as a backbone in KerasNLP Models (#1314) * Add MistralBackbone * Fix Keras 2 failure * Fix another Keras 2 failure * Make the caching step XLA compatible * Add dtype support for the MistralBackbone * Address review comments * Add docs; Make args keyword-only; Cosmetic fixes * Use keras.backend.floatx() instead of keras.config.floatx() for Keras 2 compatibility * Add review comments --- keras_nlp/layers/modeling/rotary_embedding.py | 4 +- keras_nlp/models/__init__.py | 1 + keras_nlp/models/mistral/__init__.py | 13 + keras_nlp/models/mistral/mistral_attention.py | 293 ++++++++++++ keras_nlp/models/mistral/mistral_backbone.py | 196 ++++++++ .../models/mistral/mistral_backbone_test.py | 56 +++ .../models/mistral/mistral_layer_norm.py | 48 ++ .../mistral/mistral_transformer_decoder.py | 233 +++++++++ .../convert_mistral_checkpoints.py | 443 ++++++++++++++++++ 9 files changed, 1285 insertions(+), 2 deletions(-) create mode 100644 keras_nlp/models/mistral/__init__.py create mode 100644 keras_nlp/models/mistral/mistral_attention.py create mode 100644 keras_nlp/models/mistral/mistral_backbone.py create mode 100644 keras_nlp/models/mistral/mistral_backbone_test.py create mode 100644 keras_nlp/models/mistral/mistral_layer_norm.py create mode 100644 keras_nlp/models/mistral/mistral_transformer_decoder.py create mode 100644 tools/checkpoint_conversion/convert_mistral_checkpoints.py diff --git a/keras_nlp/layers/modeling/rotary_embedding.py b/keras_nlp/layers/modeling/rotary_embedding.py index b3402f7e21..6f4ae449de 100644 --- a/keras_nlp/layers/modeling/rotary_embedding.py +++ b/keras_nlp/layers/modeling/rotary_embedding.py @@ -97,7 +97,7 @@ def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb): return (tensor * cos_emb) + (half_rot_tensor * sin_emb) def _compute_cos_sin_embedding(self, x, rotary_dim, start_index): - freq_range = ops.arange(0, rotary_dim, 2, dtype="float32") + freq_range = ops.arange(0, rotary_dim, 2) freq_range = ops.cast(freq_range, self.compute_dtype) freq_range = freq_range / ops.cast( self.scaling_factor, self.compute_dtype @@ -107,7 +107,7 @@ def _compute_cos_sin_embedding(self, x, rotary_dim, start_index): ** (freq_range / ops.cast(rotary_dim, self.compute_dtype)) ) seq_len = ops.shape(x)[self.sequence_axis] - tensor = ops.arange(seq_len, dtype="float32") + start_index + tensor = ops.cast(ops.arange(seq_len), self.compute_dtype) + start_index tensor = ops.cast(tensor, dtype=inverse_freq.dtype) freq = ops.einsum("i, j -> ij", tensor, inverse_freq) embedding = ops.concatenate((freq, freq), axis=self.feature_axis) diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py index 23500c7460..8f8e3a2ab3 100644 --- a/keras_nlp/models/__init__.py +++ b/keras_nlp/models/__init__.py @@ -89,6 +89,7 @@ GPTNeoXPreprocessor, ) from keras_nlp.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer +from keras_nlp.models.mistral.mistral_backbone import MistralBackbone from keras_nlp.models.opt.opt_backbone import OPTBackbone from keras_nlp.models.opt.opt_causal_lm import OPTCausalLM from keras_nlp.models.opt.opt_causal_lm_preprocessor import ( diff --git a/keras_nlp/models/mistral/__init__.py b/keras_nlp/models/mistral/__init__.py new file mode 100644 index 0000000000..ba0c2545e4 --- /dev/null +++ b/keras_nlp/models/mistral/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The KerasNLP Authors +# +# 
Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/models/mistral/mistral_attention.py b/keras_nlp/models/mistral/mistral_attention.py new file mode 100644 index 0000000000..680f1f6d1b --- /dev/null +++ b/keras_nlp/models/mistral/mistral_attention.py @@ -0,0 +1,293 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.layers.modeling.rotary_embedding import RotaryEmbedding +from keras_nlp.utils.keras_utils import clone_initializer + + +# This is just a self-attention layer in Mistral. But it can be generalized +# to use the `keras_nlp.layers.CachedMultiHeadAttention` API. Since this layer +# implements grouped-query attention and sliding window attention, it might be +# useful outside of Mistral itself. 
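+# A quick summary of the two tricks this layer combines:
+#  - Grouped-query attention: `num_query_heads` query heads share
+#    `num_key_value_heads` key/value heads; the keys and values are repeated
+#    `num_query_heads // num_key_value_heads` times before the dot-product,
+#    which shrinks both the key/value projections and the KV cache.
+#  - Sliding-window attention: each token attends to at most the previous
+#    `sliding_window` tokens, so the KV cache can be a fixed-size rolling
+#    buffer instead of growing with the sequence length.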
+# TODO(tirthasheshpatel): Generalize the attention layer +# TODO(tirthasheshpatel): Merge `LlamaAttention` with this layer +# TODO(tirthasheshpatel): Use flash attention +class CachedMistralAttention(keras.layers.Layer): + """A cached grounded query attention layer with sliding window.""" + + def __init__( + self, + num_query_heads, + num_key_value_heads, + rope_max_wavelength=10000, + rope_scaling_factor=1.0, + kernel_initializer="glorot_uniform", + sliding_window=512, + dropout=0, + **kwargs, + ): + super().__init__(**kwargs) + self._num_query_heads = num_query_heads + self._num_key_value_heads = num_key_value_heads + self._sliding_window = sliding_window + self._dropout = dropout + + self._num_key_value_groups = num_query_heads // num_key_value_heads + self._rope_max_wavelength = rope_max_wavelength + + self._kernel_initializer = keras.initializers.get( + clone_initializer(kernel_initializer) + ) + + self._rope_scaling_factor = rope_scaling_factor + + def build(self, inputs_shape): + # Einsum variables: + # b = batch size + # q = query length + # k = key/value length + # m = model dim + # u = num query heads + # v = num key/value heads + # h = head dim + self._hidden_dim = inputs_shape[-1] + self._head_dim = self._hidden_dim // self._num_query_heads + + self._query_dense = keras.layers.EinsumDense( + equation="bqm,muh->bquh", + output_shape=(None, self._num_query_heads, self._head_dim), + kernel_initializer=self._kernel_initializer, + dtype=self.compute_dtype, + name="query", + ) + self._query_dense.build(inputs_shape) + + self._key_dense = keras.layers.EinsumDense( + equation="bkm,mvh->bkvh", + output_shape=( + None, + self._num_key_value_heads, + self._head_dim, + ), + kernel_initializer=self._kernel_initializer, + dtype=self.compute_dtype, + name="key", + ) + self._key_dense.build(inputs_shape) + + self._value_dense = keras.layers.EinsumDense( + equation="bkm,mvh->bkvh", + output_shape=( + None, + self._num_key_value_heads, + self._head_dim, + ), + kernel_initializer=self._kernel_initializer, + dtype=self.compute_dtype, + name="value", + ) + self._value_dense.build(inputs_shape) + + self._softmax = keras.layers.Softmax(axis=-1, name="attention_softmax") + + self._dropout_layer = keras.layers.Dropout( + rate=self._dropout, dtype=self.compute_dtype + ) + + self._output_dense = keras.layers.EinsumDense( + equation="bquh,uhm->bqm", + output_shape=(None, self._hidden_dim), + kernel_initializer=self._kernel_initializer, + dtype=self.compute_dtype, + name="attention_output", + ) + self._output_dense.build( + (None, None, self._num_query_heads, self._head_dim) + ) + + self.rotary_embedding_layer = RotaryEmbedding( + max_wavelength=self._rope_max_wavelength, + scaling_factor=self._rope_scaling_factor, + dtype=self.compute_dtype, + ) + + self._dot_product_equation = "bquh,bkuh->buqk" + self._combine_equation = "buqk,bkuh->bquh" + + self.built = True + + def call( + self, + hidden_states, + attention_mask=None, + cache=None, + cache_update_index=None, + training=None, + ): + seq_len = ops.shape(hidden_states)[1] + start_index = ( + cache_update_index if cache_update_index is not None else 0 + ) + # If `cache_update_index` is a tensor, RotaryEmbedding expects it + # to have dtype `self.compute_dtype`. + start_index = ops.cast( + start_index, self.rotary_embedding_layer.compute_dtype + ) + + query = self._query_dense(hidden_states) + + # Note that the original PyTorch implementation uses + # view_as_complex/view_as_real while we use split/concatenate to + # convert to/from complex numbers. 
The transformations below make + # the rope computation numerically equivalent to the original + # implementation. + def _mistral_rope(x): + x = ops.concatenate([x[..., ::2], x[..., 1::2]], axis=-1) + x = self.rotary_embedding_layer(x, start_index=start_index) + x = ops.reshape( + ops.stack(ops.split(x, 2, axis=-1), axis=-1), ops.shape(x) + ) + return x + + # Compute RoPE for queries + query = _mistral_rope(query) + + def _compute_key_value(x): + key, value = self._key_dense(x), self._value_dense(x) + key = _mistral_rope(key) + return key, value + + if cache is not None: + cache_k = cache[:, 0, ...] + cache_v = cache[:, 1, ...] + + if cache_update_index is not None: + # Compute the new keys and values + key, value = _compute_key_value(hidden_states) + + # Cache is a rotating buffer, we want to warp around if + # the sequence length exceeds the sliding window. + update_end_index = ( + cache_update_index + seq_len - 1 + ) % self._sliding_window + 1 + update_end_index = ops.cast(update_end_index, "int32") + cache_update_index = cache_update_index % self._sliding_window + update_start_index = ops.cond( + update_end_index > cache_update_index, + lambda: ops.cast(cache_update_index, "int32"), + lambda: ops.cast(0, "int32"), + ) + # Also note that the update step below assumes that the + # sequence length is always one when `cache_update_index != 0`. + # This is necessary to support XLA compilation. Ideally, we + # would want to use + # `key[:, -(update_end_index - update_start_index):, ...]` + # as the update but updating using a dynamic slice gives an + # XLA compilation error in TensorFlow. + # Passing a sequence of length > 1 with cache update might give + # incorrect results (since there is no way to determine how + # many most recent tokens are to be saved if the tokens exceed + # the sliding window length). + cache_k = ops.slice_update( + cache_k, + [0, update_start_index, 0, 0], + # We slice the keys and values since if the user has passed + # a sequence of length > `self._sliding_window`. We want to + # prefill the cache using just the most recent values in the + # sliding window. + ops.cast( + key[:, -self._sliding_window :, ...], cache_k.dtype + ), + ) + cache_v = ops.slice_update( + cache_v, + [0, update_start_index, 0, 0], + ops.cast( + value[:, -self._sliding_window :, ...], cache_v.dtype + ), + ) + cache = ops.stack([cache_k, cache_v], axis=1) + + # Get the required keys and values from the cache. + # Since we expect the user to pass a fixed-size cache, we just + # pick the first few slices up-to and including the newly computed + # keys and values. + cache_k = cache_k[:, :update_end_index, ...] + cache_v = cache_v[:, :update_end_index, ...] 
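+            # The cache may be stored in a different (possibly lower-precision)
+            # dtype than this layer's compute dtype, so cast the gathered
+            # keys/values back before the attention computation below.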
+ + key = ops.cast(cache_k, dtype=self.compute_dtype) + value = ops.cast(cache_v, dtype=self.compute_dtype) + else: + # Compute keys and values + key, value = _compute_key_value(hidden_states) + + # [batch_shape, seq_len, num_key_value_heads, head_dim] + # -> [batch_shape, seq_len, num_heads, head_dim] + key = ops.repeat(key, repeats=self._num_key_value_groups, axis=2) + value = ops.repeat(value, repeats=self._num_key_value_groups, axis=2) + + attention_output = self._compute_attention( + query, key, value, attention_mask + ) + + attention_output = self._dropout_layer( + attention_output, training=training + ) + + attention_output = self._output_dense(attention_output) + + if cache is not None: + return attention_output, cache + return attention_output + + def _masked_softmax(self, attention_scores, attention_mask=None): + if attention_mask is not None: + return self._softmax( + attention_scores, attention_mask[:, None, :, :] + ) + return self._softmax(attention_scores) + + def _compute_attention(self, query, key, value, attention_mask=None): + attention_scores = ops.einsum(self._dot_product_equation, key, query) + + norm_factor = ops.sqrt(ops.cast(self._head_dim, self.compute_dtype)) + + attention_scores = attention_scores / norm_factor + + attention_scores = self._masked_softmax( + attention_scores, attention_mask + ) + attention_output = ops.einsum( + self._combine_equation, attention_scores, value + ) + + return attention_output + + def get_config(self): + config = super().get_config() + config.update( + { + "num_query_heads": self._num_query_heads, + "num_key_value_heads": self._num_key_value_heads, + "rope_max_wavelength": self._rope_max_wavelength, + "rope_scaling_factor": self._rope_scaling_factor, + "kernel_initializer": keras.initializers.serialize( + self._kernel_initializer + ), + "sliding_window": self._sliding_window, + "dropout": self._dropout, + } + ) + return config diff --git a/keras_nlp/models/mistral/mistral_backbone.py b/keras_nlp/models/mistral/mistral_backbone.py new file mode 100644 index 0000000000..42cec8b218 --- /dev/null +++ b/keras_nlp/models/mistral/mistral_backbone.py @@ -0,0 +1,196 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding +from keras_nlp.models.backbone import Backbone +from keras_nlp.models.mistral.mistral_layer_norm import ( + MistralLayerNormalization, +) +from keras_nlp.models.mistral.mistral_transformer_decoder import ( + MistralTransformerDecoder, +) + + +def _mistral_kernel_initializer(stddev=0.02): + return keras.initializers.RandomNormal(stddev=stddev) + + +@keras_nlp_export("keras_nlp.models.MistralBackbone") +class MistralBackbone(Backbone): + """ + The Mistral Transformer core architecture with hyperparameters. 
+
+    This network implements a Transformer-based decoder network,
+    Mistral, as described in
+    ["Mistral 7B"](https://arxiv.org/pdf/2310.06825.pdf).
+    It includes the embedding lookups and transformer layers.
+
+    The default constructor gives a fully customizable, randomly initialized
+    Mistral model with any number of layers, heads, and embedding
+    dimensions. To load preset architectures and weights, use the `from_preset`
+    constructor.
+
+    Args:
+        vocabulary_size (int): The size of the token vocabulary.
+        num_layers (int): The number of transformer layers.
+        num_query_heads (int): The number of query attention heads for
+            each transformer.
+        hidden_dim (int): The size of the transformer hidden states.
+        intermediate_dim (int): The output dimension of the first Dense layer in a
+            three-layer feedforward network for each transformer.
+        num_key_value_heads (int): The number of key and value attention heads for
+            each transformer.
+        rope_max_wavelength (int, optional): The maximum angular wavelength of the
+            sine/cosine curves used for rotary embeddings. Defaults to `10000`.
+        rope_scaling_factor (float, optional): The scaling factor for calculating
+            rotary embeddings. Defaults to `1.0`.
+        layer_norm_epsilon (float, optional): Epsilon for the layer normalization
+            layers in the transformer decoder. Defaults to `1e-6`.
+        sliding_window (int, optional): The sliding window for the Mistral
+            attention layers. This controls the maximum cache size for the attention
+            layers in each transformer decoder. Only the latest `sliding_window`
+            tokens are saved in the cache and used to generate the next token.
+            Defaults to `512`.
+        dtype (str, optional): The dtype policy for the Mistral model. Defaults
+            to `keras.backend.floatx()`.
+
+    Examples:
+
+    ```python
+    input_data = {
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
+    }
+
+    # Pretrained Mistral decoder.
+    model = keras_nlp.models.MistralBackbone.from_preset("mistral7b_base_en")
+    model(input_data)
+
+    # Randomly initialized Mistral decoder with custom config.
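+    # (The hyperparameters below are intentionally small for illustration and
+    # do not correspond to the 7B preset.)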
+ model = keras_nlp.models.MistralBackbone( + vocabulary_size=10, + hidden_dim=512, + num_layers=2, + num_query_heads=32, + num_key_value_heads=8, + intermediate_dim=1024, + sliding_window=512, + layer_norm_epsilon=1e-6, + dtype="float32" + ) + model(input_data) + ``` + """ + + def __init__( + self, + vocabulary_size, + num_layers, + num_query_heads, + hidden_dim, + intermediate_dim, + num_key_value_heads, + rope_max_wavelength=10000, + rope_scaling_factor=1.0, + layer_norm_epsilon=1e-6, + sliding_window=512, + dropout=0, + **kwargs, + ): + # Get the dtype + dtype = kwargs.pop("dtype", keras.backend.floatx()) + + # Inputs + token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids") + padding_mask = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + + # Embed Tokens + token_embedding_layer = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + tie_weights=False, + embeddings_initializer=_mistral_kernel_initializer(stddev=0.01), + dtype=dtype, + name="token_embedding", + ) + x = token_embedding_layer(token_ids) + + # Apply successive transformer decoder blocks + for i in range(num_layers): + x = MistralTransformerDecoder( + intermediate_dim=intermediate_dim, + num_query_heads=num_query_heads, + num_key_value_heads=num_key_value_heads, + rope_max_wavelength=rope_max_wavelength, + rope_scaling_factor=rope_scaling_factor, + layer_norm_epsilon=layer_norm_epsilon, + activation=ops.silu, + kernel_initializer=_mistral_kernel_initializer(stddev=0.02), + sliding_window=sliding_window, + dropout=dropout, + dtype=dtype, + name=f"transformer_layer_{i}", + )(x, decoder_padding_mask=padding_mask) + + sequence_output = MistralLayerNormalization( + name="sequence_output_layernorm", + epsilon=layer_norm_epsilon, + dtype=dtype, + )(x) + + # Instantiate using Functional API Model constructor + super().__init__( + inputs={ + "token_ids": token_ids, + "padding_mask": padding_mask, + }, + outputs=sequence_output, + **kwargs, + ) + + # All references to `self` below this line + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + self.num_query_heads = num_query_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.rope_max_wavelength = rope_max_wavelength + self.num_key_value_heads = num_key_value_heads + self.rope_scaling_factor = rope_scaling_factor + self.sliding_window = sliding_window + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + self.token_embedding = token_embedding_layer + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_query_heads": self.num_query_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_factor": self.rope_scaling_factor, + "num_key_value_heads": self.num_key_value_heads, + "sliding_window": self.sliding_window, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + } + ) + return config diff --git a/keras_nlp/models/mistral/mistral_backbone_test.py b/keras_nlp/models/mistral/mistral_backbone_test.py new file mode 100644 index 0000000000..fc2b0a592b --- /dev/null +++ b/keras_nlp/models/mistral/mistral_backbone_test.py @@ -0,0 +1,56 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from keras_nlp.backend import ops +from keras_nlp.models.mistral.mistral_backbone import MistralBackbone +from keras_nlp.tests.test_case import TestCase + + +class MistralBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 10, + "num_layers": 2, + "num_query_heads": 8, + "num_key_value_heads": 4, + "hidden_dim": 16, + "intermediate_dim": 8, + "sliding_window": 2, + } + self.input_data = { + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=MistralBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 5, 16), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=MistralBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + def test_num_parameters(self): + model = MistralBackbone(**self.init_kwargs) + # Reference value calculated using the PyTorch model + self.assertEqual(model.count_params(), 2704) diff --git a/keras_nlp/models/mistral/mistral_layer_norm.py b/keras_nlp/models/mistral/mistral_layer_norm.py new file mode 100644 index 0000000000..9f9ddf26b5 --- /dev/null +++ b/keras_nlp/models/mistral/mistral_layer_norm.py @@ -0,0 +1,48 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.backend import keras +from keras_nlp.backend import ops + + +# TODO: Deprecate this in favor of +# `keras.layers.LayerNormalization(rms_scaling=True)` once Keras 2 support is +# removed. 
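+# RMSNorm, as used by Mistral: activations are rescaled by the reciprocal
+# root-mean-square of the features and multiplied by a learned per-feature
+# gain, i.e. `y = x / sqrt(mean(x**2, axis=-1) + epsilon) * weight`. Unlike
+# standard layer normalization there is no mean-centering and no bias term.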
+class MistralLayerNormalization(keras.layers.Layer): + """A normalization layer for Mistral that implements RMS normalization.""" + + def __init__(self, epsilon=1e-6, **kwargs): + super().__init__(**kwargs) + self._epsilon = epsilon + + def build(self, input_shape): + self._dim = input_shape[-1] + self._weight = self.add_weight( + name="weight", + trainable=True, + shape=(self._dim,), + initializer="ones", + dtype=self.compute_dtype, + ) + self.built = True + + def call(self, x): + x = x * ops.rsqrt( + ops.mean(ops.power(x, 2), axis=-1, keepdims=True) + self._epsilon + ) + return x * self._weight + + def get_config(self): + config = super().get_config() + config.update({"epsilon": self._epsilon}) + return config diff --git a/keras_nlp/models/mistral/mistral_transformer_decoder.py b/keras_nlp/models/mistral/mistral_transformer_decoder.py new file mode 100644 index 0000000000..9b6f7fdbf8 --- /dev/null +++ b/keras_nlp/models/mistral/mistral_transformer_decoder.py @@ -0,0 +1,233 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.layers.modeling.transformer_layer_utils import ( + compute_causal_mask, +) +from keras_nlp.layers.modeling.transformer_layer_utils import ( + merge_padding_and_attention_mask, +) +from keras_nlp.models.mistral.mistral_attention import CachedMistralAttention +from keras_nlp.models.mistral.mistral_layer_norm import ( + MistralLayerNormalization, +) +from keras_nlp.utils.keras_utils import clone_initializer + + +class MistralTransformerDecoder(keras.layers.Layer): + """A Transformer decoder layer for the Mistral backbone.""" + + def __init__( + self, + intermediate_dim, + num_query_heads, + num_key_value_heads, + rope_max_wavelength=10000, + rope_scaling_factor=1.0, + activation="relu", + layer_norm_epsilon=1e-5, + kernel_initializer="glorot_uniform", + sliding_window=512, + dropout=0, + **kwargs, + ): + super().__init__(**kwargs) + self.intermediate_dim = intermediate_dim + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor + + self.dropout = dropout + + self.sliding_window = sliding_window + self.activation = keras.activations.get(activation) + self.layer_norm_epsilon = layer_norm_epsilon + self.kernel_initializer = keras.initializers.get(kernel_initializer) + + self.supports_masking = True + + def build(self, decoder_sequence_shape): + self._decoder_sequence_shape = decoder_sequence_shape + self.hidden_dim = decoder_sequence_shape[-1] + + # Self attention layer. 
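+        # Grouped-query attention with rotary embeddings applied to queries
+        # and keys inside the layer, plus a rolling KV cache holding at most
+        # `sliding_window` entries.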
+ self._self_attention_layer = CachedMistralAttention( + num_query_heads=self.num_query_heads, + num_key_value_heads=self.num_key_value_heads, + rope_max_wavelength=self.rope_max_wavelength, + rope_scaling_factor=self.rope_scaling_factor, + sliding_window=self.sliding_window, + kernel_initializer=clone_initializer(self.kernel_initializer), + dropout=self.dropout, + dtype=self.compute_dtype, + name="self_attention", + ) + self._self_attention_layer.build(decoder_sequence_shape) + + self._self_attention_layernorm = MistralLayerNormalization( + epsilon=self.layer_norm_epsilon, + name="self_attention_layernorm", + dtype=self.compute_dtype, + ) + self._self_attention_layernorm.build(decoder_sequence_shape) + self._self_attention_dropout = keras.layers.Dropout( + rate=self.dropout, + dtype=self.compute_dtype, + name="self_attention_dropout", + ) + + # Feedforward layers. + self._feedforward_intermediate_dense = keras.layers.Dense( + self.intermediate_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.compute_dtype, + name="feedforward_intermediate_dense", + ) + self._feedforward_intermediate_dense.build(decoder_sequence_shape) + + self._feedforward_gate_dense = keras.layers.Dense( + self.intermediate_dim, + activation=self.activation, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + name="feedforward_gate_dense", + ) + self._feedforward_gate_dense.build(decoder_sequence_shape) + + self._feedforward_output_dense = keras.layers.Dense( + self.hidden_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.compute_dtype, + name="feedforward_output_dense", + ) + + self._feedforward_output_dense.build( + self._feedforward_gate_dense.compute_output_shape( + decoder_sequence_shape + ) + ) + + self._feedforward_layernorm = MistralLayerNormalization( + epsilon=self.layer_norm_epsilon, + name="feedforward_layernorm", + dtype=self.compute_dtype, + ) + self._feedforward_layernorm.build(decoder_sequence_shape) + + self.built = True + + def call( + self, + decoder_sequence, + decoder_padding_mask=None, + decoder_attention_mask=None, + self_attention_cache=None, + self_attention_cache_update_index=None, + training=None, + ): + self_attention_mask = self._compute_self_attention_mask( + decoder_sequence=decoder_sequence, + decoder_padding_mask=decoder_padding_mask, + decoder_attention_mask=decoder_attention_mask, + ) + residual = decoder_sequence + + x = self._self_attention_layernorm(decoder_sequence) + + # Self attention block. 
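+        # Pre-normalization residual block: `x` was normalized above and the
+        # attention output is added back onto the raw `residual` below. When a
+        # cache is passed in, the attention layer returns an `(output, cache)`
+        # tuple, which is unpacked right after this call.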
+ x = self._self_attention_layer( + hidden_states=x, + attention_mask=self_attention_mask, + cache=self_attention_cache, + cache_update_index=self_attention_cache_update_index, + ) + + if self_attention_cache is not None: + x, self_attention_cache = x + + x = self._self_attention_dropout(x, training=training) + + x = x + residual + residual = x + + x = self._feedforward_layernorm(x) + gate_output = self._feedforward_gate_dense(x) + + x = self._feedforward_intermediate_dense(x) + + x = self._feedforward_output_dense(ops.multiply(x, gate_output)) + + decoder_output = x + residual + + if self_attention_cache is not None: + return decoder_output, self_attention_cache + return decoder_output + + def _compute_self_attention_mask( + self, + decoder_sequence, + decoder_padding_mask, + decoder_attention_mask, + ): + decoder_mask = merge_padding_and_attention_mask( + decoder_sequence, decoder_padding_mask, decoder_attention_mask + ) + batch_size = ops.shape(decoder_sequence)[0] + input_length = output_length = ops.shape(decoder_sequence)[1] + + # Mistral uses a banded attention mask + causal_mask_lower = compute_causal_mask( + batch_size, input_length, output_length, 0 + ) + # Below is a workaround for `ops.triu` for Keras 2. + # TODO(tirthasheshpatel): Use `ops.triu` once Keras 2 support is removed. + # causal_mask = ops.triu(causal_mask_lower, k=-self.sliding_window) + i = ops.arange(output_length)[:, None] + j = ops.arange(input_length)[None, :] + causal_mask_upper = ops.cast(i <= j + self.sliding_window, "int32") + causal_mask = ops.minimum(causal_mask_lower, causal_mask_upper) + + return ( + ops.minimum(decoder_mask, causal_mask) + if decoder_mask is not None + else causal_mask + ) + + def compute_output_shape(self, decoder_sequence_shape): + return decoder_sequence_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "intermediate_dim": self.intermediate_dim, + "num_query_heads": self.num_query_heads, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_factor": self.rope_scaling_factor, + "num_key_value_heads": self.num_key_value_heads, + "sliding_window": self.sliding_window, + "activation": keras.activations.serialize(self.activation), + "layer_norm_epsilon": self.layer_norm_epsilon, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "dropout": self.dropout, + } + ) + return config diff --git a/tools/checkpoint_conversion/convert_mistral_checkpoints.py b/tools/checkpoint_conversion/convert_mistral_checkpoints.py new file mode 100644 index 0000000000..3bc443d910 --- /dev/null +++ b/tools/checkpoint_conversion/convert_mistral_checkpoints.py @@ -0,0 +1,443 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
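+
+# Converts the official Mistral-7B PyTorch checkpoint into a KerasNLP
+# `MistralBackbone` weights file.
+#
+# The script expects the downloaded checkpoint (with `params.json` and
+# `consolidated.00.pth`) to live in `mistral-7B-v0.1/` relative to the working
+# directory. It rebuilds the reference PyTorch model, instantiates the Keras 3
+# backbone with matching hyperparameters, copies the weights layer by layer,
+# and writes `mistral_7b.weights.h5`.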
+import json +import pathlib +from dataclasses import dataclass +from pathlib import Path +from typing import Optional +from typing import Tuple + +import torch +from torch import nn + +from keras_nlp.models import MistralBackbone + +MODEL_PATH = pathlib.Path("mistral-7B-v0.1") + +# Torch model taken from: +# https://github.com/mistralai/mistral-src/blob/147c4e68279b90eb61b19bdea44e16f5539d5a5d/one_file_ref.py + + +@dataclass +class ModelArgs: + dim: int + n_layers: int + head_dim: int + hidden_dim: int + n_heads: int + n_kv_heads: int + sliding_window: int + norm_eps: float + vocab_size: int + + max_batch_size: int = 0 + + +def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int): + keys = torch.repeat_interleave(keys, repeats=repeats, dim=2) + values = torch.repeat_interleave(values, repeats=repeats, dim=2) + return keys, values + + +def _reshape_for_broadcast( + freqs_cis: torch.Tensor, x: torch.Tensor +) -> torch.Tensor: + """ + freqs_cis: complex - (seq_len, head_dim / 2) + x: complex - (bsz, seq_len, head_dim / 2) + """ + ndim = x.ndim + assert 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]), ( + freqs_cis.shape, + (x.shape[1], x.shape[-1]), + ) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = _reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + self.n_heads: int = args.n_heads + self.n_kv_heads: int = args.n_kv_heads + + self.repeats = self.n_heads // self.n_kv_heads + self.sliding_window = self.args.sliding_window + + self.scale = self.args.head_dim**-0.5 + + self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False) + self.wk = nn.Linear( + args.dim, args.n_kv_heads * args.head_dim, bias=False + ) + self.wv = nn.Linear( + args.dim, args.n_kv_heads * args.head_dim, bias=False + ) + self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False) + self.cache_k = torch.empty( + ( + args.max_batch_size, + args.sliding_window, + self.n_kv_heads, + self.args.head_dim, + ), + dtype=torch.float16, + ) + self.cache_v = torch.empty( + ( + args.max_batch_size, + args.sliding_window, + self.n_kv_heads, + self.args.head_dim, + ), + dtype=torch.float16, + ) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + positions: torch.Tensor, + mask: Optional[torch.Tensor], + ) -> torch.Tensor: + bsz, seqlen, _ = x.shape + + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + xq = xq.view(bsz, seqlen, self.n_heads, self.args.head_dim) + xk = xk.view(bsz, seqlen, self.n_kv_heads, self.args.head_dim) + xv = xv.view(bsz, seqlen, self.n_kv_heads, self.args.head_dim) + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + # The cache is a rotating buffer + scatter_pos = (positions[-self.sliding_window :] % self.sliding_window)[ + None, :, None, None + ] + scatter_pos = scatter_pos.repeat( + bsz, 1, self.n_kv_heads, self.args.head_dim + ) + self.cache_k[:bsz].scatter_( + dim=1, + index=scatter_pos, + src=xk[:, 
-self.sliding_window :].to(self.cache_k.dtype), + ) + self.cache_v[:bsz].scatter_( + dim=1, + index=scatter_pos, + src=xv[:, -self.sliding_window :].to(self.cache_v.dtype), + ) + + if positions.shape[0] > 1: + # prefill + key, value = repeat_kv(xk, xv, self.repeats) + else: + cur_pos = positions[-1].item() + 1 + key, value = repeat_kv( + self.cache_k[:bsz, :cur_pos, ...].to(xk.dtype), + self.cache_v[:bsz, :cur_pos, ...].to(xv.dtype), + self.repeats, + ) + + query = xq.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + # scores : [bsz, n_heads, seqlen | 1, seqlen] + scores = torch.matmul(query, key.transpose(2, 3)) * self.scale + + if mask is not None: + scores += mask[None, None, ...] + + scores = scores.float() + scores = nn.functional.softmax(scores, dim=-1).type_as(query) + output = torch.matmul( + scores, value + ) # (bs, n_local_heads, slen, head_dim) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False) + self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False) + self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False) + + def forward(self, x) -> torch.Tensor: + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.attention = Attention(args) + self.feed_forward = FeedForward(args=args) + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.args = args + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + positions: torch.Tensor, + mask: Optional[torch.Tensor], + ) -> torch.Tensor: + r = self.attention.forward( + self.attention_norm(x), freqs_cis, positions, mask + ) + h = x + r + r = self.feed_forward.forward(self.ffn_norm(h)) + out = h + r + return out + + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0 +) -> torch.Tensor: + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + ) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + return torch.polar(torch.ones_like(freqs), freqs) # complex64 + + +class TorchTransformer(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.n_layers = args.n_layers + assert self.vocab_size > 0 + + self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) + + self.layers = torch.nn.ModuleList( + [TransformerBlock(args=args) for _ in range(args.n_layers)] + ) + + self.norm = RMSNorm(args.dim, eps=args.norm_eps) + + self.output = nn.Linear(args.dim, args.vocab_size, bias=False) + + self.freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + ): + h = self.tok_embeddings(input_ids) + freqs_cis = self.freqs_cis[positions] + + mask: 
Optional[torch.Tensor] = None + if input_ids.shape[1] > 1: + seqlen = input_ids.shape[1] + tensor = torch.full( + (seqlen, seqlen), + dtype=h.dtype, + fill_value=1, + device=h.device, + ) + mask = torch.tril(tensor, diagonal=0).to(h.dtype) + # make the mask banded to account for sliding window + mask = torch.triu(mask, diagonal=-self.args.sliding_window) + mask = torch.log(mask) + + for layer in self.layers: + h = layer(h, freqs_cis, positions, mask) + + return self.output(self.norm(h)).float() + + @staticmethod + def from_folder( + folder: Path, max_batch_size: int = 1, device="cpu", dtype=torch.float16 + ): + with open(folder / "params.json", "r") as f: + model_args = ModelArgs(**json.loads(f.read())) + model_args.max_batch_size = max_batch_size + model = TorchTransformer(model_args).to(device=device, dtype=dtype) + loaded = torch.load(folder / "consolidated.00.pth") + model.load_state_dict(loaded) + return model + + +def port_weights( + model_k3: MistralBackbone, model_torch: TorchTransformer, params: ModelArgs +): + model_k3.get_layer("token_embedding").embeddings.assign( + model_torch.tok_embeddings.weight.detach().cpu().numpy() + ) + + for i in range(model_k3.num_layers): + model_k3.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer._key_dense.set_weights( + [ + model_torch.layers[i] + .attention.wk.weight.T.reshape( + params.dim, params.n_kv_heads, params.head_dim + ) + .detach() + .cpu() + .numpy() + ] + ) + model_k3.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer._query_dense.set_weights( + [ + model_torch.layers[i] + .attention.wq.weight.T.reshape( + params.dim, params.n_heads, params.head_dim + ) + .detach() + .cpu() + .numpy() + ] + ) + model_k3.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer._value_dense.set_weights( + [ + model_torch.layers[i] + .attention.wv.weight.T.reshape( + params.dim, params.n_kv_heads, params.head_dim + ) + .detach() + .cpu() + .numpy() + ] + ) + model_k3.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer._output_dense.set_weights( + [ + model_torch.layers[i] + .attention.wo.weight.T.reshape( + params.n_heads, params.head_dim, params.dim + ) + .detach() + .cpu() + .numpy() + ] + ) + model_k3.get_layer( + f"transformer_layer_{i}" + )._self_attention_layernorm.set_weights( + [model_torch.layers[i].attention_norm.weight.detach().cpu().numpy()] + ) + model_k3.get_layer( + f"transformer_layer_{i}" + )._feedforward_intermediate_dense.set_weights( + [ + model_torch.layers[i] + .feed_forward.w3.weight.T.detach() + .cpu() + .numpy() + ] + ) + model_k3.get_layer( + f"transformer_layer_{i}" + )._feedforward_output_dense.set_weights( + [ + model_torch.layers[i] + .feed_forward.w2.weight.T.detach() + .cpu() + .numpy() + ] + ) + model_k3.get_layer( + f"transformer_layer_{i}" + )._feedforward_gate_dense.set_weights( + [ + model_torch.layers[i] + .feed_forward.w1.weight.T.detach() + .cpu() + .numpy() + ] + ) + model_k3.get_layer( + f"transformer_layer_{i}" + )._feedforward_layernorm.set_weights( + [model_torch.layers[i].ffn_norm.weight.detach().cpu().numpy()] + ) + + model_k3.get_layer("sequence_output_layernorm").set_weights( + [model_torch.norm.weight.detach().cpu().numpy()] + ) + model_k3.get_layer("token_embedding").reverse_embeddings.assign( + model_torch.output.weight.T.detach().cpu().numpy() + ) + + +if __name__ == "__main__": + with open(MODEL_PATH / "params.json", "r") as params_file: + params = ModelArgs(**json.load(params_file)) + + model_torch = TorchTransformer.from_folder( + MODEL_PATH, 
device="cpu", dtype=torch.float16 + ) + print("Torch model loaded") + model_k3 = MistralBackbone( + vocabulary_size=32000, + hidden_dim=4096, + num_layers=32, + num_query_heads=32, + num_key_value_heads=8, + intermediate_dim=14336, + sliding_window=4096, + layer_norm_epsilon=1e-6, + dtype="float16", + ) + print("Keras 3 model loaded.") + + port_weights(model_k3, model_torch, params) + print("Weight transfer done.") + + model_k3.save_weights("mistral_7b.weights.h5") + print("Weights saved.")
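+
+    # No numerical check is run here; a useful follow-up is to feed the same
+    # token ids through `model_torch` and `model_k3`, project the Keras hidden
+    # states to logits with `model_k3.token_embedding(outputs, reverse=True)`,
+    # and confirm the two sets of logits agree to within float16 tolerance.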