@@ -1427,17 +1427,17 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
 static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer,
     float m, int64_t size, float start, float stop, float step){
     int64_t ne[] = {size};
-    size_t nb[] = {sizeof(float)};
+    size_t nb[] = {sizeof(uint16_t)};

-    ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(float));
+    ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(uint16_t));
     void* arange_buffer = arange_allocator.get();

     aclTensor* arange_tensor = ggml_cann_create_tensor(
-        arange_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
+        arange_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1);
     aclnn_arange(ctx, arange_tensor, start, stop, step, size);

     aclTensor* slope_tensor = ggml_cann_create_tensor(
-        slope_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
+        slope_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1);

     aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT);

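Context for this hunk: judging from the rest of the helper (not all of it is shown here), `aclnn_get_slope_inner` appears to fill `slope_buffer` with the scalar `m` raised element-wise to `arange(start, stop, step)`, and the change stores that result as fp16, which is why the buffers are now sized with `sizeof(uint16_t)`. A rough host-side sketch of that computation, under that assumption:

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Assumed reference only: slope[i] = m^(start + i * step), computed in fp32 on the host.
// The backend additionally keeps the device-side result in fp16.
std::vector<float> slope_reference(float m, int64_t size, float start, float step) {
    std::vector<float> out(static_cast<size_t>(size));
    for (int64_t i = 0; i < size; ++i) {
        out[static_cast<size_t>(i)] = std::pow(m, start + step * static_cast<float>(i));
    }
    return out;
}
```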
@@ -3180,11 +3180,38 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {

 void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){

-    ggml_tensor* src0 = dst->src[0]; // q, fp32
-    ggml_tensor* src1 = dst->src[1]; // k, fp16
-    ggml_tensor* src2 = dst->src[2]; // v, fp16
+    ggml_tensor* src0 = dst->src[0]; // q, fp32 | B, N, S, D (uncont) -> B, S, N, D (cont)
+    ggml_tensor* src1 = dst->src[1]; // k, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)
+    ggml_tensor* src2 = dst->src[2]; // v, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)
     ggml_tensor* src3 = dst->src[3]; // mask, fp16

+    // B, N, S, D (uncont) -> B, S, N, D (cont)
+    int64_t src0_bsnd_ne[GGML_MAX_DIMS];
+    memcpy(src0_bsnd_ne, src0->ne, GGML_MAX_DIMS * sizeof(int64_t));
+    size_t src0_bsnd_nb[GGML_MAX_DIMS];
+    memcpy(src0_bsnd_nb, src0->nb, GGML_MAX_DIMS * sizeof(size_t));
+    int64_t src1_bsnd_ne[GGML_MAX_DIMS];
+    memcpy(src1_bsnd_ne, src1->ne, GGML_MAX_DIMS * sizeof(int64_t));
+    size_t src1_bsnd_nb[GGML_MAX_DIMS];
+    memcpy(src1_bsnd_nb, src1->nb, GGML_MAX_DIMS * sizeof(size_t));
+    int64_t src2_bsnd_ne[GGML_MAX_DIMS];
+    memcpy(src2_bsnd_ne, src2->ne, GGML_MAX_DIMS * sizeof(int64_t));
+    size_t src2_bsnd_nb[GGML_MAX_DIMS];
+    memcpy(src2_bsnd_nb, src2->nb, GGML_MAX_DIMS * sizeof(size_t));
+
+    auto transpose12 = [](int64_t* ne, size_t* nb) {
+        int64_t ne_tmp = ne[1];
+        size_t  nb_tmp = nb[1];
+        ne[1] = ne[2];
+        nb[1] = nb[2];
+        ne[2] = ne_tmp;
+        nb[2] = nb_tmp;
+    };
+
+    transpose12(src0_bsnd_ne, src0_bsnd_nb);
+    transpose12(src1_bsnd_ne, src1_bsnd_nb);
+    transpose12(src2_bsnd_ne, src2_bsnd_nb);
+
     float maxBias = 0.0f;
     float scaleValue = 1.0f;
     float logitSoftcap = 0.0f;
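Note on the hunk above: the added `transpose12` lambda never moves data. Swapping a dimension's extent together with its byte stride simply re-labels axes 1 and 2 of the view, which is how the BNSD-shaped ggml tensors are presented to the BSND attention kernel further down. A small standalone sketch (illustrative only, not backend code) of why the swapped view addresses the same elements:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
    // a contiguous 4-D buffer with ne = {4, 3, 2, 1}, dimension 0 varying fastest (ggml order)
    int64_t ne[4] = {4, 3, 2, 1};
    size_t  nb[4];
    nb[0] = sizeof(float);
    for (int i = 1; i < 4; ++i) nb[i] = nb[i - 1] * static_cast<size_t>(ne[i - 1]);

    std::vector<float> data(static_cast<size_t>(ne[0] * ne[1] * ne[2] * ne[3]));
    for (size_t i = 0; i < data.size(); ++i) data[i] = static_cast<float>(i);

    // read an element through a given (ne, nb) view, stride arithmetic in bytes
    auto at = [&](const int64_t*, const size_t* nbv,
                  int64_t i0, int64_t i1, int64_t i2, int64_t i3) -> float {
        const char* base = reinterpret_cast<const char*>(data.data());
        return *reinterpret_cast<const float*>(
            base + i0 * nbv[0] + i1 * nbv[1] + i2 * nbv[2] + i3 * nbv[3]);
    };

    // swap dims 1 and 2, exactly what transpose12 does in the diff
    int64_t ne_t[4] = {ne[0], ne[2], ne[1], ne[3]};
    size_t  nb_t[4] = {nb[0], nb[2], nb[1], nb[3]};

    // element (d, n, s, b) of the original view equals element (d, s, n, b) of the swapped view
    assert(at(ne, nb, 1, 2, 1, 0) == at(ne_t, nb_t, 1, 1, 2, 0));
    return 0;
}
```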
@@ -3206,11 +3233,12 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     void* src0_f16_buffer = nullptr;

     if (ggml_cann_type_mapping(src0->type) != faDataType){
-        aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0);
+        aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne,
+                                             src0_bsnd_nb, GGML_MAX_DIMS);
         src0_f16_buffer = src0_f16_allocator.alloc(
             ggml_nelements(src0) * faElemSize);

-        int64_t* src0_f16_ne = src0->ne;
+        int64_t* src0_f16_ne = src0_bsnd_ne;
         size_t src0_f16_nb[GGML_MAX_DIMS];
         src0_f16_nb[0] = sizeof(uint16_t);
         for (int i = 1; i < GGML_MAX_DIMS; ++i){
@@ -3224,20 +3252,23 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType);
         ggml_cann_release_resources(ctx, acl_src0_f32_tensor);
     } else {
-        acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
+        acl_src0_f16_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne,
+                                  src0_bsnd_nb, GGML_MAX_DIMS);
     }

     // Step 2: create the acl tensors for src1 (Key), src2 (Value),
     // and the direct output from FusedInferAttention

-    acl_src1_f16_tensor = ggml_cann_create_tensor(src1);
-    acl_src2_f16_tensor = ggml_cann_create_tensor(src2);
+    acl_src1_f16_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne,
+                              src1_bsnd_nb, GGML_MAX_DIMS);
+    acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne,
+                              src2_bsnd_nb, GGML_MAX_DIMS);

     ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
     void* out_f16_buffer = out_f16_allocator.alloc(
         ggml_nelements(dst) * faElemSize);

-    int64_t* out_f16_ne = src0->ne;
+    int64_t* out_f16_ne = src0_bsnd_ne;
     size_t out_f16_nb[GGML_MAX_DIMS];
     out_f16_nb[0] = faElemSize;
     for (int i = 1; i < GGML_MAX_DIMS; ++i){
@@ -3251,88 +3282,81 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){

     // Step 3: create the PSEShift tensor if needed
     // this tensor is considered as mask (f16) in the llama.cpp
-
     aclTensor* bcast_pse_tensor = nullptr;
-    int64_t bcast_pse_ne[GGML_MAX_DIMS];
-    size_t bcast_pse_nb[GGML_MAX_DIMS];
     ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool());
-    void* bcast_pse_buffer = nullptr;
-
     if (src3 != nullptr){
-        bcast_pse_buffer = bcast_pse_allocator.alloc(
-            ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t));
-
-        if (src0->ne[1] > 1){
-            // Case 1: broadcast pse for prefill stage with multiple head
-            aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3);
-            bcast_pse_ne[0] = src3->ne[0];
-            bcast_pse_ne[1] = src3->ne[1];
-            bcast_pse_ne[2] = src0->ne[2];
-            bcast_pse_ne[3] = src3->ne[3];
+        // Construct the truncated pse tensor (common for prefill/decode)
+        int64_t trunc_pse_ne[GGML_MAX_DIMS] = {
+            src3->ne[0],  // D
+            src0->ne[1],  // S (number of Q tokens)
+            src3->ne[2],  // mask N
+            src3->ne[3]   // B
+        };
+        size_t* trunc_pse_nb = src3->nb;
+
+        aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
+            src3->data, ACL_FLOAT16, sizeof(uint16_t),
+            trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS
+        );

+        int64_t bcast_pse_ne[GGML_MAX_DIMS];
+        size_t bcast_pse_nb[GGML_MAX_DIMS];
+        bcast_pse_ne[0] = src3->ne[0];  // D
+        bcast_pse_ne[1] = src0->ne[1];  // S
+        bcast_pse_ne[2] = src0->ne[2];  // N (num_heads)
+        bcast_pse_ne[3] = src3->ne[3];  // B
+        if (maxBias == 0.0f) {
+            // When maxBias == 0.0f every head shares the same mask, so construct the bcast
+            // tensor as a stride-0 view on the head dimension instead of a repeat (e.g. Qwen2)
             bcast_pse_nb[0] = sizeof(uint16_t);
-            for (int i = 1; i < GGML_MAX_DIMS; ++i){
-                bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
-            }
+            bcast_pse_nb[1] = bcast_pse_nb[0] * bcast_pse_ne[0];
+            bcast_pse_nb[2] = 0;  // <---- the head dimension shares the same data
+            bcast_pse_nb[3] = src3->nb[3];

             bcast_pse_tensor = ggml_cann_create_tensor(
-                bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
-                bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
-
-            int64_t repeats[] = {1, src0->ne[2], 1, 1};
-            aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats);
-
-            ggml_cann_release_resources(ctx, acl_mask_f16_tensor);
-        }else{
-            // Case 2: trunc the first row and broadcast pse for decode stage with multiple head
-            int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]};
-            size_t* trunc_pse_nb = src3->nb;
-
-            aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
                 src3->data, ACL_FLOAT16, sizeof(uint16_t),
-                trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS);
-
-            bcast_pse_ne[0] = src3->ne[0];
-            bcast_pse_ne[1] = src0->ne[1];
-            bcast_pse_ne[2] = src0->ne[2];
-            bcast_pse_ne[3] = src3->ne[3];
+                bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS
+            );

+            ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
+        } else {
             bcast_pse_nb[0] = sizeof(uint16_t);
-            for (int i = 1; i < GGML_MAX_DIMS; ++i) {
+            for (int i = 1; i < GGML_MAX_DIMS; i++) {
                 bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
             }

+            void* bcast_pse_buffer = bcast_pse_allocator.alloc(
+                ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t)
+            );
+
             bcast_pse_tensor = ggml_cann_create_tensor(
                 bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
-                bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
+                bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS
+            );

             int64_t repeats[] = {1, src0->ne[2], 1, 1};
             aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats);

-            ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
-        }
-
-        // Compute the slope if needed. Derived from ggml_cann_softmax().
-        if (maxBias != 0.0f){
             // alibi
+            // Compute the slope if needed. Derived from ggml_cann_softmax().
             const int64_t n_heads = src0->ne[2];
-            ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float));
+            ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(uint16_t));
             void* slope_buffer = slope_allocator.get();
             aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias);

             int64_t slope_ne[] = {1, 1, n_heads, 1};
             size_t slope_nb[GGML_MAX_DIMS];
-            slope_nb[0] = sizeof(float);
+            slope_nb[0] = sizeof(uint16_t);
             for (int i = 1; i < GGML_MAX_DIMS; i++) {
                 slope_nb[i] = slope_nb[i-1] * slope_ne[0];
             }

             aclTensor* slope_tensor = ggml_cann_create_tensor(
-                slope_buffer, ACL_FLOAT, sizeof(float),
+                slope_buffer, ACL_FLOAT16, sizeof(uint16_t),
                 slope_ne, slope_nb, GGML_MAX_DIMS);
             GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, slope_tensor);

-            ggml_cann_release_resources(ctx, slope_tensor);
+            ggml_cann_release_resources(ctx, slope_tensor, acl_mask_f16_trunc_tensor);
         }
     }

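The `bcast_pse_nb[2] = 0` branch above relies on a general property of stride-based tensor views: a zero stride makes every index along that axis resolve to the same memory, so the per-head copy of the mask never has to be materialized. A toy host-side illustration (not backend code) of the same idea:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
    const float mask[4] = {0.0f, -1.0f, -2.0f, -3.0f};  // one mask row of length 4
    const int64_t n_heads = 3;

    const size_t nb0     = sizeof(float);  // stride along the row
    const size_t nb_head = 0;              // zero stride: all heads alias the same row

    const char* base = reinterpret_cast<const char*>(mask);
    for (int64_t h = 0; h < n_heads; ++h) {
        for (int64_t i = 0; i < 4; ++i) {
            // every head index lands on the same underlying elements
            const float v = *reinterpret_cast<const float*>(base + h * nb_head + i * nb0);
            std::printf("%g ", v);
        }
        std::printf("\n");
    }
    return 0;
}
```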
@@ -3349,7 +3373,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     // double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d)
     int64_t preTokens = 65535;
     int64_t nextTokens = 65535;
-    char layout[5] = {'B', 'N', 'S', 'D', 0};
+    char layout[5] = {'B', 'S', 'N', 'D', 0};
     int64_t sparseMode = 0;
     int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2;
     int64_t blockSize = 0;
@@ -3386,32 +3410,9 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     );

     // Step 6: post-processing, permute and cast to f32
-
-    int64_t new_dim[] = {0, 2, 1, 3};
     aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
-
-    if (ggml_cann_type_mapping(dst->type) != faDataType){
-        ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool());
-        perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
-        void* perm_out_f16_buffer = perm_out_f16_allocator.get();
-
-        int64_t* perm_out_f16_ne = dst->ne;
-        size_t perm_out_f16_nb[GGML_MAX_DIMS];
-        perm_out_f16_nb[0] = faElemSize;
-        for (int i = 1; i < GGML_MAX_DIMS; ++i){
-            perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1];
-        }
-        aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor(
-            perm_out_f16_buffer, faDataType, faElemSize,
-            perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS);
-        aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS);
-        aclnn_cast(ctx,
-            acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
-        ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor);
-    }else{
-        // only need to permute
-        aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS);
-    }
+    // TODO: when dst is fp16 the cast is unnecessary
+    aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
     ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
                                 acl_src1_f16_tensor,
                                 acl_src2_f16_tensor,