
Commit 2733401

varlen combine scheduler (#70)

* varlen combine scheduler
* cleanup
* move check
* standard scheduling algo
* better heuristic
* better comments
* cleanup
* cleanup
* put in a more readable heuristic
* Apply suggestions from code review
* FA2 8.0 PTX (#69)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>

1 parent 92949c3 commit 2733401

File tree

3 files changed: +259 −15 lines changed

hopper/flash_fwd_combine_kernel.h
hopper/flash_fwd_combine_launch_template.h
hopper/utils.h


hopper/flash_fwd_combine_kernel.h
231 additions, 11 deletions
@@ -122,16 +122,24 @@ class FlashAttnFwdCombine {
     using ShapeLSE = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen, head, batch)
     using StrideLSE = cute::Stride<_1, int64_t, int64_t>;  // (seqlen, head, batch)
 
+    struct BlockCoord {
+        int block_m;
+        int block_k;
+        int bidb;
+    };
+
     struct SharedStorage : cute::aligned_struct<128> {
         cute::array_aligned<float, cute::cosize_v<SmemLayoutLSE>> smem_lse_partial;
         cute::array_aligned<int, kBlockM> smem_max_valid_split;
         cute::array_aligned<ElementPartial, cute::cosize_v<SmemLayoutO>> smem_o_partial;
+        BlockCoord block_coord;
     };
 
     static constexpr int SharedStorageSize = sizeof(SharedStorage);
 
     // Device side arguments
     struct Arguments {
+        int b;
         ElementPartial const* const ptr_O_partial;
         ShapeOPartial const shape_O_partial;
         StrideOPartial const stride_O_partial;
@@ -149,7 +157,8 @@ class FlashAttnFwdCombine {
     };
 
     // Kernel entry point API
-    struct Params {
+    struct CollectiveParams {
+        int b;
         ElementPartial const* const ptr_O_partial;
         ShapeOPartial const shape_O_partial;
         StrideOPartial const stride_O_partial;
@@ -169,10 +178,11 @@ class FlashAttnFwdCombine {
 
     // Convert to underlying arguments. In this case, a simple copy for the aliased type.
     static
-    Params
+    CollectiveParams
    to_underlying_arguments(Arguments const& args) {
         assert(get<1>(args.shape_LSE_partial) <= kMaxSplits);
         return {
+            args.b,
             args.ptr_O_partial,
             args.shape_O_partial,
             args.stride_O_partial,
@@ -191,33 +201,243 @@ class FlashAttnFwdCombine {
         };
     }
 
+    struct SchedulerArguments {
+        int b;
+        int seqlen_q;
+        int total_q;
+        int num_heads;
+        int dv;
+        int const* cu_seqlens_q;
+        int const* seqused_q;
+    };
+
+    struct StaticTileScheduler {
+        struct Params {};
+        static Params to_underlying_arguments(SchedulerArguments const& args) { return {}; }
+
+        SharedStorage& shared_storage;
+        CUTE_DEVICE StaticTileScheduler(SharedStorage& shared_storage): shared_storage(shared_storage) {}
+
+        static dim3 get_grid_shape(SchedulerArguments const& args) {
+            unsigned int num_blocks_k = cute::ceil_div(args.dv, kBlockK);
+            unsigned int num_blocks_m = cute::ceil_div(args.seqlen_q * args.num_heads, kBlockM);
+            return {num_blocks_m, num_blocks_k, static_cast<unsigned int>(args.b)};
+        }
+
+        CUTE_DEVICE BlockCoord get_block_coord(Params const& params) {
+            int block_m = blockIdx.x;
+            int block_k = blockIdx.y;
+            int bidb = blockIdx.z;
+            return {block_m, block_k, bidb};
+        }
+    };
+
+    struct StaticVarlenTileScheduler {
+        //
+        // For varlen we have two scheduling algos:
+        //  1) STANDARD: same as StaticTileScheduler.
+        //  2) LINEARIZE_M_AND_BATCH: flattens the tiled M dimension and the
+        //     batch dimension into a single linear tile index, so the grid is
+        //     a 2D grid of (tile_id, k_block). get_block_coord then maps the
+        //     linear tile id back to (m_block, bidb). This mapping is
+        //     non-trivial since each batch element can have a different
+        //     number of m_blocks, so there is overhead when computing the
+        //     block coordinates, but it is more efficient when prefills and
+        //     decodes are mixed, since in that case the STANDARD algo leaves
+        //     a lot of empty (no work) blocks in the grid.
+        //
+
+        enum SchedulingAlgo {
+            STANDARD,              // Same as StaticTileScheduler
+            LINEARIZE_M_AND_BATCH, // Linearize the M and batch dimensions into a single tile index
+        };
+
+        struct Params {
+            int b;
+            int num_heads;
+            int const* const cu_seqlens_q;
+            int const* const seqused_q;
+            SchedulingAlgo algo;
+        };
+
+        SharedStorage& shared_storage;
+        CUTE_DEVICE StaticVarlenTileScheduler(SharedStorage& shared_storage): shared_storage(shared_storage) {}
+
+        static SchedulingAlgo choose_scheduling_algo(SchedulerArguments const& args) {
+            // Choose the scheduling algorithm based on how dense the grid of
+            // tiles that do actual work is. If the grid is more than 50%
+            // sparse, we linearize the M and batch dimensions; if it is more
+            // than 50% dense, we use the standard scheduling algorithm, since
+            // it is more efficient at calculating the block coordinates.
+            // NOTE: in the varlen case args.seqlen_q is the max seqlen_q
+            // across all batches; use a lower bound to estimate when the
+            // density is above 50%.
+            int lower_bound_on_non_empty_tiles = cute::ceil_div(args.total_q, kBlockM);
+            int grid_size = args.b * cute::ceil_div(args.seqlen_q, kBlockM);
+            return 2 * lower_bound_on_non_empty_tiles >= grid_size ?
+                SchedulingAlgo::STANDARD :
+                SchedulingAlgo::LINEARIZE_M_AND_BATCH;
+        }
+
+        static Params to_underlying_arguments(SchedulerArguments const& args) {
+            return {
+                args.b,
+                args.num_heads,
+                args.cu_seqlens_q,
+                args.seqused_q,
+                choose_scheduling_algo(args)
+            };
+        }
+
+        static dim3 get_grid_shape(SchedulerArguments const& args) {
+            unsigned int num_blocks_k = cute::ceil_div(args.dv, kBlockK);
+
+            switch (choose_scheduling_algo(args)) {
+            case SchedulingAlgo::STANDARD: {
+                unsigned int num_blocks_m = cute::ceil_div(args.seqlen_q * args.num_heads, kBlockM);
+                return {num_blocks_m, num_blocks_k, static_cast<unsigned int>(args.b)};
+            }
+            case SchedulingAlgo::LINEARIZE_M_AND_BATCH: {
+                // rough worst-case upper bound on the number of blocks required
+                // (assuming each batch has an additional partial block)
+                unsigned int num_blocks_m = cute::ceil_div(args.total_q * args.num_heads, kBlockM) + args.b;
+                return {num_blocks_m, num_blocks_k, 1};
+            }}
+
+            // not reachable for a valid algo; kept to satisfy compilers that
+            // warn about a missing return (same bound as LINEARIZE_M_AND_BATCH)
+            unsigned int num_blocks_m = cute::ceil_div(args.total_q * args.num_heads, kBlockM) + args.b;
+            return {num_blocks_m, num_blocks_k, 1};
+        }
+
+        CUTE_DEVICE BlockCoord get_block_coord_linearized_m_and_batch(Params const& params) {
+            int num_heads = params.num_heads;
+            int curr_tile_id = blockIdx.x;
+
+            // Scan through the batches to find the batch that contains the
+            // current tile_id. Compute using only the first warp of the block.
+            if (threadIdx.x < 32) {
+                // We compute linearized tile index starts and ends for each
+                // batch in groups of 32 in parallel
+                int group_start_bidb = -(cutlass::NumThreadsPerWarp);
+                int group_end_bidb = 0;
+                int group_end_tile_id = 0;
+                int group_start_tile_id = 0;
+                int group_total_num_tiles = 0;
+
+                int local_num_m_blocks = 0;
+                int local_num_m_blocks_cumulative = 0;
+
+                do {
+                    group_start_bidb += cutlass::NumThreadsPerWarp;
+                    group_end_bidb += cutlass::NumThreadsPerWarp;
+
+                    auto get_num_m_blocks = [&](int bidb) {
+                        if (bidb >= params.b) return 0;
+                        flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, 0, params.cu_seqlens_q, params.seqused_q};
+                        return cute::ceil_div(seqlen_info.seqlen * num_heads, Int<kBlockM>{}());
+                    };
+
+                    // Cumulative number of blocks for the next 31 batches
+                    local_num_m_blocks = get_num_m_blocks(group_start_bidb + threadIdx.x);
+                    local_num_m_blocks_cumulative = warp_prefix_sum(local_num_m_blocks);
+                    // Total number of blocks for the next 32 batches
+                    group_total_num_tiles = warp_shfl_get_last(local_num_m_blocks_cumulative);
+
+                    group_start_tile_id = group_end_tile_id;
+                    group_end_tile_id += group_total_num_tiles;
+                } while (curr_tile_id >= group_end_tile_id && group_end_bidb < params.b);
+
+                int local_batch_end_tile_id = group_start_tile_id + local_num_m_blocks_cumulative;
+                // Find the last batch idx in the group where
+                // `local_batch_end_tile_id <= curr_tile_id`; the values below
+                // are now common to all threads in the warp
+                int batch_idx_in_group = warp_last_true_laneid(local_batch_end_tile_id <= curr_tile_id);
+                int batch_num_m_blocks = warp_shfl_get(local_num_m_blocks, batch_idx_in_group);
+                int batch_m_start_tile_id = group_start_tile_id + (batch_idx_in_group > 0 ?
+                    warp_shfl_get(local_num_m_blocks_cumulative, batch_idx_in_group - 1) : 0);
+
+                int bidb = group_start_bidb + batch_idx_in_group;
+                int block_m = curr_tile_id - batch_m_start_tile_id;
+                // NOTE(lucas): not sure why this causes a block_k unused
+                // warning; just inlined `blockIdx.y` to suppress the warning
+                //   int block_k = blockIdx.y;
+                //   shared_storage.block_coord = {block_m, block_k, bidb};
+                BlockCoord block_coord{block_m, static_cast<int>(blockIdx.y), bidb};
+                if (threadIdx.x == 0) { shared_storage.block_coord = block_coord; }
+            }
+
+            __syncthreads();
+            return shared_storage.block_coord;
+        }
+
+        CUTE_DEVICE BlockCoord get_block_coord_standard(Params const& params) {
+            int block_m = blockIdx.x;
+            int block_k = blockIdx.y;
+            int bidb = blockIdx.z;
+            return {block_m, block_k, bidb};
+        }
+
+        CUTE_DEVICE BlockCoord get_block_coord(Params const& params) {
+            switch (params.algo) {
+            case SchedulingAlgo::STANDARD:
+                return get_block_coord_standard(params);
+            case SchedulingAlgo::LINEARIZE_M_AND_BATCH:
+                return get_block_coord_linearized_m_and_batch(params);
+            }
+            return {0, 0, 0}; // should never reach here
+        }
+    };
+
+    using TileScheduler = std::conditional_t<
+        Varlen,
+        StaticVarlenTileScheduler,
+        StaticTileScheduler
+    >;
+
+    using SchedulerParams = typename TileScheduler::Params;
+
+    struct Params {
+        CollectiveParams params;
+        SchedulerParams scheduler_params;
+    };
+
     CUTLASS_DEVICE
     void
-    operator()(Params const& params, char* smem_buf) {
+    operator()(Params const& kernel_params, char* smem_buf) {
+        CollectiveParams const& params = kernel_params.params;
 
         SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
+        TileScheduler tile_scheduler{shared_storage};
+
         Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse_partial.data()), SmemLayoutLSE{});
         Tensor sMaxValidSplit = make_tensor(make_smem_ptr(shared_storage.smem_max_valid_split.data()), Shape<Int<kBlockM>>{});
         Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{});
 
         int const thread_idx = threadIdx.x;
-        int const m_block = blockIdx.x;
-        int const k_block = blockIdx.y;
-        int const batch = blockIdx.z;
-        int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial);
+
+        BlockCoord block_coord = tile_scheduler.get_block_coord(kernel_params.scheduler_params);
+
+        int const m_block = block_coord.block_m;
+        int const k_block = block_coord.block_k;
+        int const batch = block_coord.bidb;
 
         if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) {
             cutlass::arch::wait_on_dependent_grids();
             *params.semaphore_to_reset = 0;
         }
-        if (num_splits <= 1) { return; }
+
         flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused};
         int const offset = seqlen_info.offset;
         int const seqlen = seqlen_info.seqlen;
         int max_idx = seqlen * get<2>(params.shape_LSE_partial);
-        if constexpr (Varlen) {
-            if (m_block * kBlockM >= max_idx) { return; }
-        }
+
+        bool block_coord_valid =
+            block_coord.block_m < cute::ceil_div(max_idx, Int<kBlockM>{}) &&
+            block_coord.bidb < params.b;
+        if (!block_coord_valid) { return; }
+
+        int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial);
+        if (num_splits <= 1) { return; }
 
         cutlass::FastDivmod seqlen_divmod_dynamic(seqlen);

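The core of this change is the mapping from a linear tile id back to (m_block, bidb) in get_block_coord_linearized_m_and_batch. As a sanity check, here is a minimal host-side sketch of the same mapping (hypothetical code, not part of the commit): it scans batches sequentially where the kernel uses a warp-wide prefix-sum search over 32 batches at a time, but it lands on the same coordinates for a given tile id. The kBlockM value, file name, and helper names here are illustrative assumptions.

// tile_mapping_sketch.cpp -- hypothetical sequential reference for the
// LINEARIZE_M_AND_BATCH mapping from linear tile id to (m_block, bidb).
#include <cstdio>
#include <utility>
#include <vector>

constexpr int kBlockM = 8;  // assumed; the kernel's kBlockM is a template parameter

int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Walk batches in order, accumulating each batch's m_block count, until the
// batch containing tile_id is found.
std::pair<int, int> tile_to_block_coord(int tile_id, std::vector<int> const& seqlens, int num_heads) {
    int batch_m_start_tile_id = 0;
    for (int bidb = 0; bidb < (int)seqlens.size(); ++bidb) {
        int num_m_blocks = ceil_div(seqlens[bidb] * num_heads, kBlockM);
        if (tile_id < batch_m_start_tile_id + num_m_blocks) {
            return {tile_id - batch_m_start_tile_id, bidb};  // (m_block, bidb)
        }
        batch_m_start_tile_id += num_m_blocks;
    }
    return {-1, -1};  // tile id beyond the last real tile (padding block, no work)
}

int main() {
    // Mixed prefill + decode: one 100-token sequence, three 1-token sequences.
    std::vector<int> seqlens = {100, 1, 1, 1};
    int num_heads = 2;  // 100*2/8 -> 25 m_blocks for batch 0, 1 each for batches 1-3
    for (int tile_id : {0, 24, 25, 26, 27, 28}) {
        auto [m_block, bidb] = tile_to_block_coord(tile_id, seqlens, num_heads);
        std::printf("tile %2d -> m_block %2d, bidb %d\n", tile_id, m_block, bidb);
    }
    return 0;
}

Tiles 0 through 24 land in batch 0, tiles 25 through 27 are the three decode batches, and tile 28 maps to (-1, -1), which corresponds to the kernel's early-out for padding blocks from the worst-case grid bound.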
hopper/flash_fwd_combine_launch_template.h
12 additions, 4 deletions
@@ -25,6 +25,7 @@ void run_flash_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool e
         IsEvenK, Varlen, Element, ElementPartial, ArchTag>;
 
     typename CombineKernel::Arguments args {
+        params.b,
         static_cast<ElementPartial const*>(params.oaccum_ptr),
         {!Varlen ? params.seqlen_q : params.total_q, params.dv, params.num_splits, params.h, !Varlen ? params.b : 1},  // shape_O_partial
         {params.oaccum_row_stride, _1{}, params.oaccum_split_stride, params.oaccum_head_stride, !Varlen ? params.oaccum_batch_stride : 0},  // stride_O_partial
@@ -38,10 +39,17 @@ void run_flash_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool e
         params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore
     };
 
-    typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args);
-    int num_blocks_k = cute::ceil_div(params.dv, kBlockK);
-    int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h, kBlockM);
-    dim3 grid_m(num_blocks_m, num_blocks_k, params.b);
+    typename CombineKernel::SchedulerArguments scheduler_args {
+        params.b, params.seqlen_q, params.total_q, params.h, params.dv,
+        params.cu_seqlens_q, params.seqused_q
+    };
+
+    typename CombineKernel::Params kernel_params = {
+        CombineKernel::to_underlying_arguments(args),
+        CombineKernel::TileScheduler::to_underlying_arguments(scheduler_args)
+    };
+
+    dim3 grid_m = CombineKernel::TileScheduler::get_grid_shape(scheduler_args);
     auto kernel = cutlass::device_kernel<CombineKernel>;
     int smem_size = CombineKernel::SharedStorageSize;
     if (smem_size >= 48 * 1024) {

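The launch path now defers grid sizing to the scheduler, so whether the combine kernel runs with a dense 3D grid or a linearized 2D grid is decided by choose_scheduling_algo at launch time. The snippet below is a self-contained restatement of that heuristic with two worked inputs (hypothetical code; kBlockM = 8 is an assumed value, in the real kernel it comes from the tile configuration):

// scheduling_heuristic_sketch.cpp -- hypothetical restatement of
// StaticVarlenTileScheduler::choose_scheduling_algo.
#include <cstdio>

constexpr int kBlockM = 8;  // assumed; set by the kernel's tile configuration in practice

int ceil_div(int a, int b) { return (a + b - 1) / b; }

// STANDARD when at least half of the b x ceil(max_seqlen_q / kBlockM) grid is
// guaranteed non-empty, LINEARIZE_M_AND_BATCH otherwise.
bool use_standard(int b, int max_seqlen_q, int total_q) {
    int lower_bound_on_non_empty_tiles = ceil_div(total_q, kBlockM);
    int grid_size = b * ceil_div(max_seqlen_q, kBlockM);
    return 2 * lower_bound_on_non_empty_tiles >= grid_size;
}

int main() {
    // Uniform batch: 4 sequences of 64 tokens. The grid is fully dense.
    std::printf("uniform: %s\n", use_standard(4, 64, 256) ? "STANDARD" : "LINEARIZE");
    // Mixed: one 4096-token prefill plus 63 single-token decodes. The
    // 64 x 512-tile grid is ~98% empty, so linearizing wins.
    std::printf("mixed:   %s\n", use_standard(64, 4096, 4096 + 63) ? "STANDARD" : "LINEARIZE");
    return 0;
}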
hopper/utils.h
16 additions, 0 deletions
@@ -646,6 +646,22 @@ CUTE_DEVICE T warp_prefix_sum(T val) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+template<typename T>
+CUTE_DEVICE T warp_shfl_get(T val, int src_lane) {
+    return __shfl_sync(0xffffffff, val, src_lane);
+}
+
+template<typename T>
+CUTE_DEVICE T warp_shfl_get_last(T val) {
+    return __shfl_sync(0xffffffff, val, cutlass::NumThreadsPerWarp - 1);
+}
+
+// Returns the number of lanes for which `cond` is true; for a predicate that
+// is true on a prefix of lanes, this is (last true laneid + 1).
+CUTE_DEVICE int warp_last_true_laneid(bool cond) {
+    return __popc(__ballot_sync(0xffffffff, cond));
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 template<class T>
 CUTE_DEVICE T warp_uniform(T a) {
     return __shfl_sync(0xffffffff, a, 0);

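A note on the new warp helpers: warp_last_true_laneid returns __popc of the ballot, i.e. the number of lanes whose predicate is true. The scheduler's predicate (batch end tile id <= current tile id) is true on a prefix of lanes because the per-lane end tile ids are a nondecreasing prefix sum, so the count equals the index of the first lane whose batch extends past the current tile, which is exactly the batch being searched for. Below is a host-side emulation for intuition only (hypothetical code; on device this is a single __ballot_sync plus __popc):

// warp_helpers_sketch.cpp -- hypothetical host-side emulation of the helper
// added to hopper/utils.h.
#include <cstdint>
#include <cstdio>

constexpr int kWarpSize = 32;

// Emulates __popc(__ballot_sync(0xffffffff, cond)): the number of lanes whose
// predicate is true. For a predicate that is true on lanes 0..k-1 and false
// from lane k onward, this returns k, the index of the first false lane.
int emulated_warp_last_true_laneid(bool const cond[kWarpSize]) {
    uint32_t ballot = 0;
    for (int lane = 0; lane < kWarpSize; ++lane) {
        ballot |= uint32_t(cond[lane]) << lane;
    }
    return __builtin_popcount(ballot);  // GCC/Clang host builtin standing in for __popc
}

int main() {
    // Lane i holds the cumulative tile-id end of batch i; suppose batches
    // 0..4 end at or before the current tile and batch 5 extends past it.
    bool cond[kWarpSize] = {};
    for (int lane = 0; lane < 5; ++lane) cond[lane] = true;
    std::printf("batch_idx_in_group = %d\n", emulated_warp_last_true_laneid(cond));  // prints 5
    return 0;
}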