[kernel] Refactor FA kernel to be FA_transV when possible. (#568)
The Flash Attention transpose_V variant is significantly faster than the non-transpose_V variant, because many MM intrinsics are mmtb by default. Doing FA transpose_V therefore allows better, more contiguous reads from shared memory into registers, which improves attention performance substantially and also makes FP8 faster than FP16. I have tested that it indeed improves SDXL performance on FP8, making FP8 faster than our FP16 model.
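
For intuition, here is a minimal numpy sketch (not the sharktank kernel, and with toy shapes) of the transpose_V formulation: with V kept transposed, both matmuls contract their operands over the innermost, fastest-varying axis, which is the kind of layout the mmtb intrinsics mentioned above prefer.

```python
import numpy as np

# Minimal numpy sketch (not the sharktank kernel) of the transpose_V formulation:
# V is kept as (B0, B1, N, K2), so both matmuls contract over each operand's
# innermost axis (K1 for Q/K, K2 for P/V^T).
# The real SDXL benchmark size is (2, 10, 4096, 64, 4096, 64); toy sizes here.
B0, B1, M, K1, K2, N = 2, 10, 256, 64, 256, 64

q = np.random.randn(B0, B1, M, K1).astype(np.float32)   # !q_type: l x d
k = np.random.randn(B0, B1, K2, K1).astype(np.float32)  # !k_type: s x d
v = np.random.randn(B0, B1, K2, N).astype(np.float32)   # !v_type: s x e

v_t = v.transpose(0, 1, 3, 2)                            # (B0, B1, N, K2), like !trans_v_type

s = np.einsum("...mk,...nk->...mn", q, k) / np.sqrt(K1)  # Q @ K^T, contracts K1
p = np.exp(s - s.max(-1, keepdims=True))
p /= p.sum(-1, keepdims=True)                            # softmax over K2
out = np.einsum("...mk,...nk->...mn", p, v_t)            # P @ V, contracts K2
print(out.shape)                                         # (2, 10, 256, 64) = (B0, B1, M, N)
```

Note that once V is transposed, the second einsum has the same "innermost dims are the reduction dims" structure as the first.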

I have also tested/confirmed that, if we do not find any producers to fuse with, the transpose re-fuses back into the attention. Hence, the worst case is the same performance as before the transpose was split out.

Some data from a microbenchmark with real sizes from SDXL:
```
(B0, B1, M, K1, K2, N): (2, 10, 4096, 64, 4096, 64)
Over 100 runs:

FP16 non-transpose: 22.7 ms
FP8  non-transpose: 23.8 ms

FP16 transpose:     20.1 ms
FP8  transpose:     17.5 ms
```

Additionally, this PR moves the reduction dimension of attention to the fastest dimension. This is preferable because many optimization passes expect reduction dims to be the fastest dims, and it matches IREE's lowering passes more closely.
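
As a rough illustration of the layout argument (spelled out here with numpy strides as an assumption, not the actual compiler pass): once V is transposed, stepping along the K2 reduction axis becomes a unit-stride walk.

```python
import numpy as np

# Stride check (illustration only): walking the K2 reduction axis is unit-stride
# in the transposed layout, but strides by N elements in the original layout.
v   = np.zeros((2, 10, 4096, 64), dtype=np.float16)   # (B0, B1, K2, N), fp16 = 2 bytes
v_t = np.ascontiguousarray(v.transpose(0, 1, 3, 2))   # (B0, B1, N, K2)

print(v.strides[-2])    # 128 bytes per step along K2 (= N * 2 bytes)
print(v_t.strides[-1])  # 2 bytes per step along K2 (contiguous)
```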

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
raikonenfnu authored Nov 19, 2024
1 parent 4dd2fc8 commit a7feae8
10 changes: 7 additions & 3 deletions sharktank/sharktank/kernels/templates/flash_attention.mlir
@@ -7,6 +7,7 @@
!q_type = tensor<?x?x{{l}}x{{d}}x{{i_type}}>
!k_type = tensor<?x?x{{s}}x{{d}}x{{i_type}}>
!v_type = tensor<?x?x{{s}}x{{e}}x{{i_type}}>
!trans_v_type = tensor<?x?x{{e}}x{{s}}x{{i_type}}>
!o_type = tensor<?x?x{{l}}x{{e}}x{{o_type}}>
!o_dyn_type = tensor<?x?x?x?x{{o_type}}>
!s_type = tensor<{{scale_type}}>
@@ -32,16 +33,19 @@ util.func private @sharktank_flash_attention_{{l}}_{{s}}_{{d}}_{{e}}_{{i_type}}_

%scale = tensor.extract %s[] : !s_type

%init_trans_v = tensor.empty(%b0, %b1) : !trans_v_type
%transpose_v = linalg.transpose ins(%v: !v_type) outs(%init_trans_v: !trans_v_type) permutation = [0, 1, 3, 2]

%empty_dyn = tensor.empty(%b0, %b1, %l, %e) : !o_dyn_type
%empty = tensor.cast %empty_dyn : !o_dyn_type to !o_type

%atten = iree_linalg_ext.attention {indexing_maps = [
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> ()>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]}
ins(%q, %k, %v, %scale : !q_type, !k_type, !v_type, {{scale_type}}) outs(%empty : !o_type) {
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>]}
ins(%q, %k, %transpose_v, %scale : !q_type, !k_type, !trans_v_type, {{scale_type}}) outs(%empty : !o_type) {
^bb0(%score: f32):
iree_linalg_ext.yield %score : f32
} -> !o_type