@@ -105,45 +105,29 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
105105}
106106
107107template <int block_size, bool do_multiply = false , bool do_add = false >
108- static __global__ void rms_norm_f32 (const float * x,
109- float * dst,
110- const int ncols,
111- const int64_t stride_row,
112- const int64_t stride_channel,
113- const int64_t stride_sample,
114- const float eps,
115- const float * mul = nullptr ,
116- const int64_t mul_stride_row = 0 ,
117- const int64_t mul_stride_channel = 0 ,
118- const int64_t mul_stride_sample = 0 ,
119- const uint32_t mul_ncols = 0 ,
120- const uint32_t mul_nrows = 0 ,
121- const uint32_t mul_nchannels = 0 ,
122- const uint32_t mul_nsamples = 0 ,
123- const uint32_t mp_mul_cols = 0 ,
124- const uint32_t L_mul_cols = 0 ,
125- const uint32_t mp_mul_rows = 0 ,
126- const uint32_t L_mul_rows = 0 ,
127- const uint32_t mp_mul_channels = 0 ,
128- const uint32_t L_mul_channels = 0 ,
129- const uint32_t mp_mul_samples = 0 ,
130- const uint32_t L_mul_samples = 0 ,
131- const float * add = nullptr ,
132- const int64_t add_stride_row = 0 ,
133- const int64_t add_stride_channel = 0 ,
134- const int64_t add_stride_sample = 0 ,
135- const uint32_t add_ncols = 0 ,
136- const uint32_t add_nrows = 0 ,
137- const uint32_t add_nchannels = 0 ,
138- const uint32_t add_nsamples = 0 ,
139- const uint32_t mp_add_cols = 0 ,
140- const uint32_t L_add_cols = 0 ,
141- const uint32_t mp_add_rows = 0 ,
142- const uint32_t L_add_rows = 0 ,
143- const uint32_t mp_add_channels = 0 ,
144- const uint32_t L_add_channels = 0 ,
145- const uint32_t mp_add_samples = 0 ,
146- const uint32_t L_add_samples = 0 ) {
108+ static __global__ void rms_norm_f32 (const float * x,
109+ float * dst,
110+ const int ncols,
111+ const int64_t stride_row,
112+ const int64_t stride_channel,
113+ const int64_t stride_sample,
114+ const float eps,
115+ const float * mul = nullptr ,
116+ const int64_t mul_stride_row = 0 ,
117+ const int64_t mul_stride_channel = 0 ,
118+ const int64_t mul_stride_sample = 0 ,
119+ const uint3 mul_ncols_packed = make_uint3(0 , 0 , 0 ),
120+ const uint3 mul_nrows_packed = make_uint3(0 , 0 , 0 ),
121+ const uint3 mul_nchannels_packed = make_uint3(0 , 0 , 0 ),
122+ const uint3 mul_nsamples_packed = make_uint3(0 , 0 , 0 ),
123+ const float * add = nullptr,
124+ const int64_t add_stride_row = 0,
125+ const int64_t add_stride_channel = 0,
126+ const int64_t add_stride_sample = 0,
127+ const uint3 add_ncols_packed = make_uint3(0 , 0 , 0 ),
128+ const uint3 add_nrows_packed = make_uint3(0 , 0 , 0 ),
129+ const uint3 add_nchannels_packed = make_uint3(0 , 0 , 0 ),
130+ const uint3 add_nsamples_packed = make_uint3(0 , 0 , 0 )) {
147131 const int nrows = gridDim .x ;
148132 const int nchannels = gridDim .y ;
149133
@@ -158,16 +142,16 @@ static __global__ void rms_norm_f32(const float * x,
158142 dst += ((sample*nchannels + channel)*nrows + row)*ncols;
159143
160144 if constexpr (do_multiply) {
161- const uint32_t mul_row = fastmodulo (row, mul_nrows, mp_mul_rows, L_mul_rows );
162- const uint32_t mul_channel = fastmodulo (channel, mul_nchannels, mp_mul_channels, L_mul_channels );
163- const uint32_t mul_sample = fastmodulo (sample, mul_nsamples, mp_mul_samples, L_mul_samples );
145+ const uint32_t mul_row = fastmodulo (row, mul_nrows_packed );
146+ const uint32_t mul_channel = fastmodulo (channel, mul_nchannels_packed );
147+ const uint32_t mul_sample = fastmodulo (sample, mul_nsamples_packed );
164148 mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;
165149 }
166150
167151 if constexpr (do_add) {
168- const int add_row = fastmodulo (row, add_nrows, mp_add_rows, L_add_rows );
169- const int add_channel = fastmodulo (channel, add_nchannels, mp_add_channels, L_add_channels );
170- const int add_sample = fastmodulo (sample, add_nsamples, mp_add_samples, L_add_samples );
152+ const int add_row = fastmodulo (row, add_nrows_packed );
153+ const int add_channel = fastmodulo (channel, add_nchannels_packed );
154+ const int add_sample = fastmodulo (sample, add_nsamples_packed );
171155 add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
172156 }
173157
@@ -201,11 +185,11 @@ static __global__ void rms_norm_f32(const float * x,
201185
202186 for (int col = tid; col < ncols; col += block_size) {
203187 if constexpr (do_multiply && do_add) {
204- const int mul_col = fastmodulo (col, mul_ncols, mp_mul_cols, L_mul_cols );
205- const int add_col = fastmodulo (col, add_ncols, mp_add_cols, L_add_cols );
188+ const int mul_col = fastmodulo (col, mul_ncols_packed );
189+ const int add_col = fastmodulo (col, add_ncols_packed );
206190 dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
207191 } else if constexpr (do_multiply) {
208- const int mul_col = fastmodulo (col, mul_ncols, mp_mul_cols, L_mul_cols );
192+ const int mul_col = fastmodulo (col, mul_ncols_packed );
209193 dst[col] = scale * x[col] * mul[mul_col];
210194 } else {
211195 dst[col] = scale * x[col];
@@ -414,63 +398,45 @@ static void rms_norm_mul_f32_cuda(const float * x,
414398 return ;
415399 }
416400 if (add == nullptr ) {
417- uint32_t mp_mul_cols, L_mul_cols;
418- init_fastdiv_values (mul_ncols, mp_mul_cols, L_mul_cols);
419- uint32_t mp_mul_rows, L_mul_rows;
420- init_fastdiv_values (mul_nrows, mp_mul_rows, L_mul_rows);
421- uint32_t mp_mul_channels, L_mul_channels;
422- init_fastdiv_values (mul_nchannels, mp_mul_channels, L_mul_channels);
423- uint32_t mp_mul_samples, L_mul_samples;
424- init_fastdiv_values (mul_nsamples, mp_mul_samples, L_mul_samples);
401+ uint3 mul_ncols_packed = init_fastmodulo_values (mul_ncols);
402+ uint3 mul_nrows_packed = init_fastmodulo_values (mul_nrows);
403+ uint3 mul_nchannels_packed = init_fastmodulo_values (mul_nchannels);
404+ uint3 mul_nsamples_packed = init_fastmodulo_values (mul_nsamples);
425405 if (ncols < 1024 ) {
426406 const dim3 block_dims (256 , 1 , 1 );
427407 rms_norm_f32<256 , true ><<<blocks_num, block_dims, 0 , stream>>> (
428408 x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
429- mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols,
430- mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples);
409+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
431410 } else {
432411 const dim3 block_dims (1024 , 1 , 1 );
433412 rms_norm_f32<1024 , true ><<<blocks_num, block_dims, 0 , stream>>> (
434413 x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
435- mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols,
436- mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples);
414+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
437415 }
438416 } else {
439- uint32_t mp_mul_cols, L_mul_cols;
440- init_fastdiv_values (mul_ncols, mp_mul_cols, L_mul_cols);
441- uint32_t mp_mul_rows, L_mul_rows;
442- init_fastdiv_values (mul_nrows, mp_mul_rows, L_mul_rows);
443- uint32_t mp_mul_channels, L_mul_channels;
444- init_fastdiv_values (mul_nchannels, mp_mul_channels, L_mul_channels);
445- uint32_t mp_mul_samples, L_mul_samples;
446- init_fastdiv_values (mul_nsamples, mp_mul_samples, L_mul_samples);
447-
448- uint32_t mp_add_cols, L_add_cols;
449- init_fastdiv_values (add_ncols, mp_add_cols, L_add_cols);
450- uint32_t mp_add_rows, L_add_rows;
451- init_fastdiv_values (add_nrows, mp_add_rows, L_add_rows);
452- uint32_t mp_add_channels, L_add_channels;
453- init_fastdiv_values (add_nchannels, mp_add_channels, L_add_channels);
454- uint32_t mp_add_samples, L_add_samples;
455- init_fastdiv_values (add_nsamples, mp_add_samples, L_add_samples);
417+ uint3 mul_ncols_packed = init_fastmodulo_values (mul_ncols);
418+ uint3 mul_nrows_packed = init_fastmodulo_values (mul_nrows);
419+ uint3 mul_nchannels_packed = init_fastmodulo_values (mul_nchannels);
420+ uint3 mul_nsamples_packed = init_fastmodulo_values (mul_nsamples);
421+
422+ uint3 add_ncols_packed = init_fastmodulo_values (add_ncols);
423+ uint3 add_nrows_packed = init_fastmodulo_values (add_nrows);
424+ uint3 add_nchannels_packed = init_fastmodulo_values (add_nchannels);
425+ uint3 add_nsamples_packed = init_fastmodulo_values (add_nsamples);
456426 if (ncols < 1024 ) {
457427 const dim3 block_dims (256 , 1 , 1 );
458428 rms_norm_f32<256 , true , true ><<<blocks_num, block_dims, 0 , stream>>> (
459429 x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
460- mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols,
461- mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples, add,
462- add_stride_row, add_stride_channel, add_stride_sample, add_ncols, add_nrows, add_nchannels,
463- add_nsamples, mp_add_cols, L_add_cols, mp_add_rows, L_add_rows, mp_add_channels, L_add_channels,
464- mp_add_samples, L_add_samples);
430+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
431+ add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
432+ add_nchannels_packed, add_nsamples_packed);
465433 } else {
466434 const dim3 block_dims (1024 , 1 , 1 );
467435 rms_norm_f32<1024 , true , true ><<<blocks_num, block_dims, 0 , stream>>> (
468436 x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
469- mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols,
470- mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples, add,
471- add_stride_row, add_stride_channel, add_stride_sample, add_ncols, add_nrows, add_nchannels,
472- add_nsamples, mp_add_cols, L_add_cols, mp_add_rows, L_add_rows, mp_add_channels, L_add_channels,
473- mp_add_samples, L_add_samples);
437+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
438+ add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
439+ add_nchannels_packed, add_nsamples_packed);
474440 }
475441 }
476442}
0 commit comments