Add LLaMA Causal LM with 7B presets #1526

Merged 7 commits on Mar 28, 2024

Changes from all commits
6 changes: 6 additions & 0 deletions keras_nlp/models/__init__.py
@@ -108,6 +108,12 @@
)
from keras_nlp.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
from keras_nlp.models.llama.llama_backbone import LlamaBackbone
from keras_nlp.models.llama.llama_causal_lm import LlamaCausalLM
from keras_nlp.models.llama.llama_causal_lm_preprocessor import (
LlamaCausalLMPreprocessor,
)
from keras_nlp.models.llama.llama_preprocessor import LlamaPreprocessor
from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer
from keras_nlp.models.mistral.mistral_backbone import MistralBackbone
from keras_nlp.models.mistral.mistral_causal_lm import MistralCausalLM
from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import (
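
With these exports added, the new classes become importable from the top-level `keras_nlp.models` namespace. A minimal usage sketch, assuming the `llama2_7b_en` preset weights referenced in the tests below are available:

import keras_nlp

# Build the causal LM task (with its preprocessor) from the 7B preset.
causal_lm = keras_nlp.models.LlamaCausalLM.from_preset("llama2_7b_en")
causal_lm.generate("What is Keras?", max_length=64)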
15 changes: 6 additions & 9 deletions keras_nlp/models/llama/llama_backbone.py
@@ -11,20 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# import copy
import copy

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
from keras_nlp.models.backbone import Backbone

# from keras_nlp.models.llama.llama_presets import backbone_presets
from keras_nlp.models.llama.llama_decoder import LlamaTransformerDecoder
from keras_nlp.models.llama.llama_layernorm import LlamaLayerNorm

# from keras_nlp.utils.python_utils import classproperty
from keras_nlp.models.llama.llama_presets import backbone_presets
from keras_nlp.utils.python_utils import classproperty


def _llama_kernel_initializer(stddev=0.02):
@@ -191,6 +188,6 @@ def get_config(self):
)
return config

# @classproperty
# def presets(cls):
# return copy.deepcopy(backbone_presets)
@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
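
Re-enabling the preset imports and the `presets` classproperty means the backbone can now be built straight from the registry. A minimal sketch, assuming `llama2_7b_en` is registered in `llama_presets.py`:

import keras_nlp

# Inspect the registered presets, then instantiate the backbone from one.
print(keras_nlp.models.LlamaBackbone.presets.keys())
backbone = keras_nlp.models.LlamaBackbone.from_preset("llama2_7b_en")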
31 changes: 31 additions & 0 deletions keras_nlp/models/llama/llama_backbone_test.py
@@ -49,3 +49,34 @@ def test_saved_model(self):
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)

def test_num_parameters(self):
model = LlamaBackbone(**self.init_kwargs)
# Reference value calculated using the PyTorch model
self.assertEqual(model.count_params(), 968)

@pytest.mark.extra_large
def test_smallest_preset(self):
self.run_preset_test(
cls=LlamaBackbone,
preset="llama2_7b_en",
input_data={
"token_ids": ops.array([[1, 1824, 349, 524, 11234, 28804]]),
"padding_mask": ops.ones((1, 6), dtype="int32"),
},
expected_output_shape=(1, 6, 4096),
# The forward pass from a preset should be stable!
# Reference values computed using PyTorch HF model.
expected_partial_output=ops.array(
[0.0153, 1.1657, 2.2452, -2.0192, -0.5801]
),
)

@pytest.mark.extra_large
def test_all_presets(self):
for preset in LlamaBackbone.presets:
self.run_preset_test(
cls=LlamaBackbone,
preset=preset,
input_data=self.input_data,
)
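
The two preset tests above are gated behind the `extra_large` marker, and their reference values are documented as coming from the Hugging Face PyTorch model. A hypothetical sketch of how such reference values could be reproduced (assumes the `transformers` package and access to the `meta-llama/Llama-2-7b-hf` checkpoint):

import torch
from transformers import AutoModel

# Run the same token ids through the Hugging Face Llama 2 7B model.
model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf")
token_ids = torch.tensor([[1, 1824, 349, 524, 11234, 28804]])
with torch.no_grad():
    hidden_states = model(token_ids).last_hidden_state  # shape (1, 6, 4096)
# First five values of the first token's hidden state, to compare against
# `expected_partial_output` above.
print(hidden_states[0, 0, :5])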
221 changes: 221 additions & 0 deletions keras_nlp/models/llama/llama_causal_lm.py
@@ -0,0 +1,221 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.models.generative_task import GenerativeTask
from keras_nlp.models.llama.llama_backbone import LlamaBackbone
from keras_nlp.models.llama.llama_causal_lm_preprocessor import (
LlamaCausalLMPreprocessor,
)
from keras_nlp.models.llama.llama_presets import backbone_presets
from keras_nlp.utils.python_utils import classproperty
from keras_nlp.utils.tensor_utils import any_equal


@keras_nlp_export("keras_nlp.models.LlamaCausalLM")
class LlamaCausalLM(GenerativeTask):
"""An end-to-end Llama model for causal language modeling.

A causal language model (LM) predicts the next token based on previous
tokens. This task setup can be used to train the model unsupervised on
plain text input, or to autoregressively generate plain text similar to
the data used for training. This task can be used for pre-training or
fine-tuning a LLaMA model, simply by calling `fit()`.

This model has a `generate()` method, which generates text based on a
prompt. The generation strategy used is controlled by an additional
`sampler` argument on `compile()`. You can recompile the model with
different `keras_nlp.samplers` objects to control the generation. By
default, `"top_k"` sampling will be used.

Args:
backbone: A `keras_nlp.models.LlamaBackbone` instance.
preprocessor: A `keras_nlp.models.LlamaCausalLMPreprocessor` or `None`.
If `None`, this model will not apply preprocessing, and inputs
should be preprocessed before calling the model.
"""

def __init__(self, backbone, preprocessor=None, **kwargs):
# === Layers ===
self.backbone = backbone
self.preprocessor = preprocessor

# === Functional Model ===
inputs = backbone.inputs
hidden_states = backbone(inputs)
outputs = backbone.token_embedding(hidden_states, reverse=True)
super().__init__(
inputs=inputs,
outputs=outputs,
**kwargs,
)

# === Default compilation ===
self.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(2e-5),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
jit_compile=True,
)

@classproperty
def backbone_cls(cls):
return LlamaBackbone

@classproperty
def preprocessor_cls(cls):
return LlamaCausalLMPreprocessor

def call_with_cache(
self,
token_ids,
cache,
cache_update_index,
):
"""Forward pass of `LlamaCausalLM` with cache.

`call_with_cache` adds an additional forward pass for the model for
autoregressive inference. Unlike calling the model directly, this method
allows caching previous key/value Tensors in the multi-head attention
layers, and avoids recomputing the outputs of seen tokens.

Args:
token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
cache: a dense float Tensor, the cache of key and value.
cache_update_index: int, or int Tensor. The index of current inputs
in the whole sequence.

Returns:
A (logits, hidden_states, cache) tuple. Where `logits` is the
language model logits for the input token_ids, `hidden_states` is
the final hidden representation of the input tokens, and `cache` is
the decoding cache.
"""
x = self.backbone.token_embedding(token_ids)
# Each decoder layer has a cache; we update them separately.
updated_cache = []
for i in range(self.backbone.num_layers):
current_cache = cache[:, i, ...]
x, next_cache = self.backbone.transformer_layers[i](
x,
self_attention_cache=current_cache,
self_attention_cache_update_index=cache_update_index,
)
updated_cache.append(next_cache)
cache = ops.stack(updated_cache, axis=1)
hidden_states = x = self.backbone.layer_norm(x)
logits = self.backbone.token_embedding(x, reverse=True)
return logits, hidden_states, cache

def _build_cache(self, token_ids):
"""Build an empty cache for use with `call_with_cache()`."""
batch_size = ops.shape(token_ids)[0]
max_length = ops.shape(token_ids)[1]
num_layers = self.backbone.num_layers
num_key_value_heads = self.backbone.num_key_value_heads
head_dim = self.backbone.hidden_dim // self.backbone.num_query_heads
shape = [
batch_size,
num_layers,
2,
max_length,
num_key_value_heads,
head_dim,
]
cache = ops.zeros(shape, dtype=self.compute_dtype)
# Seed the cache.
_, hidden_states, cache = self.call_with_cache(token_ids, cache, 0)
return hidden_states, cache

def generate_step(
self,
inputs,
stop_token_ids=None,
):
"""A compilable generation function for a single batch of inputs.

This function represents the inner, XLA-compilable, generation function
for a single batch of inputs. Inputs should have the same structure as
model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`.

Args:
inputs: A dictionary with two keys `"token_ids"` and
`"padding_mask"` and batched tensor values.
            stop_token_ids: Tuple of ids of end tokens to stop on. If all
                sequences have produced a new stop token, generation
                will stop.
"""
token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
# Create and seed cache with a single forward pass.
hidden_states, cache = self._build_cache(token_ids)
        # Compute the lengths of all user-provided token ids.
row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
# Start at the first index that has no user inputted id.
index = ops.min(row_lengths)

def next(prompt, cache, index):
# The cache index is the index of our previous token.
cache_update_index = index - 1
batch_size = ops.shape(prompt)[0]
prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
logits, hidden_states, cache = self.call_with_cache(
prompt,
cache,
cache_update_index,
)
return (
ops.squeeze(logits, axis=1),
ops.squeeze(hidden_states, axis=1),
cache,
)

token_ids = self._sampler(
next=next,
prompt=token_ids,
cache=cache,
index=index,
mask=padding_mask,
stop_token_ids=stop_token_ids,
hidden_states=hidden_states,
model=self,
)

# Compute an output padding mask with the token ids we updated.
if stop_token_ids is not None:
# Build a mask of stop token locations not in the original
# prompt (not in locations where `padding_mask` is True).
end_locations = ops.logical_and(
any_equal(token_ids, stop_token_ids),
ops.logical_not(padding_mask),
)
end_locations = ops.cast(end_locations, "int32")
# Use cumsum to get ones in all locations after end_locations.
cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
overflow = cumsum - end_locations
# Our padding mask is the inverse of these overflow locations.
padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
else:
# Without early stopping, all locations will have been updated.
padding_mask = ops.ones_like(token_ids, dtype="bool")
return {
"token_ids": token_ids,
"padding_mask": padding_mask,
}

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
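
As the class docstring notes, the generation strategy is controlled by the `sampler` argument to `compile()`. A minimal usage sketch, assuming the `llama2_7b_en` preset:

import keras_nlp

causal_lm = keras_nlp.models.LlamaCausalLM.from_preset("llama2_7b_en")

# Default sampling is `"top_k"`; switch strategies by recompiling.
causal_lm.compile(sampler="greedy")
causal_lm.generate("What is Keras?", max_length=64)

# A configured sampler object also works, and generate() accepts a batch of prompts.
causal_lm.compile(sampler=keras_nlp.samplers.TopKSampler(k=5))
causal_lm.generate(["What is Keras?", "Tell me about LLaMA."], max_length=64)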