
Add MistralAI's 7B Transformer as a backbone in KerasNLP Models #1314

Merged — 9 commits, Dec 20, 2023

Conversation

@tirthasheshpatel (Contributor) commented on Nov 13, 2023:

Fixes #1275

This PR adds a MistralBackbone backbone model and all its components.

Most of the components share a lot of code with #1203.

Reference implementation: mistralai/mistral-src

Colab for weight transfer: https://colab.research.google.com/drive/1MoD7JJasThxmalspG3c21oYMEc_qRbti?usp=sharing

TODOs:

  • Add docs for all the layers and the backbone.
  • Add tests to confirm the forward pass matches.
  • Add a checkpoint conversion script.
  • Add the 7B model preset.
  • Add dropout to the CachedMistralAttention and MistralTransformerDecoder layers.
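
For reviewers, a minimal usage sketch of the new backbone (tiny, randomly initialized config; the argument names follow this PR and may differ in the released API):

import numpy as np
from keras_nlp.models import MistralBackbone

# Tiny toy config; the real preset uses the 7B hyperparameters.
backbone = MistralBackbone(
    vocabulary_size=1000,
    num_layers=2,
    num_query_heads=8,
    num_key_value_heads=2,
    hidden_dim=64,
    intermediate_dim=128,
    sliding_window=512,
)
inputs = {
    "token_ids": np.random.randint(0, 1000, size=(1, 12)),
    "padding_mask": np.ones((1, 12), dtype="int32"),
}
# Returns a (batch, sequence, hidden_dim) tensor of hidden states.
print(backbone(inputs).shape)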

@tirthasheshpatel tirthasheshpatel added the type:feature New feature or request label Nov 13, 2023
@tirthasheshpatel tirthasheshpatel changed the title Add MistralAI's Transformer as a backbone in KerasNLP Models Add MistralAI's 7B Transformer as a backbone in KerasNLP Models Nov 13, 2023
@mattdangerw (Member) commented:

Still need to take a pass, but a quick note on tests.

It looks like some recent Keras nightly changes broke our CI; I'm debugging that now. You can ignore the Keras 3 failures. However, the Keras 2 failure looks Mistral-related and is worth digging into.

@mattdangerw (Member) left a comment:

Nice work! I will probably try to step more carefully through the sliding window caching part to understand it better, but left some initial comments.

# TODO(tirthasheshpatel): Generalize the attention layer
# TODO(tirthasheshpatel): Merge `LlamaAttention` with this layer
# TODO(tirthasheshpatel): Use flash attention
# TODO(tirthasheshpatel): Add dropout

Member:

Let's try to do this one if it's easy enough. We usually try to add dropout along with the original architecture.

Contributor (author):

Done.

query = self._query_dense(hidden_states)

# Note that the original PyTorch implementation uses
# view_as_complex/view_as_real while we use split/concatenate to

Member:

Can you explain this a bit more? Why do we need to consider complex numbers here?

Contributor (author):

Here's the Mistral source for computing the frequencies (same as Llama 2) and for applying the embeddings (also the same as Llama 2).

The frequencies and inputs are treated as complex numbers and the computation follows the "Theoretical Explanation" section in the paper.

PyTorch's view_as_complex is used to convert the tensors to complex numbers: it reshapes the inputs to shape (*x.shape[:-1], x.shape[-1] // 2, 2) and treats each pair of elements along the last axis as a (real, imaginary) pair. RotaryEmbedding instead uses ops.split(x, 2, axis=-1) to get a complex representation (after splitting, the first half of the last axis becomes the real part and the second half becomes the imaginary part).

This is the only fundamental difference between the two computations. We can get the same results if we shuffle the inputs so that the elements at even indices form the first half of the last axis and the elements at odd indices form the second half. Hence the x = ops.concatenate([x[..., ::2], x[..., 1::2]], axis=-1) bit before passing the inputs to the rotary embedding layer.

The reverse transformation exactly mirrors/undoes what we did above.

Code demonstration of the above explanation:
import torch
import numpy as np
from keras import ops

def _reshape_for_broadcast(freqs_cis, x):
    """
    freqs_cis: complex - (seq_len, head_dim / 2)
    x: complex - (bsz, seq_len, head_dim / 2)
    """
    ndim = x.ndim
    assert 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
        freqs_cis.shape,
        (x.shape[1], x.shape[-1]),
    )
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


# Llama's version of rotary embeddings
def apply_rotary_emb(
    xq,
    freqs_cis,
):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    freqs_cis = _reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq)


# Our version of the same computation.
# With transformations to match the `apply_rotary_emb` function above.
def apply_rotary_pos_emb(tensor, cos_emb, sin_emb):
    tensor = ops.concatenate((tensor[..., ::2], tensor[..., 1::2]), axis=-1)
    x1, x2 = ops.split(tensor, 2, axis=-1)
    half_rot_tensor = ops.concatenate((-x2, x1), axis=-1)
    res = (tensor * cos_emb) + (half_rot_tensor * sin_emb)
    return res

# Toy inputs: (batch=1, seq_len=2, heads=1, head_dim=16).
x = np.random.standard_normal((1, 2, 1, 16))
cos_emb = np.random.standard_normal((2, 8))
sin_emb = np.random.standard_normal((2, 8))

print(x)

# The interleaved -> [even indices, odd indices] shuffle and its split into halves.
print(ops.concatenate((x[..., ::2], x[..., 1::2]), axis=-1))
print(np.split(np.concatenate((x[..., ::2], x[..., 1::2]), axis=-1), 2, axis=-1))

# Reference: PyTorch's complex-number formulation.
print(torch.view_as_complex(torch.tensor(x).reshape(*x.shape[:-1], -1, 2)))
print(apply_rotary_emb(torch.tensor(x), torch.tensor(cos_emb + sin_emb * 1.0j)))

# Our formulation, with the shuffle applied to the inputs inside `apply_rotary_pos_emb`.
y = apply_rotary_pos_emb(
    x,
    np.concatenate([cos_emb[None, :, None, :]] * 2, axis=-1),
    np.concatenate([sin_emb[None, :, None, :]] * 2, axis=-1),
)

# Undo the shuffle on the output; this matches `apply_rotary_emb` above.
print(ops.reshape(ops.stack(ops.split(y, 2, axis=-1), axis=-1), (y.shape[0], y.shape[1], y.shape[2], -1)))

A bit complicated, but it should be possible to achieve the same behavior by shuffling the weights with the same transformation (a toy sketch of the idea follows). I believe that's what the Hugging Face folks have done, which is why this isn't required in the Llama backbone PR.
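
A toy sketch of the weight-shuffling idea (shapes and names are illustrative, not the actual checkpoint-conversion code): permuting the output columns of the query/key projection weight once is equivalent to permuting the projection outputs before the rotary embedding.

import numpy as np

rng = np.random.default_rng(0)
head_dim = 8
w = rng.standard_normal((16, head_dim))  # toy query projection: hidden -> head_dim
x = rng.standard_normal((1, 16))         # toy hidden state

# The interleaved -> [even indices, odd indices] permutation used above.
perm = np.concatenate([np.arange(0, head_dim, 2), np.arange(1, head_dim, 2)])

out_then_perm = (x @ w)[..., perm]  # permute the projection output (what this PR does)
perm_then_out = x @ w[:, perm]      # permute the weight columns once ("shuffle the weights")

print(np.allclose(out_then_perm, perm_then_out))  # True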

Member:

Thanks!

by shuffling the weights using the same transformations

Shuffling what weights?

At the highest level, we should just consider whether we should pull this into the lower level RotaryEmbedding layer. We want it to be useful for the most common use cases of rotary embeddings.

key = ops.cast(
    cache_k[
        :,
        : (cache_update_index + seq_len - 1) % self._sliding_window

Member:

General note: use intermediate variables to improve readability here if you can, especially if you can come up with a good name for (cache_update_index + seq_len - 1) % self._sliding_window.
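
A hedged toy sketch of the naming idea (variable names and toy values are illustrative, not the PR's final code):

from keras import ops

seq_len, cache_update_index, sliding_window = 3, 4, 4
cache_k = ops.ones((1, sliding_window, 2, 8))  # (batch, window, heads, head_dim)

# The modular index from the slice above, pulled into a named variable.
cache_slice_end = (cache_update_index + seq_len - 1) % sliding_window
key = cache_k[:, :cache_slice_end, ...]
print(ops.shape(key))  # (1, 2, 2, 8)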

Contributor (author):

Done. This part was removed to make the caching step XLA compatible.

    attention_scores, attention_mask
)
attention_output = ops.einsum(
    "acbe,aecd->abcd", attention_scores, value

Member:

Consider something like this, where we co-locate all the einsum equations in build and add a symbol key at the top. It helps readability.

https://github.com/keras-team/keras/blob/master/keras/layers/attention/grouped_query_attention.py#L124-L167
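
A hedged sketch of the suggested pattern (equations and shapes here are illustrative, not this PR's exact ones): define every einsum equation in one place with a shared symbol key and reuse them.

import numpy as np
from keras import ops

# Symbol key:
#   b = batch, q = query length, k = key length, h = attention heads, d = head dim
dot_product_equation = "bqhd,bkhd->bhqk"
combine_equation = "bhqk,bkhd->bqhd"

query = np.random.standard_normal((2, 5, 4, 8))
key = np.random.standard_normal((2, 7, 4, 8))
value = np.random.standard_normal((2, 7, 4, 8))

scores = ops.einsum(dot_product_equation, query, key)   # (2, 4, 5, 7)
output = ops.einsum(combine_equation, scores, value)    # (2, 5, 4, 8)
print(ops.shape(scores), ops.shape(output))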

Contributor (author):

Done.

Member:

Could still pull these up into build for a nice co-location, and rewrite them to use the same symbols as the key above.

in build...
self._dot_product_equation = ...
self._combine_equation = ...

Contributor (author):

Done.

    sliding_window=512,
    **kwargs,
):
    decoder_sequence_shape = kwargs.pop("decoder_sequence_shape", None)

Member:

The reason we needed this for TransformerDecoder was that Keras 2 struggles with multiple build shape arguments. I don't think we should need it here.

Contributor (author):

Removed.

"kernel_initializer": keras.initializers.serialize(
self.kernel_initializer
),
"decoder_sequence_shape": self._decoder_sequence_shape,

Member:

I don't think we should need this.

Contributor (author):

Removed.

)
# Below is a workaround for `ops.triu` for Keras 2.
# TODO(tirthasheshpatel): Use `ops.triu` once Keras 2 support is removed.
# causal_mask = ops.triu(causal_mask_upper, k=-self.sliding_window)

Member:

If ops.triu is ready now, we could do it like this:

if config.keras_3:
    ops.triu(...)
else:
    ops.arange...

What does the overall structure of this mask look like?

Contributor (author):

Mistral's attention mask has a banded matrix structure. For example, for an input of sequence length 5 and a sliding window of size 2, we get something like:

In [1]: from keras import ops

In [2]: ops.triu(ops.tril(ops.ones((5, 5)), k=0), k=-2)  # generally k = -sliding_window
Out[2]: 
<tf.Tensor: shape=(5, 5), dtype=float32, numpy=
array([[1., 0., 0., 0., 0.],
       [1., 1., 0., 0., 0.],
       [1., 1., 1., 0., 0.],
       [0., 1., 1., 1., 0.],
       [0., 0., 1., 1., 1.]], dtype=float32)>
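
For reference, a hedged sketch (toy code, not the PR's exact workaround) of building the same banded mask with plain comparisons, which also works without ops.triu on Keras 2:

import numpy as np

def banded_causal_mask(seq_len, sliding_window):
    # Query position i may attend to key position j iff j <= i (causal) and
    # i - j <= sliding_window (inside the band).
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    return ((j <= i) & (i - j <= sliding_window)).astype("float32")

print(banded_causal_mask(5, 2))  # matches the ops.triu/ops.tril output above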

@tirthasheshpatel (Contributor, author) commented:

@mattdangerw I think I have addressed all your comments except the docs one. Will add docs in the next commit.

    **kwargs,
):
    # Get the dtype
    dtype = kwargs.pop("dtype", keras.backend.floatx())

Contributor (author):

Dtypes work as expected for the TensorFlow and JAX backends, but the PyTorch backend currently fails internally in Keras 3 due to dtype issues.

Comment on lines +225 to +228
cache_k = cache_k[:, :update_end_index, ...]
cache_v = cache_v[:, :update_end_index, ...]

Contributor (author):

JAX fails here if cache_update_index is a traced JAX array. But the value of cache_update_index should be known at each step. I think the right fix here is to make sure that the GenerateTask model passes concrete values here. Otherwise, it would be pretty tricky to make sliding window attention work in JAX.
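
A hedged standalone illustration of the failure mode (not this PR's code): under jit, Python slice bounds must be static, so slicing with a traced index raises.

import jax
import jax.numpy as jnp

@jax.jit
def take_prefix(cache, end_index):
    return cache[:, :end_index]  # `end_index` is a tracer here -> error

try:
    take_prefix(jnp.ones((2, 8, 4)), 5)
except Exception as err:
    print(type(err).__name__)

# Marking the index static (or passing a concrete Python value from the
# generation loop) avoids the problem:
take_prefix_static = jax.jit(lambda cache, end_index: cache[:, :end_index], static_argnums=1)
print(take_prefix_static(jnp.ones((2, 8, 4)), 5).shape)  # (2, 5, 4)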

Member:

Yeah, this looks unsupported by XLA today, as it would involve dynamic shapes in a compiled while_loop. Let's work on a fix as a follow-up.

Comment on lines +133 to +136
cache=None,
cache_update_index=None,

Contributor (author):

Note: Right now, caching doesn't work when the sequence length is greater than the sliding window.

Can be addressed as a follow-up when adding the Generator model; shouldn't be a blocker here.

Member:

Sounds good. If the upstream version is not solving this correctly, let's not worry too much about this.

@mattdangerw (Member) commented:

/gcbrun

@mattdangerw (Member) left a comment:

Looks good! Feel free to pull this in once tests are green and the remaining comments are addressed.

@@ -97,7 +97,7 @@ def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
         return (tensor * cos_emb) + (half_rot_tensor * sin_emb)

     def _compute_cos_sin_embedding(self, x, rotary_dim, start_index):
-        freq_range = ops.arange(0, rotary_dim, 2, dtype="float32")
+        freq_range = ops.cast(ops.arange(0, rotary_dim, 2), self.compute_dtype)

Member:

This looks like a double cast (see the next line). Remove one or the other?

Contributor (author):

Done.


update_end_index = (
    cache_update_index + seq_len - 1
) % self._sliding_window + 1
update_end_index = ops.cast(update_end_index, "int32")

Member:

Just a general note: torch and JAX like int32 on GPU, but TensorFlow has limited op support for int32 (and does better with int64). We probably don't have GPU coverage for this code path on the TensorFlow backend yet, but it might come up down the line.



layers in each transformer decoder. Only `sliding_window` number of tokens
are saved in the cache and used to generate the next token.
Defaults to `512`.

Member:

Document dtype here, as most models won't support it in the way this backbone does.

Contributor (author):

Done.

from keras_nlp.backend import ops


# TODO: Deprecate this in favor of `keras.layers.LayerNormalization` once

Member:

keras.layers.LayerNormalization(rms_scaling=True)
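
A hedged sketch of the suggestion (assuming the installed Keras version supports rms_scaling): RMS-style normalization via the built-in layer instead of a custom layer norm.

import numpy as np
import keras

rms_norm = keras.layers.LayerNormalization(rms_scaling=True, epsilon=1e-6)
x = np.random.standard_normal((2, 8, 16)).astype("float32")
print(rms_norm(x).shape)  # (2, 8, 16)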

Contributor (author):

Done.


class MistralTransformerDecoder(keras.layers.Layer):
"""A Transformer decoder layer for the Mistral backbone."""

Member:

Remove newline.

Contributor (author):

Done.


def __init__(
    self,
    *,

Member:

Remove star for now.

Contributor (author):

Done.

@sampathweb sampathweb added the kokoro:force-run Runs Tests on GPU label Dec 9, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Dec 9, 2023
@mattdangerw (Member) commented:

Looks all green! Let's pull this in.

@mattdangerw mattdangerw merged commit 4ea8c23 into keras-team:master Dec 20, 2023
6 checks passed
Labels: type:feature (New feature or request)
Linked issue (may be closed by this merge): Add Mistral.AI's Transformer Model to KerasNLP
4 participants