Skip to content

Commit 23f2ccc

Browse files
committed
mul_mat_vec_f: load bf16 pairs directly as nv_bfloat162 (instead of reinterpreting an int) + fix formatting
1 parent 8898040 commit 23f2ccc

File tree

1 file changed

+9
-12
lines changed

1 file changed

+9
-12
lines changed

ggml/src/ggml-cuda/mmvf.cu

Lines changed: 9 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
1-
21
#include "ggml.h"
32
#include "common.cuh"
43
#include "convert.cuh"
@@ -8,14 +7,14 @@ template <typename T, typename type_acc, int ncols_dst, int block_size>
87
static __global__ void mul_mat_vec_f(
98
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
109
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
11-
const uint3 channel_ratio_fd, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
12-
const uint3 sample_ratio_fd, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
10+
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
11+
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
1312
const int row = blockIdx.x;
1413
const int channel_dst = blockIdx.y;
15-
const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio_fd);
14+
const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio);
1615
const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
1716
const int sample_dst = blockIdx.z;
18-
const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio_fd);
17+
const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio);
1918
const int sample_y = sample_dst;
2019
const int tid = threadIdx.x;
2120

@@ -89,16 +88,14 @@ static __global__ void mul_mat_vec_f(
8988
#endif // FP16_AVAILABLE
9089
}
9190
} else if constexpr (std::is_same_v<T, nv_bfloat16>) {
92-
const int * x2 = (const int *) x;
91+
const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
9392
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
94-
const int tmpx = x2[col2];
93+
const nv_bfloat162 tmpx = x2[col2];
9594
#pragma unroll
9695
for (int j = 0; j < ncols_dst; ++j) {
9796
const float2 tmpy = y2[j*stride_col_y2 + col2];
98-
const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
99-
const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
100-
ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
101-
ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
97+
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
98+
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
10299
}
103100
}
104101
} else {
@@ -143,7 +140,7 @@ static void launch_mul_mat_vec_f_cuda(
143140
GGML_ASSERT(stride_col_y % 2 == 0);
144141
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
145142
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
146-
const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
143+
const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
147144
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
148145

149146
const int device = ggml_cuda_get_device();

0 commit comments

Comments
 (0)