
Commit a884461

ggerganov authored and olexiyb committed
cuda : improve text-generation and batched decoding performance (ggml-org#3776)
* cuda : prints wip
* cuda : new cublas gemm branch for multi-batch quantized src0
* cuda : add F32 sgemm branch
* cuda : fine-tune >= VOLTA params + use MMQ only for small batches
* cuda : remove duplicated cuBLAS GEMM code
* cuda : add CUDA_USE_TENSOR_CORES and GGML_CUDA_FORCE_MMQ macros
* build : add compile option to force use of MMQ kernels
1 parent 73b81db commit a884461

File tree

5 files changed: +125 additions, -19 deletions

CMakeLists.txt (+7)

@@ -82,6 +82,7 @@ set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
 option(LLAMA_CUBLAS                          "llama: use CUDA"                                 OFF)
 #option(LLAMA_CUDA_CUBLAS                    "llama: use cuBLAS for prompt processing"         OFF)
 option(LLAMA_CUDA_FORCE_DMMV                 "llama: use dmmv instead of mmvq CUDA kernels"    OFF)
+option(LLAMA_CUDA_FORCE_MMQ                  "llama: use mmq kernels instead of cuBLAS"        OFF)
 set(LLAMA_CUDA_DMMV_X      "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
 set(LLAMA_CUDA_MMV_Y        "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
 option(LLAMA_CUDA_F16                        "llama: use 16 bit floats for some calculations"  OFF)

@@ -305,6 +306,9 @@ if (LLAMA_CUBLAS)
     if (LLAMA_CUDA_FORCE_DMMV)
         add_compile_definitions(GGML_CUDA_FORCE_DMMV)
     endif()
+    if (LLAMA_CUDA_FORCE_MMQ)
+        add_compile_definitions(GGML_CUDA_FORCE_MMQ)
+    endif()
     add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
     add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
     if (DEFINED LLAMA_CUDA_DMMV_Y)

@@ -405,6 +409,9 @@ if (LLAMA_HIPBLAS)
     if (LLAMA_CUDA_FORCE_DMMV)
         target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_DMMV)
     endif()
+    if (LLAMA_CUDA_FORCE_MMQ)
+        target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_MMQ)
+    endif()
     target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
     target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
     target_compile_definitions(ggml-rocm PRIVATE K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
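
For reference, one way the new option could be enabled with the CMake build; this is only a usage sketch assuming the usual out-of-tree llama.cpp workflow, with the LLAMA_CUBLAS and LLAMA_CUDA_FORCE_MMQ flags taken from the diff above:

    mkdir build && cd build
    cmake .. -DLLAMA_CUBLAS=ON -DLLAMA_CUDA_FORCE_MMQ=ON   # force MMQ kernels instead of the cuBLAS/tensor-core path
    cmake --build . --config Release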

Makefile (+3)

@@ -397,6 +397,9 @@ endif # CUDA_DOCKER_ARCH
 ifdef LLAMA_CUDA_FORCE_DMMV
   NVCCFLAGS += -DGGML_CUDA_FORCE_DMMV
 endif # LLAMA_CUDA_FORCE_DMMV
+ifdef LLAMA_CUDA_FORCE_MMQ
+  NVCCFLAGS += -DGGML_CUDA_FORCE_MMQ
+endif # LLAMA_CUDA_FORCE_MMQ
 ifdef LLAMA_CUDA_DMMV_X
   NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
 else
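
The equivalent with the Makefile build, again only a sketch: LLAMA_CUBLAS=1 is the existing CUDA switch, and LLAMA_CUDA_FORCE_MMQ is the variable added above:

    make clean
    make LLAMA_CUBLAS=1 LLAMA_CUDA_FORCE_MMQ=1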

ggml-cuda.cu (+114 -16)

@@ -87,6 +87,24 @@
 #define CC_OFFSET_AMD 1000000
 #define CC_RDNA2      (CC_OFFSET_AMD + 1030)

+// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
+// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
+// for large computational tasks. the drawback is that this requires some extra amount of VRAM:
+// - 7B quantum model: +100-200 MB
+// - 13B quantum model: +200-400 MB
+//
+//#define GGML_CUDA_FORCE_MMQ
+
+// TODO: improve this to be correct for more hardware
+//       for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores
+//       probably other such cases, and not sure what happens on AMD hardware
+#if !defined(GGML_CUDA_FORCE_MMQ)
+#define CUDA_USE_TENSOR_CORES
+#endif
+
+// max batch size to use MMQ kernels when tensor cores are available
+#define MMQ_MAX_BATCH_SIZE 32
+
 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300

@@ -470,7 +488,6 @@ static int g_device_count = -1;
 static int g_main_device = 0;
 static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
 static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
-static bool g_mul_mat_q = true;

 static void * g_scratch_buffer = nullptr;
 static size_t g_scratch_size = 0; // disabled by default

@@ -3554,9 +3571,15 @@ static __device__ __forceinline__ void mul_mat_q(
 #define MMQ_X_Q4_0_RDNA1  64
 #define MMQ_Y_Q4_0_RDNA1  64
 #define NWARPS_Q4_0_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define MMQ_X_Q4_0_AMPERE 4
+#define MMQ_Y_Q4_0_AMPERE 32
+#define NWARPS_Q4_0_AMPERE 4
+#else
 #define MMQ_X_Q4_0_AMPERE 64
 #define MMQ_Y_Q4_0_AMPERE 128
 #define NWARPS_Q4_0_AMPERE 4
+#endif
 #define MMQ_X_Q4_0_PASCAL 64
 #define MMQ_Y_Q4_0_PASCAL 64
 #define NWARPS_Q4_0_PASCAL 8

@@ -3615,9 +3638,15 @@ template <bool need_check> static __global__ void
 #define MMQ_X_Q4_1_RDNA1  64
 #define MMQ_Y_Q4_1_RDNA1  64
 #define NWARPS_Q4_1_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define MMQ_X_Q4_1_AMPERE 4
+#define MMQ_Y_Q4_1_AMPERE 32
+#define NWARPS_Q4_1_AMPERE 4
+#else
 #define MMQ_X_Q4_1_AMPERE 64
 #define MMQ_Y_Q4_1_AMPERE 128
 #define NWARPS_Q4_1_AMPERE 4
+#endif
 #define MMQ_X_Q4_1_PASCAL 64
 #define MMQ_Y_Q4_1_PASCAL 64
 #define NWARPS_Q4_1_PASCAL 8

@@ -3678,9 +3707,15 @@ template <bool need_check> static __global__ void
 #define MMQ_X_Q5_0_RDNA1  64
 #define MMQ_Y_Q5_0_RDNA1  64
 #define NWARPS_Q5_0_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define MMQ_X_Q5_0_AMPERE 4
+#define MMQ_Y_Q5_0_AMPERE 32
+#define NWARPS_Q5_0_AMPERE 4
+#else
 #define MMQ_X_Q5_0_AMPERE 128
 #define MMQ_Y_Q5_0_AMPERE 64
 #define NWARPS_Q5_0_AMPERE 4
+#endif
 #define MMQ_X_Q5_0_PASCAL 64
 #define MMQ_Y_Q5_0_PASCAL 64
 #define NWARPS_Q5_0_PASCAL 8

@@ -3739,9 +3774,15 @@ template <bool need_check> static __global__ void
 #define MMQ_X_Q5_1_RDNA1  64
 #define MMQ_Y_Q5_1_RDNA1  64
 #define NWARPS_Q5_1_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define MMQ_X_Q5_1_AMPERE 4
+#define MMQ_Y_Q5_1_AMPERE 32
+#define NWARPS_Q5_1_AMPERE 4
+#else
 #define MMQ_X_Q5_1_AMPERE 128
 #define MMQ_Y_Q5_1_AMPERE 64
 #define NWARPS_Q5_1_AMPERE 4
+#endif
 #define MMQ_X_Q5_1_PASCAL 64
 #define MMQ_Y_Q5_1_PASCAL 64
 #define NWARPS_Q5_1_PASCAL 8

@@ -3800,9 +3841,15 @@ mul_mat_q5_1(
 #define MMQ_X_Q8_0_RDNA1  64
 #define MMQ_Y_Q8_0_RDNA1  64
 #define NWARPS_Q8_0_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define MMQ_X_Q8_0_AMPERE 4
+#define MMQ_Y_Q8_0_AMPERE 32
+#define NWARPS_Q8_0_AMPERE 4
+#else
 #define MMQ_X_Q8_0_AMPERE 128
 #define MMQ_Y_Q8_0_AMPERE 64
 #define NWARPS_Q8_0_AMPERE 4
+#endif
 #define MMQ_X_Q8_0_PASCAL 64
 #define MMQ_Y_Q8_0_PASCAL 64
 #define NWARPS_Q8_0_PASCAL 8

@@ -3861,9 +3908,15 @@ template <bool need_check> static __global__ void
 #define MMQ_X_Q2_K_RDNA1 128
 #define MMQ_Y_Q2_K_RDNA1 32
 #define NWARPS_Q2_K_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define MMQ_X_Q2_K_AMPERE 4
+#define MMQ_Y_Q2_K_AMPERE 32
+#define NWARPS_Q2_K_AMPERE 4
+#else
 #define MMQ_X_Q2_K_AMPERE 64
 #define MMQ_Y_Q2_K_AMPERE 128
 #define NWARPS_Q2_K_AMPERE 4
+#endif
 #define MMQ_X_Q2_K_PASCAL 64
 #define MMQ_Y_Q2_K_PASCAL 64
 #define NWARPS_Q2_K_PASCAL 8

@@ -3922,9 +3975,15 @@ mul_mat_q2_K(
 #define MMQ_X_Q3_K_RDNA1 32
 #define MMQ_Y_Q3_K_RDNA1 128
 #define NWARPS_Q3_K_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define MMQ_X_Q3_K_AMPERE 4
+#define MMQ_Y_Q3_K_AMPERE 32
+#define NWARPS_Q3_K_AMPERE 4
+#else
 #define MMQ_X_Q3_K_AMPERE 128
 #define MMQ_Y_Q3_K_AMPERE 128
 #define NWARPS_Q3_K_AMPERE 4
+#endif
 #define MMQ_X_Q3_K_PASCAL 64
 #define MMQ_Y_Q3_K_PASCAL 64
 #define NWARPS_Q3_K_PASCAL 8

@@ -3985,9 +4044,15 @@ template <bool need_check> static __global__ void
 #define MMQ_X_Q4_K_RDNA1 32
 #define MMQ_Y_Q4_K_RDNA1 64
 #define NWARPS_Q4_K_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define MMQ_X_Q4_K_AMPERE 4
+#define MMQ_Y_Q4_K_AMPERE 32
+#define NWARPS_Q4_K_AMPERE 4
+#else
 #define MMQ_X_Q4_K_AMPERE 64
 #define MMQ_Y_Q4_K_AMPERE 128
 #define NWARPS_Q4_K_AMPERE 4
+#endif
 #define MMQ_X_Q4_K_PASCAL 64
 #define MMQ_Y_Q4_K_PASCAL 64
 #define NWARPS_Q4_K_PASCAL 8

@@ -4048,9 +4113,15 @@ template <bool need_check> static __global__ void
 #define MMQ_X_Q5_K_RDNA1 32
 #define MMQ_Y_Q5_K_RDNA1 64
 #define NWARPS_Q5_K_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define MMQ_X_Q5_K_AMPERE 4
+#define MMQ_Y_Q5_K_AMPERE 32
+#define NWARPS_Q5_K_AMPERE 4
+#else
 #define MMQ_X_Q5_K_AMPERE 64
 #define MMQ_Y_Q5_K_AMPERE 128
 #define NWARPS_Q5_K_AMPERE 4
+#endif
 #define MMQ_X_Q5_K_PASCAL 64
 #define MMQ_Y_Q5_K_PASCAL 64
 #define NWARPS_Q5_K_PASCAL 8

@@ -4109,9 +4180,15 @@ mul_mat_q5_K(
 #define MMQ_X_Q6_K_RDNA1 32
 #define MMQ_Y_Q6_K_RDNA1 64
 #define NWARPS_Q6_K_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define MMQ_X_Q6_K_AMPERE 4
+#define MMQ_Y_Q6_K_AMPERE 32
+#define NWARPS_Q6_K_AMPERE 4
+#else
 #define MMQ_X_Q6_K_AMPERE 64
 #define MMQ_Y_Q6_K_AMPERE 64
 #define NWARPS_Q6_K_AMPERE 4
+#endif
 #define MMQ_X_Q6_K_PASCAL 64
 #define MMQ_Y_Q6_K_PASCAL 64
 #define NWARPS_Q6_K_PASCAL 8

@@ -5663,6 +5740,16 @@ void ggml_init_cublas() {
     CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
     GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
     int64_t total_vram = 0;
+#if defined(GGML_CUDA_FORCE_MMQ)
+    fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ:   yes\n", __func__);
+#else
+    fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ:   no\n", __func__);
+#endif
+#if defined(CUDA_USE_TENSOR_CORES)
+    fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: yes\n", __func__);
+#else
+    fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: no\n", __func__);
+#endif
     fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count);
    for (int id = 0; id < g_device_count; ++id) {
        cudaDeviceProp prop;
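
With a default cuBLAS build (GGML_CUDA_FORCE_MMQ not defined), the diagnostics added above would print something along these lines at startup; this is only a sketch, assuming GGML_CUDA_NAME expands to "CUDA" on a non-HIP build, and the device count/list depends on the machine:

    ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   no
    ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
    ggml_init_cublas: found 1 CUDA devices:
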
@@ -6347,7 +6434,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
         cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 row_diff, src1_ncols, ne10,
                 &alpha, src0_ddf_i, ne00,
-                        src1_ddf_i, ne10,
+                        src1_ddf_i, ne10,
                 &beta,  dst_dd_i,   ldc));

     if (src0_as != 0) {

@@ -7048,9 +7135,10 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }

-static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
+
     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);

@@ -7202,17 +7290,24 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 }

 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
-        src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
+    const bool all_on_device =
+        (src0->backend == GGML_BACKEND_GPU) &&
+        (src1->backend == GGML_BACKEND_GPU) &&
+        ( dst->backend == GGML_BACKEND_GPU);

     int64_t min_compute_capability = INT_MAX;
     for (int64_t id = 0; id < g_device_count; ++id) {
-        if (min_compute_capability > g_compute_capabilities[id]
-                && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
+        if (min_compute_capability > g_compute_capabilities[id] && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
             min_compute_capability = g_compute_capabilities[id];
         }
     }

+#ifdef CUDA_USE_TENSOR_CORES
+    const bool use_tensor_cores = true;
+#else
+    const bool use_tensor_cores = false;
+#endif
+
     // debug helpers
     //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
     //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);

@@ -7221,20 +7316,19 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);

-    if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+    if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // KQ single-batch
         ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
-    } else if (all_on_device && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
         ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
-    } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
         // KQ + KQV multi-batch
         ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
     } else if (src0->type == GGML_TYPE_F32) {
         ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
     } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
         if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
-
 #ifdef GGML_CUDA_FORCE_DMMV
             const bool use_mul_mat_vec_q = false;
 #else

@@ -7247,7 +7341,15 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
                 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
             }
         } else {
-            if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) {
+            bool use_mul_mat_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
+
+            // when tensor cores are available, use them for large batch size
+            // ref: https://github.com/ggerganov/llama.cpp/pull/3776
+            if (use_tensor_cores && min_compute_capability >= CC_VOLTA && src1->ne[1] > MMQ_MAX_BATCH_SIZE) {
+                use_mul_mat_q = false;
+            }
+
+            if (use_mul_mat_q) {
                 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
             } else {
                 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
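
Condensed, the new kernel-selection rule for quantized src0 reads as the stand-alone sketch below. It is not part of the commit: the constants MIN_CC_DP4A, CC_VOLTA and MMQ_MAX_BATCH_SIZE are the ones from ggml-cuda.cu, and the helper name is made up for illustration.

    // Decide between the MMQ kernels and the cuBLAS path for a quantized matrix multiplication.
    // src1_ncols is src1->ne[1], i.e. the batch size of the multiplication.
    static bool should_use_mul_mat_q(bool use_tensor_cores, int min_compute_capability,
                                     bool src0_is_quantized, int64_t src1_ncols) {
        // MMQ needs DP4A-capable hardware and a quantized src0
        bool use_mul_mat_q = min_compute_capability >= MIN_CC_DP4A && src0_is_quantized;

        // with tensor cores available (>= VOLTA), fall back to the F16 cuBLAS GEMM once the
        // batch size exceeds MMQ_MAX_BATCH_SIZE (32), where the tensor-core path is faster
        if (use_tensor_cores && min_compute_capability >= CC_VOLTA && src1_ncols > MMQ_MAX_BATCH_SIZE) {
            use_mul_mat_q = false;
        }

        return use_mul_mat_q;
    }
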
@@ -7601,10 +7703,6 @@ void ggml_cuda_set_main_device(const int main_device) {
     }
 }

-void ggml_cuda_set_mul_mat_q(const bool mul_mat_q) {
-    g_mul_mat_q = mul_mat_q;
-}
-
 void ggml_cuda_set_scratch_size(const size_t scratch_size) {
     // this is a hack to not completely break llama.cpp when using multiple models or contexts simultaneously
     // it still won't always work as expected, but it's better than nothing

llama.cpp (-2)

@@ -5959,8 +5959,6 @@ static int llama_decode_internal(
         }
     }

-    ggml_cuda_set_mul_mat_q(cparams.mul_mat_q);
-
     // HACK: ggml-alloc may change the tensor backend when reusing a parent, so force output to be on the CPU here if needed
     if (!lctx.embedding.empty()) {
         embeddings->backend = GGML_BACKEND_CPU;

llama.h (+1 -1)

@@ -178,7 +178,7 @@ extern "C" {
         float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model

         // Keep the booleans together to avoid misalignment during copy-by-value.
-        bool mul_mat_q;  // if true, use experimental mul_mat_q kernels
+        bool mul_mat_q;  // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
         bool f16_kv;     // use fp16 for KV cache, fp32 otherwise
         bool logits_all; // the llama_eval() call computes all logits, not just the last one
         bool embedding;  // embedding mode only
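
For illustration only (not part of the commit): with g_mul_mat_q and ggml_cuda_set_mul_mat_q removed, the CUDA backend no longer consults this flag, so code along these lines no longer switches kernels; the choice is made automatically in ggml_cuda_mul_mat as shown above:

    #include "llama.h"

    struct llama_context_params params = llama_context_default_params();
    params.mul_mat_q = false; // DEPRECATED: kept only for struct layout/compatibility; ignored by the CUDA backend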
