Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GroupedQueryAttention layer #18488

Merged
merged 38 commits into from
Oct 22, 2023
Merged

Add GroupedQueryAttention layer #18488

merged 38 commits into from
Oct 22, 2023

Conversation

awsaf49
Copy link
Contributor

@awsaf49 awsaf49 commented Sep 25, 2023

This PR corresponds to Issue: #18402.

@codecov-commenter
Copy link

codecov-commenter commented Sep 25, 2023

Codecov Report

Attention: 6 lines in your changes are missing coverage. Please review.

Comparison is base (29a954a) 77.40% compared to head (a3b89dc) 78.36%.
Report is 114 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #18488      +/-   ##
==========================================
+ Coverage   77.40%   78.36%   +0.95%     
==========================================
  Files         331      335       +4     
  Lines       31972    32600     +628     
  Branches     6241     6355     +114     
==========================================
+ Hits        24749    25548     +799     
+ Misses       5646     5486     -160     
+ Partials     1577     1566      -11     
Flag Coverage Δ
keras 78.25% <94.82%> (+0.93%) ⬆️
keras-jax 63.69% <94.82%> (+1.45%) ⬆️
keras-numpy 58.00% <75.86%> (+1.84%) ⬆️
keras-tensorflow 64.53% <94.82%> (+2.35%) ⬆️
keras-torch 65.42% <94.82%> (+1.44%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files Coverage Δ
keras/layers/__init__.py 96.03% <100.00%> (+0.03%) ⬆️
keras/layers/attention/grouped_query_attention.py 94.78% <94.78%> (ø)

... and 67 files with indirect coverage changes

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@awsaf49
Copy link
Contributor Author

awsaf49 commented Sep 25, 2023

Hello @mattdangerw,

I would appreciate your valuable feedback on the implementation. I used llamav2 as a reference for this layer but also used keras.layers.MultiHeadAttention to maintain the keras-api.

I did observe some key differences between the llamav2 implementation and the Keras MultiHeadAttention:

  1. Llamav2 approaches this layer as self-attention, accepting a single tensor as input, whereas Keras MultiHeadAttention takes query, key, and value as separate inputs.

  2. In Llamav2, the layer accepts a mask input and directly integrates it with attention_scores. In contrast, Keras MHA computes the mask internally within the layer and applies it separately within the Softmax layer.

  3. Llamav2 utilizes a key-value cache mechanism, which differs somewhat from keras-nlp.layers.CachedMultiHeadAttention. I'm currently undecided whether to incorporate the cache within this layer or create a separate layer akin to the keras-nlp approach. Your insights on this matter would be highly appreciated.

Additionally, I haven't yet included the Dense layer kwargs and unit testing. I believe this can be easily integrated at a later stage. This implementation does show speed up of GQA and MQA than MHA .

@awsaf49 awsaf49 marked this pull request as ready for review September 25, 2023 13:05
Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

Copy link
Member

@mattdangerw mattdangerw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! Left some initial high level comments.

@fchollet
Copy link
Collaborator

The code / API looks good to me (would like @mattdangerw to LGTM it though). Please add unit tests!

Copy link
Member

@mattdangerw mattdangerw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the delay! Overall this looks good to me, just a few last comments.

Do we know of a reference layer somewhere in a torch or jax lib we could use to sanity check our numerics?

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Oct 19, 2023
What happens if `num_query_heads` is not divisible by `num_key_value_heads`?
use different letters to denote `num_query_heads` vs `num_key_value_heads`
@google-ml-butler google-ml-butler bot removed the ready to pull Ready to be merged into the codebase label Oct 19, 2023
@awsaf49
Copy link
Contributor Author

awsaf49 commented Oct 19, 2023

Sorry for the delay! Overall this looks good to me, just a few last comments.

Do we know of a reference layer somewhere in a torch or jax lib we could use to sanity check our numerics?

I tried to match numeric result of PyTorch (simple ver. of llama-v2) with this impl. but I am having trouble initialing them with same weights which I think is essential for testing. Probably it's due to Einsum.

Here is the pytorch code I'm using, which is simpler ver. of llama-v2:

import numpy as np
import torch
import math
import torch.nn.functional as F

class GroupedQueryAttentionTorch(nn.Module):
    """Multi-head attention module."""
    
    def __init__(self, dim, head_dim, n_kv_heads, n_q_heads):
        """Initialize the Attention module.
        
        Args:
            dim (int): Hidden dimension size
            n_kv_heads (int): Number of key and value heads 
            n_q_heads (int): Number of query heads
        """
        
        super().__init__()
        
        self.n_kv_heads = n_kv_heads
        self.n_q_heads = n_q_heads 
        self.n_rep = self.n_q_heads // self.n_kv_heads
        self.dim = dim
        self.head_dim = head_dim
        
        self.wq = nn.Linear(dim, n_q_heads * self.head_dim, 
                            bias=False,) 
        self.wk = nn.Linear(dim, n_kv_heads * self.head_dim,
                            bias=False)
        self.wv = nn.Linear(dim, n_kv_heads * self.head_dim,
                            bias=False)
        self.wo = nn.Linear(n_q_heads * self.head_dim, dim,
                            bias=False)

    def forward(
        self,
        x,
        mask,
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_q_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)

        # repeat k/v heads if n_kv_heads < n_q_heads
        key = self.repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_q_heads, head_dim)
        value = self.repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_q_heads, head_dim)

        query = xq.transpose(1, 2) / math.sqrt(self.head_dim) # (bs, n_q_heads, seqlen, head_dim)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        scores = torch.matmul(query, key.transpose(2, 3)) 
        if mask is not None:
            scores = scores + mask  # (bs, n_q_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, value)  # (bs, n_q_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
    
    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
        bs, slen, n_kv_heads, head_dim = x.shape
        if n_rep == 1:
            return x
        return (
            x[:, :, :, None, :]
            .expand(bs, slen, n_kv_heads, n_rep, head_dim)
            .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
        )

@fchollet
Copy link
Collaborator

I tried to match numeric result of PyTorch (simple ver. of llama-v2) with this impl. but I am having trouble initialing them with same weights which I think is essential for testing. Probably it's due to Einsum.

Do you mean the weight layout is different? What are the different weight matrixes and their shapes?

@awsaf49
Copy link
Contributor Author

awsaf49 commented Oct 20, 2023

I tried to match numeric result of PyTorch (simple ver. of llama-v2) with this impl. but I am having trouble initialing them with same weights which I think is essential for testing. Probably it's due to Einsum.

Do you mean the weight layout is different? What are the different weight matrixes and their shapes?

@fchollet Thanks or your help with this query! After some digging, I noticed an interesting difference in the weight shapes between Linear layer in PyTorch version and the EinsumDense layer in Keras. In PyTorch, the weight shape for the Linear layer looks like (n_q_heads * head_dim, dim), while for the EinsumDense layer in Keras, it's (dim, n_q_heads, head_dim). To make things work, I had to add a little reshape action after the transpose operation to fix the shape mismatch.

After that, I ran a numeric comparison of GQA between Keras (TensorFlow backend) and PyTorch version (simpler version of llama-v2; code above) and it matched up perfectly. I used the following code to test,

# Define parameters and data
head_dim = 2
n_q_heads = 3 * 2
n_kv_heads = 3
batch_size = 2
seq_len = 4
dim = 16

# Instantiate the torch GQA
gqa_torch = GroupedQueryAttentionTorch(dim=dim, 
                                       head_dim=head_dim,
                                       n_kv_heads=n_kv_heads, 
                                       n_q_heads=n_q_heads)


# Manually initialize PyTorch model weights to known values
init_tensor1 = torch.rand(n_q_heads * head_dim, dim)
init_tensor2 = torch.rand(n_kv_heads * head_dim, dim)
init_tensor3 = torch.rand(dim, n_q_heads * head_dim)
gqa_torch.wq.weight = nn.Parameter(init_tensor1.clone())
gqa_torch.wk.weight = nn.Parameter(init_tensor2.clone())
gqa_torch.wv.weight = nn.Parameter(init_tensor2.clone())
gqa_torch.wo.weight = nn.Parameter(init_tensor3.clone())

# Generate some example input data
inputs = torch.randn(batch_size, seq_len, dim)
mask = None  # You can add a mask tensor here if needed

# Forward Pass
output_torch = gqa_torch(inputs, mask).detach().numpy()
print(output_torch.shape)


# Instantiate the keras GQA
gqa_keras = GroupedQueryAttentionKeras(head_dim=head_dim,
                                    num_query_heads=n_q_heads,
                                    num_key_value_heads=n_kv_heads)
gqa_keras.build((batch_size, seq_len, dim), (batch_size, seq_len, dim))

# Manually initialize Keras model weights to same known values
gqa_keras._query_dense.kernel = tf.Variable(init_tensor1.numpy().T.reshape(dim, n_q_heads, head_dim))
gqa_keras._key_dense.kernel = tf.Variable(init_tensor2.numpy().T.reshape(dim, n_kv_heads, head_dim))
gqa_keras._value_dense.kernel = tf.Variable(init_tensor2.numpy().T.reshape(dim, n_kv_heads, head_dim))
gqa_keras._output_dense.kernel = tf.Variable(init_tensor3.numpy().T.reshape(n_q_heads, head_dim, dim))


# Generate inputs
inputs = tf.convert_to_tensor(inputs.numpy())
mask = None  # You can add a mask tensor here if needed

# Forward Pass
output_keras = gqa_keras(inputs, inputs, attention_mask=mask).numpy()
print(output_keras.shape)

# Test if torch keras result match
np.testing.assert_allclose(output_keras, output_torch, rtol=1e-5, atol=1e-5)

Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM -- thank you for the great contribution! 👍

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Oct 22, 2023
@fchollet fchollet merged commit a35287a into keras-team:master Oct 22, 2023
@google-ml-butler google-ml-butler bot removed awaiting review ready to pull Ready to be merged into the codebase kokoro:force-run labels Oct 22, 2023
@awsaf49
Copy link
Contributor Author

awsaf49 commented Oct 22, 2023

Thanks for your feedbacks @mattdangerw and @fchollet. I hope this layer helps the community to build LLMs more easily.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
Status: Merged
Development

Successfully merging this pull request may close these issues.

6 participants