
Commit 6545c87

ORippler authored and JohannesGaessler committed
CUDA: Optimize rms_norm_f32 kernel and its fused variants, giving 1-6% perf E2E (ggml-org#15715)
* Add fastdiv, use it in modulo and use modulo in rms_norm_f32

  Fastdiv is a much faster way to do integer division, which was identified as a bottleneck in rms_norm_f32

* Support more `block_size` values in `rms_norm_f32`

  This makes us more flexible in selecting the optimal threads w.r.t. parallelizing across a col vs. launch overheads of threads and MIO throttles

* Update ggml/src/ggml-cuda/common.cuh

  Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Replace modulo with fastmodulo in `rms_norm_f32`

* Use `BinPackArguments=true` for formatting function calls

  Will file a separate PR to adjust the .clang-format file

* Update ggml/src/ggml-cuda/common.cuh

  Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Use uint3 for both `fastdiv` and `fastmodulo`

  The compiler seems to reliably optimize away the unused .z component in the fastdiv use-case, see https://godbolt.org/z/rx8KPrKr3

* More constrained type declarations

  Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Rename fastdiv and fastmodulo variables to shared variable name

  As suggested by JohannesGaessler, this increases clarity of the intended use

* Pack fastdiv/fastmodulo constants into uint2/uint3 objects

  By packing constants to be used together into a struct, we are less likely to make errors.

* Rename function parameter of fastmodulo

  `modulo_consts` is more fitting/descriptive

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
1 parent 375e61d commit 6545c87
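Background (editor's note, not part of the commit): the constants produced by init_fastdiv_values implement the division-by-invariant-integers scheme of Granlund and Montgomery (the PLDI'94 paper cited in the new common.cuh comment). For a fixed divisor $d \ge 1$, precompute

$$L = \lceil \log_2 d \rceil, \qquad m' = \left\lfloor \frac{2^{32}\,(2^L - d)}{d} \right\rfloor + 1,$$

and each runtime division then becomes

$$\lfloor n/d \rfloor = \bigl(\operatorname{mulhi}(n, m') + n\bigr) \gg L, \qquad n \bmod d = n - d\,\lfloor n/d \rfloor,$$

i.e. one high 32-bit multiply, one add, and one shift in place of a hardware divide. This is exact for the extents used here, which stay well below $2^{31}$, so $\operatorname{mulhi}(n, m') + n$ cannot wrap 32 bits.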

File tree

2 files changed: +129 -85 lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 32 additions & 0 deletions

@@ -563,6 +563,38 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
 #endif // CUDART_VERSION >= 12050
 }
 
+// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
+// Precompute mp (m' in the paper) and L such that division
+// can be computed using a multiply (high 32b of 64b result)
+// and a shift:
+//
+// n/d = (mulhi(n, mp) + n) >> L;
+static const uint3 init_fastdiv_values(uint32_t d) {
+    // compute L = ceil(log2(d));
+    uint32_t L = 0;
+    while (L < 32 && (uint32_t{ 1 } << L) < d) {
+        L++;
+    }
+
+    uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
+    // pack divisor as well to reduce error surface
+    return make_uint3(mp, L, d);
+}
+
+static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 fastdiv_values) {
+    // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z>
+    // fastdiv_values.z is unused and optimized away by the compiler.
+    // Compute high 32 bits of n * mp
+    const uint32_t hi = __umulhi(n, fastdiv_values.x);
+    // add n, apply bit shift
+    return (hi + n) >> fastdiv_values.y;
+}
+
+static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fastdiv_values) {
+    // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
+    return n - fastdiv(n, fastdiv_values) * fastdiv_values.z;
+}
+
 typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);
 
 static __device__ __forceinline__ float get_alibi_slope(
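Editor's aside (not part of the commit): the fastdiv/fastmodulo math above can be sanity-checked on the host with a small standalone program. The sketch below mirrors the formulas from the diff and emulates __umulhi with a 64-bit multiply; all names in it are local stand-ins, not ggml code.

// fastdiv_check.cc -- hypothetical host-side check of the fastdiv constants.
#include <cassert>
#include <cstdint>
#include <cstdio>

struct fastdiv_vals { uint32_t mp, L, d; };  // host stand-in for the packed uint3

// Same precomputation as init_fastdiv_values in the diff above.
static fastdiv_vals init_fastdiv_values(uint32_t d) {
    uint32_t L = 0;
    while (L < 32 && (uint32_t{ 1 } << L) < d) {
        L++;
    }
    const uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
    return { mp, L, d };
}

// High 32 bits of a 32x32-bit product, like CUDA's __umulhi.
static uint32_t umulhi(uint32_t a, uint32_t b) {
    return (uint32_t) (((uint64_t) a * b) >> 32);
}

int main() {
    for (uint32_t d = 1; d < 5000; ++d) {
        const fastdiv_vals fd = init_fastdiv_values(d);
        for (uint32_t n = 0; n < 200000; n += 37) {
            const uint32_t q = (umulhi(n, fd.mp) + n) >> fd.L;  // fastdiv
            const uint32_t r = n - q * fd.d;                    // fastmodulo
            assert(q == n / d && r == n % d);
        }
    }
    printf("fastdiv/fastmodulo agree with / and %% on all tested (n, d)\n");
    return 0;
}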

ggml/src/ggml-cuda/norm.cu

Lines changed: 97 additions & 85 deletions
@@ -105,29 +105,29 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
 }
 
 template <int block_size, bool do_multiply = false, bool do_add = false>
-static __global__ void rms_norm_f32(const float * x, float * dst,
+static __global__ void rms_norm_f32(const float * x,
+                                    float * dst,
                                     const int ncols,
                                     const int64_t stride_row,
                                     const int64_t stride_channel,
                                     const int64_t stride_sample,
                                     const float eps,
-                                    const float * mul = nullptr,
-                                    const int64_t mul_stride_row = 0,
-                                    const int64_t mul_stride_channel = 0,
-                                    const int64_t mul_stride_sample = 0,
-                                    const int mul_ncols = 0,
-                                    const int mul_nrows = 0,
-                                    const int mul_nchannels = 0,
-                                    const int mul_nsamples = 0,
-                                    const float * add = nullptr,
-                                    const int64_t add_stride_row = 0,
-                                    const int64_t add_stride_channel = 0,
-                                    const int64_t add_stride_sample = 0,
-                                    const int add_ncols = 0,
-                                    const int add_nrows = 0,
-                                    const int add_nchannels = 0,
-                                    const int add_nsamples = 0) {
-
+                                    const float * mul = nullptr,
+                                    const int64_t mul_stride_row = 0,
+                                    const int64_t mul_stride_channel = 0,
+                                    const int64_t mul_stride_sample = 0,
+                                    const uint3 mul_ncols_packed = make_uint3(0, 0, 0),
+                                    const uint3 mul_nrows_packed = make_uint3(0, 0, 0),
+                                    const uint3 mul_nchannels_packed = make_uint3(0, 0, 0),
+                                    const uint3 mul_nsamples_packed = make_uint3(0, 0, 0),
+                                    const float * add = nullptr,
+                                    const int64_t add_stride_row = 0,
+                                    const int64_t add_stride_channel = 0,
+                                    const int64_t add_stride_sample = 0,
+                                    const uint3 add_ncols_packed = make_uint3(0, 0, 0),
+                                    const uint3 add_nrows_packed = make_uint3(0, 0, 0),
+                                    const uint3 add_nchannels_packed = make_uint3(0, 0, 0),
+                                    const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) {
     const int nrows     = gridDim.x;
     const int nchannels = gridDim.y;
@@ -142,16 +142,16 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
     dst += ((sample*nchannels + channel)*nrows + row)*ncols;
 
     if constexpr (do_multiply) {
-        const int mul_row = row % mul_nrows;
-        const int mul_channel = channel % mul_nchannels;
-        const int mul_sample = sample % mul_nsamples;
-        mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
+        const uint32_t mul_row     = fastmodulo(row, mul_nrows_packed);
+        const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed);
+        const uint32_t mul_sample  = fastmodulo(sample, mul_nsamples_packed);
+        mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;
     }
 
     if constexpr (do_add) {
-        const int add_row = row % add_nrows;
-        const int add_channel = channel % add_nchannels;
-        const int add_sample = sample % add_nsamples;
+        const int add_row     = fastmodulo(row, add_nrows_packed);
+        const int add_channel = fastmodulo(channel, add_nchannels_packed);
+        const int add_sample  = fastmodulo(sample, add_nsamples_packed);
         add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
     }
@@ -165,15 +165,18 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
     // sum up partial sums
     tmp = warp_reduce_sum(tmp);
     if constexpr (block_size > WARP_SIZE) {
-        static_assert(block_size == 1024, "unexpected block_size");
+        static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size");
         __shared__ float s_sum[32];
-        const int warp_id = threadIdx.x / WARP_SIZE;
-        const int lane_id = threadIdx.x % WARP_SIZE;
+        const int warp_id = tid / WARP_SIZE;
+        const int lane_id = tid % WARP_SIZE;
         if (lane_id == 0) {
             s_sum[warp_id] = tmp;
         }
         __syncthreads();
-        tmp = s_sum[lane_id];
+        tmp = 0.0f;
+        if (lane_id < (block_size / WARP_SIZE)) {
+            tmp = s_sum[lane_id];
+        }
         tmp = warp_reduce_sum(tmp);
     }
 
@@ -182,12 +185,12 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
 
     for (int col = tid; col < ncols; col += block_size) {
         if constexpr (do_multiply && do_add) {
-            const int mul_col = col % mul_ncols;
-            const int add_col = col % add_ncols;
-            dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
+            const int mul_col = fastmodulo(col, mul_ncols_packed);
+            const int add_col = fastmodulo(col, add_ncols_packed);
+            dst[col]          = scale * x[col] * mul[mul_col] + add[add_col];
         } else if constexpr (do_multiply) {
-            const int mul_col = col % mul_ncols;
-            dst[col] = scale * x[col] * mul[mul_col];
+            const int mul_col = fastmodulo(col, mul_ncols_packed);
+            dst[col]          = scale * x[col] * mul[mul_col];
         } else {
             dst[col] = scale * x[col];
         }
@@ -354,77 +357,86 @@ static void rms_norm_f32_cuda(
     const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
     const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
-        const dim3 block_dims(WARP_SIZE, 1, 1);
-        rms_norm_f32<WARP_SIZE, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        const dim3 block_dims(256, 1, 1);
+        rms_norm_f32<256, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
         rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
 
-static void rms_norm_mul_f32_cuda(const float * x,
-                                  const float * mul,
-                                  const float * add,
-                                  float * dst,
-                                  const int ncols,
-                                  const int nrows,
-                                  const int nchannels,
-                                  const int nsamples,
-                                  const int64_t stride_row,
-                                  const int64_t stride_channel,
-                                  const int64_t stride_sample,
-                                  const int64_t mul_stride_row,
-                                  const int64_t mul_stride_channel,
-                                  const int64_t mul_stride_sample,
-                                  const int mul_ncols,
-                                  const int mul_nrows,
-                                  const int mul_nchannels,
-                                  const int mul_nsamples,
-                                  const int64_t add_stride_row,
-                                  const int64_t add_stride_channel,
-                                  const int64_t add_stride_sample,
-                                  const int add_ncols,
-                                  const int add_nrows,
-                                  const int add_nchannels,
-                                  const int add_nsamples,
-                                  const float eps,
-                                  cudaStream_t stream) {
+static void rms_norm_mul_f32_cuda(const float * x,
+                                  const float * mul,
+                                  const float * add,
+                                  float * dst,
+                                  const int ncols,
+                                  const int nrows,
+                                  const int nchannels,
+                                  const int nsamples,
+                                  const int64_t stride_row,
+                                  const int64_t stride_channel,
+                                  const int64_t stride_sample,
+                                  const int64_t mul_stride_row,
+                                  const int64_t mul_stride_channel,
+                                  const int64_t mul_stride_sample,
+                                  const uint32_t mul_ncols,
+                                  const uint32_t mul_nrows,
+                                  const uint32_t mul_nchannels,
+                                  const uint32_t mul_nsamples,
+                                  const int64_t add_stride_row,
+                                  const int64_t add_stride_channel,
+                                  const int64_t add_stride_sample,
+                                  const uint32_t add_ncols,
+                                  const uint32_t add_nrows,
+                                  const uint32_t add_nchannels,
+                                  const uint32_t add_nsamples,
+                                  const float eps,
+                                  cudaStream_t stream) {
     const dim3 blocks_num(nrows, nchannels, nsamples);
     if (mul == nullptr) {
         rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
         return;
     }
     if (add == nullptr) {
+        const uint3 mul_ncols_packed     = init_fastdiv_values(mul_ncols);
+        const uint3 mul_nrows_packed     = init_fastdiv_values(mul_nrows);
+        const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
+        const uint3 mul_nsamples_packed  = init_fastdiv_values(mul_nsamples);
         if (ncols < 1024) {
-            const dim3 block_dims(WARP_SIZE, 1, 1);
-            rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
-                ncols, stride_row, stride_channel, stride_sample, eps,
-                mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
-                mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+            const dim3 block_dims(256, 1, 1);
+            rms_norm_f32<256, true><<<blocks_num, block_dims, 0, stream>>>(
+                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
         } else {
             const dim3 block_dims(1024, 1, 1);
-            rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
-                ncols, stride_row, stride_channel, stride_sample, eps,
-                mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
-                mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+            rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(
+                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
         }
     } else {
+        const uint3 mul_ncols_packed     = init_fastdiv_values(mul_ncols);
+        const uint3 mul_nrows_packed     = init_fastdiv_values(mul_nrows);
+        const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
+        const uint3 mul_nsamples_packed  = init_fastdiv_values(mul_nsamples);
+
+        const uint3 add_ncols_packed     = init_fastdiv_values(add_ncols);
+        const uint3 add_nrows_packed     = init_fastdiv_values(add_nrows);
+        const uint3 add_nchannels_packed = init_fastdiv_values(add_nchannels);
+        const uint3 add_nsamples_packed  = init_fastdiv_values(add_nsamples);
         if (ncols < 1024) {
-            const dim3 block_dims(WARP_SIZE, 1, 1);
-            rms_norm_f32<WARP_SIZE, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
-                ncols, stride_row, stride_channel, stride_sample, eps,
-                mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
-                mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
-                add, add_stride_row, add_stride_channel, add_stride_sample,
-                add_ncols, add_nrows, add_nchannels, add_nsamples);
+            const dim3 block_dims(256, 1, 1);
+            rms_norm_f32<256, true, true><<<blocks_num, block_dims, 0, stream>>>(
+                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
+                add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
+                add_nchannels_packed, add_nsamples_packed);
         } else {
             const dim3 block_dims(1024, 1, 1);
-            rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
-                ncols, stride_row, stride_channel, stride_sample, eps,
-                mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
-                mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
-                add, add_stride_row, add_stride_channel, add_stride_sample,
-                add_ncols, add_nrows, add_nchannels, add_nsamples);
+            rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(
+                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
+                add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
+                add_nchannels_packed, add_nsamples_packed);
         }
     }
 }
