[Infer Symbolic Shape No.132] flash_attn_unpadded InferMeta and Symbolic Inference #67701

Conversation

@ykcai-daniel commented Aug 25, 2024

PR Category

CINN

PR Types

Improvements

Description

Add the corresponding infer_symbolic_shape interface for flash_attn_unpadded.
Based on the Paddle nn documentation, the shape constraints are as follows:

  1. Inputs q, k, v:
        q, k, v: The query/key/value tensors in the Attention module.
                        3-D tensors with shape:
                        [total_seq_len, num_heads, head_dim].
                        The dtype can be float16 or bfloat16.
  2. cu_seqlens:
        cu_seqlens_q (Tensor): The cumulative sequence lengths of the sequences in the batch,
                        used to index query.
        cu_seqlens_k (Tensor): The cumulative sequence lengths of the sequences in the batch,
                        used to index key and value.

Note: the Paddle documentation here may be inaccurate. According to the flash-attention documentation, flash_attn_unpadded is meant to handle the case where q, k, v have different numbers of heads, so their shapes can differ (this still needs to be confirmed against Paddle's implementation); see the quote and the shape sketch below:

 Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
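
For clarity, here is a minimal NumPy sketch (not code from this PR) of how the unpadded layout, cu_seqlens, and the MQA/GQA head counts described above fit together; the sequence lengths and head counts are made-up example values:

```python
import numpy as np

seq_lens_q = [3, 5, 2]                   # made-up per-sequence query lengths
seq_lens_k = [4, 6, 2]                   # made-up per-sequence key/value lengths
nheads_q, nheads_k, head_dim = 6, 2, 64  # GQA example: nheads_q % nheads_k == 0

total_q = sum(seq_lens_q)                # all query tokens packed along one axis
total_k = sum(seq_lens_k)

# cu_seqlens_* have shape (batch_size + 1,): prefix sums of the per-sequence
# lengths, so cu_seqlens[i]:cu_seqlens[i+1] slices out sequence i.
cu_seqlens_q = np.concatenate([[0], np.cumsum(seq_lens_q)]).astype(np.int32)
cu_seqlens_k = np.concatenate([[0], np.cumsum(seq_lens_k)]).astype(np.int32)

q = np.zeros((total_q, nheads_q, head_dim), dtype=np.float16)
k = np.zeros((total_k, nheads_k, head_dim), dtype=np.float16)
v = np.zeros((total_k, nheads_k, head_dim), dtype=np.float16)

print(cu_seqlens_q)                      # [ 0  3  8 10]
print(q.shape, k.shape, v.shape)         # (10, 6, 64) (12, 2, 64) (12, 2, 64)
```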

Changes in this PR:
InferMeta and the symbolic shape inference (infer_symbolic_shape) now also accept q, k, v of rank 3, i.e. the unpadded [total_seq_len, num_heads, head_dim] layout; the intended shape relations are sketched below.
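
A rough Python sketch of the shape relations the rank-3 branch is meant to check and propagate. The actual change lives in the C++ InferMeta / infer_symbolic_shape code, so this is only an illustration; in particular, the MQA/GQA divisibility check follows the flash-attention docstring quoted above and may not match what Paddle currently enforces:

```python
def infer_unpadded_out_shape(q_shape, k_shape, v_shape,
                             cu_seqlens_q_shape, cu_seqlens_k_shape):
    """Illustrative only: shape relations for the rank-3 (unpadded) inputs."""
    assert len(q_shape) == len(k_shape) == len(v_shape) == 3
    total_q, nheads_q, head_dim = q_shape
    total_k, nheads_k, head_dim_k = k_shape

    assert k_shape == v_shape            # k and v share [total_k, nheads_k, head_dim]
    assert head_dim == head_dim_k        # q/k/v share the head dimension
    assert nheads_q % nheads_k == 0      # MQA/GQA constraint from the flash-attention docs
    assert len(cu_seqlens_q_shape) == 1  # both cu_seqlens are (batch_size + 1,)
    assert cu_seqlens_q_shape == cu_seqlens_k_shape

    # The attention output keeps the query layout.
    return [total_q, nheads_q, head_dim]


print(infer_unpadded_out_shape([10, 6, 64], [12, 2, 64], [12, 2, 64], [4], [4]))
# [10, 6, 64]
```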

TODO:

  1. flash_attn_with_sparse_mask also uses this InferMeta, but the InferMeta does not support attn_mask yet; attn_mask support could be added later.
  2. Add support for the softmax shape of flash_attn_unpadded.
  3. Add more constraints (cstr).


paddle-bot bot commented Aug 25, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.


CLAassistant commented Aug 25, 2024

CLA assistant check
All committers have signed the CLA.

@ykcai-daniel (Author) commented:

#67510 also modifies the flash_attn_sparse_mask code, so the InferMeta changes may conflict.


paddle-ci-bot bot commented Sep 2, 2024

Sorry to inform you that b524487's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@gongshaotian (Contributor) commented:

> #67510 also modifies the flash_attn_sparse_mask code, so the InferMeta changes may conflict.

OK, I'll contact the other developer.
