From ec00c0271d56ac7c8c68459b04166285614f3aa5 Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Wed, 27 Mar 2024 00:46:08 +0000 Subject: [PATCH 1/6] Add LLaMA Causal LM --- keras_nlp/models/__init__.py | 5 + keras_nlp/models/llama/llama_backbone.py | 15 +- keras_nlp/models/llama/llama_causal_lm.py | 220 +++++++++++ keras_nlp/models/llama/llama_preprocessor.py | 7 + .../models/llama/llama_preprocessor_test.py | 11 + keras_nlp/models/llama/llama_presets.py | 38 ++ keras_nlp/models/llama/llama_tokenizer.py | 8 + .../models/llama/llama_tokenizer_test.py | 20 + .../convert_llama_checkpoints.py | 350 +++++++++++++----- 9 files changed, 569 insertions(+), 105 deletions(-) create mode 100644 keras_nlp/models/llama/llama_causal_lm.py create mode 100644 keras_nlp/models/llama/llama_presets.py diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py index 033a9dc87..b830be3ee 100644 --- a/keras_nlp/models/__init__.py +++ b/keras_nlp/models/__init__.py @@ -107,6 +107,11 @@ ) from keras_nlp.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer from keras_nlp.models.llama.llama_backbone import LlamaBackbone +from keras_nlp.models.llama.llama_causal_lm_preprocessor import ( + LlamaCausalLMPreprocessor, +) +from keras_nlp.models.llama.llama_preprocessor import LlamaPreprocessor +from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer from keras_nlp.models.mistral.mistral_backbone import MistralBackbone from keras_nlp.models.mistral.mistral_causal_lm import MistralCausalLM from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import ( diff --git a/keras_nlp/models/llama/llama_backbone.py b/keras_nlp/models/llama/llama_backbone.py index b5383d528..ec35989e0 100644 --- a/keras_nlp/models/llama/llama_backbone.py +++ b/keras_nlp/models/llama/llama_backbone.py @@ -11,20 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# import copy +import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.backend import ops from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_nlp.models.backbone import Backbone - -# from keras_nlp.models.llama.llama_presets import backbone_presets from keras_nlp.models.llama.llama_decoder import LlamaTransformerDecoder from keras_nlp.models.llama.llama_layernorm import LlamaLayerNorm - -# from keras_nlp.utils.python_utils import classproperty +from keras_nlp.models.llama.llama_presets import backbone_presets +from keras_nlp.utils.python_utils import classproperty def _llama_kernel_initializer(stddev=0.02): @@ -191,6 +188,6 @@ def get_config(self): ) return config - # @classproperty - # def presets(cls): - # return copy.deepcopy(backbone_presets) + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/llama/llama_causal_lm.py b/keras_nlp/models/llama/llama_causal_lm.py new file mode 100644 index 000000000..feb4df034 --- /dev/null +++ b/keras_nlp/models/llama/llama_causal_lm.py @@ -0,0 +1,220 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.models.generative_task import GenerativeTask +from keras_nlp.models.llama.llama_backbone import LlamaBackbone +from keras_nlp.models.llama.llama_causal_lm_preprocessor import ( + LlamaCausalLMPreprocessor, +) +from keras_nlp.models.llama.llama_presets import backbone_presets +from keras_nlp.utils.python_utils import classproperty + + +@keras_nlp_export("keras_nlp.models.LlamaCausalLM") +class LlamaCausalLM(GenerativeTask): + """An end-to-end Llama model for causal language modeling. + + A causal language model (LM) predicts the next token based on previous + tokens. This task setup can be used to train the model unsupervised on + plain text input, or to autoregressively generate plain text similar to + the data used for training. This task can be used for pre-training or + fine-tuning a GPT-NeoX model, simply by calling `fit()`. + + This model has a `generate()` method, which generates text based on a + prompt. The generation strategy used is controlled by an additional + `sampler` argument on `compile()`. You can recompile the model with + different `keras_nlp.samplers` objects to control the generation. By + default, `"top_k"` sampling will be used. + + Args: + backbone: A `keras_nlp.models.LlamaBackbone` instance. + preprocessor: A `keras_nlp.models.LlamaCausalLMPreprocessor` or `None`. + If `None`, this model will not apply preprocessing, and inputs + should be preprocessed before calling the model. + """ + + def __init__(self, backbone, preprocessor=None, **kwargs): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + inputs = backbone.inputs + hidden_states = backbone(inputs) + outputs = backbone.token_embedding(hidden_states, reverse=True) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Default compilation === + self.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(2e-5), + metrics=[keras.metrics.SparseCategoricalAccuracy()], + jit_compile=True, + ) + + @classproperty + def backbone_cls(cls): + return LlamaBackbone + + @classproperty + def preprocessor_cls(cls): + return LlamaCausalLMPreprocessor + + def call_with_cache( + self, + token_ids, + cache, + cache_update_index, + ): + """Forward pass of `LlamaCausalLM` with cache. + + `call_with_cache` adds an additional forward pass for the model for + autoregressive inference. Unlike calling the model directly, this method + allows caching previous key/value Tensors in multi-head attention layer, + and avoids recomputing the outputs of seen tokens. + + Args: + token_ids: a dense int Tensor with shape `(batch_size, max_length)`. + cache: a dense float Tensor, the cache of key and value. + cache_update_index: int, or int Tensor. The index of current inputs + in the whole sequence. + + Returns: + A (logits, hidden_states, cache) tuple. Where `logits` is the + language model logits for the input token_ids, `hidden_states` is + the final hidden representation of the input tokens, and `cache` is + the decoding cache. + """ + x = self.backbone.token_embedding(token_ids) + # Each decoder layer has a cache; we update them separately. + updated_cache = [] + for i in range(self.backbone.num_layers): + current_cache = cache[:, i, ...] + x, next_cache = self.backbone.transformer_layers[i]( + x, + self_attention_cache=current_cache, + self_attention_cache_update_index=cache_update_index, + ) + updated_cache.append(next_cache) + cache = ops.stack(updated_cache, axis=1) + hidden_states = x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + return logits, hidden_states, cache + + def _build_cache(self, token_ids): + """Build an empty cache for use with `call_with_cache()`.""" + batch_size = ops.shape(token_ids)[0] + max_length = ops.shape(token_ids)[1] + num_layers = self.backbone.num_layers + num_key_value_heads = self.backbone.num_key_value_heads + head_dim = self.backbone.hidden_dim // self.backbone.num_query_heads + shape = [ + batch_size, + num_layers, + 2, + max_length, + num_key_value_heads, + head_dim, + ] + cache = ops.zeros(shape, dtype=self.compute_dtype) + # Seed the cache. + _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) + return hidden_states, cache + + def generate_step( + self, + inputs, + end_token_id=None, + ): + """A compilable generation function for a single batch of inputs. + + This function represents the inner, XLA-compilable, generation function + for a single batch of inputs. Inputs should have the same structure as + model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. + + Args: + inputs: A dictionary with two keys `"token_ids"` and + `"padding_mask"` and batched tensor values. + end_token_id: The id of the end token to stop on. If all + sequences have produced a new `end_token_id`, generation + will stop. + """ + token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] + # Create and seed cache with a single forward pass. + hidden_states, cache = self._build_cache(token_ids) + # Compute the lengths of all user inputted tokens ids. + row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) + # Start at the first index that has no user inputted id. + index = ops.min(row_lengths) + + def next(prompt, cache, index): + # The cache index is the index of our previous token. + cache_update_index = index - 1 + batch_size = ops.shape(prompt)[0] + prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) + logits, hidden_states, cache = self.call_with_cache( + prompt, + cache, + cache_update_index, + ) + return ( + ops.squeeze(logits, axis=1), + ops.squeeze(hidden_states, axis=1), + cache, + ) + + token_ids = self._sampler( + next=next, + prompt=token_ids, + cache=cache, + index=index, + mask=padding_mask, + end_token_id=end_token_id, + hidden_states=hidden_states, + model=self, + ) + + # Compute an output padding mask with the token ids we updated. + if end_token_id is not None: + # Build a mask of `end_token_id` locations not in the original + # prompt (not in locations where `padding_mask` is True). + end_locations = ops.logical_and( + ops.equal(token_ids, end_token_id), + ops.logical_not(padding_mask), + ) + end_locations = ops.cast(end_locations, "int32") + # Use cumsum to get ones in all locations after end_locations. + cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") + overflow = cumsum - end_locations + # Our padding mask is the inverse of these overflow locations. + padding_mask = ops.logical_not(ops.cast(overflow, "bool")) + else: + # Without early stopping, all locations will have been updated. + padding_mask = ops.ones_like(token_ids, dtype="bool") + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/llama/llama_preprocessor.py b/keras_nlp/models/llama/llama_preprocessor.py index 580557f50..a24c42508 100644 --- a/keras_nlp/models/llama/llama_preprocessor.py +++ b/keras_nlp/models/llama/llama_preprocessor.py @@ -11,8 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy + from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.models.llama.llama_presets import backbone_presets from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( @@ -189,3 +192,7 @@ def sequence_length(self, value): @classproperty def tokenizer_cls(cls): return LlamaTokenizer + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/llama/llama_preprocessor_test.py b/keras_nlp/models/llama/llama_preprocessor_test.py index 680788681..52a559aa2 100644 --- a/keras_nlp/models/llama/llama_preprocessor_test.py +++ b/keras_nlp/models/llama/llama_preprocessor_test.py @@ -14,6 +14,8 @@ import os +import pytest + from keras_nlp.models.llama.llama_preprocessor import LlamaPreprocessor from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer from keras_nlp.tests.test_case import TestCase @@ -55,3 +57,12 @@ def test_errors_for_2d_list_input(self): ambiguous_input = [["one", "two"], ["three", "four"]] with self.assertRaises(ValueError): preprocessor(ambiguous_input) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in LlamaPreprocessor.presets: + self.run_preset_test( + cls=LlamaPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/llama/llama_presets.py b/keras_nlp/models/llama/llama_presets.py new file mode 100644 index 000000000..4c8bd6de5 --- /dev/null +++ b/keras_nlp/models/llama/llama_presets.py @@ -0,0 +1,38 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Llama model preset configurations.""" + +# Metadata for loading pretrained model weights. +backbone_presets = { + "llama_7b_en": { + "metadata": { + "description": "Llama 7B Base model", + "params": 6738415616, + "official_name": "Llama", + "path": "llama", + "model_card": "https://github.com/llamaai/llama-src/blob/main/README.md", + }, + "kaggle_handle": "kaggle://keras/llama/keras/llama_7b_en/1", + }, + "llama_instruct_7b_en": { + "metadata": { + "description": "LLaMA 7B Chat model", + "params": 6738415616, + "official_name": "LLaMA", + "path": "llama", + "model_card": "https://github.com/llamaai/llama-src/blob/main/README.md", + }, + "kaggle_handle": "kaggle://keras/llama/keras/llama_instruct_7b_en/1", + }, +} diff --git a/keras_nlp/models/llama/llama_tokenizer.py b/keras_nlp/models/llama/llama_tokenizer.py index 7acdf8687..07b0f2103 100644 --- a/keras_nlp/models/llama/llama_tokenizer.py +++ b/keras_nlp/models/llama/llama_tokenizer.py @@ -11,8 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy + from keras_nlp.api_export import keras_nlp_export +from keras_nlp.models.llama.llama_presets import backbone_presets from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer +from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.LlamaTokenizer") @@ -79,3 +83,7 @@ def set_proto(self, proto): self.start_token_id = None self.end_token_id = None self.pad_token_id = None + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/llama/llama_tokenizer_test.py b/keras_nlp/models/llama/llama_tokenizer_test.py index 9a3c22545..3ff573b67 100644 --- a/keras_nlp/models/llama/llama_tokenizer_test.py +++ b/keras_nlp/models/llama/llama_tokenizer_test.py @@ -14,6 +14,8 @@ import os +import pytest + from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer from keras_nlp.tests.test_case import TestCase @@ -44,3 +46,21 @@ def test_errors_missing_special_tokens(self): self.get_test_data_dir(), "no_special_token_vocab.spm" ) ) + + @pytest.mark.large + def test_smallest_preset(self): + self.run_preset_test( + cls=LlamaTokenizer, + preset="llama_7b_en", + input_data=["The quick brown fox."], + expected_output=[[415, 2936, 9060, 285, 1142, 28723]], + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in LlamaTokenizer.presets: + self.run_preset_test( + cls=LlamaTokenizer, + preset=preset, + input_data=self.input_data, + ) diff --git a/tools/checkpoint_conversion/convert_llama_checkpoints.py b/tools/checkpoint_conversion/convert_llama_checkpoints.py index 5eb3973f3..ec473c5fc 100644 --- a/tools/checkpoint_conversion/convert_llama_checkpoints.py +++ b/tools/checkpoint_conversion/convert_llama_checkpoints.py @@ -11,131 +11,289 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import gc import os +import shutil +import tempfile +import traceback -import torch -from transformers import AutoModel +import numpy as np +import requests +from absl import app +from absl import flags +from keras import ops +from transformers import AutoTokenizer +from transformers import LlamaForCausalLM -from keras_nlp.models.llama.llama_backbone import LlamaBackbone +from keras_nlp.models import LlamaBackbone +from keras_nlp.models import LlamaCausalLMPreprocessor +from keras_nlp.models import LlamaTokenizer +from keras_nlp.utils.preset_utils import save_to_preset -os.environ["KERAS_BACKEND"] = "torch" +PRESET_MAP = { + "llama_7b_en": "meta-llama/Llama-2-7b-hf", + "llama_instruct_7b_en": "meta-llama/Llama-2-7b-chat-hf", +} -# from huggingface_hub import login -# llama weights as of now are on request access -# login(token=' Huggingface model and tokenizer loaded") - # MLP - keras_model.get_layer( - f"transformer_layer_{ilayer}" - )._feedforward_intermediate_dense.kernel.assign( - hf_wts[f"layers.{ilayer}.mlp.up_proj.weight"].numpy().T - ) + # === Load the KerasNLP model === + backbone_kwargs = dict( + vocabulary_size=hf_model.config.vocab_size, + hidden_dim=hf_model.config.hidden_size, + num_layers=hf_model.config.num_hidden_layers, + num_query_heads=hf_model.config.num_attention_heads, + num_key_value_heads=hf_model.config.num_key_value_heads, + intermediate_dim=hf_model.config.intermediate_size, + layer_norm_epsilon=hf_model.config.rms_norm_eps, + rope_max_wavelength=hf_model.config.rope_theta, + dtype="float32", + ) + keras_nlp_model = LlamaBackbone(**backbone_kwargs) - keras_model.get_layer( - f"transformer_layer_{ilayer}" - )._feedforward_gate_dense.kernel.assign( - hf_wts[f"layers.{ilayer}.mlp.gate_proj.weight"].numpy().T - ) + # === Download the tokenizer from Huggingface model card === + spm_path = ( + f"https://huggingface.co/{hf_preset}/resolve/main/tokenizer.model" + ) + response = requests.get(spm_path) + if not response.ok: + raise ValueError(f"Couldn't fetch {preset}'s tokenizer.") + tokenizer_path = os.path.join(temp_dir, "vocabulary.spm") + with open(tokenizer_path, "wb") as tokenizer_file: + tokenizer_file.write(response.content) + keras_nlp_tokenizer = LlamaTokenizer(tokenizer_path) + print("\n-> Keras 3 model and tokenizer loaded.") - keras_model.get_layer( - f"transformer_layer_{ilayer}" - )._feedforward_output_dense.kernel.assign( - hf_wts[f"layers.{ilayer}.mlp.down_proj.weight"].numpy().T - ) + # === Port the weights === + convert_checkpoints(keras_nlp_model, hf_model) + print("\n-> Weight transfer done.") - # LAYERNORM - keras_model.get_layer( - f"transformer_layer_{ilayer}" - )._self_attention_layernorm.weight.assign( - hf_wts[f"layers.{ilayer}.input_layernorm.weight"] - ) + # === Check that the models and tokenizers outputs match === + test_tokenizer(keras_nlp_tokenizer, hf_tokenizer) + test_model(keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_tokenizer) + print("\n-> Tests passed!") - keras_model.get_layer( - f"transformer_layer_{ilayer}" - )._feedforward_layernorm.weight.assign( - hf_wts[f"layers.{ilayer}.post_attention_layernorm.weight"] - ) + # === Save the model weights in float32 format === + keras_nlp_model.save_weights(os.path.join(temp_dir, "model.weights.h5")) + print("\n-> Saved the model weights in float32") + del keras_nlp_model, hf_model + gc.collect() -keras_model.get_layer("layer_norm").gamma.assign(hf_wts["norm.weight"]) + # === Save the weights again in float16 === + backbone_kwargs["dtype"] = "float16" + keras_nlp_model = LlamaBackbone(**backbone_kwargs) + keras_nlp_model.load_weights(os.path.join(temp_dir, "model.weights.h5")) + save_to_preset(keras_nlp_model, preset) + print("\n-> Saved the model preset in float16") -token_ids = [1, 2181, 8522, 338] -padding_mask = [1, 1, 1, 1] + # === Save the tokenizer === + save_to_preset( + keras_nlp_tokenizer, preset, config_filename="tokenizer.json" + ) + print("\n-> Saved the tokenizer") + finally: + shutil.rmtree(temp_dir) -keras_inputs = { - "token_ids": torch.tensor([token_ids]), - "padding_mask": torch.tensor([padding_mask]), -} -with torch.no_grad(): - keras_outputs = keras_model(keras_inputs) -print("Keras output = ", keras_outputs.numpy()) +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main) From 460e7e3d5ae2aa8181e498c45472edc26daab30b Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Wed, 27 Mar 2024 00:57:31 +0000 Subject: [PATCH 2/6] Add causal lm to the public API --- keras_nlp/models/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py index b830be3ee..5ca6b657a 100644 --- a/keras_nlp/models/__init__.py +++ b/keras_nlp/models/__init__.py @@ -107,6 +107,7 @@ ) from keras_nlp.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer from keras_nlp.models.llama.llama_backbone import LlamaBackbone +from keras_nlp.models.llama.llama_causal_lm import LlamaCausalLM from keras_nlp.models.llama.llama_causal_lm_preprocessor import ( LlamaCausalLMPreprocessor, ) From 4c1661b01ecd96bb8a9ac3284454df6e03da3980 Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Wed, 27 Mar 2024 20:32:43 +0000 Subject: [PATCH 3/6] Update preset names and fix checkpoint script --- keras_nlp/models/llama/llama_presets.py | 24 +++++++++---------- .../convert_llama_checkpoints.py | 17 ++++--------- 2 files changed, 16 insertions(+), 25 deletions(-) diff --git a/keras_nlp/models/llama/llama_presets.py b/keras_nlp/models/llama/llama_presets.py index 4c8bd6de5..292848a11 100644 --- a/keras_nlp/models/llama/llama_presets.py +++ b/keras_nlp/models/llama/llama_presets.py @@ -15,24 +15,24 @@ # Metadata for loading pretrained model weights. backbone_presets = { - "llama_7b_en": { + "llama2_7b_en": { "metadata": { - "description": "Llama 7B Base model", + "description": "LLaMA 2 7B Base model", "params": 6738415616, - "official_name": "Llama", - "path": "llama", - "model_card": "https://github.com/llamaai/llama-src/blob/main/README.md", + "official_name": "LLaMA 2", + "path": "llama2", + "model_card": "https://github.com/meta-llama/llama", }, - "kaggle_handle": "kaggle://keras/llama/keras/llama_7b_en/1", + "kaggle_handle": "kaggle://keras/llama2/keras/llama2_7b_en/1", }, - "llama_instruct_7b_en": { + "llama2_instruct_7b_en": { "metadata": { - "description": "LLaMA 7B Chat model", + "description": "LLaMA 2 7B Chat model", "params": 6738415616, - "official_name": "LLaMA", - "path": "llama", - "model_card": "https://github.com/llamaai/llama-src/blob/main/README.md", + "official_name": "LLaMA 2", + "path": "llama2", + "model_card": "https://github.com/meta-llama/llama", }, - "kaggle_handle": "kaggle://keras/llama/keras/llama_instruct_7b_en/1", + "kaggle_handle": "kaggle://keras/llama2/keras/llama2_instruct_7b_en/1", }, } diff --git a/tools/checkpoint_conversion/convert_llama_checkpoints.py b/tools/checkpoint_conversion/convert_llama_checkpoints.py index ec473c5fc..4e127b2c7 100644 --- a/tools/checkpoint_conversion/convert_llama_checkpoints.py +++ b/tools/checkpoint_conversion/convert_llama_checkpoints.py @@ -18,7 +18,6 @@ import traceback import numpy as np -import requests from absl import app from absl import flags from keras import ops @@ -31,8 +30,8 @@ from keras_nlp.utils.preset_utils import save_to_preset PRESET_MAP = { - "llama_7b_en": "meta-llama/Llama-2-7b-hf", - "llama_instruct_7b_en": "meta-llama/Llama-2-7b-chat-hf", + "llama2_7b_en": "meta-llama/Llama-2-7b-hf", + "llama2_instruct_7b_en": "meta-llama/Llama-2-7b-chat-hf", } FLAGS = flags.FLAGS @@ -249,16 +248,8 @@ def main(_): ) keras_nlp_model = LlamaBackbone(**backbone_kwargs) - # === Download the tokenizer from Huggingface model card === - spm_path = ( - f"https://huggingface.co/{hf_preset}/resolve/main/tokenizer.model" - ) - response = requests.get(spm_path) - if not response.ok: - raise ValueError(f"Couldn't fetch {preset}'s tokenizer.") - tokenizer_path = os.path.join(temp_dir, "vocabulary.spm") - with open(tokenizer_path, "wb") as tokenizer_file: - tokenizer_file.write(response.content) + # === Get the tokenizer from the Huggingface model === + tokenizer_path = hf_tokenizer.vocab_file keras_nlp_tokenizer = LlamaTokenizer(tokenizer_path) print("\n-> Keras 3 model and tokenizer loaded.") From d0756970c4826fcad02c6f44eec5f4011367dec0 Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Wed, 27 Mar 2024 23:45:06 +0000 Subject: [PATCH 4/6] Fix discrepancies and add tests --- keras_nlp/models/llama/llama_backbone_test.py | 31 +++++++++++++++++++ keras_nlp/models/llama/llama_causal_lm.py | 2 +- .../models/llama/llama_tokenizer_test.py | 4 +-- 3 files changed, 34 insertions(+), 3 deletions(-) diff --git a/keras_nlp/models/llama/llama_backbone_test.py b/keras_nlp/models/llama/llama_backbone_test.py index 56d8c44bd..b641a0152 100644 --- a/keras_nlp/models/llama/llama_backbone_test.py +++ b/keras_nlp/models/llama/llama_backbone_test.py @@ -49,3 +49,34 @@ def test_saved_model(self): init_kwargs=self.init_kwargs, input_data=self.input_data, ) + + def test_num_parameters(self): + model = LlamaBackbone(**self.init_kwargs) + # Reference value calculated using the PyTorch model + self.assertEqual(model.count_params(), 968) + + @pytest.mark.extra_large + def test_smallest_preset(self): + self.run_preset_test( + cls=LlamaBackbone, + preset="llama2_7b_en", + input_data={ + "token_ids": ops.array([[1, 1824, 349, 524, 11234, 28804]]), + "padding_mask": ops.ones((1, 6), dtype="int32"), + }, + expected_output_shape=(1, 6, 4096), + # The forward pass from a preset should be stable! + # Reference values computed using PyTorch HF model. + expected_partial_output=ops.array( + [0.0153, 1.1657, 2.2452, -2.0192, -0.5801] + ), + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in LlamaBackbone.presets: + self.run_preset_test( + cls=LlamaBackbone, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/llama/llama_causal_lm.py b/keras_nlp/models/llama/llama_causal_lm.py index feb4df034..a5f23bd97 100644 --- a/keras_nlp/models/llama/llama_causal_lm.py +++ b/keras_nlp/models/llama/llama_causal_lm.py @@ -33,7 +33,7 @@ class LlamaCausalLM(GenerativeTask): tokens. This task setup can be used to train the model unsupervised on plain text input, or to autoregressively generate plain text similar to the data used for training. This task can be used for pre-training or - fine-tuning a GPT-NeoX model, simply by calling `fit()`. + fine-tuning a LLaMA model, simply by calling `fit()`. This model has a `generate()` method, which generates text based on a prompt. The generation strategy used is controlled by an additional diff --git a/keras_nlp/models/llama/llama_tokenizer_test.py b/keras_nlp/models/llama/llama_tokenizer_test.py index 3ff573b67..51687731e 100644 --- a/keras_nlp/models/llama/llama_tokenizer_test.py +++ b/keras_nlp/models/llama/llama_tokenizer_test.py @@ -51,9 +51,9 @@ def test_errors_missing_special_tokens(self): def test_smallest_preset(self): self.run_preset_test( cls=LlamaTokenizer, - preset="llama_7b_en", + preset="llama2_7b_en", input_data=["The quick brown fox."], - expected_output=[[415, 2936, 9060, 285, 1142, 28723]], + expected_output=[[450, 4996, 17354, 1701, 29916, 29889]], ) @pytest.mark.extra_large From f8326a135830501fbb30fa6195aca58ce77099e7 Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Wed, 27 Mar 2024 23:46:21 +0000 Subject: [PATCH 5/6] Add tests for CausalLM --- .../models/llama/llama_causal_lm_test.py | 130 ++++++++++++++++++ 1 file changed, 130 insertions(+) create mode 100644 keras_nlp/models/llama/llama_causal_lm_test.py diff --git a/keras_nlp/models/llama/llama_causal_lm_test.py b/keras_nlp/models/llama/llama_causal_lm_test.py new file mode 100644 index 000000000..ff71a75b3 --- /dev/null +++ b/keras_nlp/models/llama/llama_causal_lm_test.py @@ -0,0 +1,130 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from unittest.mock import patch + +import pytest + +from keras_nlp.backend import ops +from keras_nlp.models.llama.llama_backbone import LlamaBackbone +from keras_nlp.models.llama.llama_causal_lm import LlamaCausalLM +from keras_nlp.models.llama.llama_causal_lm_preprocessor import ( + LlamaCausalLMPreprocessor, +) +from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer +from keras_nlp.tests.test_case import TestCase + + +class LlamaCausalLMTest(TestCase): + def setUp(self): + self.preprocessor = LlamaCausalLMPreprocessor( + LlamaTokenizer( + # Generated using create_llama_test_proto.py + proto=os.path.join( + self.get_test_data_dir(), "llama_test_vocab.spm" + ) + ), + sequence_length=8, + ) + self.backbone = LlamaBackbone( + vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(), + num_layers=2, + num_query_heads=4, + num_key_value_heads=2, + hidden_dim=8, + intermediate_dim=16, + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.train_data = (["the quick brown fox", "the earth is round"],) + self.input_data = self.preprocessor(*self.train_data)[0] + + def test_causal_lm_basics(self): + self.run_task_test( + cls=LlamaCausalLM, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 8, 10), + ) + + def test_generate(self): + causal_lm = LlamaCausalLM(**self.init_kwargs) + # String input. + prompt = "the quick brown fox" + output = causal_lm.generate(prompt) + self.assertTrue(prompt in output) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess([prompt]) + causal_lm.preprocessor = None + outputs = causal_lm.generate(prompt_ids) + # Assert prompt is in output in token id space. + self.assertAllEqual( + outputs["token_ids"][:, :5], + prompt_ids["token_ids"][:, :5], + ) + self.assertAllEqual( + outputs["padding_mask"][:, :5], + prompt_ids["padding_mask"][:, :5], + ) + + def test_early_stopping(self): + causal_lm = LlamaCausalLM(**self.init_kwargs) + call_with_cache = causal_lm.call_with_cache + + def wrapper(*args, **kwargs): + """Modify output logits to always favor end_token_id""" + logits, hidden_states, cache = call_with_cache(*args, **kwargs) + index = self.preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) + return logits, hidden_states, cache + + with patch.object(causal_lm, "call_with_cache", wraps=wrapper): + prompt = ["the quick brown fox", "the earth"] + output = causal_lm.generate(prompt) + # We should immediately abort and output the prompt. + self.assertEqual(prompt, output) + + def test_generate_compilation(self): + causal_lm = LlamaCausalLM(**self.init_kwargs) + # Assert we do not recompile with successive calls. + causal_lm.generate("the quick brown fox") + first_fn = causal_lm.generate_function + causal_lm.generate("the quick brown fox") + second_fn = causal_lm.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. + causal_lm.compile(sampler="greedy") + self.assertIsNone(causal_lm.generate_function) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=LlamaCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in LlamaCausalLM.presets: + self.run_preset_test( + cls=LlamaCausalLM, + preset=preset, + input_data=self.input_data, + ) From 9a0804ac01422ffd3ee28d7282ce646d9185105d Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Thu, 28 Mar 2024 00:35:53 +0000 Subject: [PATCH 6/6] end_token -> stop_token_ids --- keras_nlp/models/llama/llama_causal_lm.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/keras_nlp/models/llama/llama_causal_lm.py b/keras_nlp/models/llama/llama_causal_lm.py index a5f23bd97..7527766f0 100644 --- a/keras_nlp/models/llama/llama_causal_lm.py +++ b/keras_nlp/models/llama/llama_causal_lm.py @@ -23,6 +23,7 @@ ) from keras_nlp.models.llama.llama_presets import backbone_presets from keras_nlp.utils.python_utils import classproperty +from keras_nlp.utils.tensor_utils import any_equal @keras_nlp_export("keras_nlp.models.LlamaCausalLM") @@ -143,7 +144,7 @@ def _build_cache(self, token_ids): def generate_step( self, inputs, - end_token_id=None, + stop_token_ids=None, ): """A compilable generation function for a single batch of inputs. @@ -154,8 +155,8 @@ def generate_step( Args: inputs: A dictionary with two keys `"token_ids"` and `"padding_mask"` and batched tensor values. - end_token_id: The id of the end token to stop on. If all - sequences have produced a new `end_token_id`, generation + stop_token_ids: Tuple of id's of the end token to stop on. If all + sequences have produced a new stop token, generation will stop. """ token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] @@ -188,17 +189,17 @@ def next(prompt, cache, index): cache=cache, index=index, mask=padding_mask, - end_token_id=end_token_id, + stop_token_ids=stop_token_ids, hidden_states=hidden_states, model=self, ) # Compute an output padding mask with the token ids we updated. - if end_token_id is not None: - # Build a mask of `end_token_id` locations not in the original + if stop_token_ids is not None: + # Build a mask of stop token locations not in the original # prompt (not in locations where `padding_mask` is True). end_locations = ops.logical_and( - ops.equal(token_ids, end_token_id), + any_equal(token_ids, stop_token_ids), ops.logical_not(padding_mask), ) end_locations = ops.cast(end_locations, "int32")