@@ -13,6 +13,29 @@ __device__ float __forceinline__ t2f32<half>(half val) {
1313 return __half2float (val);
1414}
1515
// Arguments of the soft_max_f32 kernel, bundled so the launcher can hand them over by value as one parameter.
// Field order is part of the layout; keep it in sync with any positional initialisation.
struct soft_max_params {
    // Head count (host sets it from src0->ne[2]) and the largest power of two <= that count.
    int64_t  nheads;
    uint32_t n_head_log2;

    // Row length (src0->ne[0]), total row count of src0, and src0->ne[1].
    int64_t ncols, nrows_x, nrows_y;

    // Extents of src0 in all four dimensions; ne01/ne02/ne03 also define the launch grid.
    int64_t ne00, ne01, ne02, ne03;

    // Mask (src1) strides in bytes for dims 1..3; the kernel divides them by sizeof(T).
    // The host sets them to 1 when there is no mask.
    int64_t nb11, nb12, nb13;

    // Mask extents in dims 2 and 3 (1 without a mask); the kernel indexes the mask modulo these.
    int64_t ne12, ne13;

    // Multiplier applied to every input element, and the bias handed to get_alibi_slope.
    float scale, max_bias;

    // Slope bases precomputed on the host: 2^(-max_bias / n_head_log2) and 2^(-(max_bias/2) / n_head_log2).
    float m0, m1;
};
38+
1639// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
1740// As we want to keep pragma unroll for all other cases we suppress the clang transformation warning here.
1841#ifdef __clang__
@@ -21,24 +44,32 @@ __device__ float __forceinline__ t2f32<half>(half val) {
2144#endif // __clang__
2245template <bool use_shared, int ncols_template, int block_size_template, typename T>
2346static __global__ void soft_max_f32 (
24- const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
25- const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
26- const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
47+ const float * x, const T * mask, float * dst, const soft_max_params p) {
48+ const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
2749
2850 const int tid = threadIdx .x ;
29- const int rowx = blockIdx .x ;
30- const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
51+
52+ const int64_t i03 = blockIdx .z ;
53+ const int64_t i02 = blockIdx .y ;
54+ const int64_t i01 = blockIdx .x ;
55+
56+ // TODO: noncontiguous inputs/outputs
57+ const int rowx = blockIdx .x + blockIdx .y * gridDim .x + blockIdx .z * gridDim .x * gridDim .y ;
58+
59+ const int64_t i11 = i01;
60+ const int64_t i12 = i02 % p.ne12 ;
61+ const int64_t i13 = i03 % p.ne13 ;
3162
3263 x += int64_t (rowx)*ncols;
33- mask += int64_t (rowy)*ncols * (mask != nullptr );
64+ mask += (i11*p. nb11 + i12*p. nb12 + i13*p. nb13 ) / sizeof (T) * (mask != nullptr );
3465 dst += int64_t (rowx)*ncols;
3566
3667 const int block_size = block_size_template == 0 ? blockDim .x : block_size_template;
3768
3869 const int warp_id = threadIdx .x / WARP_SIZE;
3970 const int lane_id = threadIdx .x % WARP_SIZE;
4071
41- const float slope = get_alibi_slope (max_bias, rowx/nrows_y, n_head_log2, m0, m1);
72+ const float slope = get_alibi_slope (p. max_bias , i02, p. n_head_log2 , p. m0 , p. m1 );
4273
4374 extern __shared__ float data_soft_max_f32[];
4475 float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
@@ -55,7 +86,7 @@ static __global__ void soft_max_f32(
5586 break ;
5687 }
5788
58- const float val = x[col]*scale + (mask ? slope*t2f32 (mask[col]) : 0 .0f );
89+ const float val = x[col]*p. scale + (mask ? slope*t2f32 (mask[col]) : 0 .0f );
5990
6091 vals[col] = val;
6192 max_val = max (max_val, val);
@@ -151,63 +182,60 @@ static __global__ void soft_max_back_f32(
151182}
152183
// Host-side launcher for the soft_max_f32 kernel.
//
// x, mask and dst are forwarded to the kernel unchanged; T is the element type of the mask.
// All remaining kernel arguments travel by value inside `params`.
//
// Launch layout:
//   - grid  = (params.ne01, params.ne02, params.ne03); the kernel derives its row index from blockIdx.
//   - block = WARP_SIZE doubled until it reaches params.ncols or CUDA_SOFT_MAX_BLOCK_SIZE.
//   - dynamic shared memory = GGML_PAD(ncols, WARP_SIZE) + WARP_SIZE floats when that is below the device's
//     smpb value; otherwise the use_shared=false instantiation is launched with only WARP_SIZE floats.
//
// The kernel template arguments are <use_shared, ncols_template, block_size_template>; a value of 0 for the
// last two makes the kernel fall back to the run-time column count / blockDim.x.
template <typename T>
static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
    const int64_t ncols_x = params.ncols;

    // dim3 stores unsigned ints, so the int64_t extents are narrowed on construction. CUDA additionally caps
    // gridDim.y and gridDim.z at 65535 (gridDim.x at 2^31-1). Fail loudly instead of launching a truncated
    // or invalid grid.
    GGML_ASSERT(params.ne01 <= 0x7FFFFFFF);
    GGML_ASSERT(params.ne02 <= 0xFFFF);
    GGML_ASSERT(params.ne03 <= 0xFFFF);

    int nth = WARP_SIZE;
    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
    const dim3 block_dims(nth, 1, 1);
    const dim3 block_nums(params.ne01, params.ne02, params.ne03);
    const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");

    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
    if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
        // Pick a specialisation with a compile-time column count where one exists so the kernel loops can be
        // unrolled; every other width goes through the generic <true, 0, 0> instantiation.
        switch (ncols_x) {
            case 32:
                soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
                    (x, mask, dst, params);
                break;
            case 64:
                soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
                    (x, mask, dst, params);
                break;
            case 128:
                soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
                    (x, mask, dst, params);
                break;
            case 256:
                soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
                    (x, mask, dst, params);
                break;
            case 512:
                soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
                    (x, mask, dst, params);
                break;
            case 1024:
                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
                    (x, mask, dst, params);
                break;
            case 2048:
                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
                    (x, mask, dst, params);
                break;
            case 4096:
                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
                    (x, mask, dst, params);
                break;
            default:
                soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
                    (x, mask, dst, params);
                break;
        }
    } else {
        // The full-row buffer does not fit: request only the WARP_SIZE floats used for inter-warp communication.
        const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
    }
}
213241
@@ -235,10 +263,11 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
235263
236264 GGML_ASSERT (!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
237265
238- const int64_t ne00 = src0->ne [0 ];
239266 const int64_t nrows_x = ggml_nrows (src0);
240267 const int64_t nrows_y = src0->ne [1 ];
241268
269+ const int64_t ne00 = src0->ne [0 ];
270+
242271 float scale = 1 .0f ;
243272 float max_bias = 0 .0f ;
244273
@@ -247,10 +276,44 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
247276
248277 const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
249278
279+ const int64_t nb11 = src1 ? src1->nb [1 ] : 1 ;
280+ const int64_t nb12 = src1 ? src1->nb [2 ] : 1 ;
281+ const int64_t nb13 = src1 ? src1->nb [3 ] : 1 ;
282+
283+ const int64_t ne12 = src1 ? src1->ne [2 ] : 1 ;
284+ const int64_t ne13 = src1 ? src1->ne [3 ] : 1 ;
285+
286+ const uint32_t n_head = src0->ne [2 ];
287+ const uint32_t n_head_log2 = 1u << (uint32_t ) floorf (log2f ((float ) n_head));
288+
289+ const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
290+ const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
291+
292+
293+ soft_max_params params = {};
294+ params.nheads = src0->ne [2 ];
295+ params.n_head_log2 = n_head_log2;
296+ params.ncols = ne00;
297+ params.nrows_x = nrows_x;
298+ params.nrows_y = nrows_y;
299+ params.ne00 = src0->ne [0 ];
300+ params.ne01 = src0->ne [1 ];
301+ params.ne02 = src0->ne [2 ];
302+ params.ne03 = src0->ne [3 ];
303+ params.nb11 = nb11;
304+ params.nb12 = nb12;
305+ params.nb13 = nb13;
306+ params.ne12 = ne12;
307+ params.ne13 = ne13;
308+ params.scale = scale;
309+ params.max_bias = max_bias;
310+ params.m0 = m0;
311+ params.m1 = m1;
312+
250313 if (use_f16) {
251- soft_max_f32_cuda (src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias , stream);
314+ soft_max_f32_cuda (src0_d, (const half *) src1_d, dst_d, params , stream);
252315 } else {
253- soft_max_f32_cuda (src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias , stream);
316+ soft_max_f32_cuda (src0_d, (const float *) src1_d, dst_d, params , stream);
254317 }
255318}
256319
0 commit comments