1- 
21#include  " ggml.h" 
32#include  " common.cuh" 
43#include  " convert.cuh" 
@@ -8,14 +7,14 @@ template <typename T, typename type_acc, int ncols_dst, int block_size>
87static  __global__  void  mul_mat_vec_f (
98        const  T * __restrict__  x, const  float  * __restrict__  y, const  int32_t  * __restrict__  ids, float  * __restrict__  dst,
109        const  int  ncols2, const  int  nchannels_y, const  int  stride_row, const  int  stride_col_y2, const  int  stride_col_dst,
11-         const  uint3  channel_ratio_fd , const  int  stride_channel_x, const  int  stride_channel_y, const  int  stride_channel_dst,
12-         const  uint3  sample_ratio_fd , const  int  stride_sample_x, const  int  stride_sample_y, const  int  stride_sample_dst) {
10+         const  uint3  channel_ratio , const  int  stride_channel_x, const  int  stride_channel_y, const  int  stride_channel_dst,
11+         const  uint3  sample_ratio , const  int  stride_sample_x, const  int  stride_sample_y, const  int  stride_sample_dst) {
1312    const  int  row         = blockIdx .x ;
1413    const  int  channel_dst = blockIdx .y ;
15-     const  int  channel_x   = ids ? ids[channel_dst]          : fastdiv ((uint32_t ) channel_dst, channel_ratio_fd );
14+     const  int  channel_x   = ids ? ids[channel_dst]          : fastdiv ((uint32_t ) channel_dst, channel_ratio );
1615    const  int  channel_y   = ids ? channel_dst % nchannels_y : channel_dst;
1716    const  int  sample_dst  = blockIdx .z ;
18-     const  int  sample_x    = fastdiv ((uint32_t ) sample_dst, sample_ratio_fd );
17+     const  int  sample_x    = fastdiv ((uint32_t ) sample_dst, sample_ratio );
1918    const  int  sample_y    = sample_dst;
2019    const  int  tid         = threadIdx .x ;
2120
@@ -89,16 +88,14 @@ static __global__ void mul_mat_vec_f(
8988#endif  //  FP16_AVAILABLE
9089        }
9190    } else  if  constexpr  (std::is_same_v<T, nv_bfloat16>) {
92-         const  int  * x2 = (const  int  *) x;
91+         const  nv_bfloat162  * x2 = (const  nv_bfloat162  *) x;
9392        for  (int  col2 = tid; col2 < ncols2; col2 += block_size) {
94-             const  int  tmpx = x2[col2];
93+             const  nv_bfloat162  tmpx = x2[col2];
9594#pragma  unroll
9695            for  (int  j = 0 ; j < ncols_dst; ++j) {
9796                const  float2  tmpy = y2[j*stride_col_y2 + col2];
98-                 const  float  tmpx0 = ggml_cuda_cast<float >(reinterpret_cast <const  nv_bfloat16 *>(&tmpx)[0 ]);
99-                 const  float  tmpx1 = ggml_cuda_cast<float >(reinterpret_cast <const  nv_bfloat16 *>(&tmpx)[1 ]);
100-                 ggml_cuda_mad (sumf[j], tmpx0, tmpy.x );
101-                 ggml_cuda_mad (sumf[j], tmpx1, tmpy.y );
97+                 ggml_cuda_mad (sumf[j], tmpx.x , tmpy.x );
98+                 ggml_cuda_mad (sumf[j], tmpx.y , tmpy.y );
10299            }
103100        }
104101    } else  {
@@ -143,7 +140,7 @@ static void launch_mul_mat_vec_f_cuda(
143140    GGML_ASSERT (stride_col_y % 2  == 0 );
144141    GGML_ASSERT (ids || nchannels_dst % nchannels_x == 0 );
145142    GGML_ASSERT (       nsamples_dst  % nsamples_x  == 0 );
146-     const  uint3  channel_ratio_fd = ids ? make_uint3 (0 , 0 , 0 )               : init_fastdiv_values (nchannels_dst / nchannels_x);
143+     const  uint3  channel_ratio_fd = ids ? make_uint3 (0 , 0 , 0 ) : init_fastdiv_values (nchannels_dst / nchannels_x);
147144    const  uint3  sample_ratio_fd  = init_fastdiv_values (nsamples_dst  / nsamples_x);
148145
149146    const  int  device = ggml_cuda_get_device ();
0 commit comments