From a2672e3e5a6d05ef10042d212b0482dbc670e10b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 30 Sep 2025 10:35:55 +0200 Subject: [PATCH] CUDA: faster tile FA, add oob checks, more HSs --- ggml/src/ggml-cuda/CMakeLists.txt | 2 + ggml/src/ggml-cuda/common.cuh | 7 +- ggml/src/ggml-cuda/fattn-common.cuh | 9 +- ggml/src/ggml-cuda/fattn-tile.cu | 771 +---------- ggml/src/ggml-cuda/fattn-tile.cuh | 1213 +++++++++++++++++ ggml/src/ggml-cuda/fattn-wmma-f16.cuh | 2 + ggml/src/ggml-cuda/fattn.cu | 76 +- .../fattn-tile-instance-dkq112-dv112.cu | 5 + .../fattn-tile-instance-dkq128-dv128.cu | 5 + .../fattn-tile-instance-dkq256-dv256.cu | 5 + .../fattn-tile-instance-dkq40-dv40.cu | 5 + .../fattn-tile-instance-dkq576-dv512.cu | 5 + .../fattn-tile-instance-dkq64-dv64.cu | 5 + .../fattn-tile-instance-dkq80-dv80.cu | 5 + .../fattn-tile-instance-dkq96-dv96.cu | 5 + .../template-instances/generate_cu_files.py | 18 +- ggml/src/ggml-hip/CMakeLists.txt | 2 + ggml/src/ggml-musa/CMakeLists.txt | 2 + 18 files changed, 1358 insertions(+), 784 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index bdcefe7b7ed7a..3024775135966 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -44,6 +44,8 @@ if (CUDAToolkit_FOUND) list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h") file(GLOB GGML_SOURCES_CUDA "*.cu") + file(GLOB SRCS "template-instances/fattn-tile*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "template-instances/mmq*.cu") diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index d51abbeafa944..e0abde5427c83 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -245,7 +245,8 @@ static bool fp16_available(const int cc) { } static bool fast_fp16_available(const int cc) { - return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc); + return GGML_CUDA_CC_IS_AMD(cc) || + (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610); } // To be used for feature selection of external libraries, e.g. cuBLAS. @@ -571,6 +572,10 @@ static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, } // Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD. +// Important: do not use this function if dst and src both point at registers. +// Due to the strict aliasing rule the compiler can do incorrect optimizations if src and dst have different types. +// The function is intended for copies between registers and SRAM/VRAM to make the compiler emit the right instructions. +// If dst and src point at different address spaces then they are guaranteed to not be aliased. template static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) { if constexpr (alignment != 0) { diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 33d2f0f49e3de..bc0c2523cc82f 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -793,8 +793,6 @@ void launch_fattn( GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); - GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); - ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); const int id = ggml_cuda_get_device(); @@ -878,7 +876,7 @@ void launch_fattn( // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped. // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or // multiple sequences of possibly different lengths. - if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) { + if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) { const int s31 = mask->nb[1] / sizeof(half2); const int s33 = mask->nb[3] / sizeof(half2); @@ -916,8 +914,7 @@ void launch_fattn( dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float)); } else { - GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0); - const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. + const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. // parallel_blocks must not be larger than what the tensor size allows: parallel_blocks = std::min(parallel_blocks, ntiles_KQ); @@ -946,7 +943,7 @@ void launch_fattn( blocks_num.x = ntiles_x; blocks_num.y = parallel_blocks; - blocks_num.z = Q->ne[2]*Q->ne[3]; + blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3]; if (parallel_blocks > 1) { dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu index 68de623d80349..3a5806d9091d7 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cu +++ b/ggml/src/ggml-cuda/fattn-tile.cu @@ -1,756 +1,45 @@ #include "common.cuh" -#include "fattn-common.cuh" #include "fattn-tile.cuh" #include "fattn-wmma-f16.cuh" -// kq_stride == number of KQ rows to process per iteration -// kq_nbatch == number of K columns to load in parallel for KQ calculation - -static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) { - if (GGML_CUDA_CC_IS_AMD(cc)) { - if (GGML_CUDA_CC_IS_RDNA(cc)) { - switch (D) { - case 64: - return 128; - case 128: - case 256: - return ncols <= 16 ? 128 : 64; - default: - GGML_ABORT("fatal error"); - return -1; - } - } - switch (D) { - case 64: - return ncols == 32 ? 128 : 64; - case 128: - return ncols == 32 ? 64 : 32; - case 256: - return 32; - default: - GGML_ABORT("fatal error"); - return -1; - } - } - if (fast_fp16_available(cc)) { - switch (D) { - case 64: - case 128: - case 256: - return ncols <= 16 ? 128 : 64; - default: - GGML_ABORT("fatal error"); - return -1; - } - } - switch (D) { - case 64: - return ncols <= 16 ? 128 : 64; - case 128: - return ncols <= 16 ? 64 : 32; - case 256: - return 32; - default: - GGML_ABORT("fatal error"); - return -1; - } - GGML_UNUSED(warp_size); -} - -static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) { -#ifdef GGML_USE_HIP -#ifdef RDNA - switch (D) { - case 64: - return 128; - case 128: - case 256: - return ncols <= 16 ? 128 : 64; - default: - return -1; - } -#else - switch (D) { - case 64: - return ncols == 32 ? 128 : 64; - case 128: - return ncols == 32 ? 64 : 32; - case 256: - return 32; - default: - return -1; - } -#endif // RDNA -#else -#ifdef FAST_FP16_AVAILABLE - switch (D) { - case 64: - case 128: - case 256: - return ncols <= 16 ? 128 : 64; - default: - return -1; - } -#else - switch (D) { - case 64: - return ncols <= 16 ? 128 : 64; - case 128: - return ncols <= 16 ? 64 : 32; - case 256: - return 32; - default: - return -1; - } -#endif // FAST_FP16_AVAILABLE -#endif // GGML_USE_HIP - GGML_UNUSED_VARS(ncols, warp_size); -} - -static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols, int warp_size) { -#ifdef GGML_USE_HIP - switch (D) { - case 64: - return 64; - case 128: - case 256: - return 128; - default: - return -1; - } -#else -#ifdef FAST_FP16_AVAILABLE - switch (D) { - case 64: - return 64; - case 128: - case 256: - return 128; - default: - return -1; - } -#else - switch (D) { - case 64: - return 64; - case 128: - return 128; - case 256: - return ncols <= 16 ? 128 : 64; - default: - return -1; - } -#endif // FAST_FP16_AVAILABLE -#endif // GGML_USE_HIP - GGML_UNUSED_VARS(ncols, warp_size); -} - -static int fattn_tile_get_nthreads_host(const int cc, const int ncols) { - return 256; - GGML_UNUSED_VARS(cc, ncols); -} - -static constexpr __device__ int fattn_tile_get_nthreads_device(int ncols) { - return 256; - GGML_UNUSED(ncols); -} - -static constexpr __device__ int fattn_tile_get_occupancy_device(int ncols) { -#ifdef RDNA - return 3; -#else - return ncols <= 16 ? 3 : 2; -#endif // RDNA - GGML_UNUSED(ncols); -} - -template // D == head size -__launch_bounds__(fattn_tile_get_nthreads_device(ncols), fattn_tile_get_occupancy_device(ncols)) -static __global__ void flash_attn_tile( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, - const int32_t nb01, const int32_t nb02, const int32_t nb03, - const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, - const int32_t nb11, const int32_t nb12, const int64_t nb13, - const int32_t nb21, const int32_t nb22, const int64_t nb23, - const int32_t ne31, const int32_t ne32, const int32_t ne33, - const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#ifdef FLASH_ATTN_AVAILABLE - - // Skip unused kernel variants for faster compilation: -#ifdef GGML_USE_WMMA_FATTN - NO_DEVICE_CODE; - return; -#endif // GGML_USE_WMMA_FATTN - - if (use_logit_softcap && !(D == 128 || D == 256)) { - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, - max_bias, m0, m1, n_head_log2, logit_softcap, - ne00, ne01, ne02, ne03, - nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb11, nb12, nb13, - nb21, nb22, nb23, - ne31, ne32, ne33, - nb31, nb32, nb33); - NO_DEVICE_CODE; - return; - } - - constexpr int warp_size = 32; - constexpr int nwarps = fattn_tile_get_nthreads_device(ncols) / warp_size; - constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size); - static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size."); - constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size); - static_assert(kq_nbatch % (2*warp_size) == 0, "bad kq_nbatch"); - - // In this kernel Q, K, V are matrices while i, j, k are matrix indices. - - const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. - - const int sequence = blockIdx.z / ne02; - const int head = blockIdx.z - sequence*ne02; - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const float * sinksf = (const float *) (sinks); - - const int stride_KV2 = nb11 / sizeof(half2); - - const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); - - constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); - constexpr int cpy_ne = cpy_nb / 4; - - constexpr int cpw = ncols/nwarps; // cols per warp - - // softmax_iter_j == number of KQ columns for which to calculate softmax in parallel. - // KQ is originall 2D but uses a Z-shaped memory pattern for larger reads/writes. -#ifdef FAST_FP16_AVAILABLE - constexpr int softmax_iter_j = cpw < 2*cpy_ne ? cpw : 2*cpy_ne; - - __shared__ half KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j]; - __shared__ half2 Q_tmp[ncols][D/2]; - __shared__ half2 KV_tmp[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts. - half2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}}; -#else - constexpr int softmax_iter_j = cpw < 1*cpy_ne ? cpw : 1*cpy_ne; - - __shared__ float KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j]; - __shared__ float Q_tmp[ncols][D]; - __shared__ float KV_tmp[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts. - float2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}}; -#endif // FAST_FP16_AVAILABLE - static_assert(cpw % softmax_iter_j == 0, "bad softmax_iter_j"); - - float KQ_max[cpw]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - KQ_max[j0/nwarps] = -FLT_MAX/2.0f; - } - float KQ_sum[cpw] = {0.0f}; - - // Load Q data, convert to FP16 if fast. -#pragma unroll - for (int j0 = 0; j0 < cpw; ++j0) { - const int j = j0 + threadIdx.y*cpw; - - constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; - -#pragma unroll - for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { - float tmp_f[cpy_ne_D] = {0.0f}; - if (ic0 + j < ne01) { - ggml_cuda_memcpy_1(tmp_f, &Q_f[j*(nb01/sizeof(float)) + i0 + threadIdx.x*cpy_ne_D]); - } - -#pragma unroll - for (int i1 = 0; i1 < cpy_ne_D; ++i1) { - tmp_f[i1] *= scale; - } - -#ifdef FAST_FP16_AVAILABLE - half2 tmp_h2[cpy_ne_D/2]; -#pragma unroll - for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) { - tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]); - } - ggml_cuda_memcpy_1(&Q_tmp[j][i0/2 + threadIdx.x*(cpy_ne_D/2)], tmp_h2); -#else - ggml_cuda_memcpy_1 (&Q_tmp[j][i0 + threadIdx.x* cpy_ne_D], tmp_f); -#endif // FAST_FP16_AVAILABLE - } - } - - __syncthreads(); - - // Main loop over KV cache: - const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; - for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) { - // Calculate KQ tile and keep track of new maximum KQ values: - - float KQ_max_new[cpw]; -#pragma unroll - for (int j = 0; j < cpw; ++j) { - KQ_max_new[j] = KQ_max[j]; - } - - float KQ_acc[kq_stride/warp_size][cpw] = {{0.0f}}; // Accumulators for KQ matrix multiplication. - - // KQ = K @ Q matrix multiplication: -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) { -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) { - const int i_KQ = i_KQ_0 + threadIdx.y; - -#ifdef FAST_FP16_AVAILABLE - constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/(2*warp_size) ? cpy_ne : kq_nbatch/(2*warp_size); -#pragma unroll - for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size*cpy_ne_kqnb) { - ggml_cuda_memcpy_1( - &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], - &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x*cpy_ne_kqnb]); - } -#else - constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/warp_size ? cpy_ne : kq_nbatch/warp_size; -#pragma unroll - for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += warp_size*cpy_ne_kqnb) { - half2 tmp_h2[cpy_ne_kqnb/2]; - ggml_cuda_memcpy_1( - tmp_h2, &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1/2 + threadIdx.x*(cpy_ne_kqnb/2)]); - - float2 tmp_f2[cpy_ne_kqnb/2]; -#pragma unroll - for (int k_KQ_2 = 0; k_KQ_2 < cpy_ne_kqnb/2; ++k_KQ_2) { - tmp_f2[k_KQ_2] = __half22float2(tmp_h2[k_KQ_2]); - } - ggml_cuda_memcpy_1( - &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], tmp_f2); - } -#endif // FAST_FP16_AVAILABLE - } - - __syncthreads(); - -#ifdef FAST_FP16_AVAILABLE -#pragma unroll - for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) { - half2 K_k[kq_stride/warp_size][cpy_ne]; - half2 Q_k[cpw][cpy_ne]; -#else -#pragma unroll - for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) { - float K_k[kq_stride/warp_size][cpy_ne]; - float Q_k[cpw][cpy_ne]; -#endif // FAST_FP16_AVAILABLE - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { - const int i_KQ = i_KQ_0 + threadIdx.x; - -#ifdef FAST_FP16_AVAILABLE - ggml_cuda_memcpy_1(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]); -#else - ggml_cuda_memcpy_1(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]); -#endif // FAST_FP16_AVAILABLE - } -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) { - const int j_KQ = j_KQ_0 + threadIdx.y*cpw; - -#ifdef FAST_FP16_AVAILABLE - ggml_cuda_memcpy_1(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]); -#else - ggml_cuda_memcpy_1(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]); -#endif // FAST_FP16_AVAILABLE - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) { -#pragma unroll - for (int k = 0; k < cpy_ne; ++k) { - ggml_cuda_mad(KQ_acc[i_KQ_0/warp_size][j_KQ_0], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0][k]); - } - } - } - } - - if (k_KQ_0 + kq_nbatch < D) { - __syncthreads(); // Sync not needed on last iteration. - } - } - - // Apply logit softcap, mask, update KQ_max: -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { - const int i_KQ = i_KQ_0 + threadIdx.x; - -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) { - const int j_KQ = j_KQ_0 + threadIdx.y*cpw; - - if (use_logit_softcap) { - KQ_acc[i_KQ_0/warp_size][j_KQ_0] = logit_softcap * tanhf(KQ_acc[i_KQ_0/warp_size][j_KQ_0]); - } - - KQ_acc[i_KQ_0/warp_size][j_KQ_0] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; - - KQ_max_new[j_KQ_0] = fmaxf(KQ_max_new[j_KQ_0], KQ_acc[i_KQ_0/warp_size][j_KQ_0]); - } - } - - __syncthreads(); - - // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators: -#pragma unroll - for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) { -#ifdef FAST_FP16_AVAILABLE - half tmp[kq_stride/warp_size][softmax_iter_j]; -#else - float tmp[kq_stride/warp_size][softmax_iter_j]; -#endif // FAST_FP16_AVAILABLE - -#pragma unroll - for (int j1 = 0; j1 < softmax_iter_j; ++j1) { - KQ_max_new[j0+j1] = warp_reduce_max(KQ_max_new[j0+j1]); - const float KQ_max_scale = expf(KQ_max[j0+j1] - KQ_max_new[j0+j1]); - KQ_max[j0+j1] = KQ_max_new[j0+j1]; - - float KQ_sum_add = 0.0f; -#pragma unroll - for (int i0 = 0; i0 < kq_stride; i0 += warp_size) { - const float val = expf(KQ_acc[i0/warp_size][j0+j1] - KQ_max[j0+j1]); - KQ_sum_add += val; - tmp[i0/warp_size][j1] = val; - } - KQ_sum[j0+j1] = KQ_sum[j0+j1]*KQ_max_scale + KQ_sum_add; - -#ifdef FAST_FP16_AVAILABLE - const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { - VKQ[j0+j1][i0/warp_size] *= KQ_max_scale_h2; - } -#else -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { - VKQ[j0+j1][i0/warp_size].x *= KQ_max_scale; - VKQ[j0+j1][i0/warp_size].y *= KQ_max_scale; - } -#endif // FAST_FP16_AVAILABLE - } - -#pragma unroll - for (int i0 = 0; i0 < kq_stride; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - ggml_cuda_memcpy_1( - KQ[j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j)][i], tmp[i0/warp_size]); - } - } - - // VKQ = V @ KQ matrix multiplication: - constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; // Number of V columns that fit in SRAM for K. - static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter"); -#pragma unroll - for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) { -#pragma unroll - for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) { - const int k_tile = k1 + threadIdx.y; - -#ifdef FAST_FP16_AVAILABLE - constexpr int cpy_ne_D = cpy_ne < D/(2*warp_size) ? cpy_ne : D/(2*warp_size); -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) { - ggml_cuda_memcpy_1( - &KV_tmp[k_tile*(D/2) + i0 + threadIdx.x*cpy_ne_D], - &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0 + threadIdx.x*cpy_ne_D]); - } -#else - constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; -#pragma unroll - for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { - half2 tmp_h2[cpy_ne_D/2]; - ggml_cuda_memcpy_1( - tmp_h2, &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0/2 + threadIdx.x*(cpy_ne_D/2)]); - - float2 tmp_f2[cpy_ne_D/2]; -#pragma unroll - for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) { - tmp_f2[i1] = __half22float2(tmp_h2[i1]); - } - ggml_cuda_memcpy_1( - &KV_tmp[k_tile*D + i0 + threadIdx.x*cpy_ne_D], tmp_f2); - } -#endif // FAST_FP16_AVAILABLE - } - - __syncthreads(); - -#ifdef FAST_FP16_AVAILABLE -#pragma unroll - for (int k1 = 0; k1 < V_cols_per_iter; ++k1) { - half2 V_k[(D/2)/warp_size]; - half2 KQ_k[cpw]; - - constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size; -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) { - ggml_cuda_memcpy_1(&V_k[i0/warp_size], &KV_tmp[k1*(D/2) + i0 + threadIdx.x*cpy_ne_D]); - } -#pragma unroll - for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) { - const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j); - - half tmp[softmax_iter_j]; - ggml_cuda_memcpy_1( - &tmp, KQ[j][k0 + k1]); -#pragma unroll - for (int j1 = 0; j1 < softmax_iter_j; ++j1) { - KQ_k[j0+j1] = __half2half2(tmp[j1]); - } - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { -#pragma unroll - for (int j0 = 0; j0 < cpw; ++j0) { - VKQ[j0][i0/warp_size] += V_k[i0/warp_size]*KQ_k[j0]; - } - } - } -#else -#pragma unroll - for (int k1 = 0; k1 < V_cols_per_iter; ++k1) { - float2 V_k[(D/2)/warp_size]; - float KQ_k[cpw]; - - constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; -#pragma unroll - for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { - ggml_cuda_memcpy_1(&V_k[i0/(2*warp_size)], &KV_tmp[k1*D + i0 + threadIdx.x*cpy_ne_D]); - } -#pragma unroll - for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) { - const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j); - - ggml_cuda_memcpy_1( - &KQ_k[j0], KQ[j][k0 + k1]); - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { -#pragma unroll - for (int j0 = 0; j0 < cpw; ++j0) { - VKQ[j0][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0]; - VKQ[j0][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0]; - } - } - } -#endif // FAST_FP16_AVAILABLE - - __syncthreads(); - } - } - - - // Attention sink: adjust running max and sum once per head - if (sinksf && blockIdx.y == 0) { - const float sink = sinksf[head]; - -#pragma unroll - for (int j0 = 0; j0 < cpw; ++j0) { - float KQ_max_new_j = fmaxf(KQ_max[j0], sink); - KQ_max_new_j = warp_reduce_max(KQ_max_new_j); - - const float KQ_max_scale = expf(KQ_max[j0] - KQ_max_new_j); - KQ_max[j0] = KQ_max_new_j; - - const float val = expf(sink - KQ_max[j0]); - KQ_sum[j0] = KQ_sum[j0] * KQ_max_scale; - if (threadIdx.x == 0) { - KQ_sum[j0] += val; - } - -#ifdef FAST_FP16_AVAILABLE - const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { - VKQ[j0][i0/warp_size] *= KQ_max_scale_h2; - } -#else -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { - VKQ[j0][i0/warp_size].x *= KQ_max_scale; - VKQ[j0][i0/warp_size].y *= KQ_max_scale; - } -#endif // FAST_FP16_AVAILABLE - } - } - -#pragma unroll - for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) { - KQ_sum[j_VKQ_0] = warp_reduce_sum(KQ_sum[j_VKQ_0]); - } - if (gridDim.y == 1) { -#pragma unroll - for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) { -#ifdef FAST_FP16_AVAILABLE - const half2 KQ_sum_j_inv = make_half2(1.0f/KQ_sum[j_VKQ_0], 1.0f/KQ_sum[j_VKQ_0]); -#pragma unroll - for (int i = 0; i < (D/2)/warp_size; ++i) { - VKQ[j_VKQ_0][i] *= KQ_sum_j_inv; - } -#else - const float KQ_sum_j_inv = 1.0f/KQ_sum[j_VKQ_0]; -#pragma unroll - for (int i = 0; i < (D/2)/warp_size; ++i) { - VKQ[j_VKQ_0][i].x *= KQ_sum_j_inv; - VKQ[j_VKQ_0][i].y *= KQ_sum_j_inv; - } -#endif // FAST_FP16_AVAILABLE - } - } - - // Write back results: -#pragma unroll - for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) { - const int j_VKQ = j_VKQ_0 + threadIdx.y*cpw; - - if (ic0 + j_VKQ >= ne01) { - return; - } - - const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; - -#ifdef FAST_FP16_AVAILABLE - constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size; -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) { - float2 tmp[cpy_ne_D]; -#pragma unroll - for (int i1 = 0; i1 < cpy_ne_D; ++i1) { - tmp[i1] = __half22float2(VKQ[j_VKQ_0][i0/warp_size + i1]); - } - ggml_cuda_memcpy_1(&dst[j_dst_unrolled*D + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp); - } -#else - constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; -#pragma unroll - for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { - ggml_cuda_memcpy_1( - &dst[j_dst_unrolled*D + i0 + threadIdx.x*cpy_ne_D], &VKQ[j_VKQ_0][i0/(2*warp_size)]); - } -#endif // FAST_FP16_AVAILABLE - - if (gridDim.y != 1 && threadIdx.x == 0) { - dst_meta[j_dst_unrolled] = make_float2(KQ_max[j_VKQ_0], KQ_sum[j_VKQ_0]); - } - } -#else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, - max_bias, m0, m1, n_head_log2, logit_softcap, - ne00, ne01, ne02, ne03, - nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb11, nb12, nb13, - nb21, nb22, nb23, - ne31, ne32, ne33, - nb31, nb32, nb33); - NO_DEVICE_CODE; -#endif // FLASH_ATTN_AVAILABLE -} - -template -static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - - const int id = ggml_cuda_get_device(); - const int cc = ggml_cuda_info().devices[id].cc; - const int warp_size = 32; - - constexpr size_t nbytes_shared = 0; - -#ifdef GGML_USE_HIP - if constexpr (D <= 128) { - if (Q->ne[1] > 32) { - constexpr int cols_per_block = 64; - const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size; - fattn_kernel_t fattn_kernel = flash_attn_tile; - const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); - return; - } - } -#endif // GGML_USE_HIP - - if (Q->ne[1] > 16) { - constexpr int cols_per_block = 32; - const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size; - fattn_kernel_t fattn_kernel = flash_attn_tile; - const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); - return; - } - - constexpr int cols_per_block = 16; - const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size; - fattn_kernel_t fattn_kernel = flash_attn_tile; - const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); -} - -template -static void launch_fattn_tile_switch_head_size(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - switch (Q->ne[0]) { +void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + switch (K->ne[0]) { + case 40: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 40, 40>(ctx, dst); + } break; case 64: { - launch_fattn_tile_switch_ncols< 64, use_logit_softcap>(ctx, dst); + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 64, 64>(ctx, dst); + } break; + case 80: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 80, 80>(ctx, dst); + } break; + case 96: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 96, 96>(ctx, dst); + } break; + case 112: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<112, 112>(ctx, dst); } break; case 128: { - launch_fattn_tile_switch_ncols<128, use_logit_softcap>(ctx, dst); + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst); } break; case 256: { - launch_fattn_tile_switch_ncols<256, use_logit_softcap>(ctx, dst); + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst); + } break; + case 576: { + GGML_ASSERT(V->ne[0] == 512); + ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst); } break; default: { GGML_ABORT("Unsupported head size"); } break; } } - -void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - - float logit_softcap; - memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_switch_head_size(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_switch_head_size(ctx, dst); - } -} diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 10dc22d1bf971..2efc9cc880cf8 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -1,3 +1,1216 @@ #include "common.cuh" +#include "fattn-common.cuh" +#include "fattn-wmma-f16.cuh" + +// nbatch_fa == number of KQ rows to process per iteration +// nbatch_K == number of K columns to load in parallel for KQ calculation + +// TODO optimize kernel parameters for FP16 NVIDIA (P100) +// TODO optimize kernel parameters for head sizes 40, 80, 96, 112 + +// The ROCm compiler cannot handle templating in __launch_bounds__. +// As a workaround, define a macro to package the kernel parameters as uint32_t: +#define GGML_CUDA_FATTN_TILE_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads, occupancy, nbatch_fa, nbatch_K) \ + if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \ + static_assert((nthreads) <= 512, "bad nthreads"); \ + static_assert((occupancy) <= 8, "bad occupancy"); \ + static_assert((nbatch_fa) <= 256, "bad nbatch_fa"); \ + static_assert((nbatch_K) <= 256, "bad nbatch_K"); \ + return ((nthreads) << 0) | ((occupancy) << 10) | ((nbatch_fa) << 14) | ((nbatch_K) << 23); \ + } \ + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp16(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 64, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 64, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 64, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 64, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 64, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 64, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 64, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 64, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) + + return 0; +} + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp32(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 3, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 3, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) + + return 0; +} + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64) + + return 0; +} + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 5, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 5, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 128, 4, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 128, 5, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 8, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 8, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64) + + return 0; +} + +static __host__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) { + if (GGML_CUDA_CC_IS_AMD(cc)) { + if (GGML_CUDA_CC_IS_RDNA(cc)) { + return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols); + } + return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols); + } + if (fast_fp16_available(cc)) { + return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols); + } + return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols); +} + +static constexpr __device__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols) { +#ifdef GGML_USE_HIP +#ifdef RDNA + return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols); +#else + return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols); +#endif // RDNA +#else +#ifdef FAST_FP16_AVAILABLE + return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols); +#else + return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols); +#endif // FAST_FP16_AVAILABLE +#endif // GGML_USE_HIP +} + +static __host__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 0) & ((1 << 10) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 0) & ((1 << 10) - 1); +} + +static __host__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 10) & ((1 << 4) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 10) & ((1 << 4) - 1); +} + +static __host__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 14) & ((1 << 9) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 14) & ((1 << 9) - 1); +} + +static __host__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 23) & ((1 << 9) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 23) & ((1 << 9) - 1); +} + +// TODO: deduplicate with mma-f16 +template +static __device__ __forceinline__ void flash_attn_tile_load_tile( + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + auto load = [&] __device__ (const int n) { + const int stride_j = warp_size >> n; + + if (stride_j == 0) { + return; + } + + const int j0_start = stride_j == warp_size ? 0 : ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (2*stride_j); + const int j0_stop = ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (1*stride_j); + const int stride_i = warp_size / stride_j; + + if (j0_start == j0_stop) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j); + + if (i0 + nwarps*stride_i <= I || i < I) { +#pragma unroll + for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { + const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne; + + const half2 zero[cpy_ne] = {{0.0f, 0.0f}}; + ggml_cuda_memcpy_1( + tile_KV + i*(J/2 + J_padding) + j, + !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); + } + } + } + }; + // 1: max 64*16=512 bytes, 512 half + // 2: max 32*16=512 bytes, 256 half + // 3: max 16*16=256 bytes, 128 half + // 4: max 8*16=128 bytes, 64 half + // 5: max 4*16= 64 bytes, 32 half + // 6: max 2*16= 32 bytes, 16 half + // 7: max 1*16= 16 bytes, 8 half + static_assert(J % 8 == 0, "bad J"); + static_assert((J/2) % cpy_ne == 0, "bad J"); + ggml_cuda_unroll<7>{}(load); +} + +template +static __device__ __forceinline__ void flash_attn_tile_load_tile( + const half2 * const __restrict__ KV, float * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + auto load = [&] __device__ (const int n) { + const int stride_j = warp_size >> n; + + if (stride_j == 0) { + return; + } + + const int j0_start = stride_j == warp_size ? 0 : (J/cpy_ne) - (J/cpy_ne) % (2*stride_j); + const int j0_stop = (J/cpy_ne) - (J/cpy_ne) % (1*stride_j); + const int stride_i = warp_size / stride_j; + + if (j0_start == j0_stop) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j); + + if (i0 + nwarps*stride_i <= I || i < I) { +#pragma unroll + for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { + const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2); + + const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}}; + half2 tmp_h2[cpy_ne/2]; + ggml_cuda_memcpy_1( + tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); + + float2 tmp_f2[cpy_ne/2]; +#pragma unroll + for (int l = 0; l < cpy_ne/2; ++l) { + tmp_f2[l] = __half22float2(tmp_h2[l]); + } + ggml_cuda_memcpy_1(tile_KV + i*(J + J_padding) + 2*j, tmp_f2); + } + } + } + }; + // 1: max 32*16=512 bytes, 128 float + // 2: max 16*16=256 bytes, 64 float + // 3: max 8*16=128 bytes, 32 float + // 4: max 4*16= 64 bytes, 16 float + // 5: max 2*16= 32 bytes, 8 float + static_assert(J % 8 == 0, "bad J"); + static_assert(J % cpy_ne == 0, "bad J"); + ggml_cuda_unroll<5>{}(load); +} + +// Function that performs a single iteration in for the KQ matrix multiplication: +template +static __device__ __forceinline__ void flash_attn_tile_iter_KQ( + T_vec_dot * const Q_tmp, + const half2 * const __restrict__ K_h2, + T_vec_dot * const KV_tmp, + const int stride_K2, + const int k_VKQ_0, + const int k_VKQ_sup, + const int k_KQ_0, + float * KQ_acc) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int ncols = ncols1*ncols2; + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column + + flash_attn_tile_load_tile + (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup); + __syncthreads(); + +#ifdef FAST_FP16_AVAILABLE + static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K"); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) { + half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne]; + half2 Q_k[cpw][cpy_ne]; +#else + static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K"); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) { + float K_k[nbatch_fa/(np*warp_size)][cpy_ne]; + float Q_k[cpw][cpy_ne]; +#endif // FAST_FP16_AVAILABLE + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { + const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x; + +#ifdef FAST_FP16_AVAILABLE + ggml_cuda_memcpy_1(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K/2 + cpy_ne) + k_KQ_1]); +#else + ggml_cuda_memcpy_1(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K + cpy_ne) + k_KQ_1]); +#endif // FAST_FP16_AVAILABLE + } +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y / np)*cpw; + +#ifdef FAST_FP16_AVAILABLE + ggml_cuda_memcpy_1(&Q_k[jc0], &Q_tmp[jc*(DKQ/2) + k_KQ_0/2 + k_KQ_1]); +#else + ggml_cuda_memcpy_1(&Q_k[jc0], &Q_tmp[jc* DKQ + k_KQ_0 + k_KQ_1]); +#endif // FAST_FP16_AVAILABLE + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { +#pragma unroll + for (int k = 0; k < cpy_ne; ++k) { + ggml_cuda_mad(KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0], K_k[i_KQ_0/(np*warp_size)][k], Q_k[jc0][k]); + } + } + } + } + + if (k_KQ_0 + nbatch_K < DKQ) { + __syncthreads(); // Sync not needed on last iteration. + } +} + +// Function that performs a single iteration of the main loop over up to nbatch_fa tokens. +template +static __device__ __forceinline__ void flash_attn_tile_iter( + T_vec_dot * const Q_tmp, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half * const __restrict__ mask, + const float logit_softcap, + const float slope, + T_KQ * const KQ, + T_vec_dot * const KV_tmp, + const int stride_K2, + const int stride_V2, + const int stride_mask, + float * const KQ_max, + float * const KQ_sum, + T_acc * const VKQ, + const int k_VKQ_0, + const int k_VKQ_max) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int ncols = ncols1*ncols2; + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column + + constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size. + + // KQ_cs == KQ chunk size, number of KQ values in j direction to store as one contiguous chunk in memory. + // KQ is originally 2D but uses a Z-shaped 3D memory pattern like KQ[ncols/KQ_cs][DVp][KQ_cs]. +#ifdef FAST_FP16_AVAILABLE + constexpr int KQ_cs = cpw < 2*cpy_ne ? cpw : 2*cpy_ne; +#else + constexpr int KQ_cs = cpw < 1*cpy_ne ? cpw : 1*cpy_ne; +#endif // FAST_FP16_AVAILABLE + static_assert(cpw % KQ_cs == 0, "bad KQ_cs"); + const int k_VKQ_sup = k_VKQ_max - k_VKQ_0; // k supremum, only smaller k values have valid KV data + + float KQ_max_new[cpw]; +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + KQ_max_new[jc0] = KQ_max[jc0]; + } + + float KQ_acc[nbatch_fa/(np*warp_size) * cpw] = {0.0f}; // Accumulators for KQ matrix multiplication. + + // KQ = K @ Q matrix multiplication: + constexpr int nbatch_K_last = DKQ % nbatch_K; +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) { + flash_attn_tile_iter_KQ( + Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); + } + if (nbatch_K_last > 0) { + constexpr int k_KQ_0 = DKQ - nbatch_K_last; + flash_attn_tile_iter_KQ( + Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); + } + + // Apply logit softcap + mask, update KQ_max: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int j = (jc0 + (threadIdx.y / np)*cpw)/ncols2; + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { + const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x; + + if (use_logit_softcap) { + KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); + } + + KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) && (!oob_check || i_KQ < k_VKQ_sup) ? + slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f; + + KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); + } + + KQ_max_new[jc0] = warp_reduce_max(KQ_max_new[jc0]); + } + + if constexpr (np == 1) { + __syncthreads(); + } else { + static_assert(cpw == 1, "bad cpw"); + __shared__ float KQ_max_new_shared[nwarps]; + if (threadIdx.x == 0) { + KQ_max_new_shared[threadIdx.y] = KQ_max_new[0]; + } + __syncthreads(); + KQ_max_new[0] = KQ_max_new_shared[(threadIdx.y & ~(np-1)) + threadIdx.x % np]; + KQ_max_new[0] = warp_reduce_max(KQ_max_new[0]); + } + + // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) { +#ifdef FAST_FP16_AVAILABLE + half tmp[nbatch_fa/(np*warp_size)][KQ_cs]; +#else + float tmp[nbatch_fa/(np*warp_size)][KQ_cs]; +#endif // FAST_FP16_AVAILABLE + +#pragma unroll + for (int jc1 = 0; jc1 < KQ_cs; ++jc1) { + const int jc = jc0 + jc1; + + const float KQ_max_scale = expf(KQ_max[jc] - KQ_max_new[jc]); + KQ_max[jc] = KQ_max_new[jc]; + + float KQ_sum_add = 0.0f; +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { + const float val = expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]); + if (!oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup) { + KQ_sum_add += val; + } + tmp[i0/(np*warp_size)][jc1] = val; + } + KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add; + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale; + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { + const int i = i0 + (threadIdx.y % np)*warp_size + threadIdx.x; + + ggml_cuda_memcpy_1( + KQ + (jc0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs))*(nbatch_fa*KQ_cs) + i*KQ_cs, + tmp[i0/(np*warp_size)]); + } + } + + // VKQ = V @ KQ matrix multiplication: + static_assert(DV <= DKQ, "bad DV"); + static_assert(DV % nbatch_K == 0 || (nbatch_K % 3 == 0 && DV % (nbatch_K*2/3) == 0), "bad nbatch_K"); + constexpr int nbatch_V = (DV % nbatch_K == 0 ? nbatch_K : nbatch_K*2/3) * nbatch_fa / DV; // Number of V columns that fit in SRAM for K. + static_assert(nbatch_fa % nbatch_V == 0, "bad nbatch_V"); + static_assert(nbatch_V % np == 0, "bad nbatch_V"); +#pragma unroll + for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) { + flash_attn_tile_load_tile + (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0); + __syncthreads(); + +#ifdef FAST_FP16_AVAILABLE +#pragma unroll + for (int k1 = 0; k1 < nbatch_V; k1 += np) { + half2 V_k[(DVp/2)/warp_size]; + half2 KQ_k[cpw]; + + constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1(&V_k[i0/warp_size], &KV_tmp[(k1 + threadIdx.y % np)*(DV/2) + i0 + threadIdx.x*cpy_ne_D]); + } +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) { + const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs); + + half tmp[KQ_cs]; + ggml_cuda_memcpy_1( + &tmp, KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs); +#pragma unroll + for (int jc_VKQ_1 = 0; jc_VKQ_1 < KQ_cs; ++jc_VKQ_1) { + KQ_k[jc_VKQ_0+jc_VKQ_1] = __half2half2(tmp[jc_VKQ_1]); + } + } + +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) { + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size] += V_k[i0/warp_size]*KQ_k[jc_VKQ_0]; + } + } + } +#else +#pragma unroll + for (int k1 = 0; k1 < nbatch_V; k1 += np) { + float2 V_k[(DVp/2)/warp_size]; + float KQ_k[cpw]; + + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1(&V_k[i0/(2*warp_size)], &KV_tmp[(k1 + threadIdx.y % np)*DV + i0 + threadIdx.x*cpy_ne_D]); + } +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) { + const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs); + + ggml_cuda_memcpy_1( + &KQ_k[jc_VKQ_0], KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs); + } + +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) { + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[jc_VKQ_0]; + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[jc_VKQ_0]; + } + } + } +#endif // FAST_FP16_AVAILABLE + + __syncthreads(); + } +} + +template // D == head size +__launch_bounds__(ggml_cuda_fattn_tile_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_tile_get_occupancy(DKQ, DV, ncols1*ncols2)) +static __global__ void flash_attn_tile( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + const char * __restrict__ sinks, + const int * __restrict__ KV_max, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { +#ifdef FLASH_ATTN_AVAILABLE + + // Skip unused kernel variants for faster compilation: + + if ( +#ifdef GGML_USE_WMMA_FATTN + (ncols2 != 1 && DV != 40 && DV != 512) || +#endif // GGML_USE_WMMA_FATTN + (use_logit_softcap && !(DV == 128 || DV == 256)) + ) { + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; + return; + } + + static_assert(ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, "kernel config not defined"); + + constexpr int ncols = ncols1*ncols2; + constexpr int warp_size = 32; + constexpr int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, ncols1*ncols2) / warp_size; + constexpr int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, ncols1*ncols2); + constexpr int nbatch_K = ggml_cuda_fattn_tile_get_nbatch_K (DKQ, DV, ncols1*ncols2); + + // In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int col_Q_0 = blockIdx.x * ncols1; // Index of the first Q column for this CUDA block to work on. + + const int sequence = blockIdx.z / (ne02/ncols2); + const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == blockIdx.z % (ne02/ncols2) + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0 + nb01*col_Q_0); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape + + const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33) + nb31*col_Q_0) : nullptr; + + const int stride_K2 = nb11 / sizeof(half2); + const int stride_V2 = nb21 / sizeof(half2); + const int stride_mask = nb31 / sizeof(half); + + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; + + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp. + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // Number of parallel warps per Q column. + static_assert(cpw == 1 || np == 1, "bad cpw / np"); + static_assert(nbatch_fa % (np*warp_size) == 0, "nbatch_fa % (np*warp_size) != 0"); + + constexpr int DKQp = (DKQ + 2*warp_size - 1) & ~(2*warp_size - 1); // DKQ padded to multiple of 2*warp_size. + constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size. + + // Q_tmp == SRAM buffer to hold Q data for the entire lifetime of the kernel. + // KV_tmp == SRAM buffer to hold fragments of K/V data while iterating over ne11. + // KV_tmp is padded to avoid memory conflicts for K (cpy_ne) and OOB accesses for V (DVp-DV). + // KQ == SRAM buffer to hold KQ fragments between KQ and VKQ matrix multiplications. + // VKQ == Accumulators in registers for the final VKQ result. +#ifdef FAST_FP16_AVAILABLE + __shared__ half2 Q_tmp[ncols * DKQ/2]; + __shared__ half2 KV_tmp[nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV]; + __shared__ half KQ[ncols * nbatch_fa]; + half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; +#else + __shared__ float Q_tmp[ncols * DKQ]; + __shared__ float KV_tmp[nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV]; + __shared__ float KQ[ncols * nbatch_fa]; + float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; +#endif // FAST_FP16_AVAILABLE + + float KQ_max[cpw]; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + KQ_max[j0/nwarps] = -FLT_MAX/2.0f; + } + float KQ_sum[cpw] = {0.0f}; + + // Load Q data, convert to FP16 if fast: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y / np)*cpw; + + const int j = jc / ncols2; + const int c = jc % ncols2; + + constexpr int cpy_ne_D = cpy_ne < DKQp/warp_size ? cpy_ne : DKQp/warp_size; + +#pragma unroll + for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) { + if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) { + float tmp_f[cpy_ne_D] = {0.0f}; + if (ncols1 == 1 || col_Q_0 + j < ne01) { + ggml_cuda_memcpy_1 + (tmp_f, &Q_f[c*(nb02/sizeof(float)) + j*(nb01/sizeof(float)) + + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]); + } + +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + tmp_f[i1] *= scale; + } + +#ifdef FAST_FP16_AVAILABLE + half2 tmp_h2[cpy_ne_D/2]; +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) { + tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]); + } + ggml_cuda_memcpy_1( + &Q_tmp[jc*(DKQ/2) + i0/2 + (threadIdx.y % np)*(warp_size*cpy_ne_D/2) + threadIdx.x*(cpy_ne_D/2)], + tmp_h2); +#else + ggml_cuda_memcpy_1( + &Q_tmp[jc* DKQ + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x* cpy_ne_D], + tmp_f); +#endif // FAST_FP16_AVAILABLE + } + } + } + + __syncthreads(); + + // Main loop over KV cache: + const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; + if (ncols2 == 1) { + // Branch with out-of-bounds checks. + int k_VKQ_0 = blockIdx.y*nbatch_fa; + while (k_VKQ_0 < k_VKQ_max - nbatch_fa) { + constexpr bool oob_check = false; + flash_attn_tile_iter + (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + k_VKQ_0 += gridDim.y*nbatch_fa; + } + if (k_VKQ_0 < k_VKQ_max) { + constexpr bool oob_check = true; + flash_attn_tile_iter + (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + } + } else { + // Branch without out-of-bounds checks. + for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) { + constexpr bool oob_check = false; + flash_attn_tile_iter + (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + } + } + +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + KQ_sum[jc0] = warp_reduce_sum(KQ_sum[jc0]); + } + + if constexpr (np > 1) { + static_assert(cpw == 1, "bad cpw"); + static_assert(nbatch_fa*nbatch_K >= nwarps*DVp, "KV_tmp too small"); + +#ifdef FAST_FP16_AVAILABLE + half2 * VKQ_combine = (half2 *) KV_tmp; +#else + float * VKQ_combine = (float *) KV_tmp; +#endif // FAST_FP16_AVAILABLE + float * KQ_sum_combine = (float *) Q_tmp; + + if (threadIdx.y % np != 0) { +#ifdef FAST_FP16_AVAILABLE + constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1(&VKQ_combine[threadIdx.y*(DVp/2) + i0 + threadIdx.x*cpy_ne_D], &VKQ[i0/warp_size]); + } +#else + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1( + &VKQ_combine[threadIdx.y*DVp + i0 + threadIdx.x*cpy_ne_D], ((const float *) VKQ) + i0/warp_size); + } +#endif // FAST_FP16_AVAILABLE + + if (threadIdx.x == 0) { + KQ_sum_combine[threadIdx.y] = KQ_sum[0]; + } + + return; + } + + __syncthreads(); + +#pragma unroll + for (int ip = 1; ip < np; ++ip) { +#ifdef FAST_FP16_AVAILABLE + constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + half2 tmp[cpy_ne_D]; + ggml_cuda_memcpy_1(tmp, &VKQ_combine[(threadIdx.y + ip)*(DVp/2) + i0 + threadIdx.x*cpy_ne_D]); +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + VKQ[i0/warp_size + i1] += tmp[i1]; + } + } +#else + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + float tmp[cpy_ne_D]; + ggml_cuda_memcpy_1(tmp, &VKQ_combine[(threadIdx.y + ip)*DVp + i0 + threadIdx.x*cpy_ne_D]); +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + ((float *)VKQ)[i0/warp_size + i1] += tmp[i1]; + } + } +#endif // FAST_FP16_AVAILABLE + + KQ_sum[0] += KQ_sum_combine[threadIdx.y + ip]; + } + } + + // Attention sink: adjust KQ max and sum only for the first of all parallel blocks: + if (sinks && blockIdx.y == 0) { +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y/np)*cpw; + const float sink = ((const float *) sinks)[head0 + jc % ncols2]; + + float KQ_max_new_j = fmaxf(KQ_max[jc0], sink); + const float KQ_max_scale = expf(KQ_max[jc0] - KQ_max_new_j); + KQ_max[jc0] = KQ_max_new_j; + + const float val = expf(sink - KQ_max[jc0]); + KQ_sum[jc0] = KQ_sum[jc0]*KQ_max_scale + val; + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale; + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + } + + if (gridDim.y == 1) { +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_sum_jc_inv = make_half2(1.0f/KQ_sum[jc0], 1.0f/KQ_sum[jc0]); +#pragma unroll + for (int i = 0; i < (DVp/2)/warp_size; ++i) { + VKQ[jc0*((DVp/2)/warp_size) + i] *= KQ_sum_jc_inv; + } +#else + const float KQ_sum_jc_inv = 1.0f/KQ_sum[jc0]; +#pragma unroll + for (int i = 0; i < (DVp/2)/warp_size; ++i) { + VKQ[jc0*((DVp/2)/warp_size) + i].x *= KQ_sum_jc_inv; + VKQ[jc0*((DVp/2)/warp_size) + i].y *= KQ_sum_jc_inv; + } +#endif // FAST_FP16_AVAILABLE + } + } + + // Write back results: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y/np)*cpw; + + const int j = jc / ncols2; + const int c = jc % ncols2; + + if (ncols1 > 1 && col_Q_0 + j >= ne01) { + return; + } + + const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y; + +#ifdef FAST_FP16_AVAILABLE + constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + float2 tmp[cpy_ne_D]; +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]); + } + if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) { + ggml_cuda_memcpy_1(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp); + } + } +#else + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) { + ggml_cuda_memcpy_1( + &dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D], + &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]); + } + } +#endif // FAST_FP16_AVAILABLE + + if (gridDim.y != 1 && threadIdx.x == 0) { + dst_meta[j_dst_unrolled] = make_float2(KQ_max[jc0], KQ_sum[jc0]); + } + } +#else + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; +#endif // FLASH_ATTN_AVAILABLE +} + +template +static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int warp_size = 32; + + constexpr size_t nbytes_shared = 0; + +#ifdef GGML_USE_HIP + if constexpr (DV <= 128) { + if (Q->ne[1] > 32/ncols2) { + constexpr int cols_per_block = 64; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } +#endif // GGML_USE_HIP + +#ifndef GGML_USE_HIP + if constexpr (DV <= 256) +#endif // GGML_USE_HIP + { + if (Q->ne[1] > 16/ncols2) { + constexpr int cols_per_block = 32; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } + + if (Q->ne[1] > 8/ncols2) { + constexpr int cols_per_block = 16; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + + if constexpr (ncols2 <= 8) { + if (Q->ne[1] > 4/ncols2) { + constexpr int cols_per_block = 8; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } + + if constexpr (ncols2 <= 4) { + if (Q->ne[1] > 2/ncols2) { + constexpr int cols_per_block = 4; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } + + if constexpr (ncols2 <= 2) { + constexpr int cols_per_block = 2; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + + GGML_ABORT("fatal error"); +} + +template +static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * mask = dst->src[3]; + + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + + const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc); + const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX; + const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; + + if constexpr (DV == 512) { + if (use_gqa_opt && gqa_ratio % 16 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + } + + if constexpr (DV <= 256) { + if (use_gqa_opt && gqa_ratio % 8 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 4 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 2 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + GGML_ABORT("fatal error"); +} + +template +void ggml_cuda_flash_attn_ext_tile_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_switch_ncols2(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_switch_ncols2(ctx, dst); + } +} void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +#define DECL_FATTN_TILE_CASE(DKQ, DV) \ + template void ggml_cuda_flash_attn_ext_tile_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +extern DECL_FATTN_TILE_CASE( 40, 40); +extern DECL_FATTN_TILE_CASE( 64, 64); +extern DECL_FATTN_TILE_CASE( 80, 80); +extern DECL_FATTN_TILE_CASE( 96, 96); +extern DECL_FATTN_TILE_CASE(112, 112); +extern DECL_FATTN_TILE_CASE(128, 128); +extern DECL_FATTN_TILE_CASE(256, 256); +extern DECL_FATTN_TILE_CASE(576, 512); diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index 1848d08836185..7235f1b77aeed 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -1,3 +1,5 @@ +#pragma once + #include "common.cuh" #if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 0c8e7b3e41904..fe970adaecef3 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -198,6 +198,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_NONE; #endif// FLASH_ATTN_AVAILABLE + const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; @@ -206,37 +207,32 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const const int gqa_ratio = Q->ne[2] / K->ne[2]; GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); - const int cc = ggml_cuda_info().devices[device].cc; + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); - // TODO: temporary until support is extended - // https://github.com/ggml-org/llama.cpp/pull/16148#issuecomment-3343525206 - if (K->ne[1] % FATTN_KQ_STRIDE != 0) { - return BEST_FATTN_KERNEL_NONE; - } + // The effective batch size for the kernel can be increased by gqa_ratio. + // The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded, + const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; + + const int cc = ggml_cuda_info().devices[device].cc; switch (K->ne[0]) { + case 40: case 64: - case 128: - case 256: - if (V->ne[0] != K->ne[0]) { - return BEST_FATTN_KERNEL_NONE; - } - break; case 80: case 96: + case 128: case 112: + case 256: if (V->ne[0] != K->ne[0]) { return BEST_FATTN_KERNEL_NONE; } - if (!ggml_cuda_should_use_wmma_fattn(cc) && !turing_mma_available(cc)) { - return BEST_FATTN_KERNEL_NONE; - } break; case 576: if (V->ne[0] != 512) { return BEST_FATTN_KERNEL_NONE; } - if (!turing_mma_available(cc) || gqa_ratio % 16 != 0) { + if (!gqa_opt_applies || gqa_ratio % 16 != 0) { return BEST_FATTN_KERNEL_NONE; } break; @@ -270,47 +266,57 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_NONE; } - const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0; - - // If Turing tensor cores available, use them except for some cases with batch size 1: - if (turing_mma_available(cc)) { - best_fattn_kernel best = BEST_FATTN_KERNEL_MMA_F16; + // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: + const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; + // If Turing tensor cores available, use them: + if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) { if (can_use_vector_kernel) { if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) { if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) { - best = BEST_FATTN_KERNEL_VEC; + return BEST_FATTN_KERNEL_VEC; } } else { if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { if (Q->ne[1] <= 2) { - best = BEST_FATTN_KERNEL_VEC; + return BEST_FATTN_KERNEL_VEC; } } else { if (Q->ne[1] == 1) { - best = BEST_FATTN_KERNEL_VEC; + return BEST_FATTN_KERNEL_VEC; } } } - if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) { - best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply. + if (!gqa_opt_applies && Q->ne[1] == 1) { + return BEST_FATTN_KERNEL_VEC; } } - return best; - } - - // Use kernels specialized for small batch sizes if possible: - if (Q->ne[1] <= 8 && can_use_vector_kernel) { - return BEST_FATTN_KERNEL_VEC; + return BEST_FATTN_KERNEL_MMA_F16; } - // For large batch sizes, use the WMMA kernel if possible: - if (ggml_cuda_should_use_wmma_fattn(cc)) { + // Use the WMMA kernel if possible: + if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) { + if (can_use_vector_kernel && Q->ne[1] <= 2) { + return BEST_FATTN_KERNEL_VEC; + } return BEST_FATTN_KERNEL_WMMA_F16; } - // If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes: + // If there are no tensor cores available, use the generic tile kernel: + if (can_use_vector_kernel) { + if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) { + if (Q->ne[1] == 1) { + if (!gqa_opt_applies) { + return BEST_FATTN_KERNEL_VEC; + } + } + } else { + if (Q->ne[1] <= 2) { + return BEST_FATTN_KERNEL_VEC; + } + } + } return BEST_FATTN_KERNEL_TILE; } diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu new file mode 100644 index 0000000000000..a8b15ad72a916 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(112, 112); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu new file mode 100644 index 0000000000000..1da18105508ac --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(128, 128); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu new file mode 100644 index 0000000000000..bc65c723eca9d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(256, 256); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu new file mode 100644 index 0000000000000..10b330fa6c031 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(40, 40); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu new file mode 100644 index 0000000000000..254b7d2e1dc29 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(576, 512); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu new file mode 100644 index 0000000000000..5caffac0467d8 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(64, 64); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu new file mode 100644 index 0000000000000..90abb3b186261 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(80, 80); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu new file mode 100644 index 0000000000000..7292c0aab8f98 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(96, 96); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index d410080fab841..81a986f38cacf 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -3,8 +3,17 @@ from glob import glob import os +HEAD_SIZES_KQ = [40, 64, 80, 96, 112, 128, 256, 576] + TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"] +SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE({head_size_kq}, {head_size_v}); +""" + SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec.cuh" @@ -51,6 +60,11 @@ def get_short_name(long_quant_name): for filename in glob("*.cu"): os.remove(filename) +for head_size_kq in HEAD_SIZES_KQ: + head_size_v = head_size_kq if head_size_kq != 576 else 512 + with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f: + f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v)) + for type_k in TYPES_KV: for type_v in TYPES_KV: with open(f"fattn-vec-instance-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f: @@ -64,7 +78,9 @@ def get_short_name(long_quant_name): with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f: f.write(SOURCE_FATTN_MMA_START) - for head_size_kq in [64, 80, 96, 112, 128, 256, 576]: + for head_size_kq in HEAD_SIZES_KQ: + if head_size_kq == 40: + continue if head_size_kq != 576 and ncols2 == 16: continue if head_size_kq == 576 and ncols2 != 16: diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 0e2b1847e09e2..934aefdcb45fa 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -53,6 +53,8 @@ file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh") list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h") file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu") +file(GLOB SRCS "../ggml-cuda/template-instances/fattn-tile*.cu") +list(APPEND GGML_SOURCES_ROCM ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_ROCM ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu") diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index f8477a2ef356d..d76cb51977f90 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -30,6 +30,8 @@ if (MUSAToolkit_FOUND) list(APPEND GGML_HEADERS_MUSA "../ggml-musa/mudnn.cuh") file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu") + file(GLOB SRCS "../ggml-cuda/template-instances/fattn-tile*.cu") + list(APPEND GGML_SOURCES_MUSA ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_MUSA ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")