[LinalgExt] Add online_attention op #17536
Conversation
Force-pushed from 1918c97 to cd0db37.
Force-pushed from 6a947fe to 8ea2ff6.
Overall this looks OK to me. I didn't look too much into the details of the attention op implementation itself. I am happy to stamp if needed.
Meta comment: please add more comments on methods (more for future you than anything else).
```diff
@@ -366,6 +366,16 @@ void DecomposeAttentionPass::runOnOperation() {
   SmallVector<Operation *> ops;
   decomposeTiledAttention(attnOp, ops, rewriter, optionalTileSize);
```
Can we remove the "decomposeTiledAttention" part now? They both do the same thing, right?
I'd like to do that in a separate patch. There are a number of transform scripts for attention (the CUDA attention transform scripts) that I need to take into account before doing this.
First round of comments, mostly asking questions.
The PR is very big. It's fine for this one, but please break it into smaller PRs in the future. If it were me, I'd split it into:
- Introduce the OnlineAttention op
- Implement TilingInterface methods for the op
- Implement AggregatedOpInterface methods for the op
- Implement convertToOnlineAttention
- The rest of the changes on the CPU side
```cpp
funcPassManager.addPass(
    IREE::LinalgExt::createConvertAttentionToOnlineAttentionPass());
```
Can this happen before tiling parallel dims? Is it a requirement for tiling reduction loops?
You can only tile reduction loops on the online_attention op. We could do this before tiling parallel dims, but we would then need to propagate the lowering_config info in the createConvertAttentionToOnlineAttention pass. For more context, the conversion rewrites:
```
attention { lowering_config }
```
to
```
acc = acc_fill
max = max_fill
sum = sum_fill
out:3 = online_attention acc, max, sum { lowering_config }
elementwise out#0, out#2
```
The lowering_config gets preserved on the online_attention op and is used for reduction tiling. Until we have consumer fusion (and greedy fusion for multiple operands/results) fixed, I don't think we can do it.
As a side note, this doesn't allow us to do further levels of parallel tiling on the elementwise and fill operations (which is not ideal).
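For readers new to the op: the reason the k2 (reduction) dimension is tileable at all is the standard online-softmax recurrence that online_attention implements; this is the usual flash-attention formulation, not something specific to this PR. Per k2 tile, with S the Q·Kᵀ scores for that tile, m the running max, s the running sum, and acc the unnormalized accumulator (symbol names are mine):

```latex
m' = \max\bigl(m,\ \mathrm{rowmax}(S)\bigr), \qquad
P = \exp(S - m') \\
s' = e^{\,m - m'}\, s + \mathrm{rowsum}(P), \qquad
acc' = e^{\,m - m'}\, acc + P\,V_{\text{tile}}
```

After the final tile, the trailing elementwise op normalizes out = acc / s, which is exactly the `elementwise out#0, out#2` step in the sketch above.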
Ideally, I would like there to be a way to propagate the lowering_config attribute when I do a conversion like this (which would mean putting the tiling information on the type, or somewhere more persistent).
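Until something like that exists, the conversion itself has to forward the attribute. A minimal sketch of what that looks like, assuming the config is reachable as a plain discardable attribute named `lowering_config` (the spelling and access pattern are illustrative, not a quote of IREE's actual helpers):

```cpp
// Forward the tiling config across the 1:N rewrite so reduction tiling
// still finds it. `attnOp` is the attention op being replaced;
// `onlineAttnOp` is the freshly created online_attention op.
if (mlir::Attribute config = attnOp->getAttr("lowering_config"))
  onlineAttnOp->setAttr("lowering_config", config);
// The fill and elementwise ops created by the conversion intentionally
// get no config: only online_attention drives reduction tiling.
```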
> We could do this before tiling parallel dims, but we would then need to propagate lowering_config info in createConvertAttentionToOnlineAttention pass.

This is more me asking questions than a requirement to address in this PR. I'm trying to see the whole picture of how it could be done in the CPU backend.
So it seems that we can convert the op to an online_attention op before lowering strategy selection, like what we've done for the softmax op. Do you think we want to keep it in attention form when we're doing the tiling on parallel loops? Or does it not matter, as long as we can "tile the online_attention op and fuse its producers/consumers into the for loop"?
Ah, I understand what you mean now. I can try. I'm thinking there might be problems with fusion because the online_attention op has multiple results. Let me try and see if I can do it.
No need to try it and land it in this PR, because the PR is already big and this is fairly new to the CPU backend. I can pull in others to help with the CPU changes later. Are there other pending changes for attention ops?
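For readers reconstructing the pipeline from this thread: schematically, the ordering the PR lands is parallel tiling first, then the conversion, then reduction tiling on the resulting op. The sketch below is mine; the tile-pass factory name and level arguments are assumptions modeled on the LLVMCPUTile pass the PR description mentions, not verbatim from the code:

```cpp
// Schematic CPU pipeline ordering (pass names/levels are illustrative).
// Parallel dims are tiled while the op is still a plain attention op:
funcPassManager.addPass(createLLVMCPUTilePass(/*tilingLevel=*/parallelLevel));
// Then attention is rewritten to fills + online_attention + elementwise,
// preserving lowering_config on the online_attention op:
funcPassManager.addPass(
    IREE::LinalgExt::createConvertAttentionToOnlineAttentionPass());
// Reduction (k2) tiling is only legal on online_attention:
funcPassManager.addPass(createLLVMCPUTilePass(/*tilingLevel=*/reductionLevel));
```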
Yeah, ideally this patch should have been split up. I just sent my entire experimentation branch as a patch for now (because we need this patch soon).
Force-pushed from 8ea2ff6 to 997815b.
Left some cosmetic comments
LGTM for CPU changes and code structure. I'll review other implementation details later.
Force-pushed from bf8cfce to 48470bf.
Force-pushed from 48470bf to cdc9f6b.
Force-pushed from cdc9f6b to dc651fa.
Force-pushed from 53a65bb to 713c95b.
Did this regress CPU performance? Presubmit test results are suspicious on this PR, and postsubmit started failing after the merge.
https://github.com/iree-org/iree/actions/runs/9484305572/job/26134004282
Looking at the CI logs, this may have timed out during compilation. That makes more sense than timing out at runtime, but it should still be investigated. Compile command from the logs:
This reverts commit abf0087.
Reverts #17536. This caused `sdxl-scheduled-unet-3-tank` to hit timeouts when compiling for CPU: https://github.com/iree-org/iree/actions/runs/9484305572/job/26134004282
@Groverkss I'm not able to take a look today, but I can do it tomorrow. Let me know if you want me to look into it.
This patch adds a new online_attention op. This op represents a partially reduced attention op which can be tiled along its k2 reduction dimension. The op also carries indexing maps, supports tiling on all dimensions other than the k1 dimension, and can be decomposed based on any given indexing maps.
This patch also makes the CPU backend use online attention to decompose attention and tile its reduction dimension, allowing it to be tiled along the N and batch dimensions, with tiling done by LLVMCPUTile.
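One more note, added for context: k2 tiling is legal because partial results over disjoint k2 ranges combine associatively. With (m, s, acc) the running max, running sum, and unnormalized accumulator of two partial reductions, the combine is:

```latex
m = \max(m_1, m_2), \qquad
s = e^{\,m_1 - m} s_1 + e^{\,m_2 - m} s_2, \qquad
acc = e^{\,m_1 - m}\, acc_1 + e^{\,m_2 - m}\, acc_2
```

This combine is associative and has identity (m, s, acc) = (-∞, 0, 0), matching the max/sum/acc fill ops that convertToOnlineAttention creates (presumably initialized to -∞, 0, and 0 respectively), so tiling along k2 is just a re-association of the reduction.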