Add LlamaBackbone #1203

Conversation
LLaMA uses the SiLU (sigmoid linear unit) activation function, which we don't seem to have in Keras yet (?). Ref: https://huggingface.co/meta-llama/Llama-2-13b-chat-hf/blob/main/config.json. [Update]: We do have it here, but it's not listed on the docs page.
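For reference, a minimal check that the activation is available (this assumes Keras 3, where the function is exposed as `keras.activations.silu`):

```python
import numpy as np
import keras

# SiLU(x) = x * sigmoid(x). Assumed to be exposed in Keras 3 as
# `keras.activations.silu` (also aliased as `swish`).
x = np.array([-1.0, 0.0, 1.0], dtype="float32")
print(keras.activations.silu(x))  # elementwise x * sigmoid(x)
```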
Overall structure looks good! Left some high level comments for now.
Few more comments.
TODO: Conversion script
@mattdangerw I've added a checkpoint conversion script. The output matching is almost there; I just can't figure out one thing.
@awsaf49 We absolutely should! But we first need to figure out when we drop Keras 2 support from KerasCV and KerasNLP. Until we do, we can't rely on symbols that only exist in Keras 3. But long term, no question, we should use the grouped query attention layer to cut a lot of code from KerasNLP!
Hey @mattdangerw!
/gcbrun
All checks pass. Nice. |
Awesome this is working! Left a few comments.
Oops, didn't mean to approve till we have the updated tests and docstrings.
Mistral shares almost exactly the same backbone as Llama, so just a few comments to make sure this can easily be reused for Mistral.
@mattdangerw I've added caching as well; outputs continue to match.
/gcbrun
Added docstrings too.
```python
mask_expansion_axis = -3
for _ in range(
    len(attention_scores.shape) - len(attention_mask.shape)
):
    attention_mask = ops.expand_dims(
        attention_mask, axis=mask_expansion_axis
    )
```
Since the inputs are constrained to be 3-dimensional, we can simplify this as:
```diff
-mask_expansion_axis = -3
-for _ in range(
-    len(attention_scores.shape) - len(attention_mask.shape)
-):
-    attention_mask = ops.expand_dims(
-        attention_mask, axis=mask_expansion_axis
-    )
+attention_mask = attention_mask[:, None, :, :]
```
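A tiny illustration of the equivalence with toy shapes (nothing below is from the PR itself; the shapes are made up):

```python
import numpy as np
from keras import ops

# Toy shapes for illustration only.
attention_scores = np.zeros((2, 8, 5, 5))  # (batch, num_heads, query_len, key_len)
attention_mask = np.ones((2, 5, 5))        # (batch, query_len, key_len)

# Loop-based expansion from the original code.
expanded = attention_mask
for _ in range(len(attention_scores.shape) - len(expanded.shape)):
    expanded = ops.expand_dims(expanded, axis=-3)

# Direct slice, valid because the mask is always rank 3 here.
sliced = attention_mask[:, None, :, :]

print(ops.shape(expanded), ops.shape(sliced))  # both (2, 1, 5, 5)
```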
IIRC @mattdangerw and I had a conversation about it. Let's keep this as is.
No strong feeling. The thing to keep in mind here is what is public API and what's internal to the model. `RotaryEmbedding` is public; that's the one we want to support with multiple different call ranks/configurations. Llama attention is unexposed, so it's OK to make assumptions about the input shape as long as they're valid for Llama models.
CI failures seem unrelated.
/gcbrun
Thanks! Mostly minor changes.
```python
        if axis != sequence_axis and axis != feature_axis:
            embedding = ops.expand_dims(embedding, axis)

        return ops.cos(embedding), ops.sin(embedding)

    def _get_inverse_freq(self, rotary_dim):
        freq_range = ops.arange(0, rotary_dim, 2, dtype="float32")
```
We should still add that unit test I was mentioning, as it's clear we have been breaking the `feature_axis` and `sequence_axis` args without meaning to. Something like the below as a new unit test for this file:
```python
inputs = random.uniform(shape=(batch_size, sequence_length, feature_dim))
permuted_inputs = ops.transpose(inputs, (0, 2, 1))
outputs = RotaryEmbedding()(inputs)
permuted_outputs = RotaryEmbedding(sequence_axis=-1, feature_axis=-2)(permuted_inputs)
self.assertAllClose(outputs, ops.transpose(permuted_outputs, (0, 2, 1)))
```
Thanks @mattdangerw for the suggestion here. Turns out that it does break.
I don't have bandwidth to fix this today.
All good! We have been heads down getting this Kaggle integration ready anyway, which will be needed before we can actually provide any Llama 2 checkpoints.
Let's check in next week. If you are strapped for time, I can just patch this in and fix it here. I think this and comparing outputs in the conversion script are basically the last remaining issues?
```python
        self.rope_scaling_factor = rope_scaling_factor
        self.rope_max_wavelength = rope_max_wavelength

    def build(self, inputs_shape):
```
Same comment as for Mistral... Consider something like this, where we collocate all the einsum equations in `build` and add a nice key at the top. It helps readability.
(OK if we want to punt on this for this PR.)
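A rough sketch of the idea, purely illustrative; the class name and arguments (`LlamaAttentionSketch`, `num_query_heads`, `head_dim`, etc.) are assumptions, not necessarily what this PR uses:

```python
import keras


class LlamaAttentionSketch(keras.layers.Layer):
    """Illustration: collocate all einsum equations in build(), with a key."""

    def __init__(self, num_query_heads, num_key_value_heads, head_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_query_heads = num_query_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim

    def build(self, inputs_shape):
        # Einsum variable key:
        #   b = batch size
        #   q = query sequence length
        #   k = key/value sequence length
        #   m = model (hidden) dim
        #   u = num query heads
        #   v = num key/value heads
        #   h = head dim
        hidden_dim = inputs_shape[-1]
        self._query_dense = keras.layers.EinsumDense(
            "bqm,muh->bquh",
            output_shape=(None, self.num_query_heads, self.head_dim),
            name="query",
        )
        self._query_dense.build(inputs_shape)
        self._key_dense = keras.layers.EinsumDense(
            "bkm,mvh->bkvh",
            output_shape=(None, self.num_key_value_heads, self.head_dim),
            name="key",
        )
        self._key_dense.build(inputs_shape)
        self._output_dense = keras.layers.EinsumDense(
            "bquh,uhm->bqm",
            output_shape=(None, hidden_dim),
            name="attention_output",
        )
        self._output_dense.build(
            (None, None, self.num_query_heads, self.head_dim)
        )
        self.built = True
```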
Thanks, this looks good! Added.
```python
with torch.no_grad():
    keras_outputs = keras_model(keras_inputs)
    print("Keras output = ", keras_outputs.numpy())
```
Can we add a line that also runs output through the hf version and compares the difference? How close do we get?
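A minimal sketch of such a comparison, assuming hypothetical names `hf_model` and `hf_inputs` for the Hugging Face model and its matching inputs already loaded in the conversion script:

```python
import keras
import numpy as np
import torch

# `hf_model` and `hf_inputs` are hypothetical names; `keras_outputs`
# comes from the snippet above.
with torch.no_grad():
    hf_outputs = hf_model(**hf_inputs).last_hidden_state

hf_outputs = hf_outputs.detach().cpu().numpy()
keras_outputs = keras.ops.convert_to_numpy(keras_outputs)

print("Max absolute difference:", np.max(np.abs(keras_outputs - hf_outputs)))
print("Outputs match:", np.allclose(keras_outputs, hf_outputs, atol=1e-4))
```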
/gcbrun
Talked with @shivance; going to try to merge this in with some last fixes to the rotary embedding layer. We will need to follow up and fix the conversion script so it actually validates the output.
Can you please tell how to load the Llama code in KerasNLP, just like how we load the BERT model as below? `classifier = keras_nlp.models.BertClassifier.from_preset(...)`
The Keras team has accepted #1162. This PR adds the attention, decoder, and backbone for Llama.
Here is the colab
This PR is still a work in progress!
@mattdangerw @fchollet
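For context, a rough usage sketch of the backbone; the constructor and input names below (`vocabulary_size`, `num_query_heads`, `token_ids`, `padding_mask`, the rope arguments) follow the usual KerasNLP backbone pattern and the config values discussed above, so treat them as assumptions rather than the final API:

```python
import numpy as np
import keras_nlp

# Argument names are assumptions based on the standard KerasNLP backbone
# pattern; the merged layer may differ.
backbone = keras_nlp.models.LlamaBackbone(
    vocabulary_size=32000,
    num_layers=2,
    num_query_heads=8,
    num_key_value_heads=8,
    hidden_dim=64,
    intermediate_dim=128,
    rope_max_wavelength=10000,
    rope_scaling_factor=1.0,
)

inputs = {
    "token_ids": np.ones((1, 12), dtype="int32"),
    "padding_mask": np.ones((1, 12), dtype="int32"),
}
hidden_states = backbone(inputs)  # (batch, sequence, hidden_dim)
print(hidden_states.shape)
```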