@@ -105,29 +105,29 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
105105}
106106
107107template  <int  block_size, bool  do_multiply = false , bool  do_add = false >
108- static  __global__  void  rms_norm_f32 (const  float  * x, float  *       dst,
108+ static  __global__  void  rms_norm_f32 (const  float  * x,
109+                                     float  *       dst,
109110                                    const  int      ncols,
110111                                    const  int64_t  stride_row,
111112                                    const  int64_t  stride_channel,
112113                                    const  int64_t  stride_sample,
113114                                    const  float    eps,
114-                                     const  float  * mul                = nullptr ,
115-                                     const  int64_t  mul_stride_row     = 0 ,
116-                                     const  int64_t  mul_stride_channel = 0 ,
117-                                     const  int64_t  mul_stride_sample  = 0 ,
118-                                     const  int      mul_ncols          = 0 ,
119-                                     const  int      mul_nrows          = 0 ,
120-                                     const  int      mul_nchannels      = 0 ,
121-                                     const  int      mul_nsamples       = 0 ,
122-                                     const  float  * add                = nullptr ,
123-                                     const  int64_t  add_stride_row     = 0 ,
124-                                     const  int64_t  add_stride_channel = 0 ,
125-                                     const  int64_t  add_stride_sample  = 0 ,
126-                                     const  int      add_ncols          = 0 ,
127-                                     const  int      add_nrows          = 0 ,
128-                                     const  int      add_nchannels      = 0 ,
129-                                     const  int      add_nsamples       = 0 ) {
130- 
115+                                     const  float  * mul                  = nullptr ,
116+                                     const  int64_t  mul_stride_row       = 0 ,
117+                                     const  int64_t  mul_stride_channel   = 0 ,
118+                                     const  int64_t  mul_stride_sample    = 0 ,
119+                                     const  uint3    mul_ncols_packed     = make_uint3(0 , 0 , 0 ),
120+                                     const  uint3   mul_nrows_packed     = make_uint3(0 , 0 , 0 ),
121+                                     const  uint3   mul_nchannels_packed = make_uint3(0 , 0 , 0 ),
122+                                     const  uint3   mul_nsamples_packed  = make_uint3(0 , 0 , 0 ),
123+                                     const  float * add                  = nullptr,
124+                                     const  int64_t add_stride_row       = 0,
125+                                     const  int64_t add_stride_channel   = 0,
126+                                     const  int64_t add_stride_sample    = 0,
127+                                     const  uint3   add_ncols_packed     = make_uint3(0 , 0 , 0 ),
128+                                     const  uint3   add_nrows_packed     = make_uint3(0 , 0 , 0 ),
129+                                     const  uint3   add_nchannels_packed = make_uint3(0 , 0 , 0 ),
130+                                     const  uint3   add_nsamples_packed  = make_uint3(0 , 0 , 0 )) {
131131    const  int  nrows     = gridDim .x ;
132132    const  int  nchannels = gridDim .y ;
133133
@@ -142,16 +142,16 @@ static __global__ void rms_norm_f32(const float * x, float *       dst,
142142    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
143143
144144    if  constexpr  (do_multiply) {
145-         const  int  mul_row =  row % mul_nrows ;
146-         const  int  mul_channel = channel % mul_nchannels ;
147-         const  int  mul_sample =  sample % mul_nsamples ;
148-         mul += mul_sample* mul_stride_sample + mul_channel* mul_stride_channel + mul_row* mul_stride_row;
145+         const  uint32_t  mul_row     =  fastmodulo ( row, mul_nrows_packed) ;
146+         const  uint32_t  mul_channel = fastmodulo ( channel, mul_nchannels_packed) ;
147+         const  uint32_t  mul_sample  =  fastmodulo ( sample, mul_nsamples_packed) ;
148+         mul += mul_sample *  mul_stride_sample + mul_channel *  mul_stride_channel + mul_row *  mul_stride_row;
149149    }
150150
151151    if  constexpr  (do_add) {
152-         const  int  add_row     = row % add_nrows ;
153-         const  int  add_channel = channel % add_nchannels ;
154-         const  int  add_sample  = sample % add_nsamples ;
152+         const  int  add_row     = fastmodulo ( row, add_nrows_packed) ;
153+         const  int  add_channel = fastmodulo ( channel, add_nchannels_packed) ;
154+         const  int  add_sample  = fastmodulo ( sample, add_nsamples_packed) ;
155155        add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
156156    }
157157
@@ -165,15 +165,18 @@ static __global__ void rms_norm_f32(const float * x, float *       dst,
165165    //  sum up partial sums
166166    tmp = warp_reduce_sum (tmp);
167167    if  constexpr  (block_size > WARP_SIZE) {
168-         static_assert (block_size ==  1024 , " unexpected block_size" 
168+         static_assert (( block_size <=  1024 ) && (block_size %  32  ==  0 ) , " unexpected block_size" 
169169        __shared__  float  s_sum[32 ];
170-         const  int  warp_id = threadIdx . x  / WARP_SIZE;
171-         const  int  lane_id = threadIdx . x  % WARP_SIZE;
170+         const  int          warp_id = tid  / WARP_SIZE;
171+         const  int          lane_id = tid  % WARP_SIZE;
172172        if  (lane_id == 0 ) {
173173            s_sum[warp_id] = tmp;
174174        }
175175        __syncthreads ();
176-         tmp = s_sum[lane_id];
176+         tmp = 0 .0f ;
177+         if  (lane_id < (block_size / WARP_SIZE)) {
178+             tmp = s_sum[lane_id];
179+         }
177180        tmp = warp_reduce_sum (tmp);
178181    }
179182
@@ -182,12 +185,12 @@ static __global__ void rms_norm_f32(const float * x, float *       dst,
182185
183186    for  (int  col = tid; col < ncols; col += block_size) {
184187        if  constexpr  (do_multiply && do_add) {
185-             const  int  mul_col = col % mul_ncols ;
186-             const  int  add_col = col % add_ncols ;
187-             dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
188+             const  int  mul_col = fastmodulo ( col, mul_ncols_packed) ;
189+             const  int  add_col = fastmodulo ( col, add_ncols_packed) ;
190+             dst[col]           = scale * x[col] * mul[mul_col] + add[add_col];
188191        } else  if  constexpr  (do_multiply) {
189-             const  int  mul_col = col % mul_ncols ;
190-             dst[col] = scale * x[col] * mul[mul_col];
192+             const  int  mul_col = fastmodulo ( col, mul_ncols_packed) ;
193+             dst[col]           = scale * x[col] * mul[mul_col];
191194        } else  {
192195            dst[col] = scale * x[col];
193196        }
@@ -354,77 +357,86 @@ static void rms_norm_f32_cuda(
354357        const  int64_t  stride_row, const  int64_t  stride_channel, const  int64_t  stride_sample, const  float  eps, cudaStream_t stream) {
355358    const  dim3  blocks_num (nrows, nchannels, nsamples);
356359    if  (ncols < 1024 ) {
357-         const  dim3  block_dims (WARP_SIZE , 1 , 1 );
358-         rms_norm_f32<WARP_SIZE , false ><<<blocks_num, block_dims, 0 , stream>>> (x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
360+         const  dim3  block_dims (256 , 1 , 1 );
361+         rms_norm_f32<256 , false ><<<blocks_num, block_dims, 0 , stream>>> (x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
359362    } else  {
360363        const  dim3  block_dims (1024 , 1 , 1 );
361364        rms_norm_f32<1024 , false ><<<blocks_num, block_dims, 0 , stream>>> (x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
362365    }
363366}
364367
365- static  void  rms_norm_mul_f32_cuda (const  float  * x,
366-                                   const  float  * mul,
367-                                   const  float  * add,
368-                                   float  *       dst,
369-                                   const  int      ncols,
370-                                   const  int      nrows,
371-                                   const  int      nchannels,
372-                                   const  int      nsamples,
373-                                   const  int64_t  stride_row,
374-                                   const  int64_t  stride_channel,
375-                                   const  int64_t  stride_sample,
376-                                   const  int64_t  mul_stride_row,
377-                                   const  int64_t  mul_stride_channel,
378-                                   const  int64_t  mul_stride_sample,
379-                                   const  int       mul_ncols,
380-                                   const  int       mul_nrows,
381-                                   const  int       mul_nchannels,
382-                                   const  int       mul_nsamples,
383-                                   const  int64_t  add_stride_row,
384-                                   const  int64_t  add_stride_channel,
385-                                   const  int64_t  add_stride_sample,
386-                                   const  int       add_ncols,
387-                                   const  int       add_nrows,
388-                                   const  int       add_nchannels,
389-                                   const  int       add_nsamples,
390-                                   const  float    eps,
391-                                   cudaStream_t  stream) {
368+ static  void  rms_norm_mul_f32_cuda (const  float  *   x,
369+                                   const  float  *   mul,
370+                                   const  float  *   add,
371+                                   float  *         dst,
372+                                   const  int        ncols,
373+                                   const  int        nrows,
374+                                   const  int        nchannels,
375+                                   const  int        nsamples,
376+                                   const  int64_t    stride_row,
377+                                   const  int64_t    stride_channel,
378+                                   const  int64_t    stride_sample,
379+                                   const  int64_t    mul_stride_row,
380+                                   const  int64_t    mul_stride_channel,
381+                                   const  int64_t    mul_stride_sample,
382+                                   const  uint32_t  mul_ncols,
383+                                   const  uint32_t  mul_nrows,
384+                                   const  uint32_t  mul_nchannels,
385+                                   const  uint32_t  mul_nsamples,
386+                                   const  int64_t    add_stride_row,
387+                                   const  int64_t    add_stride_channel,
388+                                   const  int64_t    add_stride_sample,
389+                                   const  uint32_t  add_ncols,
390+                                   const  uint32_t  add_nrows,
391+                                   const  uint32_t  add_nchannels,
392+                                   const  uint32_t  add_nsamples,
393+                                   const  float      eps,
394+                                   cudaStream_t    stream) {
392395    const  dim3  blocks_num (nrows, nchannels, nsamples);
393396    if  (mul == nullptr ) {
394397        rms_norm_f32_cuda (x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
395398        return ;
396399    }
397400    if  (add == nullptr ) {
401+         const  uint3  mul_ncols_packed     = init_fastdiv_values (mul_ncols);
402+         const  uint3  mul_nrows_packed     = init_fastdiv_values (mul_nrows);
403+         const  uint3  mul_nchannels_packed = init_fastdiv_values (mul_nchannels);
404+         const  uint3  mul_nsamples_packed  = init_fastdiv_values (mul_nsamples);
398405        if  (ncols < 1024 ) {
399-             const  dim3  block_dims (WARP_SIZE, 1 , 1 );
400-             rms_norm_f32<WARP_SIZE, true ><<<blocks_num, block_dims, 0 , stream>>> (x, dst,
401-                 ncols, stride_row, stride_channel, stride_sample, eps,
402-                 mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
403-                 mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
406+             const  dim3  block_dims (256 , 1 , 1 );
407+             rms_norm_f32<256 , true ><<<blocks_num, block_dims, 0 , stream>>> (
408+                 x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
409+                 mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
404410        } else  {
405411            const  dim3  block_dims (1024 , 1 , 1 );
406-             rms_norm_f32<1024 , true ><<<blocks_num, block_dims, 0 , stream>>> (x, dst,
407-                 ncols, stride_row, stride_channel, stride_sample, eps,
408-                 mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
409-                 mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
412+             rms_norm_f32<1024 , true ><<<blocks_num, block_dims, 0 , stream>>> (
413+                 x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
414+                 mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
410415        }
411416    } else  {
417+         const  uint3  mul_ncols_packed     = init_fastdiv_values (mul_ncols);
418+         const  uint3  mul_nrows_packed     = init_fastdiv_values (mul_nrows);
419+         const  uint3  mul_nchannels_packed = init_fastdiv_values (mul_nchannels);
420+         const  uint3  mul_nsamples_packed  = init_fastdiv_values (mul_nsamples);
421+ 
422+         const  uint3  add_ncols_packed     = init_fastdiv_values (add_ncols);
423+         const  uint3  add_nrows_packed     = init_fastdiv_values (add_nrows);
424+         const  uint3  add_nchannels_packed = init_fastdiv_values (add_nchannels);
425+         const  uint3  add_nsamples_packed  = init_fastdiv_values (add_nsamples);
412426        if  (ncols < 1024 ) {
413-             const  dim3  block_dims (WARP_SIZE, 1 , 1 );
414-             rms_norm_f32<WARP_SIZE, true , true ><<<blocks_num, block_dims, 0 , stream>>> (x, dst,
415-                 ncols, stride_row, stride_channel, stride_sample, eps,
416-                 mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
417-                 mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
418-                 add, add_stride_row, add_stride_channel, add_stride_sample,
419-                 add_ncols, add_nrows, add_nchannels, add_nsamples);
427+             const  dim3  block_dims (256 , 1 , 1 );
428+             rms_norm_f32<256 , true , true ><<<blocks_num, block_dims, 0 , stream>>> (
429+                 x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
430+                 mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
431+                 add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
432+                 add_nchannels_packed, add_nsamples_packed);
420433        } else  {
421434            const  dim3  block_dims (1024 , 1 , 1 );
422-             rms_norm_f32<1024 , true , true ><<<blocks_num, block_dims, 0 , stream>>> (x, dst,
423-                 ncols, stride_row, stride_channel, stride_sample, eps,
424-                 mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
425-                 mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
426-                 add, add_stride_row, add_stride_channel, add_stride_sample,
427-                 add_ncols, add_nrows, add_nchannels, add_nsamples);
435+             rms_norm_f32<1024 , true , true ><<<blocks_num, block_dims, 0 , stream>>> (
436+                 x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
437+                 mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
438+                 add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
439+                 add_nchannels_packed, add_nsamples_packed);
428440        }
429441    }
430442}
0 commit comments