[Infer Symbolic Shape No.132] flash_attn_unpadded InferMeta and Symbolic Inference #67701
PR Category
CINN
PR Types
Improvements
Description
Add the corresponding infer_symbolic_shape interface for flash_attn_unpadded. Following the constraints documented in Paddle nn, the shape constraints are as follows:
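Note: the documentation here may be inaccurate. According to the flash-attention documentation, flash_attn_unpadded is meant to handle the case where q, k, and v have different numbers of heads, so the shapes of q, k, and v may differ (Paddle's implementation still needs to be confirmed):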
Changes include:
InferMeta and Infer Shape Symbolic now support q, k, v shapes of size 3 (rank-3 tensors)
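As an illustration of what rank-3 support involves, here is a minimal Python sketch of the shape rule. It is not the code in this PR. It assumes the unpadded layout `[total_tokens, num_heads, head_dim]` and assumes the output takes q's shape with the last dimension taken from v. The function name `infer_unpadded_out_shape` and the `Dim` alias are hypothetical. Equalities between dimensions are recorded as constraints rather than checked, since symbolic dimensions cannot be compared by value.

```python
from typing import List, Sequence, Tuple, Union

# int = statically known extent, str = symbolic dimension name such as "S0".
Dim = Union[int, str]


def infer_unpadded_out_shape(
    q: Sequence[Dim], k: Sequence[Dim], v: Sequence[Dim]
) -> Tuple[List[Dim], List[Tuple[Dim, Dim]]]:
    """Hypothetical sketch: output shape plus equality constraints for rank-3 q, k, v."""
    for name, shape in (("q", q), ("k", k), ("v", v)):
        if len(shape) != 3:
            raise ValueError(f"{name} must be rank 3, got rank {len(shape)}")

    # Pairs of dimensions that are assumed to be equal.
    constraints = [
        (q[2], k[2]),  # q and k share head_dim so that q @ k^T is defined
        (k[0], v[0]),  # k and v cover the same key/value tokens
        (k[1], v[1]),  # k and v share the number of key/value heads
    ]

    # Assumption: out follows q, with the last dimension taken from v.
    out = [q[0], q[1], v[2]]
    return out, constraints
```

For example, `infer_unpadded_out_shape(["S0", 8, 64], ["S1", 2, 64], ["S1", 2, 64])` yields an output shape with q's token and head dimensions. The number of heads in q is allowed to differ from k and v here, which matches the note above but still needs to be confirmed against Paddle's implementation.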
TODO: