
Add LlamaBackbone #1203

Merged: 2 commits into keras-team:master on Dec 22, 2023

Conversation

@shivance (Collaborator) commented Aug 9, 2023

The Keras team has accepted #1162. This PR adds the attention, decoder, and backbone layers for Llama.

Here is the colab

This PR is still a work in progress!

@mattdangerw @fchollet

@shivance shivance requested a review from mattdangerw August 9, 2023 03:46
@shivance (Collaborator Author) commented Aug 9, 2023

LLaMA uses the SiLU (sigmoid linear unit) activation function; we don't seem to have it in Keras yet (?)

ref: https://huggingface.co/meta-llama/Llama-2-13b-chat-hf/blob/main/config.json

[Update]: We do have it, but it's not listed on the docs page.
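
For reference, SiLU (also called swish) is just x * sigmoid(x). A minimal sketch, assuming the silu alias is exposed by the installed Keras version:

import numpy as np
import keras

# SiLU (a.k.a. swish): silu(x) = x * sigmoid(x).
# Assumes keras.activations.silu is available in the installed Keras version.
x = np.array([-1.0, 0.0, 1.0], dtype="float32")
print(keras.activations.silu(x))
print(x * (1.0 / (1.0 + np.exp(-x))))  # same values, computed by hand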

@shivance shivance marked this pull request as draft August 9, 2023 04:18
@shivance shivance marked this pull request as ready for review August 9, 2023 15:52
@mattdangerw (Member) left a comment

Overall structure looks good! Left some high level comments for now.

@shivance shivance requested a review from mattdangerw August 15, 2023 06:07
@mattdangerw (Member) left a comment

Few more comments.

@shivance (Collaborator Author) commented

TODO: Conversion script

@awsaf49 commented Oct 28, 2023

As GroupedQueryAttention has been added to Keras, I think it would be nice to have it in Llama v2. PR: keras-team/keras#18488.
@fchollet @mattdangerw

@shivance (Collaborator Author) commented Nov 3, 2023

@mattdangerw I've added a checkpoint conversion script. The output matching is almost there; I just can't figure out one thing.
The output of each decoder layer matches that of the Hugging Face model (to high precision), but something happens after the last decoder layer and the final outputs differ by a margin. Not sure why that is.

@mattdangerw (Member) commented Nov 3, 2023

> As GroupedQueryAttention has been added to Keras, I think it would be nice to have it in Llama v2. PR: keras-team/keras#18488.

@awsaf49 We absolutely should! But we first need to figure out when we drop Keras 2 support from KerasCV and KerasNLP. Until we do, we can't rely on symbols that only exist in Keras 3.

But long term, no question, we should use the grouped query attention layer to cut a lot of code from KerasNLP!
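
For context, a minimal sketch of how the Keras 3 layer from keras-team/keras#18488 can be used; argument names follow the Keras 3 API and may change, and this is not code from this PR:

import numpy as np
import keras

# Grouped-query attention: fewer key/value heads than query heads.
layer = keras.layers.GroupedQueryAttention(
    head_dim=64,
    num_query_heads=8,
    num_key_value_heads=2,  # each KV head is shared by 4 query heads
)
x = np.random.rand(1, 16, 512).astype("float32")  # (batch, sequence, hidden)
out = layer(query=x, value=x)  # self-attention
print(out.shape)  # (1, 16, 512): output is projected back to the query feature dim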

@shivance shivance requested a review from mattdangerw November 4, 2023 10:14
@shivance (Collaborator Author) commented Nov 4, 2023

Hey @mattdangerw!
This PR is ready to merge. The outputs are matching (finally).

Hugging Face Llama outputs: [screenshot: post_layer_norm_hf]

KerasNLP Llama outputs: [screenshot: keras_outputs]

@shivance (Collaborator Author) commented Nov 4, 2023

/gcbrun

@shivance (Collaborator Author) commented Nov 4, 2023

All checks pass. Nice.

@mattdangerw (Member) left a comment

Awesome this is working! Left a few comments.

@mattdangerw (Member) left a comment

Oops, didn't mean to approve till we have the updated tests and docstrings.

@tirthasheshpatel (Contributor) left a comment

Mistral shares almost exactly the same backbone as Llama, so just a few comments to make it easy to reuse in Mistral.

@shivance (Collaborator Author) commented

@mattdangerw I've added caching as well, outputs continue to match.

@shivance (Collaborator Author) commented

/gcbrun

@shivance (Collaborator Author) commented

Added docstrings too.

Comment on lines +188 to +165
mask_expansion_axis = -3
for _ in range(
    len(attention_scores.shape) - len(attention_mask.shape)
):
    attention_mask = ops.expand_dims(
        attention_mask, axis=mask_expansion_axis
    )
@tirthasheshpatel (Contributor) commented Nov 12, 2023

Since the inputs are constrained to be 3 dimensional, we can simplify this as:

Suggested change:

- mask_expansion_axis = -3
- for _ in range(
-     len(attention_scores.shape) - len(attention_mask.shape)
- ):
-     attention_mask = ops.expand_dims(
-         attention_mask, axis=mask_expansion_axis
-     )
+ attention_mask = attention_mask[:, None, :, :]

@shivance (Collaborator Author) replied

IIRC @mattdangerw and I had a conversation about it. Let's keep this as is.

@mattdangerw (Member) replied

No strong feeling. The thing to keep in mind here is what is public API and what's internal to the model.

RotaryEmbedding is public, that's the one we want to support with multiple different call ranks/configurations.

Llama attention is unexposed, so it's ok to make assumptions about the input shape as long as it's valid for llama models.
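
For illustration, both forms discussed above add the same heads axis when the mask is rank 3; a quick sketch, not code from the PR:

import numpy as np
from keras import ops

# A (batch, query_len, key_len) mask gains a heads axis so it broadcasts
# against (batch, num_heads, query_len, key_len) attention scores.
mask = np.ones((2, 5, 5), dtype="bool")
expanded_generic = ops.expand_dims(mask, axis=-3)  # loop version: works for any rank gap
expanded_sliced = mask[:, None, :, :]              # shortcut: assumes a rank-3 mask
print(expanded_generic.shape, expanded_sliced.shape)  # (2, 1, 5, 5) (2, 1, 5, 5)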

@shivance (Collaborator Author) commented

CI failures seem unrelated

@mattdangerw (Member) commented

/gcbrun

@mattdangerw (Member) left a comment

Thanks! Mostly minor changes.

            if axis != sequence_axis and axis != feature_axis:
                embedding = ops.expand_dims(embedding, axis)

        return ops.cos(embedding), ops.sin(embedding)

    def _get_inverse_freq(self, rotary_dim):
        freq_range = ops.arange(0, rotary_dim, 2, dtype="float32")
@mattdangerw (Member) commented

We should still add that unit test I was mentioning, as it's clear we have been breaking the feature_axis and sequence_axis args without meaning to. Something like the below as a new unit test for this file.

inputs = random(batch, sequence, feature)
permuted_inputs = permute(inputs, (0, 2, 1))
outputs = RotaryEmbedding(inputs)
permuted_outputs = RotaryEmbedding(permuted_inputs, sequence_axis=-1, feature_axis=-2)
assertAllEqual(outputs, permute(permuted_outputs, (0, 2, 1)))
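
A runnable version of that sketch might look like the following; it assumes keras_nlp.layers.RotaryEmbedding takes sequence_axis and feature_axis as constructor arguments, the tolerance is arbitrary, and, per the follow-up below, this check was failing at the time:

import numpy as np
import keras_nlp
from keras import ops

# Rotary embeddings applied to permuted inputs should match the permuted
# outputs of the default layout.
inputs = np.random.rand(2, 8, 16).astype("float32")  # (batch, sequence, feature)
permuted_inputs = np.transpose(inputs, (0, 2, 1))     # (batch, feature, sequence)

outputs = keras_nlp.layers.RotaryEmbedding()(inputs)
permuted_outputs = keras_nlp.layers.RotaryEmbedding(
    sequence_axis=-1, feature_axis=-2
)(permuted_inputs)

np.testing.assert_allclose(
    ops.convert_to_numpy(outputs),
    np.transpose(ops.convert_to_numpy(permuted_outputs), (0, 2, 1)),
    atol=1e-6,
)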

@shivance (Collaborator Author) commented Nov 25, 2023

Thanks @mattdangerw for the suggestion here. Turns out that it does break.
I don't have bandwidth to fix this today.

@mattdangerw (Member) replied

All good! We have been heads down getting this Kaggle integration ready anyway, which will be needed before we can actually provide any llama 2 checkpoints.

Let's check in next week. If you are strapped for time I can just patch this in and fix here, I think this and comparing outputs in the conversion script are basically the last remaining issues?

        self.rope_scaling_factor = rope_scaling_factor
        self.rope_max_wavelength = rope_max_wavelength

    def build(self, inputs_shape):
@mattdangerw (Member) commented

Same comment as Mistral... Consider something like this, where we colocate all einsum equations in build and add a nice key at the top. Helps readability.

https://github.com/keras-team/keras/blob/master/keras/layers/attention/grouped_query_attention.py#L124-L167

(ok if we want to punt on this for this pr)
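
For illustration, a minimal sketch of that pattern with a toy layer (the class, names, and equations here are illustrative, not code from this PR):

import keras

# Toy layer: the einsum equation lives in build(), with a key explaining
# each letter, as in the linked GroupedQueryAttention implementation.
class TinyAttentionProjection(keras.layers.Layer):
    def __init__(self, num_heads=4, head_dim=16, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.head_dim = head_dim

    def build(self, inputs_shape):
        # Einsum key: b = batch, q = query length, m = model dim,
        #             u = num heads, h = head dim.
        self._query_dense = keras.layers.EinsumDense(
            "bqm,muh->bquh",
            output_shape=(None, self.num_heads, self.head_dim),
        )
        self._query_dense.build(inputs_shape)
        self.built = True

    def call(self, inputs):
        return self._query_dense(inputs)  # (batch, query, heads, head_dim)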

@shivance (Collaborator Author) commented Nov 25, 2023

Thanks, this looks good! Added.


with torch.no_grad():
    keras_outputs = keras_model(keras_inputs)
    print("Keras output = ", keras_outputs.numpy())
@mattdangerw (Member) commented

Can we add a line that also runs output through the hf version and compares the difference? How close do we get?
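
One way such a comparison might look (a sketch only: hf_model and hf_inputs are hypothetical names assumed to be defined earlier in the conversion script, and keras_outputs comes from the snippet above):

import numpy as np
import torch

# Run the same inputs through the Hugging Face model and report the largest
# absolute difference against the KerasNLP outputs.
with torch.no_grad():
    hf_outputs = hf_model(**hf_inputs).last_hidden_state  # hypothetical names

print(
    "Max absolute difference:",
    np.max(np.abs(hf_outputs.numpy() - keras_outputs.numpy())),
)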

@mattdangerw (Member) commented

/gcbrun

@mattdangerw (Member) commented

Talked with @shivance, going to try to merge this in with some last fixes to the rotary embedding layer.

We will need to follow up and fix the conversion script so it actually validates the output.

@mattdangerw mattdangerw merged commit 5fd92c8 into keras-team:master Dec 22, 2023
11 checks passed
@ashmalvayani commented Jan 25, 2024

Can you please tell me how to load the Llama model in KerasNLP, just like how we load the BERT model as below?

classifier = keras_nlp.models.BertClassifier.from_preset(
    "bert_base_en_uncased",
    num_classes=2,
    activation="softmax",
)
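
At the time this PR was merged, no Llama presets had been published yet (the Kaggle integration mentioned above was still needed for that), so there is no from_preset shortcut like the BERT example; the backbone has to be constructed directly. A minimal sketch with illustrative, non-official hyperparameters, assuming the constructor arguments of the LlamaBackbone added in this PR:

import numpy as np
import keras_nlp

# Hyperparameters are illustrative only; check the merged code for the exact
# LlamaBackbone signature.
backbone = keras_nlp.models.LlamaBackbone(
    vocabulary_size=32000,
    num_layers=2,
    num_query_heads=8,
    num_key_value_heads=8,
    hidden_dim=512,
    intermediate_dim=1024,
)
outputs = backbone(
    {
        "token_ids": np.ones((1, 12), dtype="int32"),
        "padding_mask": np.ones((1, 12), dtype="int32"),
    }
)
print(outputs.shape)  # (1, 12, hidden_dim)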
