
Commit 2d3b750

Attention Sinks Perf Boost (#78)
* add sink test to attention_ref
* embed sink test into main flash attn test script
* change local block positioning for fast path
* add tma gqa modifications
* fix exploding build
* try to fix exploding build again
* prune other hdims to keep build stable
* tweak tile size for causal not local case
* split compilation
* fix error with varlen q
* split compilation for root setup
* renable use one mma wg
* update for hdim diff
* include comments on how to enable hdim diff in cmakelists
* fix test variable
* update pack gqa heuristic
* fix comment
* split disable hdim diff macro into 64 and 192, enable 192 by default in cmakelists
* add assert to check using Hopper kernels with s_aux
* more logical placement of assert checks for hdim diff in flash api
* cherrypick
* typo fix
* review comments
* review comment

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Jay Shah <jayhshah@gmail.com>
1 parent 93cf5a0 · commit 2d3b750

14 files changed: +469 −185

CMakeLists.txt

Lines changed: 34 additions & 7 deletions
@@ -178,27 +178,48 @@ endif ()
 if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
     # BF16 source files
     file(GLOB FA3_BF16_GEN_SRCS
-        "hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
+        "hopper/instantiations/flash_fwd_hdim64_bf16*_sm90.cu"
+        "hopper/instantiations/flash_fwd_hdim96_bf16*_sm90.cu"
+        "hopper/instantiations/flash_fwd_hdim128_bf16*_sm90.cu"
+        "hopper/instantiations/flash_fwd_hdim192_bf16*_sm90.cu"
+        "hopper/instantiations/flash_fwd_hdim256_bf16*_sm90.cu")
+    # Add these for hdim diff cases
     file(GLOB FA3_BF16_GEN_SRCS_
-        "hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
+        # "hopper/instantiations/flash_fwd_hdim64_256_bf16*_sm90.cu"
+        # "hopper/instantiations/flash_fwd_hdim64_512_bf16*_sm90.cu"
+        "hopper/instantiations/flash_fwd_hdim192_128_bf16*_sm90.cu")
     list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
     file(GLOB FA3_BF16_GEN_SRCS_
         "hopper/instantiations/flash_fwd_*_bf16_*_sm80.cu")
     list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
+
     # FP16 source files
     file(GLOB FA3_FP16_GEN_SRCS
-        "hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
+        "hopper/instantiations/flash_fwd_hdim64_fp16*_sm90.cu"
+        "hopper/instantiations/flash_fwd_hdim96_fp16*_sm90.cu"
+        "hopper/instantiations/flash_fwd_hdim128_fp16*_sm90.cu"
+        "hopper/instantiations/flash_fwd_hdim192_fp16*_sm90.cu"
+        "hopper/instantiations/flash_fwd_hdim256_fp16*_sm90.cu")
+    # Add these for hdim diff cases
     file(GLOB FA3_FP16_GEN_SRCS_
-        "hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
+        # "hopper/instantiations/flash_fwd_hdim64_256_fp16*_sm90.cu"
+        # "hopper/instantiations/flash_fwd_hdim64_512_fp16*_sm90.cu"
+        "hopper/instantiations/flash_fwd_hdim192_128_fp16*_sm90.cu")
     list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
     file(GLOB FA3_FP16_GEN_SRCS_
         "hopper/instantiations/flash_fwd_*_fp16_*_sm80.cu")
     list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
+
     # FP8 source files
     file(GLOB FA3_FP8_GEN_SRCS
-        "hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
+        "hopper/instantiations/flash_fwd_hdim64_e4m3*_sm90.cu"
+        "hopper/instantiations/flash_fwd_hdim96_e4m3*_sm90.cu"
+        "hopper/instantiations/flash_fwd_hdim128_e4m3*_sm90.cu"
+        "hopper/instantiations/flash_fwd_hdim192_e4m3*_sm90.cu"
+        "hopper/instantiations/flash_fwd_hdim256_e4m3*_sm90.cu")
+    # Add these for hdim diff cases (192 only)
     file(GLOB FA3_FP8_GEN_SRCS_
-        "hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
+        "hopper/instantiations/flash_fwd_hdim192_128_e4m3*_sm90.cu")
     list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})
 
     set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})
@@ -244,11 +265,17 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
         FLASHATTENTION_DISABLE_BACKWARD
         FLASHATTENTION_DISABLE_DROPOUT
         # FLASHATTENTION_DISABLE_ALIBI
-        # FLASHATTENTION_DISABLE_SOFTCAP
+        FLASHATTENTION_DISABLE_SOFTCAP
         FLASHATTENTION_DISABLE_UNEVEN_K
         # FLASHATTENTION_DISABLE_LOCAL
         FLASHATTENTION_DISABLE_PYBIND
         FLASHATTENTION_VARLEN_ONLY # Custom flag to save on binary size
+        FLASHATTENTION_DISABLE_CLUSTER # disabled for varlen in any case
+        # FLASHATTENTION_DISABLE_SM8x
+        FLASHATTENTION_DISABLE_HDIMDIFF64
+        # FLASHATTENTION_DISABLE_HDIMDIFF192
+        CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED
+        CUTLASS_ENABLE_GDC_FOR_SM90
     )
 elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 12.0)
     message(STATUS "FA3 is disabled because CUDA version is not 12.0 or later.")
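
Note (not part of the commit): the FLASHATTENTION_DISABLE_HDIMDIFF64 / FLASHATTENTION_DISABLE_HDIMDIFF192 definitions above pair with the commented/uncommented "hdim diff" globs. Below is a rough, hypothetical C++ sketch of the usual pattern such flags follow: a dispatch path is only compiled when its instantiation sources are built, so both must be toggled together. All names in the sketch are illustrative, not the repository's actual dispatch code.

// Illustrative sketch only: how compile definitions such as FLASHATTENTION_DISABLE_HDIMDIFF64 /
// FLASHATTENTION_DISABLE_HDIMDIFF192 typically prune "hdim diff" cases (head_dim_qk != head_dim_v)
// to keep build time and binary size down. Function and template names here are hypothetical.
#include <cstdio>
#include <stdexcept>

template <int kHeadDimQK, int kHeadDimV>
void run_fwd_kernel() {
    // Stand-in for launching a generated kernel instantiation.
    std::printf("launch fwd kernel: hdim_qk=%d hdim_v=%d\n", kHeadDimQK, kHeadDimV);
}

inline void dispatch_hdim_diff(int hdim_qk, int hdim_v) {
#ifndef FLASHATTENTION_DISABLE_HDIMDIFF192
    if (hdim_qk == 192 && hdim_v == 128) { run_fwd_kernel<192, 128>(); return; }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIMDIFF64
    if (hdim_qk == 64 && hdim_v == 256) { run_fwd_kernel<64, 256>(); return; }
    if (hdim_qk == 64 && hdim_v == 512) { run_fwd_kernel<64, 512>(); return; }
#endif
    throw std::runtime_error("this (hdim_qk, hdim_v) pairing was compiled out");
}

int main() {
    dispatch_hdim_diff(192, 128);  // the 192/128 pairing stays enabled by default in this CMakeLists
}

With the flags as set in this diff (HDIMDIFF64 defined, HDIMDIFF192 commented out), only the 192/128 path would survive, matching the commented-out hdim64_256 / hdim64_512 globs above.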

hopper/block.h

Lines changed: 25 additions & 12 deletions
@@ -11,28 +11,39 @@ struct BlockMN {
 
     static
     CUTLASS_DEVICE
-    cute::tuple<int, int> get_n_block_min_max(
+    cute::tuple<int, int, int> get_n_block_min_max(
         SeqlenInfo_t const& seqlen_info,
         int const m_block, int const bidb, int const split_idx, int const num_splits,
         int const window_size_left, int const window_size_right,
         cutlass::FastDivmod const& qhead_per_khead_divmod) {
 
-        int const seqlen_k = seqlen_info.seqlen_k;
+        int seqlen_k = seqlen_info.seqlen_k;
         int const seqlen_q = seqlen_info.seqlen_q;
+        int n_offset = 0;
+
+        // If local, calculate n_offset and update seqlen_k
+        if constexpr (Is_local) {
+            int m_idx_min = m_block * kBlockM;
+            if (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); }
+            // unlike previously, we don't divide by kBlockN because we want offset for seqlen_k
+            n_offset = std::max(int(0), m_idx_min + seqlen_k - seqlen_q - window_size_left);
+            // Subtract n_offset from seqlen_k for subsequent calculations such as n_block_max
+            // This is the actual seqlen_k processed for this m_block
+            seqlen_k -= n_offset;
+        }
+
         int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
         if constexpr (Is_causal || Is_local) {
             int m_idx_max = (m_block + 1) * kBlockM;
             // TODO: check off-by-1 error
             if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; }
+            // If local, blocking (m_idx_max - m_idx_min + window_size_right + window_size_left)
             n_block_max = std::min(n_block_max,
                                    cute::ceil_div(m_idx_max + seqlen_k - seqlen_q + window_size_right, kBlockN));
         }
+        // Now, only adjust n_block_min if split
         int n_block_min = 0;
-        if constexpr (Is_local) {
-            int m_idx_min = m_block * kBlockM;
-            if (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); }
-            n_block_min = std::max(int(0), (m_idx_min + seqlen_k - seqlen_q - window_size_left) / kBlockN);
-        }
+
         // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
         if constexpr (Split) {
             uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
@@ -45,7 +56,9 @@ struct BlockMN {
             // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, num_splits_dynamic = %d, num_splits_actual = %d, num_n_blocks_per_split = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, num_splits_dynamic, num_splits_actual, num_n_blocks_per_split, n_block_min, n_block_max); }
         }
         // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
-        return {n_block_min, n_block_max};
+
+        // Return n_offset to add to KV gmem pointers and use in masks
+        return {n_block_min, n_block_max, n_offset};
     }
 
     static
@@ -55,12 +68,12 @@ struct BlockMN {
         int const m_block, int const bidb, int const split_idx, int const num_splits,
         int const window_size_left, int const window_size_right,
         cutlass::FastDivmod const& qhead_per_khead_divmod) {
-
-        auto [n_block_min, n_block_max] = get_n_block_min_max(
+        // TODO: check logic with n_offset
+        auto [n_block_min, n_block_max, n_offset] = get_n_block_min_max(
            seqlen_info, m_block, bidb, split_idx, num_splits,
            window_size_left, window_size_right, qhead_per_khead_divmod);
-        int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0);
-        int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new);
+        int const idx_k_new_min = std::max(n_block_min * kBlockN + n_offset - seqlen_info.seqlen_k_og, 0);
+        int const idx_k_new_max = std::min(n_block_max * kBlockN + n_offset - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new);
         int const n_block_new_min = idx_k_new_min / kBlockN;
         int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min;
         // if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);}
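
Note (not part of the commit): the "local block positioning for fast path" change above re-bases the local window at n_block 0 and hands the start of the window back as n_offset (applied to the KV gmem pointers and masks) instead of encoding it in n_block_min. The standalone host-side sketch below reproduces that arithmetic for one made-up configuration (Is_local, no PackGQA, no Split; the tile sizes and sequence lengths are example values, not the tile sizes chosen in this PR).

// Illustrative only: new vs. old local-attention block window from hopper/block.h, on the host.
#include <algorithm>
#include <cstdio>

int main() {
    constexpr int kBlockM = 128, kBlockN = 128;
    int seqlen_q = 4096, seqlen_k = 4096;
    int window_size_left = 256, window_size_right = 0;
    int m_block = 10;

    int m_idx_min = m_block * kBlockM;        // 1280
    int m_idx_max = (m_block + 1) * kBlockM;  // 1408

    // New scheme: shift the window so it starts at n_block 0 and carry n_offset separately.
    int n_offset = std::max(0, m_idx_min + seqlen_k - seqlen_q - window_size_left);  // 1024
    int seqlen_k_eff = seqlen_k - n_offset;                                          // 3072
    int n_block_max = std::min((seqlen_k_eff + kBlockN - 1) / kBlockN,
                               (m_idx_max + seqlen_k_eff - seqlen_q + window_size_right + kBlockN - 1) / kBlockN);  // 3
    int n_block_min = 0;

    // Old scheme for comparison: the same three KV blocks, addressed as blocks 8..10.
    int old_min = std::max(0, (m_idx_min + seqlen_k - seqlen_q - window_size_left) / kBlockN);  // 8
    int old_max = std::min((seqlen_k + kBlockN - 1) / kBlockN,
                           (m_idx_max + seqlen_k - seqlen_q + window_size_right + kBlockN - 1) / kBlockN);  // 11

    std::printf("new: n_block [%d, %d) with n_offset %d | old: n_block [%d, %d)\n",
                n_block_min, n_block_max, n_offset, old_min, old_max);
}

Both schemes cover KV tokens [1024, 1408); the new one just addresses them as blocks [0, 3) plus an offset of 1024, which is also why get_n_block_k_new_min_max above adds n_offset back when computing idx_k_new_min / idx_k_new_max.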

hopper/epilogue_fwd.hpp

Lines changed: 22 additions & 14 deletions
@@ -21,7 +21,7 @@ namespace flash {
 using namespace cute;
 
 template <class TileShape_MNK_PV_, class ClusterShape_, class Element_, class ArchTag_,
-          int NumEpilogueThreads_, bool Varlen_, bool PackGQA_, bool Split_, bool FP8PermuteCol=false>
+          int NumEpilogueThreads_, bool Varlen_, bool PackGQA_, bool Split_, bool FP8PermuteCol=false, int kBlockH_=1>
 struct CollectiveEpilogueFwd {
 
     using TileShape_MNK_PV = TileShape_MNK_PV_;
@@ -32,16 +32,18 @@ struct CollectiveEpilogueFwd {
     static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
     static constexpr bool Varlen = Varlen_;
     static constexpr bool PackGQA = PackGQA_;
+    static constexpr bool PackGQA_TMA = PackGQA && (kBlockH_ > 1);
     static constexpr bool Split = Split_;
     static constexpr bool Use_smem = !(Split && !Varlen);
-    static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && !PackGQA;
+    static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && (!PackGQA || PackGQA_TMA);
 
     static_assert(ArchTag::kMinComputeCapability >= 80);
     static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1);
     static_assert(sizeof(Element) <= 2);
 
     static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
     static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{});
+    static constexpr int kBlockH = kBlockH_;
 
     static constexpr bool LargeHeadDimV = kHeadDimV > 256;
 
@@ -83,7 +85,10 @@ struct CollectiveEpilogueFwd {
     using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;
     using StrideLSE = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen_q, head, batch, num_splits)
     // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits)
-    using ShapeOPacked = std::conditional_t<!PackGQA, ShapeO, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t, int32_t>>;
+    using ShapeOPackedTMA = std::conditional_t<!PackGQA, ShapeO, cute::Shape<cute::Shape<Int<kBlockH>, int32_t>, int32_t, int32_t, int32_t, int32_t>>;
+    using ShapeOPacked = std::conditional_t<PackGQA && !PackGQA_TMA,
+        cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t, int32_t>,
+        ShapeOPackedTMA>;
     using StrideOPacked = std::conditional_t<!PackGQA, StrideO, cute::Stride<cute::Stride<int64_t, int64_t>, _1, int64_t, int64_t, int64_t>>;
     // ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits)
     using ShapeLSEPacked = std::conditional_t<!PackGQA, cute::Shape<int32_t, int32_t, int32_t, int32_t>, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t>>;
@@ -110,7 +115,7 @@ struct CollectiveEpilogueFwd {
         Use_TMA_O,
         decltype(make_tma_copy(
             GmemTiledCopyOTMA{},
-            make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapeO{}, StrideO{}),
+            make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapeOPackedTMA{}, StrideOPacked{}),
             SmemLayoutOTMA{},
             select<0, 1>(TileShape_MNK_PV{}),
             _1{})), // no mcast for O
@@ -158,19 +163,13 @@ struct CollectiveEpilogueFwd {
 
     static Params
     to_underlying_arguments(Arguments const& args) {
-        Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O);
-        TMA_O tma_store_O = [&]{
-            if constexpr (Use_TMA_O) {
-                return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{}); // no mcast
-            } else {
-                return nullptr;
-            }
-        }();
         // If PackGQA, reshape O to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size, num_splits)
         int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_O), args.nheads_kv);
         auto const shape_O_packed = cute::conditional_return<!PackGQA>(
             args.shape_O,
-            make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
+            make_shape(
+                make_shape(cute::conditional_return<PackGQA_TMA>(Int<kBlockH>{}, qhead_per_khead), get<0>(args.shape_O)),
+                get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
         );
         auto const stride_O_packed = cute::conditional_return<!PackGQA>(
             args.stride_O,
@@ -180,6 +179,15 @@ struct CollectiveEpilogueFwd {
             args.stride_O_partial,
             make_stride(make_stride(get<2>(args.stride_O_partial), get<0>(args.stride_O_partial)), get<1>(args.stride_O_partial), get<2>(args.stride_O_partial) * qhead_per_khead, get<3>(args.stride_O_partial), get<4>(args.stride_O_partial))
         );
+        Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), shape_O_packed, stride_O_packed);
+        TMA_O tma_store_O = [&]{
+            if constexpr (Use_TMA_O) {
+                return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{}); // no mcast
+            } else {
+                return nullptr;
+            }
+        }();
+
         // If PackGQA, Reshape LSE to be ((qhead_per_khead, seqlen_q), nhead_k, batch_size, num_splits)
         auto const shape_LSE_packed = cute::conditional_return<!PackGQA>(
             select<0, 2, 3, 4>(args.shape_O),
@@ -308,7 +316,7 @@ struct CollectiveEpilogueFwd {
 
         // Step 3: Write O from smem -> gmem
         if constexpr (Use_TMA_O) {
-            Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx);
+            Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O_packed)(_, _, bidh, bidb, split_idx);
             Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
             auto block_tma_O = params.tma_store_O.get_slice(_0{});
             Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K)
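
Note (not part of the commit): the "tma gqa modifications" above let the epilogue keep its TMA store for O even with PackGQA, provided the block packs a compile-time number of query heads (kBlockH > 1), since the TMA descriptor is now built over the packed shape. The sketch below mirrors that decision on the host without CUTE types; the problem sizes are made-up, and the real code builds a CuTe shape ((kBlockH or qhead_per_khead, seqlen_q), headdim, nheads_kv, batch, num_splits) rather than printing it.

// Illustrative only: PackGQA_TMA / Use_TMA_O decision and the packed O shape, on the host.
#include <cstdio>

int main() {
    // Example problem: GQA with 32 query heads sharing 8 KV heads.
    int nheads_q = 32, nheads_kv = 8;
    int seqlen_q = 1024, headdim = 128, batch = 2, num_splits = 1;
    int qhead_per_khead = (nheads_q + nheads_kv - 1) / nheads_kv;  // 4

    // kBlockH: query heads packed into one block along M (the new template parameter).
    constexpr int kBlockH = 4;
    constexpr bool PackGQA = true;
    constexpr bool PackGQA_TMA = PackGQA && (kBlockH > 1);
    constexpr bool Varlen = false, Split = false;
    constexpr bool Use_TMA_O = /* sm90 */ true && !Varlen && !Split && (!PackGQA || PackGQA_TMA);

    // With PackGQA_TMA the head extent of the leading mode is the compile-time kBlockH,
    // which is what allows a TMA store descriptor to be built over the packed tensor.
    std::printf("Use_TMA_O = %d, shape_O_packed = ((%d, %d), %d, %d, %d, %d)\n",
                int(Use_TMA_O), PackGQA_TMA ? kBlockH : qhead_per_khead, seqlen_q,
                headdim, nheads_kv, batch, num_splits);
}

For this example the output is Use_TMA_O = 1 with shape ((4, 1024), 128, 8, 2, 1); with kBlockH = 1 the epilogue falls back to the previous non-TMA PackGQA store path.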
