@@ -5206,6 +5206,7 @@ kernel void kernel_flash_attn_ext(
52065206
// flash_attn_ext_t names the function type of one concrete kernel_flash_attn_ext
// specialization (taken with decltype). Every explicit instantiation below is declared
// with this type.
52075207typedef decltype (kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 64 , 64 >) flash_attn_ext_t;
52085208
// Explicit instantiations using FA_TYPES_F32 with float4x4 / dequantize_f32. Each line
// binds a host_name to one specialization; the two trailing integer arguments match the
// dk/dv numbers in that name (presumably the K and V head sizes - confirm against the
// kernel template, which is outside this view). The dk32_dv32 line is the one marked as
// added ('+') in this diff.
5209+ template [[host_name(" kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1 , dequantize_f32, float4x4, 1 , dequantize_f32, 32 , 32 >;
52095210template [[host_name(" kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1 , dequantize_f32, float4x4, 1 , dequantize_f32, 40 , 40 >;
52105211template [[host_name(" kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1 , dequantize_f32, float4x4, 1 , dequantize_f32, 64 , 64 >;
52115212template [[host_name(" kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1 , dequantize_f32, float4x4, 1 , dequantize_f32, 80 , 80 >;
@@ -5217,6 +5218,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_at
// FA_TYPES_F32 instantiations continued. Note that dk576_dv512 passes two different
// trailing integers (576, 512), unlike the lines where both values are equal.
52175218template [[host_name(" kernel_flash_attn_ext_f32_dk256_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1 , dequantize_f32, float4x4, 1 , dequantize_f32, 256 , 256 >;
52185219template [[host_name(" kernel_flash_attn_ext_f32_dk576_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1 , dequantize_f32, float4x4, 1 , dequantize_f32, 576 , 512 >;
52195220
// Explicit instantiations using FA_TYPES with half4x4 / dequantize_f16. The dk32_dv32
// line is the one marked as added ('+') in this diff.
5221+ template [[host_name(" kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 32 , 32 >;
52205222template [[host_name(" kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 40 , 40 >;
52215223template [[host_name(" kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 64 , 64 >;
52225224template [[host_name(" kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 80 , 80 >;
@@ -5229,6 +5231,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_at
// FA_TYPES / half4x4 instantiation with trailing integers 576 and 512.
52295231template [[host_name(" kernel_flash_attn_ext_f16_dk576_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 576 , 512 >;
52305232
// The bfloat4x4 / dequantize_bf16 instantiations are compiled only when
// GGML_METAL_HAS_BF16 is defined. They pass FA_TYPES_BF as the first template argument.
// The dk32_dv32 line is the one marked as added ('+') in this diff.
52315233#if defined(GGML_METAL_HAS_BF16)
5234+ template [[host_name(" kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 32 , 32 >;
52325235template [[host_name(" kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 40 , 40 >;
52335236template [[host_name(" kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 64 , 64 >;
52345237template [[host_name(" kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 80 , 80 >;
@@ -5241,6 +5244,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_at
// FA_TYPES_BF / bfloat4x4 instantiation with trailing integers 576 and 512, followed by
// the #endif that ends the GGML_METAL_HAS_BF16 conditional section.
52415244template [[host_name(" kernel_flash_attn_ext_bf16_dk576_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 576 , 512 >;
52425245#endif
52435246
// Explicit instantiations using FA_TYPES with block_q4_0 / dequantize_q4_0. The integer
// that follows each block type is 2 here, whereas the float4x4 / half4x4 / bfloat4x4
// instantiations elsewhere in this file pass 1 (its meaning is defined by the kernel
// template, which is outside this view). The dk32_dv32 line is the one marked as added
// ('+') in this diff.
5247+ template [[host_name(" kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 32 , 32 >;
52445248template [[host_name(" kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 40 , 40 >;
52455249template [[host_name(" kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 64 , 64 >;
52465250template [[host_name(" kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 80 , 80 >;
@@ -5252,6 +5256,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_at
// FA_TYPES / block_q4_0 instantiations continued: dk256_dv256, then dk576_dv512 with
// two different trailing integers (576, 512).
52525256template [[host_name(" kernel_flash_attn_ext_q4_0_dk256_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 256 , 256 >;
52535257template [[host_name(" kernel_flash_attn_ext_q4_0_dk576_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 576 , 512 >;
52545258
// Explicit instantiations using FA_TYPES with block_q4_1 / dequantize_q4_1. The
// dk32_dv32 line is the one marked as added ('+') in this diff.
5259+ template [[host_name(" kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 32 , 32 >;
52555260template [[host_name(" kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 40 , 40 >;
52565261template [[host_name(" kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 64 , 64 >;
52575262template [[host_name(" kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 80 , 80 >;
@@ -5263,6 +5268,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_at
// FA_TYPES / block_q4_1 instantiations continued: dk256_dv256, then dk576_dv512 with
// two different trailing integers (576, 512).
52635268template [[host_name(" kernel_flash_attn_ext_q4_1_dk256_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 256 , 256 >;
52645269template [[host_name(" kernel_flash_attn_ext_q4_1_dk576_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 576 , 512 >;
52655270
// Explicit instantiations using FA_TYPES with block_q5_0 / dequantize_q5_0. The
// dk32_dv32 line is the one marked as added ('+') in this diff.
5271+ template [[host_name(" kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 32 , 32 >;
52665272template [[host_name(" kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 40 , 40 >;
52675273template [[host_name(" kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 64 , 64 >;
52685274template [[host_name(" kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 80 , 80 >;
@@ -5274,6 +5280,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_at
// FA_TYPES / block_q5_0 instantiations continued: dk256_dv256, then dk576_dv512 with
// two different trailing integers (576, 512).
52745280template [[host_name(" kernel_flash_attn_ext_q5_0_dk256_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 256 , 256 >;
52755281template [[host_name(" kernel_flash_attn_ext_q5_0_dk576_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 576 , 512 >;
52765282
// Explicit instantiations using FA_TYPES with block_q5_1 / dequantize_q5_1. The
// dk32_dv32 line is the one marked as added ('+') in this diff.
5283+ template [[host_name(" kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 32 , 32 >;
52775284template [[host_name(" kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 40 , 40 >;
52785285template [[host_name(" kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 64 , 64 >;
52795286template [[host_name(" kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 80 , 80 >;
@@ -5285,6 +5292,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_at
// FA_TYPES / block_q5_1 instantiations continued: dk256_dv256, then dk576_dv512 with
// two different trailing integers (576, 512).
52855292template [[host_name(" kernel_flash_attn_ext_q5_1_dk256_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 256 , 256 >;
52865293template [[host_name(" kernel_flash_attn_ext_q5_1_dk576_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 576 , 512 >;
52875294
// Explicit instantiations using FA_TYPES with block_q8_0 / dequantize_q8_0. The
// dk32_dv32 line is the one marked as added ('+') in this diff.
5295+ template [[host_name(" kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 32 , 32 >;
52885296template [[host_name(" kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 40 , 40 >;
52895297template [[host_name(" kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 64 , 64 >;
52905298template [[host_name(" kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 80 , 80 >;
@@ -5830,6 +5838,17 @@ kernel void kernel_flash_attn_ext_vec(
58305838
// flash_attn_ext_vec_t names the function type of one concrete kernel_flash_attn_ext_vec
// specialization (taken with decltype). Every explicit instantiation below is declared
// with this type.
58315839typedef decltype (kernel_flash_attn_ext_vec<FA_TYPES, half4, 1 , dequantize_f16_t4, half4, 1 , dequantize_f16_t4, 128 , 128 , 4 >) flash_attn_ext_vec_t;
58325840
// dk32_dv32 instantiations of kernel_flash_attn_ext_vec, all marked as added ('+') in
// this diff. They use the *_t4 dequantize functions and 4-wide element types (float4,
// half4, bfloat4) or quantized block types. The integer after each block_q* type is 8
// here; the final template argument is 4 on every dk32 line, while the dk64 lines further
// down pass 2 (the meaning of both is defined by the kernel template, outside this view).
5841+ template [[host_name(" kernel_flash_attn_ext_vec_f32_dk32_dv32" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1 , dequantize_f32_t4, float4, 1 , dequantize_f32_t4, 32 , 32 , 4 >;
5842+ template [[host_name(" kernel_flash_attn_ext_vec_f16_dk32_dv32" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1 , dequantize_f16_t4, half4, 1 , dequantize_f16_t4, 32 , 32 , 4 >;
// The bfloat4 variant is compiled only when GGML_METAL_HAS_BF16 is defined.
// NOTE(review): this line passes FA_TYPES, whereas the non-vec bf16 instantiations of
// kernel_flash_attn_ext in this file pass FA_TYPES_BF - confirm this is intentional.
5843+ #if defined(GGML_METAL_HAS_BF16)
5844+ template [[host_name(" kernel_flash_attn_ext_vec_bf16_dk32_dv32" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1 , dequantize_bf16_t4, bfloat4, 1 , dequantize_bf16_t4, 32 , 32 , 4 >;
5845+ #endif
5846+ template [[host_name(" kernel_flash_attn_ext_vec_q4_0_dk32_dv32" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8 , dequantize_q4_0_t4, block_q4_0, 8 , dequantize_q4_0_t4, 32 , 32 , 4 >;
5847+ template [[host_name(" kernel_flash_attn_ext_vec_q4_1_dk32_dv32" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8 , dequantize_q4_1_t4, block_q4_1, 8 , dequantize_q4_1_t4, 32 , 32 , 4 >;
5848+ template [[host_name(" kernel_flash_attn_ext_vec_q5_0_dk32_dv32" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8 , dequantize_q5_0_t4, block_q5_0, 8 , dequantize_q5_0_t4, 32 , 32 , 4 >;
5849+ template [[host_name(" kernel_flash_attn_ext_vec_q5_1_dk32_dv32" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8 , dequantize_q5_1_t4, block_q5_1, 8 , dequantize_q5_1_t4, 32 , 32 , 4 >;
5850+ template [[host_name(" kernel_flash_attn_ext_vec_q8_0_dk32_dv32" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8 , dequantize_q8_0_t4, block_q8_0, 8 , dequantize_q8_0_t4, 32 , 32 , 4 >;
5851+
// Pre-existing dk64_dv64 instantiations (f32 and f16); their final template argument is 2.
58335852template [[host_name(" kernel_flash_attn_ext_vec_f32_dk64_dv64" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1 , dequantize_f32_t4, float4, 1 , dequantize_f32_t4, 64 , 64 , 2 >;
58345853template [[host_name(" kernel_flash_attn_ext_vec_f16_dk64_dv64" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1 , dequantize_f16_t4, half4, 1 , dequantize_f16_t4, 64 , 64 , 2 >;
58355854#if defined(GGML_METAL_HAS_BF16)
0 commit comments