Add MistralAI's 7B Transformer as a backbone in KerasNLP Models #1314
Conversation
Still need to take a pass, but a quick note on tests. Looks like some Keras nightly changes broke us recently; debugging currently. You can ignore the Keras 3 failures. However, the Keras 2 failure looks Mistral related and is worth digging into.
Nice work! I will probably try to step more carefully through the sliding window caching part to understand it better, but left some initial comments.
# TODO(tirthasheshpatel): Generalize the attention layer
# TODO(tirthasheshpatel): Merge `LlamaAttention` with this layer
# TODO(tirthasheshpatel): Use flash attention
# TODO(tirthasheshpatel): Add dropout
Let's try to do this one if it's easy enough. We usually try to add dropout along with the original architecture.
Done.
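For context, a minimal sketch of how attention dropout is commonly wired into a custom Keras attention layer (the class, attribute names, and einsum layout here are hypothetical, not the PR's exact code; assumes Keras 3 for keras.ops):

import keras
from keras import ops


class AttentionWithDropout(keras.layers.Layer):
    """Illustrative only: dropout applied to the softmaxed attention scores."""

    def __init__(self, dropout=0.0, **kwargs):
        super().__init__(**kwargs)
        self._dropout_layer = keras.layers.Dropout(rate=dropout)

    def call(self, attention_scores, value, training=None):
        # Dropout is applied after the softmax, before combining with values.
        attention_probs = ops.softmax(attention_scores, axis=-1)
        attention_probs = self._dropout_layer(attention_probs, training=training)
        # b = batch, n = heads, q = query length, k = key length, h = head dim
        return ops.einsum("bnqk,bknh->bqnh", attention_probs, value)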
query = self._query_dense(hidden_states)

# Note that the original PyTorch implementation uses
# view_as_complex/view_as_real while we use split/concatenate to
Can you explain this a bit more? Why do we need to consider complex numbers here?
Here's the Mistral source for computing the frequencies (same as Llama 2) and for computing the embeddings (also the same as Llama 2). The frequencies and inputs are treated as complex numbers, and the computation follows the "Theoretical Explanation" section of the paper.
PyTorch's view_as_complex is used to convert the tensors to complex numbers: it reshapes the inputs to shape (*x.shape[:-1], x.shape[-1] // 2, 2) and treats each pair of elements along axis -1 as a (real, imaginary) pair. RotaryEmbedding instead uses ops.split(x, 2) to build the complex representation: after splitting, the first half of the inputs becomes the real part and the second half becomes the imaginary part.
This is the only fundamental difference between the two computations. We can get the same results if we shuffle the inputs so that the alternate elements are moved to the end of the tensor; hence the x = ops.concatenate([x[..., ::2], x[..., 1::2]], axis=-1) bit before passing the inputs to the rotary embedding layer. The reverse transformation exactly mirrors/undoes this.
Code demonstration of the above explanation:

import numpy as np
import torch
from keras import ops


def _reshape_for_broadcast(freqs_cis, x):
    """
    freqs_cis: complex - (seq_len, head_dim / 2)
    x: complex - (bsz, seq_len, head_dim / 2)
    """
    ndim = x.ndim
    assert 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
        freqs_cis.shape,
        (x.shape[1], x.shape[-1]),
    )
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


# Llama's version of rotary embeddings.
def apply_rotary_emb(
    xq,
    freqs_cis,
):
    # View each (real, imaginary) pair along the last axis as a complex number.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    freqs_cis = _reshape_for_broadcast(freqs_cis, xq_)
    # Complex multiplication applies the rotation; flatten back to real pairs.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq)


# Our version of the same computation,
# with transformations to match the `apply_rotary_emb` function above.
def apply_rotary_pos_emb(tensor, cos_emb, sin_emb):
    # Move the interleaved elements so the first half is the "real" part
    # and the second half is the "imaginary" part.
    tensor = ops.concatenate((tensor[..., ::2], tensor[..., 1::2]), axis=-1)
    x1, x2 = ops.split(tensor, 2, axis=-1)
    half_rot_tensor = ops.concatenate((-x2, x1), axis=-1)
    res = (tensor * cos_emb) + (half_rot_tensor * sin_emb)
    return res


x = np.random.standard_normal((1, 2, 1, 16))
cos_emb = np.random.standard_normal((2, 8))
sin_emb = np.random.standard_normal((2, 8))

print(x)
print(ops.concatenate((x[..., ::2], x[..., 1::2]), axis=-1))
print(np.split(np.concatenate((x[..., ::2], x[..., 1::2]), axis=-1), 2, axis=-1))
print(torch.view_as_complex(torch.tensor(x).reshape(*x.shape[:-1], -1, 2)))

print(apply_rotary_emb(torch.tensor(x), torch.tensor(cos_emb + sin_emb * 1.0j)))

y = apply_rotary_pos_emb(
    x,
    np.concatenate([cos_emb[None, :, None, :]] * 2, axis=-1),
    np.concatenate([sin_emb[None, :, None, :]] * 2, axis=-1),
)
# Undo the split/concatenate shuffle to compare with the PyTorch output above.
print(ops.reshape(ops.stack(ops.split(y, 2, axis=-1), axis=-1), (y.shape[0], y.shape[1], y.shape[2], -1)))
A bit complicated, but it should be possible to achieve the same behavior by shuffling the weights using the same transformations. I believe that's what the Hugging Face folks have done, which is why this isn't required in the Llama backbone PR.
Thanks!
Regarding "by shuffling the weights using the same transformations": shuffling what weights?
At the highest level, we should just consider whether we should pull this into the lower-level RotaryEmbedding layer. We want it to be useful for the most common use cases of rotary embeddings.
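For what it's worth, a rough sketch (hypothetical helper, not part of this PR or the Hugging Face conversion scripts) of the kind of kernel-column shuffling being referred to, i.e. permuting a checkpoint's query/key projection columns so an interleaved-rotary checkpoint matches a split-half rotary implementation:

import numpy as np


def deinterleave_rotary_columns(kernel, head_dim):
    """Reorder the last axis from (r0, i0, r1, i1, ...) to (r0, r1, ..., i0, i1, ...)."""
    perm = np.concatenate([np.arange(0, head_dim, 2), np.arange(1, head_dim, 2)])
    return kernel[..., perm]


# Example: a per-head query/key projection kernel with head_dim = 8 on the last axis.
kernel = np.random.standard_normal((64, 4, 8))
shuffled = deinterleave_rotary_columns(kernel, head_dim=8)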
key = ops.cast(
    cache_k[
        :,
        : (cache_update_index + seq_len - 1) % self._sliding_window
General note: use intermediate variables to improve readability here if you can, especially if you can come up with good names, say for xx = (cache_update_index + seq_len - 1) % self._sliding_window.
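Roughly what the suggestion amounts to, as a standalone sketch (the function name, signature, and cast target are made up for illustration; this exact slicing was later dropped for XLA compatibility, see the reply below):

from keras import ops


def slice_key_cache(cache_k, cache_update_index, seq_len, sliding_window, dtype):
    # Give the wrapped end-of-update position a descriptive name, then reuse it.
    wrapped_update_index = (cache_update_index + seq_len - 1) % sliding_window
    return ops.cast(cache_k[:, :wrapped_update_index], dtype=dtype)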
Done. This part was removed to make the caching step XLA compatible.
    attention_scores, attention_mask
)
attention_output = ops.einsum(
    "acbe,aecd->abcd", attention_scores, value
Consider something like this, where we collocate all einsum equations in build, and we add a nice key at the top. Helps readability.
Done.
Could still pull these up into build for a nice co-location, and rewrite them to use the same symbols as the key above. Something like this in build:
self._dot_product_equation = ...
self._combine_equation = ...
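A sketch of what that co-location could look like (the build signature, symbol key, and equations below are illustrative and would need to match the layer's actual tensor layout):

import keras


class AttentionWithEquationKey(keras.layers.Layer):
    def build(self, inputs_shape):
        # Einsum symbol key (illustrative):
        #   b = batch size, q = query length, k = key/value length,
        #   n = number of attention heads, h = size of each attention head
        self._dot_product_equation = "bqnh,bknh->bnqk"
        self._combine_equation = "bnqk,bknh->bqnh"
        self.built = True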
Done.
    sliding_window=512,
    **kwargs,
):
    decoder_sequence_shape = kwargs.pop("decoder_sequence_shape", None)
The reason we needed this for TransformerDecoder was Keras 2's struggle with multiple build shape arguments. We shouldn't need it here, I think.
Removed.
"kernel_initializer": keras.initializers.serialize( | ||
self.kernel_initializer | ||
), | ||
"decoder_sequence_shape": self._decoder_sequence_shape, |
I don't think we should need this.
Removed.
)
# Below is a workaround for `ops.triu` for Keras 2.
# TODO(tirthasheshpatel): Use `ops.triu` once Keras 2 support is removed.
# causal_mask = ops.triu(causal_mask_upper, k=-self.sliding_window)
If ops.triu is ready now, we could do something like:
if config.keras_3:
    ops.triu(...)
else:
    ops.arange...
What does the overall structure of this mask look like?
Mistral uses a banded matrix structure. For example, for inputs of sequence length 5 and sliding window of size 2, we would have something like:
In [1]: from keras import ops

In [2]: ops.triu(ops.tril(ops.ones((5, 5)), k=0), k=-2)  # generally k = -sliding_window
Out[2]:
<tf.Tensor: shape=(5, 5), dtype=float32, numpy=
array([[1., 0., 0., 0., 0.],
       [1., 1., 0., 0., 0.],
       [1., 1., 1., 0., 0.],
       [0., 1., 1., 1., 0.],
       [0., 0., 1., 1., 1.]], dtype=float32)>
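And for the Keras 2 path mentioned above, a small numpy sketch (illustrative only, not the PR's actual workaround) of building the same banded mask from index comparisons instead of triu/tril:

import numpy as np


def banded_causal_mask(seq_len, sliding_window):
    # Query position i may attend to key positions j with i - sliding_window <= j <= i.
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    return ((j <= i) & (j >= i - sliding_window)).astype("float32")


print(banded_causal_mask(5, 2))  # same banded structure as the ops.triu/ops.tril output above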
@mattdangerw I think I have addressed all your comments except the docs one. Will add docs in the next commit.
    **kwargs,
):
    # Get the dtype
    dtype = kwargs.pop("dtype", keras.backend.floatx())
Dtypes work as expected for the TensorFlow and JAX backends but PyTorch currently fails internally in Keras 3 due to dtype issues.
cache_k = cache_k[:, :update_end_index, ...]
cache_v = cache_v[:, :update_end_index, ...]
JAX fails here if cache_update_index is a traced JAX array. But the value of cache_update_index should be known at each step. I think the right fix is to make sure that the GenerateTask model passes concrete values here. Otherwise, it would be pretty tricky to make sliding window attention work in JAX.
Yeah, this looks unsupported by XLA today, as it would involve dynamic shapes in a compiled while_loop. Let's work on a fix as a follow-up.
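A tiny standalone repro of this failure mode (not PR code): slicing with a traced index under jax.jit needs static bounds, which is why cache_update_index has to arrive as a concrete value.

import jax
import jax.numpy as jnp


@jax.jit
def truncate_cache(cache, update_end_index):
    # Under jit, `update_end_index` becomes a tracer, so this slice would have a
    # dynamic shape and JAX refuses to compile it.
    return cache[:, :update_end_index]


cache = jnp.zeros((1, 16, 4, 8))
# truncate_cache(cache, 7)  # raises: slice bounds must be static under jit
jax.jit(lambda c: c[:, :7])(cache)  # fine: the bound is a concrete constant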
cache=None,
cache_update_index=None,
Note: Right now, caching doesn't work when the sequence length is greater than the sliding window.
Can be addressed as a follow-up when adding the Generator model; shouldn't be a blocker here.
Sounds good. If the upstream version is not solving this correctly, let's not worry too much about this.
/gcbrun
Looks good! Feel free to pull this in after tests are green and the remaining comments are addressed.
@@ -97,7 +97,7 @@ def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
         return (tensor * cos_emb) + (half_rot_tensor * sin_emb)

     def _compute_cos_sin_embedding(self, x, rotary_dim, start_index):
-        freq_range = ops.arange(0, rotary_dim, 2, dtype="float32")
+        freq_range = ops.cast(ops.arange(0, rotary_dim, 2), self.compute_dtype)
This looks like a double cast (see the next line). Remove one or the other?
Done.
update_end_index = (
    cache_update_index + seq_len - 1
) % self._sliding_window + 1
update_end_index = ops.cast(update_end_index, "int32")
Just a general note: torch and jax like int32 on GPU, but tensorflow has limited op support for int32 (and does better with int64). We probably don't have coverage for this code path on TF with an accelerator yet, but it might come up down the line.
    layers in each transformer decoder. Only `sliding_window` number of tokens
    are saved in the cache and used to generate the next token.
    Defaults to `512`.
Document dtype here, as most models won't support it in the way this backbone does.
Done.
from keras_nlp.backend import ops


# TODO: Deprecate this in favor of `keras.layers.LayerNormalization` once
keras.layers.LayerNormalization(rms_scaling=True)
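For reference, the built-in option being pointed at would look roughly like this (assuming a Keras version where LayerNormalization exposes rms_scaling; the epsilon is just an example value):

import keras

# RMS-style normalization: no mean-centering, only scaling by the root mean square.
rms_norm = keras.layers.LayerNormalization(rms_scaling=True, epsilon=1e-6)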
Done.
class MistralTransformerDecoder(keras.layers.Layer):
    """A Transformer decoder layer for the Mistral backbone."""
Remove newline.
Done.
def __init__(
    self,
    *,
Remove star for now.
Done.
(force-pushed for Keras 2 compatibility: 32f166a to c3d71c2)
Looks all green! Let's pull this in.
Fixes #1275
This PR adds a MistralBackbone model and all its components. Most of the components share a lot of code with #1203.
Reference implementation: mistralai/mistral-src
Colab for weight transfer: https://colab.research.google.com/drive/1MoD7JJasThxmalspG3c21oYMEc_qRbti?usp=sharing
TODOs:
CachedMistralAttention and MistralTransformerDecoder layers.