diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp
index 5a595840aa..9511929b79 100644
--- a/hopper/flash_api.cpp
+++ b/hopper/flash_api.cpp
@@ -431,7 +431,7 @@ inline int get_num_splits(Flash_fwd_params const& params) {
     // params.page_table must already be set
     // This needs to match the kernel configs
     bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k;
-    auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
+    auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, use_one_mma_wg(params));
     // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits
     // has not been set here. It's OK though because we might just underestimate kBlockN a bit
     auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr);
@@ -585,9 +585,11 @@ mha_fwd_get_scheduler_metadata(
     params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);
 
     params.pagedkv_tma = get_pagedkv_tma(params);
-    params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
-    // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide
+    // Determine if we should pack GQA before num_splits since it impacts use_one_mma_wg (in get_num_splits)
     params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
+    params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
+    // Always enable PackGQA for Split
+    params.pack_gqa = params.pack_gqa || params.num_splits > 1;
 
     bool is_varlen = true;
@@ -611,7 +613,7 @@ mha_fwd_get_scheduler_metadata(
     }
 
     if (params.num_splits_dynamic_ptr) {
-        auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
+        auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, use_one_mma_wg(params));
         auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr);
         int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
         int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
@@ -725,7 +727,7 @@ mha_fwd(at::Tensor &q,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
     int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1);
     int const num_pages = !paged_KV ? 0 : k.size(0);
     int const page_size = !paged_KV ? 1 : k.size(1);
-    int const seqlen_k = !is_varlen_k ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value();
+    int const seqlen_k = !max_seqlen_k_.has_value() ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value();
     int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
     int const num_heads_k = k.size(-2);
     int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0);
@@ -938,9 +940,11 @@ mha_fwd(at::Tensor &q,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
     params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);
 
     params.pagedkv_tma = get_pagedkv_tma(params);
-    params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
-    // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide
+    // Determine if we should pack GQA before num_splits since it impacts use_one_mma_wg (in get_num_splits)
     params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
+    params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
+    // Always enable PackGQA for Split
+    params.pack_gqa = params.pack_gqa || params.num_splits > 1;
 
     // This needs to be set after get_num_splits
     at::Tensor tile_count_semaphore;  // Contains the semaphore and optionally num_splits_dynamic
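The reordering above exists because `get_num_splits` now calls `use_one_mma_wg`, which reads `params.pack_gqa`. A minimal sketch of that dependency chain, with a hypothetical simplified `Params` struct and placeholder heuristics standing in for the real `Flash_fwd_params`, `get_pack_gqa`, and `get_num_splits`:

```cpp
// Hypothetical, reduced illustration of the ordering constraint above.
struct Params {
    int seqlen_q, h, h_k;
    bool pack_gqa = false;
    int num_splits = 0;
};

int get_num_splits(Params const& p) {
    // Placeholder heuristic: like the real get_num_splits, it depends on
    // p.pack_gqa (via use_one_mma_wg -> tile_size_fwd_sm90 -> kBlockM/kBlockN),
    // because packing GQA multiplies the effective query rows by h / h_k.
    int const effective_rows = p.seqlen_q * (!p.pack_gqa ? 1 : p.h / p.h_k);
    return effective_rows <= 64 ? 2 : 1;
}

void set_heuristics(Params& p) {
    p.pack_gqa = p.h != p.h_k;                    // placeholder for get_pack_gqa
    p.num_splits = get_num_splits(p);             // must run after pack_gqa is set
    p.pack_gqa = p.pack_gqa || p.num_splits > 1;  // Split always implies PackGQA
}

int main() {
    Params p{/*seqlen_q=*/1, /*h=*/32, /*h_k=*/8};
    set_heuristics(p);  // pack_gqa decided first, then splits, then the OR
}
```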
diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h
index e9297e1b7c..2556321913 100644
--- a/hopper/flash_fwd_launch_template.h
+++ b/hopper/flash_fwd_launch_template.h
@@ -21,12 +21,13 @@
 #include "mainloop_fwd_sm90_tma_gmma_ws.hpp"
 #include "mainloop_fwd_sm80.hpp"
 #include "epilogue_fwd.hpp"
+#include "heuristics.h"
 
 using namespace cute;
 
 template <int Arch, int kHeadDim, int kHeadDimV, int ClusterM, class Element, class ElementOut,
           bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool PagedKVNonTMA, bool AppendKV, bool HasQv,
-          bool PackGQA, bool Split, bool V_colmajor>
+          bool PackGQA, bool Split, bool V_colmajor, bool Use_one_mma_wg>
 void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time");
     static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time");
@@ -36,7 +37,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
 
     // Can't use structured binding since it's not compatible with constexpr
-    static constexpr std::tuple<int, int, bool, bool> kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap);
+    static constexpr std::tuple<int, int, bool, bool> kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, Use_one_mma_wg);
     static constexpr std::tuple<int, int, int, int, bool> kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKVNonTMA, Varlen && Split, Has_softcap, AppendKV);
     static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS);
     static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS);
@@ -203,17 +204,22 @@ void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
         VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] {
             static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1;
             VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] {
-                // Only needed here to decide if we should use cluster
-                static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap)) : 128;
-
-                static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen;
-                BOOL_SWITCH(params.qv_ptr, HasQV_, [&] {
-                    static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256;
-                    APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] {
-                        // Only use Cluster if number of tiles along seqlen_q is even and not varlen
-                        CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] {
-                            static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1;
-                            run_flash_fwd<Arch, kHeadDim, kHeadDimV, ClusterM, T, T_out, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV, HasQv, PackGQA, Split, V_colmajor>(params, stream);
+                BOOL_SWITCH(use_one_mma_wg(params), Use_one_mma_wg_, [&] {
+                    // Avoid over-compilation by making sure this only gets set if it is actually used, i.e. we currently only support one mma wg for head dim 128 on Hopper
+                    static constexpr bool Use_one_mma_wg = Use_one_mma_wg_ && Arch >= 90 && kHeadDim == 128;
+
+                    // Only needed here to decide if we should use cluster
+                    static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, Use_one_mma_wg)) : 128;
+
+                    static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen;
+                    BOOL_SWITCH(params.qv_ptr, HasQV_, [&] {
+                        static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256;
+                        APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] {
+                            // Only use Cluster if number of tiles along seqlen_q is even and not varlen
+                            CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] {
+                                static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1;
+                                run_flash_fwd<Arch, kHeadDim, kHeadDimV, ClusterM, T, T_out, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV, HasQv, PackGQA, Split, V_colmajor, Use_one_mma_wg>(params, stream);
+                            });
                         });
                     });
                 });
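The new `BOOL_SWITCH(use_one_mma_wg(params), ...)` level follows this file's usual pattern of lifting a runtime flag into a template parameter, and the `Use_one_mma_wg_ && Arch >= 90 && kHeadDim == 128` mask keeps unsupported combinations from doubling the number of kernel instantiations. A self-contained sketch of that pattern, with a hypothetical `launch_kernel` standing in for `run_flash_fwd` and a reduced form of the real `BOOL_SWITCH` macro:

```cpp
#include <cstdio>

// Reduced form of the BOOL_SWITCH macro from the FlashAttention sources:
// the lambda body is pasted into each branch, so NAME is constexpr inside it.
#define BOOL_SWITCH(COND, NAME, ...)              \
    [&] {                                         \
        if (COND) {                               \
            constexpr static bool NAME = true;    \
            return __VA_ARGS__();                 \
        } else {                                  \
            constexpr static bool NAME = false;   \
            return __VA_ARGS__();                 \
        }                                         \
    }()

template <int Arch, int kHeadDim, bool Use_one_mma_wg>
void launch_kernel() {  // hypothetical stand-in for run_flash_fwd
    std::printf("Arch=%d hdim=%d one_mma_wg=%d\n", Arch, kHeadDim, Use_one_mma_wg);
}

template <int Arch, int kHeadDim>
void dispatch(bool one_mma_wg_runtime) {
    BOOL_SWITCH(one_mma_wg_runtime, Use_one_mma_wg_, [&] {
        // Mask with the compile-time support condition so that, e.g., kHeadDim == 64
        // instantiates only the Use_one_mma_wg = false kernel instead of both.
        constexpr bool Use_one_mma_wg = Use_one_mma_wg_ && Arch >= 90 && kHeadDim == 128;
        launch_kernel<Arch, kHeadDim, Use_one_mma_wg>();
    });
}

int main() {
    dispatch<90, 128>(true);  // takes the one-warpgroup specialization
    dispatch<90, 64>(true);   // masked off: falls back to the two-warpgroup kernel
}
```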
diff --git a/hopper/heuristics.h b/hopper/heuristics.h
index 031ea44a0b..43d06f5482 100644
--- a/hopper/heuristics.h
+++ b/hopper/heuristics.h
@@ -5,6 +5,12 @@
 #pragma once
 
 #include <vector>
+#include "flash.h"
+
+inline bool use_one_mma_wg(Flash_fwd_params const& params) {
+    return params.arch >= 90 && params.d == 128 &&
+        params.seqlen_q * (!params.pack_gqa ? 1 : params.h / params.h_k) <= 64;
+}
 
 inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, int blockM) {
     // If varlen, we don't actually know seqlen_q but only max_seqlen_q.
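To see when the heuristic fires, consider a typical decode shape: with GQA packing, the effective number of query rows is seqlen_q * (h / h_k), and the one-warpgroup path kicks in only when that stays at or below 64. A small standalone check with hypothetical numbers, using a simplified struct rather than the real `Flash_fwd_params`:

```cpp
#include <cassert>

struct P { int arch, d, seqlen_q, h, h_k; bool pack_gqa; };

// Same logic as the new use_one_mma_wg heuristic above.
bool use_one_mma_wg(P const& p) {
    return p.arch >= 90 && p.d == 128 &&
        p.seqlen_q * (!p.pack_gqa ? 1 : p.h / p.h_k) <= 64;
}

int main() {
    // Decode: 1 query token, 32 query heads packed over 8 KV heads -> 4 effective rows.
    assert(use_one_mma_wg({90, 128, 1, 32, 8, true}));
    // Prefill: 1024 query tokens -> far more than 64 rows, keep two MMA warpgroups.
    assert(!use_one_mma_wg({90, 128, 1024, 32, 8, true}));
    // Head dim 64 is not covered by the one-warpgroup path.
    assert(!use_one_mma_wg({90, 64, 1, 32, 8, true}));
}
```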
diff --git a/hopper/tile_size.h b/hopper/tile_size.h
index 4414b53ac2..b87a83afff 100644
--- a/hopper/tile_size.h
+++ b/hopper/tile_size.h
@@ -9,7 +9,7 @@
 // Return {kBlockM, kBlockN, MmaPV_is_RS, IntraWGOverlap}
 constexpr std::tuple<int, int, bool, bool> tile_size_fwd_sm90(
     int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2,
-    bool v_colmajor=false, bool paged_kv_non_TMA=false, bool softcap=false) {
+    bool v_colmajor=false, bool paged_kv_non_TMA=false, bool softcap=false, bool use_one_mma_wg=false) {
     if (element_size == 2) {
         if (headdim <= 64) {
             // return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, same_hdim};
@@ -29,7 +29,11 @@ constexpr std::tuple<int, int, bool, bool> tile_size_fwd_sm90(
         } else if (headdim <= 96) {
             return {192, is_local || paged_kv_non_TMA ? 128 : 144, false, true};
         } else if (headdim <= 128) {
-            return {128, is_causal || is_local || paged_kv_non_TMA ? 128 : 176, true, true};
+            if (use_one_mma_wg) {
+                return {64, is_causal || is_local || paged_kv_non_TMA ? 128 : 176, true, true};
+            } else {
+                return {128, is_causal || is_local || paged_kv_non_TMA ? 128 : 176, true, true};
+            }
             // {128, 192, false, false} and {192, 128, false, true} are quite good too
             // 128 x 192 hits the limit of smem if MmaPV_is_RS, 128 x 144 hits the limit if !MmaPV_is_RS
         } else if (headdim <= 192) {
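For head dim 128 the new flag only changes kBlockM (128 to 64 rows, i.e. one consumer MMA warpgroup instead of two); kBlockN and the MmaPV_is_RS / IntraWGOverlap flags are untouched. Since `tile_size_fwd_sm90` is constexpr, the selection can be checked at compile time; a minimal sketch, assuming `tile_size.h` is on the include path:

```cpp
#include <tuple>
#include "tile_size.h"

// Non-causal, non-local, fp16 (element_size = 2), head dim 128, no paged-KV or softcap:
static constexpr auto two_wg = tile_size_fwd_sm90(128, 128, false, false, 2, false, false, false, false);
static constexpr auto one_wg = tile_size_fwd_sm90(128, 128, false, false, 2, false, false, false, true);

static_assert(std::get<0>(two_wg) == 128, "two MMA warpgroups: 128-row M tile");
static_assert(std::get<0>(one_wg) == 64,  "one MMA warpgroup: 64-row M tile");
static_assert(std::get<1>(two_wg) == 176 && std::get<1>(one_wg) == 176, "kBlockN unchanged");

int main() {}
```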