@@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_CLAMP_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
 #define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
@@ -502,6 +503,31 @@ static size_t g_scratch_offset = 0;
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+}
+
 static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
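The three helpers above are butterfly reductions over a single warp: after the five __shfl_xor_sync steps every lane holds the sum (or max) of all 32 lane values. A throwaway harness along the following lines can sanity-check them; it assumes it is compiled together with the definitions above (appended to a copy of the file), and the kernel/variable names in it are made up for the example.

#include <cstdio>
#include <cuda_runtime.h>

// one warp (32 threads) cooperatively reduces 32 values
static __global__ void reduce_one_warp(const float * x, float * out_sum, float * out_max) {
    const float v = x[threadIdx.x];
    const float s = warp_reduce_sum(v);
    const float m = warp_reduce_max(v);
    if (threadIdx.x == 0) {
        *out_sum = s;
        *out_max = m;
    }
}

int main() {
    float h_x[32];
    for (int i = 0; i < 32; ++i) h_x[i] = (float) i;

    float *d_x, *d_sum, *d_max;
    cudaMalloc(&d_x, 32*sizeof(float));
    cudaMalloc(&d_sum, sizeof(float));
    cudaMalloc(&d_max, sizeof(float));
    cudaMemcpy(d_x, h_x, 32*sizeof(float), cudaMemcpyHostToDevice);

    reduce_one_warp<<<1, 32>>>(d_x, d_sum, d_max);

    float h_sum = 0.0f, h_max = 0.0f;
    cudaMemcpy(&h_sum, d_sum, sizeof(float), cudaMemcpyDeviceToHost);
    cudaMemcpy(&h_max, d_max, sizeof(float), cudaMemcpyDeviceToHost);
    printf("sum = %.1f (expect 496.0), max = %.1f (expect 31.0)\n", h_sum, h_max);

    cudaFree(d_x); cudaFree(d_sum); cudaFree(d_max);
    return 0;
}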
@@ -578,15 +604,6 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] * x[i];
 }
 
-static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
-    }
-    return a;
-}
-
 template <int block_size>
 static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -625,14 +642,6 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     }
 }
 
-static __device__ __forceinline__ float warp_reduce_sum(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
-    }
-    return x;
-}
-
 template <int block_size>
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -4718,45 +4727,74 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
 }
 
-// the CUDA soft max implementation differs from the CPU implementation
-// instead of doubles floats are used
-static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int block_size = blockDim.y;
-    const int tid = threadIdx.y;
+static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+    const int block_size = blockDim.x;
+
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
 
     float max_val = -INFINITY;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        max_val = max(max_val, x[i]);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
     }
 
     // find the max value in the block
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf[lane_id] = -INFINITY;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf[warp_id] = max_val;
+        }
+        __syncthreads();
+
+        max_val = buf[lane_id];
+        max_val = warp_reduce_max(max_val);
     }
 
     float tmp = 0.f;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        const float val = expf(x[i] - max_val);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
         tmp += val;
-        dst[i] = val;
+        dst[ix] = val;
     }
 
-    // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf[lane_id] = 0.f;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf[warp_id] = tmp;
+        }
+        __syncthreads();
+
+        tmp = buf[lane_id];
+        tmp = warp_reduce_sum(tmp);
     }
 
     const float inv_tmp = 1.f / tmp;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
+        const int i = rowx*ncols + col;
         dst[i] *= inv_tmp;
     }
 }
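For reference, the new kernel computes, per row of x, a single-precision softmax over x[rowx]*scale + y[rowx % nrows_y], with the mask rows broadcast when y has fewer rows than x and treated as zero when y is null. A plain CPU sketch of the same math (the function name and structure are chosen for the example, they are not part of the patch) makes the intended result easy to check against:

#include <algorithm>
#include <cmath>

// reference softmax over scaled, optionally masked rows (float only, like the kernel)
static void soft_max_f32_ref(const float * x, const float * y, float * dst,
                             int ncols, int nrows_x, int nrows_y, float scale) {
    for (int rowx = 0; rowx < nrows_x; ++rowx) {
        const int rowy = rowx % nrows_y; // mask rows are broadcast across x rows

        float max_val = -INFINITY;
        for (int col = 0; col < ncols; ++col) {
            const float v = x[rowx*ncols + col]*scale + (y ? y[rowy*ncols + col] : 0.0f);
            max_val = std::max(max_val, v);
        }

        float sum = 0.0f;
        for (int col = 0; col < ncols; ++col) {
            const float v = x[rowx*ncols + col]*scale + (y ? y[rowy*ncols + col] : 0.0f);
            const float e = std::exp(v - max_val); // subtract the row max for numerical stability
            dst[rowx*ncols + col] = e;
            sum += e;
        }

        for (int col = 0; col < ncols; ++col) {
            dst[rowx*ncols + col] /= sum;
        }
    }
}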
@@ -5793,10 +5831,12 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
-static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
-    const dim3 block_dims(1, WARP_SIZE, 1);
+static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+    int nth = WARP_SIZE;
+    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth, 1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
+    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
 }
 
 static void im2col_f32_f16_cuda(const float * x, half * dst,
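The launcher now picks the block size per launch instead of always using one warp: nth starts at WARP_SIZE and doubles until it covers ncols_x or hits CUDA_SOFT_MAX_BLOCK_SIZE, so short rows stay on the single-warp path and long rows use up to 1024 threads plus the shared-memory reduction in the kernel. A small illustration of that rounding, with the helper name invented for the example and the constants inlined:

#include <cassert>

static int soft_max_num_threads(int ncols_x) { // hypothetical helper, mirrors the launcher logic
    const int warp_size = 32;   // WARP_SIZE
    const int max_block = 1024; // CUDA_SOFT_MAX_BLOCK_SIZE
    int nth = warp_size;
    while (nth < ncols_x && nth < max_block) nth *= 2;
    return nth;
}

int main() {
    assert(soft_max_num_threads(17)   == 32);   // anything up to a warp: one warp
    assert(soft_max_num_threads(80)   == 128);  // rounded up to the next power of two
    assert(soft_max_num_threads(4096) == 1024); // capped at the block-size limit
    return 0;
}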
@@ -6835,14 +6875,18 @@ inline void ggml_cuda_op_soft_max(
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
     const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t nrows_x = ggml_nrows(src0);
+    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
 
-    soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
+    float scale = 1.0f;
+    memcpy(&scale, dst->op_params, sizeof(float));
+
+    soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
 
-    (void) src1;
     (void) dst;
-    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_scale(