[LinalgExt] Add online_attention op #17536

Merged
merged 31 commits into iree-org:main from new-decomposition-attention
Jun 12, 2024

Conversation

Groverkss
Contributor

@Groverkss Groverkss commented May 31, 2024

This patch adds a new online_attention op. This op represents a partially reduced attention op that can be tiled along its k2 reduction dimension. The op also carries indexing maps, supports tiling on all dimensions other than the k1 dimension, and can be decomposed based on any given indexing maps.

This patch also switches the CPU backend to use online attention for decomposition and reduction tiling, allowing attention to be tiled along the N and batch dimensions and tiled via LLVMCPUTile.
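For readers unfamiliar with the recurrence behind the op, here is a minimal scalar sketch (not code from this patch; the names, shapes, and single-query/head-size-1 simplification are all illustrative) of the partial reduction that makes tiling along k2 legal: each k2 block folds into a running max, running sum, and rescaled accumulator.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Folds one k2 block of scores (q*k) and values into the carried state.
// Start with runningMax = -inf, runningSum = 0, acc = 0.
void updateBlock(const std::vector<float> &scores,
                 const std::vector<float> &values,
                 float &runningMax, float &runningSum, float &acc) {
  float newMax = runningMax;
  for (float s : scores) newMax = std::max(newMax, s);
  // Rescale previously accumulated partials to the new running max.
  float rescale = std::exp(runningMax - newMax);
  runningSum *= rescale;
  acc *= rescale;
  for (std::size_t i = 0; i < scores.size(); ++i) {
    float p = std::exp(scores[i] - newMax);
    runningSum += p;
    acc += p * values[i];
  }
  runningMax = newMax;
}
// After the last block, the attention output for the row is acc / runningSum.

Because each block only needs the carried (max, sum, acc) triple, the k2 reduction can be split into tiles and resumed tile by tile, which is what the new op exposes through its tiling support.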

@Groverkss Groverkss force-pushed the new-decomposition-attention branch from 1918c97 to cd0db37 Compare June 3, 2024 14:56
@Groverkss Groverkss marked this pull request as ready for review June 4, 2024 18:20
@Groverkss Groverkss force-pushed the new-decomposition-attention branch from 6a947fe to 8ea2ff6 Compare June 5, 2024 14:29
Contributor

@MaheshRavishankar MaheshRavishankar left a comment


Overall this looks OK to me. I didn't look too much into the details of the attention op implementation itself. I am happy to stamp if needed.

Meta comment: please add more comments on methods (more for future you than anything else).

@@ -366,6 +366,16 @@ void DecomposeAttentionPass::runOnOperation() {
SmallVector<Operation *> ops;
decomposeTiledAttention(attnOp, ops, rewriter, optionalTileSize);
Contributor

Can we remove the "decomposeTiledAttention" part now? They are both doing the same thing, right?

Contributor Author

I'd like to do that in a separate patch. There are a number of transform scripts for attention (the CUDA attention transform scripts) that I need to take into account before doing this.

Contributor

@hanhanW hanhanW left a comment


First round of comments, more about asking questions.

The PR is very big. It's fine for this one, but please break it into small PRs in the future. If it were me, I'd split it into:

  1. Introduce OnlineAttention op
  2. Implement TilingInterface methods for the op
  3. Implement AggregatedOpInterface methods for the op
  4. Implement convertToOnlineAttention
  5. The rest of changes on CPU side.

Comment on lines +600 to +601
funcPassManager.addPass(
IREE::LinalgExt::createConvertAttentionToOnlineAttentionPass());
Contributor

Can this happen before tiling parallel dims? Is it a requirement for tiling reduction loops?

Contributor Author

@Groverkss Groverkss Jun 7, 2024

You can only tile reduction loops on the online_attention op. We could do this before tiling parallel dims, but we would then need to propagate lowering_config info in the createConvertAttentionToOnlineAttention pass. For more context, the conversion rewrites:

attention { lowering_config }

to

acc = acc_fill
max = max_fill
sum = sum_fill
out:3 = online_attention acc, max, sum {lowering_config}
elementwise out#0, out#2

The lowering config gets preserved on the online_attention op and is used for reduction tiling. Until we have consumer fusion (and greedy fusion for multiple operands/results) fixed, I don't think we can do it.

As a side note, this doesn't allow us to do further levels of parallel tiling on the elementwise and fill operations (which is not ideal).
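To make the rewrite above concrete, here is a tiny hedged sketch (illustrative names only, assuming the usual online-softmax formulation; not code from this patch) of what the three fills initialize and what the trailing elementwise op computes:

#include <limits>

struct OnlineAttentionState {
  float acc = 0.0f;                                     // acc_fill
  float max = -std::numeric_limits<float>::infinity();  // max_fill
  float sum = 0.0f;                                     // sum_fill
};

// "elementwise out#0, out#2": normalize the accumulator by the running sum.
inline float finalizeOnlineAttention(const OnlineAttentionState &s) {
  return s.acc / s.sum;
}

Since only the online_attention op (not the trailing elementwise) carries the partial state, reduction tiling has to happen on it, which is why the lowering_config needs to survive the conversion.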

Contributor Author

@Groverkss Groverkss Jun 7, 2024

Ideally, I would like there to be a way to propagate the lowering_config attribute when I do a conversion like this (which would mean putting the tiling information on the type, or somewhere more persistent).

Contributor

We could do this before tiling parallel dims, but we would then need to propagate lowering_config info in createConvertAttentionToOnlineAttention pass.

This is more a question than a requirement to address in this PR. I'm trying to see the whole picture of how it could be done in the CPU backend.

So it seems that we can convert the op to an online_attention op before lowering strategy selection, like what we've done for the softmax op. Do you think we want to keep it in attention form when we're tiling the parallel loops? Or does it not matter if we "tile the online_attention op and fuse its producers/consumers into the for loop"?

Contributor Author

Ah, I understand what you mean now. I can try. I'm thinking there might be problems with fusion because the online_attention op has multiple results. Let me try and see if I can do it.

Contributor

No need to try it and land it in this PR, because the PR is already big and this is fairly new to the CPU backend. I can pull in others to help with the CPU changes later. Are there other pending changes for attention ops?

@Groverkss
Contributor Author

First round of comments, more about asking questions.

The PR is very big. It's fine for this one, but please break it into small PRs in the future. If it were me, I'd split it into:

  1. Introduce OnlineAttention op
  2. Implement TilingInterface methods for the op
  3. Implement AggregatedOpInterface methods for the op
  4. Implement convertToOnlineAttention
  5. The rest of changes on CPU side.

Yeah, ideally this patch should have been split up. I just sent my entire experimentation branch as a patch for now (because we need this patch soon).

@Groverkss Groverkss force-pushed the new-decomposition-attention branch from 8ea2ff6 to 997815b Compare June 7, 2024 14:13
@Groverkss Groverkss requested a review from hanhanW June 7, 2024 14:13
Member

@kuhar kuhar left a comment

Left some cosmetic comments

@Groverkss Groverkss requested a review from kuhar June 7, 2024 15:11
Contributor

@hanhanW hanhanW left a comment

LGTM for CPU changes and code structure. I'll review other implementation details later.


@Groverkss Groverkss force-pushed the new-decomposition-attention branch from bf8cfce to 48470bf Compare June 10, 2024 12:11
@Groverkss Groverkss force-pushed the new-decomposition-attention branch from 48470bf to cdc9f6b Compare June 10, 2024 16:45
@Groverkss Groverkss force-pushed the new-decomposition-attention branch from cdc9f6b to dc651fa Compare June 11, 2024 17:11
@Groverkss Groverkss enabled auto-merge (squash) June 11, 2024 17:23
@Groverkss Groverkss force-pushed the new-decomposition-attention branch from 53a65bb to 713c95b Compare June 12, 2024 13:43
@Groverkss Groverkss merged commit abf0087 into iree-org:main Jun 12, 2024
50 of 51 checks passed
@ScottTodd
Member

Did this regress CPU performance?

Presubmit test results on this PR are suspicious, and postsubmit started failing after the merge.

FAILED SHARK-TestSuite/iree_tests/pytorch/models/sdxl-scheduled-unet-3-tank/model.mlirbc::cpu_llvm_task_real_weights - Failed: Timeout >1200.0s

https://github.com/iree-org/iree/actions/runs/9484305572/job/26134004282


@ScottTodd
Member

Looking at the CI logs, this may have timed out during compilation. That makes more sense than timing out at runtime, but it should still be investigated.

Compile command from the logs:

INFO     root:conftest.py:393 Launching compile command:
cd /home/nod/actions-runner/_work/iree/iree/SHARK-TestSuite/iree_tests/pytorch/models/sdxl-scheduled-unet-3-tank && iree-compile model.mlirbc --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu-features=host --iree-input-demote-f64-to-f32 -o model_cpu_llvm_task_real_weights.vmfb

Model file: https://github.com/nod-ai/SHARK-TestSuite/blob/main/iree_tests/pytorch/models/sdxl-scheduled-unet-3-tank/model.mlirbc

ScottTodd added a commit that referenced this pull request Jun 12, 2024
ScottTodd added a commit that referenced this pull request Jun 12, 2024
@hanhanW
Contributor

hanhanW commented Jun 13, 2024

@Groverkss I'm not able to take a look today, but I can do it tomorrow. Let me know if you want me to take a look.

LLITCHEV pushed a commit to LLITCHEV/iree that referenced this pull request Jul 30, 2024
This patch adds a new online_attention op. This op represents a partially reduced attention op that can be tiled along its k2 reduction dimension. The op also carries indexing maps, supports tiling on all dimensions other than the k1 dimension, and can be decomposed based on any given indexing maps.

This patch also switches the CPU backend to use online attention for decomposition and reduction tiling, allowing attention to be tiled along the N and batch dimensions and tiled via LLVMCPUTile.

Signed-off-by: Lubo Litchev <lubol@google.com>
LLITCHEV pushed a commit to LLITCHEV/iree that referenced this pull request Jul 30, 2024
Reverts iree-org#17536

This caused `sdxl-scheduled-unet-3-tank` to hit timeouts when compiling for CPU:
https://github.com/iree-org/iree/actions/runs/9484305572/job/26134004282

Signed-off-by: Lubo Litchev <lubol@google.com>