From 59d1232ea76a7d760847244f9fffa3e75a590d32 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 25 Oct 2023 10:26:58 +0300 Subject: [PATCH 01/23] cuda : prints wip --- ggml-cuda.cu | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index ba0cd5a7d3f1e..7bbef0a1a7864 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6304,6 +6304,7 @@ inline void ggml_cuda_op_mul_mat_cublas( const half alpha_f16 = 1.0f; const half beta_f16 = 0.0f; + //printf("F16: row_diff: %ld, src1_ncols: %ld, ne10: %ld, ne00: %ld, ldc: %d\n", row_diff, src1_ncols, ne10, ne00, ldc); CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream)); CUBLAS_CHECK( cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, @@ -7250,6 +7251,12 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) { ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true); } else { + //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]); + //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]); + //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]); + //printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]); + //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); + //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); } } From 52af78260884013abc3f2fc3669e68f45ec2bfb5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 25 Oct 2023 13:14:24 +0300 Subject: [PATCH 02/23] cuda : new cublas gemm branch for multi-batch quantized src0 --- ggml-cuda.cu | 115 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 113 insertions(+), 2 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 7bbef0a1a7864..ca49d73bfc4de 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6304,7 +6304,6 @@ inline void ggml_cuda_op_mul_mat_cublas( const half alpha_f16 = 1.0f; const half beta_f16 = 0.0f; - //printf("F16: row_diff: %ld, src1_ncols: %ld, ne10: %ld, ne00: %ld, ldc: %d\n", row_diff, src1_ncols, ne10, ne00, ldc); CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream)); CUBLAS_CHECK( cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, @@ -7049,9 +7048,10 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream); } -static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ +static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); + GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); @@ -7202,6 +7202,115 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_cuda_pool_free(dst_f16, dst_as); } +static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggml_tensor 
* src1, ggml_tensor * dst) { + int id; + CUDA_CHECK(cudaGetDevice(&id)); + + // require tensor cores + const int compute_capability = g_compute_capabilities[id]; + GGML_ASSERT(compute_capability >= CC_VOLTA); + + GGML_ASSERT(!ggml_is_transposed(src0)); + GGML_ASSERT(!ggml_is_transposed(src1)); + + //GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); + + GGML_ASSERT(ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00); + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; GGML_UNUSED(ne02); + const int64_t ne03 = src0->ne[3]; GGML_UNUSED(ne03); + + const int64_t nb01 = src0->nb[1]; GGML_UNUSED(nb01); + const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02); + const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03); + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; GGML_UNUSED(ne12); + const int64_t ne13 = src1->ne[3]; GGML_UNUSED(ne13); + + const int64_t nb11 = src1->nb[1]; GGML_UNUSED(nb11); + const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12); + const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13); + + const int64_t ne1 = ggml_nelements(src1); + const int64_t ne = ggml_nelements(dst); + + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream)); + + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + void * src0_ddq = src0_extra->data_device[g_main_device]; + + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; + + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; + + if (ggml_is_contiguous(src0)) { + // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 + half * src0_as_f16 = nullptr; + size_t src0_as = 0; + if (src0->type != GGML_TYPE_F16) { + const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type); + GGML_ASSERT(to_fp16_cuda != nullptr); + const size_t ne = ne01*ne00; + src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as); + to_fp16_cuda(src0_ddq, src0_as_f16, ne, main_stream); + } + + const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_ddq : src0_as_f16; + + half * src1_as_f16 = nullptr; + size_t src1_as = 0; + if (src1->type != GGML_TYPE_F16) { + const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); + GGML_ASSERT(to_fp16_cuda != nullptr); + const size_t ne = ne11*ne10; + src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as); + to_fp16_cuda(src1_ddf, src1_as_f16, ne, main_stream); + } + + const half * src1_ptr = src1->type == GGML_TYPE_F16 ? 
(const half *) src1_ddf : src1_as_f16; + + size_t dst_as = 0; + half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne01*ne11 * sizeof(half), &dst_as); + + const half alpha_f16 = 1.0f; + const half beta_f16 = 0.0f; + + CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream)); + CUBLAS_CHECK( + cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + &alpha_f16, src0_ptr, CUDA_R_16F, ne00, + src1_ptr, CUDA_R_16F, ne10, + &beta_f16, dst_f16, CUDA_R_16F, ne01, + CUBLAS_COMPUTE_16F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); + to_fp32_cuda(dst_f16, dst_ddf, ne01*ne11, main_stream); + + ggml_cuda_pool_free(dst_f16, dst_as); + + if (src0_as != 0) { + ggml_cuda_pool_free(src0_as_f16, src0_as); + } + + if (src1_as != 0) { + ggml_cuda_pool_free(src1_as_f16, src1_as); + } + } else { + GGML_ASSERT(false && "not implemented"); + } +} + static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU; @@ -7231,6 +7340,8 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { // KQ + KQV multi-batch ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst); + } else if (all_on_device && (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) && src1->ne[1] > 1) { + ggml_cuda_mul_mat_mat_deq_cublas(src0, src1, dst); } else if (src0->type == GGML_TYPE_F32) { ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { From 16b60dd75c8c89b726da5e9252454791fa1300b7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 25 Oct 2023 14:00:21 +0300 Subject: [PATCH 03/23] cuda : add F32 sgemm branch --- ggml-cuda.cu | 38 +++++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index ca49d73bfc4de..75e0dddf90b1c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -7252,7 +7252,8 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; - if (ggml_is_contiguous(src0)) { +#if 0 + { // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 half * src0_as_f16 = nullptr; size_t src0_as = 0; @@ -7306,9 +7307,40 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm if (src1_as != 0) { ggml_cuda_pool_free(src1_as_f16, src1_as); } - } else { - GGML_ASSERT(false && "not implemented"); } +#else + { + // convert src0 to fp32, multiply as fp32 + float * src0_as_f32 = nullptr; + size_t src0_as = 0; + if (src0->type != GGML_TYPE_F32) { + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); + GGML_ASSERT(to_fp32_cuda != nullptr); + const size_t ne = ne01*ne00; + src0_as_f32 = (float *) ggml_cuda_pool_malloc(ne * sizeof(float), &src0_as); + to_fp32_cuda(src0_ddq, src0_as_f32, ne, main_stream); + } + + const float * src0_ptr = src0->type == GGML_TYPE_F32 ? 
(const float *) src0_ddq : src0_as_f32; + + const float * src1_ptr = (const float *) src1_ddf; + + const float alpha = 1.0f; + const float beta = 0.0f; + + CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream)); + CUBLAS_CHECK( + cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + &alpha, src0_ptr, ne00, + src1_ptr, ne10, + &beta, dst_ddf, ne01)); + + if (src0_as != 0) { + ggml_cuda_pool_free(src0_as_f32, src0_as); + } + } +#endif } static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { From a3c28439d3c974db10092201e6228da709c801d0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 25 Oct 2023 15:07:34 +0300 Subject: [PATCH 04/23] cuda : fine-tune >= VOLTA params + use MMQ only for small batches --- ggml-cuda.cu | 46 ++++++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 75e0dddf90b1c..bafe080884073 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -3554,8 +3554,8 @@ static __device__ __forceinline__ void mul_mat_q( #define MMQ_X_Q4_0_RDNA1 64 #define MMQ_Y_Q4_0_RDNA1 64 #define NWARPS_Q4_0_RDNA1 8 -#define MMQ_X_Q4_0_AMPERE 64 -#define MMQ_Y_Q4_0_AMPERE 128 +#define MMQ_X_Q4_0_AMPERE 4 +#define MMQ_Y_Q4_0_AMPERE 32 #define NWARPS_Q4_0_AMPERE 4 #define MMQ_X_Q4_0_PASCAL 64 #define MMQ_Y_Q4_0_PASCAL 64 @@ -3615,8 +3615,8 @@ template static __global__ void #define MMQ_X_Q4_1_RDNA1 64 #define MMQ_Y_Q4_1_RDNA1 64 #define NWARPS_Q4_1_RDNA1 8 -#define MMQ_X_Q4_1_AMPERE 64 -#define MMQ_Y_Q4_1_AMPERE 128 +#define MMQ_X_Q4_1_AMPERE 4 +#define MMQ_Y_Q4_1_AMPERE 32 #define NWARPS_Q4_1_AMPERE 4 #define MMQ_X_Q4_1_PASCAL 64 #define MMQ_Y_Q4_1_PASCAL 64 @@ -3678,8 +3678,8 @@ template static __global__ void #define MMQ_X_Q5_0_RDNA1 64 #define MMQ_Y_Q5_0_RDNA1 64 #define NWARPS_Q5_0_RDNA1 8 -#define MMQ_X_Q5_0_AMPERE 128 -#define MMQ_Y_Q5_0_AMPERE 64 +#define MMQ_X_Q5_0_AMPERE 4 +#define MMQ_Y_Q5_0_AMPERE 32 #define NWARPS_Q5_0_AMPERE 4 #define MMQ_X_Q5_0_PASCAL 64 #define MMQ_Y_Q5_0_PASCAL 64 @@ -3739,8 +3739,8 @@ template static __global__ void #define MMQ_X_Q5_1_RDNA1 64 #define MMQ_Y_Q5_1_RDNA1 64 #define NWARPS_Q5_1_RDNA1 8 -#define MMQ_X_Q5_1_AMPERE 128 -#define MMQ_Y_Q5_1_AMPERE 64 +#define MMQ_X_Q5_1_AMPERE 4 +#define MMQ_Y_Q5_1_AMPERE 32 #define NWARPS_Q5_1_AMPERE 4 #define MMQ_X_Q5_1_PASCAL 64 #define MMQ_Y_Q5_1_PASCAL 64 @@ -3800,8 +3800,8 @@ mul_mat_q5_1( #define MMQ_X_Q8_0_RDNA1 64 #define MMQ_Y_Q8_0_RDNA1 64 #define NWARPS_Q8_0_RDNA1 8 -#define MMQ_X_Q8_0_AMPERE 128 -#define MMQ_Y_Q8_0_AMPERE 64 +#define MMQ_X_Q8_0_AMPERE 4 +#define MMQ_Y_Q8_0_AMPERE 32 #define NWARPS_Q8_0_AMPERE 4 #define MMQ_X_Q8_0_PASCAL 64 #define MMQ_Y_Q8_0_PASCAL 64 @@ -3861,8 +3861,8 @@ template static __global__ void #define MMQ_X_Q2_K_RDNA1 128 #define MMQ_Y_Q2_K_RDNA1 32 #define NWARPS_Q2_K_RDNA1 8 -#define MMQ_X_Q2_K_AMPERE 64 -#define MMQ_Y_Q2_K_AMPERE 128 +#define MMQ_X_Q2_K_AMPERE 4 +#define MMQ_Y_Q2_K_AMPERE 32 #define NWARPS_Q2_K_AMPERE 4 #define MMQ_X_Q2_K_PASCAL 64 #define MMQ_Y_Q2_K_PASCAL 64 @@ -3922,8 +3922,8 @@ mul_mat_q2_K( #define MMQ_X_Q3_K_RDNA1 32 #define MMQ_Y_Q3_K_RDNA1 128 #define NWARPS_Q3_K_RDNA1 8 -#define MMQ_X_Q3_K_AMPERE 128 -#define MMQ_Y_Q3_K_AMPERE 128 +#define MMQ_X_Q3_K_AMPERE 4 +#define MMQ_Y_Q3_K_AMPERE 32 #define NWARPS_Q3_K_AMPERE 4 #define MMQ_X_Q3_K_PASCAL 64 #define MMQ_Y_Q3_K_PASCAL 64 @@ -3985,8 +3985,8 @@ template static __global__ void #define MMQ_X_Q4_K_RDNA1 32 #define MMQ_Y_Q4_K_RDNA1 64 
#define NWARPS_Q4_K_RDNA1 8 -#define MMQ_X_Q4_K_AMPERE 64 -#define MMQ_Y_Q4_K_AMPERE 128 +#define MMQ_X_Q4_K_AMPERE 4 +#define MMQ_Y_Q4_K_AMPERE 32 #define NWARPS_Q4_K_AMPERE 4 #define MMQ_X_Q4_K_PASCAL 64 #define MMQ_Y_Q4_K_PASCAL 64 @@ -4048,8 +4048,8 @@ template static __global__ void #define MMQ_X_Q5_K_RDNA1 32 #define MMQ_Y_Q5_K_RDNA1 64 #define NWARPS_Q5_K_RDNA1 8 -#define MMQ_X_Q5_K_AMPERE 64 -#define MMQ_Y_Q5_K_AMPERE 128 +#define MMQ_X_Q5_K_AMPERE 4 +#define MMQ_Y_Q5_K_AMPERE 32 #define NWARPS_Q5_K_AMPERE 4 #define MMQ_X_Q5_K_PASCAL 64 #define MMQ_Y_Q5_K_PASCAL 64 @@ -4109,8 +4109,8 @@ mul_mat_q5_K( #define MMQ_X_Q6_K_RDNA1 32 #define MMQ_Y_Q6_K_RDNA1 64 #define NWARPS_Q6_K_RDNA1 8 -#define MMQ_X_Q6_K_AMPERE 64 -#define MMQ_Y_Q6_K_AMPERE 64 +#define MMQ_X_Q6_K_AMPERE 4 +#define MMQ_Y_Q6_K_AMPERE 32 #define NWARPS_Q6_K_AMPERE 4 #define MMQ_X_Q6_K_PASCAL 64 #define MMQ_Y_Q6_K_PASCAL 64 @@ -7252,7 +7252,7 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; -#if 0 +#if 1 { // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 half * src0_as_f16 = nullptr; @@ -7309,6 +7309,7 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm } } #else + // NOTE: this seems faster for tiny models and small batch-size { // convert src0 to fp32, multiply as fp32 float * src0_as_f32 = nullptr; @@ -7372,7 +7373,8 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { // KQ + KQV multi-batch ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst); - } else if (all_on_device && (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) && src1->ne[1] > 1) { + } else if (all_on_device && (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) && src1->ne[1] > 32) { + // F16 and quantized src0 + high-batch src1 ggml_cuda_mul_mat_mat_deq_cublas(src0, src1, dst); } else if (src0->type == GGML_TYPE_F32) { ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); From 4c6744b526106bdd96c2fa54da2a8dff5da05240 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 25 Oct 2023 18:25:13 +0300 Subject: [PATCH 05/23] cuda : remove duplicated cuBLAS GEMM code --- ggml-cuda.cu | 169 +++++---------------------------------------------- 1 file changed, 15 insertions(+), 154 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index bafe080884073..27824f47f1ae4 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -7202,151 +7202,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_cuda_pool_free(dst_f16, dst_as); } -static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - int id; - CUDA_CHECK(cudaGetDevice(&id)); - - // require tensor cores - const int compute_capability = g_compute_capabilities[id]; - GGML_ASSERT(compute_capability >= CC_VOLTA); - - GGML_ASSERT(!ggml_is_transposed(src0)); - GGML_ASSERT(!ggml_is_transposed(src1)); - - //GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); - - GGML_ASSERT(ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); - - const int64_t ne00 = src0->ne[0]; 
GGML_UNUSED(ne00); - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; GGML_UNUSED(ne02); - const int64_t ne03 = src0->ne[3]; GGML_UNUSED(ne03); - - const int64_t nb01 = src0->nb[1]; GGML_UNUSED(nb01); - const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02); - const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03); - - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - const int64_t ne12 = src1->ne[2]; GGML_UNUSED(ne12); - const int64_t ne13 = src1->ne[3]; GGML_UNUSED(ne13); - - const int64_t nb11 = src1->nb[1]; GGML_UNUSED(nb11); - const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12); - const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13); - - const int64_t ne1 = ggml_nelements(src1); - const int64_t ne = ggml_nelements(dst); - - CUDA_CHECK(ggml_cuda_set_device(g_main_device)); - cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; - CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream)); - - ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; - void * src0_ddq = src0_extra->data_device[g_main_device]; - - ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; - float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; - - ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; - float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; - -#if 1 - { - // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 - half * src0_as_f16 = nullptr; - size_t src0_as = 0; - if (src0->type != GGML_TYPE_F16) { - const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type); - GGML_ASSERT(to_fp16_cuda != nullptr); - const size_t ne = ne01*ne00; - src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as); - to_fp16_cuda(src0_ddq, src0_as_f16, ne, main_stream); - } - - const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_ddq : src0_as_f16; - - half * src1_as_f16 = nullptr; - size_t src1_as = 0; - if (src1->type != GGML_TYPE_F16) { - const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); - GGML_ASSERT(to_fp16_cuda != nullptr); - const size_t ne = ne11*ne10; - src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as); - to_fp16_cuda(src1_ddf, src1_as_f16, ne, main_stream); - } - - const half * src1_ptr = src1->type == GGML_TYPE_F16 ? 
(const half *) src1_ddf : src1_as_f16; - - size_t dst_as = 0; - half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne01*ne11 * sizeof(half), &dst_as); - - const half alpha_f16 = 1.0f; - const half beta_f16 = 0.0f; - - CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream)); - CUBLAS_CHECK( - cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, - ne01, ne11, ne10, - &alpha_f16, src0_ptr, CUDA_R_16F, ne00, - src1_ptr, CUDA_R_16F, ne10, - &beta_f16, dst_f16, CUDA_R_16F, ne01, - CUBLAS_COMPUTE_16F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); - to_fp32_cuda(dst_f16, dst_ddf, ne01*ne11, main_stream); - - ggml_cuda_pool_free(dst_f16, dst_as); - - if (src0_as != 0) { - ggml_cuda_pool_free(src0_as_f16, src0_as); - } - - if (src1_as != 0) { - ggml_cuda_pool_free(src1_as_f16, src1_as); - } - } -#else - // NOTE: this seems faster for tiny models and small batch-size - { - // convert src0 to fp32, multiply as fp32 - float * src0_as_f32 = nullptr; - size_t src0_as = 0; - if (src0->type != GGML_TYPE_F32) { - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); - GGML_ASSERT(to_fp32_cuda != nullptr); - const size_t ne = ne01*ne00; - src0_as_f32 = (float *) ggml_cuda_pool_malloc(ne * sizeof(float), &src0_as); - to_fp32_cuda(src0_ddq, src0_as_f32, ne, main_stream); - } - - const float * src0_ptr = src0->type == GGML_TYPE_F32 ? (const float *) src0_ddq : src0_as_f32; - - const float * src1_ptr = (const float *) src1_ddf; - - const float alpha = 1.0f; - const float beta = 0.0f; - - CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream)); - CUBLAS_CHECK( - cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, - ne01, ne11, ne10, - &alpha, src0_ptr, ne00, - src1_ptr, ne10, - &beta, dst_ddf, ne01)); - - if (src0_as != 0) { - ggml_cuda_pool_free(src0_as_f32, src0_as); - } - } -#endif -} - static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && - src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU; + const bool all_on_device = + (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && + (src1->backend == GGML_BACKEND_GPU) && + ( dst->backend == GGML_BACKEND_GPU); int64_t min_compute_capability = INT_MAX; for (int64_t id = 0; id < g_device_count; ++id) { @@ -7373,9 +7233,6 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { // KQ + KQV multi-batch ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst); - } else if (all_on_device && (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) && src1->ne[1] > 32) { - // F16 and quantized src0 + high-batch src1 - ggml_cuda_mul_mat_mat_deq_cublas(src0, src1, dst); } else if (src0->type == GGML_TYPE_F32) { ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { @@ -7393,15 +7250,19 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false); } } else { - if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) { + // ref: 
https://github.com/ggerganov/llama.cpp/pull/3776 + bool use_mul_mat_q = g_mul_mat_q && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type); + + // TODO: better way to determine availability of tensor cores + // currently fails for GeForce GTX 1660 which is TURING arch but does not have tensor cores + if (min_compute_capability >= CC_VOLTA && src1->ne[1] > 32) { + // when tensor cores are available, use them for large batch size + use_mul_mat_q = false; + } + + if (use_mul_mat_q) { ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true); } else { - //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]); - //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]); - //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]); - //printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]); - //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); - //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); } } From a4e15a36e4cd7120945eef560e591efaeb5fbd2b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 25 Oct 2023 18:48:36 +0300 Subject: [PATCH 06/23] cuda : add CUDA_USE_TENSOR_CORES and GGML_CUDA_FORCE_MMQ macros --- ggml-cuda.cu | 122 +++++++++++++++++++++++++++++++++++++++++++-------- llama.cpp | 2 - llama.h | 2 +- 3 files changed, 104 insertions(+), 22 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 27824f47f1ae4..558cd0bd8861a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -87,6 +87,23 @@ #define CC_OFFSET_AMD 1000000 #define CC_RDNA2 (CC_OFFSET_AMD + 1030) +// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication +// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant +// for large computational tasks. 
the drawback is that this requires some extra amount of VRAM: +// - 7B quantum model: +100-200 MB +// - 13B quantum model: +200-400 MB +//#define GGML_CUDA_FORCE_MMQ + +// TODO: improve this to be correct for more hardware +// for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores +// probably other such cases, and not sure what happens on AMD hardware +#if !defined(GGML_CUDA_FORCE_MMQ) +#define CUDA_USE_TENSOR_CORES +#endif + +// max batch size to use MMQ kernels when tensor cores are available +#define MMQ_MAX_BATCH_SIZE 32 + #if defined(GGML_USE_HIPBLAS) #define __CUDA_ARCH__ 1300 @@ -470,7 +487,6 @@ static int g_device_count = -1; static int g_main_device = 0; static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES]; static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0}; -static bool g_mul_mat_q = true; static void * g_scratch_buffer = nullptr; static size_t g_scratch_size = 0; // disabled by default @@ -3554,9 +3570,15 @@ static __device__ __forceinline__ void mul_mat_q( #define MMQ_X_Q4_0_RDNA1 64 #define MMQ_Y_Q4_0_RDNA1 64 #define NWARPS_Q4_0_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) #define MMQ_X_Q4_0_AMPERE 4 #define MMQ_Y_Q4_0_AMPERE 32 #define NWARPS_Q4_0_AMPERE 4 +#else +#define MMQ_X_Q4_0_AMPERE 64 +#define MMQ_Y_Q4_0_AMPERE 128 +#define NWARPS_Q4_0_AMPERE 4 +#endif #define MMQ_X_Q4_0_PASCAL 64 #define MMQ_Y_Q4_0_PASCAL 64 #define NWARPS_Q4_0_PASCAL 8 @@ -3615,9 +3637,15 @@ template static __global__ void #define MMQ_X_Q4_1_RDNA1 64 #define MMQ_Y_Q4_1_RDNA1 64 #define NWARPS_Q4_1_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) #define MMQ_X_Q4_1_AMPERE 4 #define MMQ_Y_Q4_1_AMPERE 32 #define NWARPS_Q4_1_AMPERE 4 +#else +#define MMQ_X_Q4_1_AMPERE 64 +#define MMQ_Y_Q4_1_AMPERE 128 +#define NWARPS_Q4_1_AMPERE 4 +#endif #define MMQ_X_Q4_1_PASCAL 64 #define MMQ_Y_Q4_1_PASCAL 64 #define NWARPS_Q4_1_PASCAL 8 @@ -3678,9 +3706,15 @@ template static __global__ void #define MMQ_X_Q5_0_RDNA1 64 #define MMQ_Y_Q5_0_RDNA1 64 #define NWARPS_Q5_0_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) #define MMQ_X_Q5_0_AMPERE 4 #define MMQ_Y_Q5_0_AMPERE 32 #define NWARPS_Q5_0_AMPERE 4 +#else +#define MMQ_X_Q5_0_AMPERE 128 +#define MMQ_Y_Q5_0_AMPERE 64 +#define NWARPS_Q5_0_AMPERE 4 +#endif #define MMQ_X_Q5_0_PASCAL 64 #define MMQ_Y_Q5_0_PASCAL 64 #define NWARPS_Q5_0_PASCAL 8 @@ -3739,9 +3773,15 @@ template static __global__ void #define MMQ_X_Q5_1_RDNA1 64 #define MMQ_Y_Q5_1_RDNA1 64 #define NWARPS_Q5_1_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) #define MMQ_X_Q5_1_AMPERE 4 #define MMQ_Y_Q5_1_AMPERE 32 #define NWARPS_Q5_1_AMPERE 4 +#else +#define MMQ_X_Q5_1_AMPERE 128 +#define MMQ_Y_Q5_1_AMPERE 64 +#define NWARPS_Q5_1_AMPERE 4 +#endif #define MMQ_X_Q5_1_PASCAL 64 #define MMQ_Y_Q5_1_PASCAL 64 #define NWARPS_Q5_1_PASCAL 8 @@ -3800,9 +3840,15 @@ mul_mat_q5_1( #define MMQ_X_Q8_0_RDNA1 64 #define MMQ_Y_Q8_0_RDNA1 64 #define NWARPS_Q8_0_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) #define MMQ_X_Q8_0_AMPERE 4 #define MMQ_Y_Q8_0_AMPERE 32 #define NWARPS_Q8_0_AMPERE 4 +#else +#define MMQ_X_Q8_0_AMPERE 128 +#define MMQ_Y_Q8_0_AMPERE 64 +#define NWARPS_Q8_0_AMPERE 4 +#endif #define MMQ_X_Q8_0_PASCAL 64 #define MMQ_Y_Q8_0_PASCAL 64 #define NWARPS_Q8_0_PASCAL 8 @@ -3861,9 +3907,15 @@ template static __global__ void #define MMQ_X_Q2_K_RDNA1 128 #define MMQ_Y_Q2_K_RDNA1 32 #define NWARPS_Q2_K_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) #define MMQ_X_Q2_K_AMPERE 4 #define MMQ_Y_Q2_K_AMPERE 32 #define NWARPS_Q2_K_AMPERE 4 +#else +#define MMQ_X_Q2_K_AMPERE 64 
+#define MMQ_Y_Q2_K_AMPERE 128 +#define NWARPS_Q2_K_AMPERE 4 +#endif #define MMQ_X_Q2_K_PASCAL 64 #define MMQ_Y_Q2_K_PASCAL 64 #define NWARPS_Q2_K_PASCAL 8 @@ -3922,9 +3974,15 @@ mul_mat_q2_K( #define MMQ_X_Q3_K_RDNA1 32 #define MMQ_Y_Q3_K_RDNA1 128 #define NWARPS_Q3_K_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) #define MMQ_X_Q3_K_AMPERE 4 #define MMQ_Y_Q3_K_AMPERE 32 #define NWARPS_Q3_K_AMPERE 4 +#else +#define MMQ_X_Q3_K_AMPERE 128 +#define MMQ_Y_Q3_K_AMPERE 128 +#define NWARPS_Q3_K_AMPERE 4 +#endif #define MMQ_X_Q3_K_PASCAL 64 #define MMQ_Y_Q3_K_PASCAL 64 #define NWARPS_Q3_K_PASCAL 8 @@ -3985,9 +4043,15 @@ template static __global__ void #define MMQ_X_Q4_K_RDNA1 32 #define MMQ_Y_Q4_K_RDNA1 64 #define NWARPS_Q4_K_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) #define MMQ_X_Q4_K_AMPERE 4 #define MMQ_Y_Q4_K_AMPERE 32 #define NWARPS_Q4_K_AMPERE 4 +#else +#define MMQ_X_Q4_K_AMPERE 64 +#define MMQ_Y_Q4_K_AMPERE 128 +#define NWARPS_Q4_K_AMPERE 4 +#endif #define MMQ_X_Q4_K_PASCAL 64 #define MMQ_Y_Q4_K_PASCAL 64 #define NWARPS_Q4_K_PASCAL 8 @@ -4048,9 +4112,15 @@ template static __global__ void #define MMQ_X_Q5_K_RDNA1 32 #define MMQ_Y_Q5_K_RDNA1 64 #define NWARPS_Q5_K_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) #define MMQ_X_Q5_K_AMPERE 4 #define MMQ_Y_Q5_K_AMPERE 32 #define NWARPS_Q5_K_AMPERE 4 +#else +#define MMQ_X_Q5_K_AMPERE 64 +#define MMQ_Y_Q5_K_AMPERE 128 +#define NWARPS_Q5_K_AMPERE 4 +#endif #define MMQ_X_Q5_K_PASCAL 64 #define MMQ_Y_Q5_K_PASCAL 64 #define NWARPS_Q5_K_PASCAL 8 @@ -4109,9 +4179,15 @@ mul_mat_q5_K( #define MMQ_X_Q6_K_RDNA1 32 #define MMQ_Y_Q6_K_RDNA1 64 #define NWARPS_Q6_K_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) #define MMQ_X_Q6_K_AMPERE 4 #define MMQ_Y_Q6_K_AMPERE 32 #define NWARPS_Q6_K_AMPERE 4 +#else +#define MMQ_X_Q6_K_AMPERE 64 +#define MMQ_Y_Q6_K_AMPERE 64 +#define NWARPS_Q6_K_AMPERE 4 +#endif #define MMQ_X_Q6_K_PASCAL 64 #define MMQ_Y_Q6_K_PASCAL 64 #define NWARPS_Q6_K_PASCAL 8 @@ -5663,6 +5739,16 @@ void ggml_init_cublas() { CUDA_CHECK(cudaGetDeviceCount(&g_device_count)); GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES); int64_t total_vram = 0; +#if defined(GGML_CUDA_FORCE_MMQ) + fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__); +#else + fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ: no\n", __func__); +#endif +#if defined(CUDA_USE_TENSOR_CORES) + fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: yes\n", __func__); +#else + fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: no\n", __func__); +#endif fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count); for (int id = 0; id < g_device_count; ++id) { cudaDeviceProp prop; @@ -6347,7 +6433,7 @@ inline void ggml_cuda_op_mul_mat_cublas( cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, row_diff, src1_ncols, ne10, &alpha, src0_ddf_i, ne00, - src1_ddf_i, ne10, + src1_ddf_i, ne10, &beta, dst_dd_i, ldc)); if (src0_as != 0) { @@ -7204,18 +7290,23 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool all_on_device = - (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && + (src0->backend == GGML_BACKEND_GPU) && (src1->backend == GGML_BACKEND_GPU) && ( dst->backend == GGML_BACKEND_GPU); int64_t min_compute_capability = INT_MAX; for (int64_t id = 0; id < g_device_count; ++id) { - if (min_compute_capability > g_compute_capabilities[id] - && g_tensor_split[id] < (id + 1 < g_device_count ? 
g_tensor_split[id + 1] : 1.0f)) { + if (min_compute_capability > g_compute_capabilities[id] && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) { min_compute_capability = g_compute_capabilities[id]; } } +#ifdef CUDA_USE_TENSOR_CORES + const bool use_tensor_cores = true; +#else + const bool use_tensor_cores = false; +#endif + // debug helpers //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]); //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]); @@ -7224,20 +7315,19 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); - if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { + if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { // KQ single-batch ggml_cuda_mul_mat_vec_p021(src0, src1, dst); - } else if (all_on_device && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { + } else if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { // KQV single-batch ggml_cuda_mul_mat_vec_nc(src0, src1, dst); - } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { + } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) { // KQ + KQV multi-batch ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst); } else if (src0->type == GGML_TYPE_F32) { ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) { - #ifdef GGML_CUDA_FORCE_DMMV const bool use_mul_mat_vec_q = false; #else @@ -7250,13 +7340,11 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false); } } else { - // ref: https://github.com/ggerganov/llama.cpp/pull/3776 - bool use_mul_mat_q = g_mul_mat_q && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type); + bool use_mul_mat_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type); - // TODO: better way to determine availability of tensor cores - // currently fails for GeForce GTX 1660 which is TURING arch but does not have tensor cores - if (min_compute_capability >= CC_VOLTA && src1->ne[1] > 32) { - // when tensor cores are available, use them for large batch size + // when tensor cores are available, use them for large batch size + // ref: https://github.com/ggerganov/llama.cpp/pull/3776 + if (use_tensor_cores && min_compute_capability >= CC_VOLTA && src1->ne[1] > MMQ_MAX_BATCH_SIZE) { use_mul_mat_q = false; } @@ -7614,10 +7702,6 @@ void ggml_cuda_set_main_device(const int main_device) { } } -void ggml_cuda_set_mul_mat_q(const bool 
mul_mat_q) { - g_mul_mat_q = mul_mat_q; -} - void ggml_cuda_set_scratch_size(const size_t scratch_size) { // this is a hack to not completely break llama.cpp when using multiple models or contexts simultaneously // it still won't always work as expected, but it's better than nothing diff --git a/llama.cpp b/llama.cpp index 61f30c3982f18..cc8669b0e9e23 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5959,8 +5959,6 @@ static int llama_decode_internal( } } - ggml_cuda_set_mul_mat_q(cparams.mul_mat_q); - // HACK: ggml-alloc may change the tensor backend when reusing a parent, so force output to be on the CPU here if needed if (!lctx.embedding.empty()) { embeddings->backend = GGML_BACKEND_CPU; diff --git a/llama.h b/llama.h index 2f2fee0e2ff9f..beac9a0cedd76 100644 --- a/llama.h +++ b/llama.h @@ -178,7 +178,7 @@ extern "C" { float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model // Keep the booleans together to avoid misalignment during copy-by-value. - bool mul_mat_q; // if true, use experimental mul_mat_q kernels + bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true) bool f16_kv; // use fp16 for KV cache, fp32 otherwise bool logits_all; // the llama_eval() call computes all logits, not just the last one bool embedding; // embedding mode only From 49af767fadfc44dd079d49038089b4ee99d77e0c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 27 Oct 2023 13:21:04 +0300 Subject: [PATCH 07/23] build : add compile option to force use of MMQ kernels --- CMakeLists.txt | 7 +++++++ Makefile | 3 +++ ggml-cuda.cu | 1 + 3 files changed, 11 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 202f260491d39..d9fc86237b15c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -82,6 +82,7 @@ set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor") option(LLAMA_CUBLAS "llama: use CUDA" OFF) #option(LLAMA_CUDA_CUBLAS "llama: use cuBLAS for prompt processing" OFF) option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF) +option(LLAMA_CUDA_FORCE_MMQ "llama: use mmq kernels instead of cuBLAS" OFF) set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels") option(LLAMA_CUDA_F16 "llama: use 16 bit floats for some calculations" OFF) @@ -305,6 +306,9 @@ if (LLAMA_CUBLAS) if (LLAMA_CUDA_FORCE_DMMV) add_compile_definitions(GGML_CUDA_FORCE_DMMV) endif() + if (LLAMA_CUDA_FORCE_MMQ) + add_compile_definitions(GGML_CUDA_FORCE_MMQ) + endif() add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y}) if (DEFINED LLAMA_CUDA_DMMV_Y) @@ -405,6 +409,9 @@ if (LLAMA_HIPBLAS) if (LLAMA_CUDA_FORCE_DMMV) target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_DMMV) endif() + if (LLAMA_CUDA_FORCE_MMQ) + target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_MMQ) + endif() target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y}) target_compile_definitions(ggml-rocm PRIVATE K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER}) diff --git a/Makefile b/Makefile index 80179631f95a5..68069f9ff331e 100644 --- a/Makefile +++ b/Makefile @@ -397,6 +397,9 @@ endif # CUDA_DOCKER_ARCH ifdef LLAMA_CUDA_FORCE_DMMV NVCCFLAGS += -DGGML_CUDA_FORCE_DMMV endif # LLAMA_CUDA_FORCE_DMMV +ifdef LLAMA_CUDA_FORCE_MMQ + NVCCFLAGS += -DGGML_CUDA_FORCE_MMQ +endif # 
LLAMA_CUDA_FORCE_MMQ ifdef LLAMA_CUDA_DMMV_X NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X) else diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 558cd0bd8861a..1ba951f688d82 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -92,6 +92,7 @@ // for large computational tasks. the drawback is that this requires some extra amount of VRAM: // - 7B quantum model: +100-200 MB // - 13B quantum model: +200-400 MB +// //#define GGML_CUDA_FORCE_MMQ // TODO: improve this to be correct for more hardware From a9e2b74f1a0f4ca195fb087dbd2b6377f7f9025e Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Sat, 28 Oct 2023 17:23:06 -0500 Subject: [PATCH 08/23] Super hacky starting implementation of Min P --- llama.cpp | 104 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/llama.cpp b/llama.cpp index 3d431ee7bf526..47b411bd6a079 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7328,11 +7328,115 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can } } +void read_or_write_base_min_p(float* base_min_p) { + // Define the filename + const std::string filename = "SamplerBaseMinP.txt"; + + // Check if the file exists + std::ifstream infile(filename); + if (!infile.good()) { + // File doesn't exist, create it with default values + std::ofstream outfile(filename); + if (outfile.is_open()) { + // Set the default value for base_min_p + *base_min_p = 0.05f; + outfile << "base_min_p = " << *base_min_p << "\n"; + outfile.close(); + } else { + // Handle the error during file opening or writing here (optional) + // For example, you might want to set a safe default value or notify the user + } + } else { + // File exists, read the values from it + std::string line; + while (getline(infile, line)) { + std::istringstream iss(line); + std::string key; + float value; + char equals; // Using char to read the '=' character + + // Read each line in the format key = value + if (iss >> key >> equals >> value && equals == '=' && key == "base_min_p") { + *base_min_p = value; + } + // Note: If the key doesn't match, or the format is incorrect, + // you might want to handle the error or set a default value + } + infile.close(); + } +} + void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { + // Variables for the special mode + float base_min_p; // This will hold the base minimum probability value + float multiplied_min_p; // This will hold the adjusted minimum probability threshold + + // If p is 1.0, we switch to a different sampling mode. if (p >= 1.0f) { + printf("returning, top p set to 1.0"); return; } + // If p is ~0.02, we switch to the Min P sampler + if (p >= 0.01f && p <= 0.03f) { + printf("USING MIN P SAMPLING MODE\n"); + + // Ensure the probabilities are calculated. + llama_sample_softmax(ctx, candidates); + + // Print the top tokens before filtering + printf("Top tokens before filtering:\n"); + for (size_t i = 0; i < candidates->size && i < 10; ++i) { + printf("Token %zu: %.6f%%\n", i + 1, candidates->data[i].p * 100); // Multiplying by 100 to convert to percentage + } + + base_min_p = 0.05; // For example, 5% as the base minimum probability before reading the text value + + read_or_write_base_min_p(&base_min_p); + + // Calculate the multiplication factor based on the highest scoring token. 
+ float multiplication_factor = candidates->data[0].p; // Assuming the probabilities are sorted + printf("Highest scoring token probability (multiplication factor): %f\n", multiplication_factor); + + // Calculate the dynamic threshold. + multiplied_min_p = base_min_p * multiplication_factor; + printf("Base min_p value: %f\n", base_min_p); + printf("Calculated multiplied_min_p (threshold) value: %f\n", multiplied_min_p); + + // Store the tokens that meet the threshold in a new list. + std::vector filtered_candidates; + filtered_candidates.reserve(candidates->size); // Reserve to avoid multiple reallocations + + // Variable to count how many tokens meet the condition + int count_qualifying_tokens = 0; + + for (size_t i = 0; i < candidates->size; ++i) { + // If a token's probability is above the threshold, we keep it. + if (candidates->data[i].p >= multiplied_min_p) { + filtered_candidates.push_back(candidates->data[i]); + ++count_qualifying_tokens; // Increase count + } + } + + // Debug information about how many tokens were retained + printf("Number of tokens that met the multiplied_min_p condition: %d\n", count_qualifying_tokens); + + llama_sample_softmax(ctx, candidates); // re-normalize after pruning + + // Print the top tokens after filtering + printf("Tokens after filtering:\n"); + for (size_t i = 0; i < filtered_candidates.size() && i < 10; ++i) { // Adjust 10 to however many top tokens you want to display + printf("Token %zu: %.6f%%\n", i + 1, filtered_candidates[i].p * 100); // Multiplying by 100 to convert to percentage + } + + // Now we replace the original candidates with the filtered list. + std::copy(filtered_candidates.begin(), filtered_candidates.end(), candidates->data); + candidates->size = filtered_candidates.size(); + + // Since we're not actually sampling below a certain 'p', we return from the function after this. 
+ return; + } + llama_sample_softmax(ctx, candidates); const int64_t t_start_sample_us = ggml_time_us(); From a235a0d2263de70eda9db96480313308ecabb627 Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Sat, 28 Oct 2023 20:49:17 -0500 Subject: [PATCH 09/23] Transform Min P into a proper CLI option --- common/common.cpp | 8 +++ common/sampling.cpp | 6 +- common/sampling.h | 1 + llama.cpp | 159 +++++++++++++++----------------------------- llama.h | 7 ++ 5 files changed, 75 insertions(+), 106 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index f81f4d354bc01..b7ed4827d08a1 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -218,6 +218,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } sparams.top_p = std::stof(argv[i]); + } else if (arg == "--min-p") { // Adding min_p argument + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.min_p = std::stof(argv[i]); // Parsing and setting the min_p value from command line } else if (arg == "--temp") { if (++i >= argc) { invalid_param = true; @@ -679,6 +685,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k); printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); + printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p); printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z); printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p); printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n); @@ -1275,6 +1282,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "threads: %d # default: %d\n", params.n_threads, std::thread::hardware_concurrency()); fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); + fprintf(stream, "min_p: %f # default: 0.05\n", sparams.min_p); fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p); fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? 
"true" : "false"); } diff --git a/common/sampling.cpp b/common/sampling.cpp index c4996c9857d8a..673d67a6d5380 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -89,10 +89,10 @@ std::string llama_sampling_print(const llama_sampling_params & params) { snprintf(result, sizeof(result), "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" - "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, typical_p = %.3f, temp = %.3f\n" + "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n" "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present, - params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, + params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau); return std::string(result); @@ -110,6 +110,7 @@ llama_token llama_sampling_sample( const float temp = params.temp; const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; const float top_p = params.top_p; + const float min_p = params.min_p; const float tfs_z = params.tfs_z; const float typical_p = params.typical_p; const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n; @@ -190,6 +191,7 @@ llama_token llama_sampling_sample( llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); + llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); llama_sample_temp (ctx_main, &cur_p, temp); id = llama_sample_token(ctx_main, &cur_p); diff --git a/common/sampling.h b/common/sampling.h index 62ea6d4cfb7e5..7c9b8dcf23bcb 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -14,6 +14,7 @@ typedef struct llama_sampling_params { int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. 
int32_t top_k = 40; // <= 0 to use vocab size float top_p = 0.95f; // 1.0 = disabled + float min_p = 0.05f; // 0.0 = disabled float tfs_z = 1.00f; // 1.0 = disabled float typical_p = 1.00f; // 1.0 = disabled float temp = 0.80f; // 1.0 = disabled diff --git a/llama.cpp b/llama.cpp index 47b411bd6a079..f926b99f5bd2e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7328,112 +7328,8 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can } } -void read_or_write_base_min_p(float* base_min_p) { - // Define the filename - const std::string filename = "SamplerBaseMinP.txt"; - - // Check if the file exists - std::ifstream infile(filename); - if (!infile.good()) { - // File doesn't exist, create it with default values - std::ofstream outfile(filename); - if (outfile.is_open()) { - // Set the default value for base_min_p - *base_min_p = 0.05f; - outfile << "base_min_p = " << *base_min_p << "\n"; - outfile.close(); - } else { - // Handle the error during file opening or writing here (optional) - // For example, you might want to set a safe default value or notify the user - } - } else { - // File exists, read the values from it - std::string line; - while (getline(infile, line)) { - std::istringstream iss(line); - std::string key; - float value; - char equals; // Using char to read the '=' character - - // Read each line in the format key = value - if (iss >> key >> equals >> value && equals == '=' && key == "base_min_p") { - *base_min_p = value; - } - // Note: If the key doesn't match, or the format is incorrect, - // you might want to handle the error or set a default value - } - infile.close(); - } -} - void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - // Variables for the special mode - float base_min_p; // This will hold the base minimum probability value - float multiplied_min_p; // This will hold the adjusted minimum probability threshold - - // If p is 1.0, we switch to a different sampling mode. if (p >= 1.0f) { - printf("returning, top p set to 1.0"); - return; - } - - // If p is ~0.02, we switch to the Min P sampler - if (p >= 0.01f && p <= 0.03f) { - printf("USING MIN P SAMPLING MODE\n"); - - // Ensure the probabilities are calculated. - llama_sample_softmax(ctx, candidates); - - // Print the top tokens before filtering - printf("Top tokens before filtering:\n"); - for (size_t i = 0; i < candidates->size && i < 10; ++i) { - printf("Token %zu: %.6f%%\n", i + 1, candidates->data[i].p * 100); // Multiplying by 100 to convert to percentage - } - - base_min_p = 0.05; // For example, 5% as the base minimum probability before reading the text value - - read_or_write_base_min_p(&base_min_p); - - // Calculate the multiplication factor based on the highest scoring token. - float multiplication_factor = candidates->data[0].p; // Assuming the probabilities are sorted - printf("Highest scoring token probability (multiplication factor): %f\n", multiplication_factor); - - // Calculate the dynamic threshold. - multiplied_min_p = base_min_p * multiplication_factor; - printf("Base min_p value: %f\n", base_min_p); - printf("Calculated multiplied_min_p (threshold) value: %f\n", multiplied_min_p); - - // Store the tokens that meet the threshold in a new list. 
- std::vector filtered_candidates; - filtered_candidates.reserve(candidates->size); // Reserve to avoid multiple reallocations - - // Variable to count how many tokens meet the condition - int count_qualifying_tokens = 0; - - for (size_t i = 0; i < candidates->size; ++i) { - // If a token's probability is above the threshold, we keep it. - if (candidates->data[i].p >= multiplied_min_p) { - filtered_candidates.push_back(candidates->data[i]); - ++count_qualifying_tokens; // Increase count - } - } - - // Debug information about how many tokens were retained - printf("Number of tokens that met the multiplied_min_p condition: %d\n", count_qualifying_tokens); - - llama_sample_softmax(ctx, candidates); // re-normalize after pruning - - // Print the top tokens after filtering - printf("Tokens after filtering:\n"); - for (size_t i = 0; i < filtered_candidates.size() && i < 10; ++i) { // Adjust 10 to however many top tokens you want to display - printf("Token %zu: %.6f%%\n", i + 1, filtered_candidates[i].p * 100); // Multiplying by 100 to convert to percentage - } - - // Now we replace the original candidates with the filtered list. - std::copy(filtered_candidates.begin(), filtered_candidates.end(), candidates->data); - candidates->size = filtered_candidates.size(); - - // Since we're not actually sampling below a certain 'p', we return from the function after this. return; } @@ -7464,6 +7360,61 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can } } +void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { + float base_min_p = p; // This will hold the base minimum probability value + float multiplied_min_p; // This will hold the adjusted minimum probability threshold + + printf("\nUSING MIN P SAMPLING MODE\n\n"); + + // Ensure the probabilities are calculated. + llama_sample_softmax(ctx, candidates); + + // Print the top tokens before filtering + printf("Top tokens before filtering:\n"); + for (size_t i = 0; i < candidates->size && i < 10; ++i) { + printf("Token %zu: %.6f%%\n", i + 1, candidates->data[i].p * 100); // Multiplying by 100 to convert to percentage + } + + // Calculate the multiplication factor based on the highest scoring token. + float multiplication_factor = candidates->data[0].p; // Assuming the probabilities are sorted + printf("Highest scoring token probability (multiplication factor): %f\n", multiplication_factor); + + // Calculate the dynamic threshold. + multiplied_min_p = base_min_p * multiplication_factor; + printf("Base min_p value: %f\n", base_min_p); + printf("Calculated multiplied_min_p (threshold) value: %f\n", multiplied_min_p); + + // Store the tokens that meet the threshold in a new list. + std::vector filtered_candidates; + filtered_candidates.reserve(candidates->size); // Reserve to avoid multiple reallocations + + // Variable to count how many tokens meet the condition + int count_qualifying_tokens = 0; + + for (size_t i = 0; i < candidates->size; ++i) { + // If a token's probability is above the threshold, we keep it. 
+ if (candidates->data[i].p >= multiplied_min_p) { + filtered_candidates.push_back(candidates->data[i]); + ++count_qualifying_tokens; // Increase count + } + } + + // Debug information about how many tokens were retained + printf("Number of tokens that met the multiplied_min_p condition: %d\n", count_qualifying_tokens); + + // Print the top tokens after filtering + printf("Tokens after filtering:\n\n"); + for (size_t i = 0; i < filtered_candidates.size() && i < 10; ++i) { // Adjust 10 to however many top tokens you want to display + printf("Token %zu: %.6f%%\n", i + 1, filtered_candidates[i].p * 100); // Multiplying by 100 to convert to percentage + } + + // Now we replace the original candidates with the filtered list. + std::copy(filtered_candidates.begin(), filtered_candidates.end(), candidates->data); + candidates->size = filtered_candidates.size(); + + return; +} + void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { if (z >= 1.0f || candidates->size <= 2) { return; diff --git a/llama.h b/llama.h index d901dcd9116d3..0ec28eb14d31e 100644 --- a/llama.h +++ b/llama.h @@ -600,6 +600,13 @@ extern "C" { float p, size_t min_keep); + /// @details Minimum P sampling by Kalomaze + LLAMA_API void llama_sample_min_p( + struct llama_context * ctx, + llama_token_data_array * candidates, + float p, + size_t min_keep); + /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. LLAMA_API void llama_sample_tail_free( struct llama_context * ctx, From 838d58dc3273c22ffa5f75633f4b9be7f09bb2d3 Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Sat, 28 Oct 2023 21:08:26 -0500 Subject: [PATCH 10/23] Min P disabled if set to 1.0 or 0, otherwise Top P --- common/common.cpp | 2 +- common/sampling.cpp | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index b7ed4827d08a1..fe15d03a9cc4b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -685,7 +685,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k); printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); - printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p); + printf(" --min-p N min-p sampling (default: %.2f, 1.0 = disabled)\n", (double)sparams.min_p); printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z); printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p); printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n); diff --git a/common/sampling.cpp b/common/sampling.cpp index 673d67a6d5380..f5507dfca60ac 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -190,8 +190,11 @@ llama_token llama_sampling_sample( llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); - llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); - llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); + if (min_p != 1.0 && min_p != 0.0) { + 
llama_sample_min_p(ctx_main, &cur_p, min_p, min_keep); + } else { + llama_sample_top_p(ctx_main, &cur_p, top_p, min_keep); + } llama_sample_temp (ctx_main, &cur_p, temp); id = llama_sample_token(ctx_main, &cur_p); From 69ef4ca885ac96a998ea1806bc54818b73f69698 Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Sat, 28 Oct 2023 21:14:55 -0500 Subject: [PATCH 11/23] Debugging print statements removed --- llama.cpp | 28 ++-------------------------- 1 file changed, 2 insertions(+), 26 deletions(-) diff --git a/llama.cpp b/llama.cpp index f926b99f5bd2e..e742a1406751b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7364,22 +7364,13 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can float base_min_p = p; // This will hold the base minimum probability value float multiplied_min_p; // This will hold the adjusted minimum probability threshold - printf("\nUSING MIN P SAMPLING MODE\n\n"); - // Ensure the probabilities are calculated. llama_sample_softmax(ctx, candidates); - // Print the top tokens before filtering - printf("Top tokens before filtering:\n"); - for (size_t i = 0; i < candidates->size && i < 10; ++i) { - printf("Token %zu: %.6f%%\n", i + 1, candidates->data[i].p * 100); // Multiplying by 100 to convert to percentage - } - // Calculate the multiplication factor based on the highest scoring token. - float multiplication_factor = candidates->data[0].p; // Assuming the probabilities are sorted - printf("Highest scoring token probability (multiplication factor): %f\n", multiplication_factor); + float multiplication_factor = candidates->data[0].p; - // Calculate the dynamic threshold. + // Calculate the minimum percentage requirement. multiplied_min_p = base_min_p * multiplication_factor; printf("Base min_p value: %f\n", base_min_p); printf("Calculated multiplied_min_p (threshold) value: %f\n", multiplied_min_p); @@ -7388,31 +7379,16 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can std::vector filtered_candidates; filtered_candidates.reserve(candidates->size); // Reserve to avoid multiple reallocations - // Variable to count how many tokens meet the condition - int count_qualifying_tokens = 0; - for (size_t i = 0; i < candidates->size; ++i) { // If a token's probability is above the threshold, we keep it. if (candidates->data[i].p >= multiplied_min_p) { filtered_candidates.push_back(candidates->data[i]); - ++count_qualifying_tokens; // Increase count } } - // Debug information about how many tokens were retained - printf("Number of tokens that met the multiplied_min_p condition: %d\n", count_qualifying_tokens); - - // Print the top tokens after filtering - printf("Tokens after filtering:\n\n"); - for (size_t i = 0; i < filtered_candidates.size() && i < 10; ++i) { // Adjust 10 to however many top tokens you want to display - printf("Token %zu: %.6f%%\n", i + 1, filtered_candidates[i].p * 100); // Multiplying by 100 to convert to percentage - } - // Now we replace the original candidates with the filtered list. 
std::copy(filtered_candidates.begin(), filtered_candidates.end(), candidates->data); candidates->size = filtered_candidates.size(); - - return; } void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { From 833637b703c3719ff5d1ef94db5c82b38941193e Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Sat, 28 Oct 2023 22:05:05 -0500 Subject: [PATCH 12/23] erring on the side of caution; disable by default --- common/sampling.h | 2 +- llama.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/common/sampling.h b/common/sampling.h index 7c9b8dcf23bcb..bbb9ffa077b6f 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -14,7 +14,7 @@ typedef struct llama_sampling_params { int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. int32_t top_k = 40; // <= 0 to use vocab size float top_p = 0.95f; // 1.0 = disabled - float min_p = 0.05f; // 0.0 = disabled + float min_p = 1.0f; // 1.0 (or 0.0) = disabled float tfs_z = 1.00f; // 1.0 = disabled float typical_p = 1.00f; // 1.0 = disabled float temp = 0.80f; // 1.0 = disabled diff --git a/llama.h b/llama.h index 0ec28eb14d31e..62addffdfc693 100644 --- a/llama.h +++ b/llama.h @@ -600,7 +600,7 @@ extern "C" { float p, size_t min_keep); - /// @details Minimum P sampling by Kalomaze + /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841#issue-1966758357 LLAMA_API void llama_sample_min_p( struct llama_context * ctx, llama_token_data_array * candidates, From 62fc77153b968f12d1da1e67193d94f4c591f873 Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Sat, 28 Oct 2023 23:04:29 -0500 Subject: [PATCH 13/23] Remove accidentally kept prints + min_keep support --- llama.cpp | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index e742a1406751b..d311f546189c1 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7372,17 +7372,30 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can // Calculate the minimum percentage requirement. multiplied_min_p = base_min_p * multiplication_factor; - printf("Base min_p value: %f\n", base_min_p); - printf("Calculated multiplied_min_p (threshold) value: %f\n", multiplied_min_p); // Store the tokens that meet the threshold in a new list. std::vector filtered_candidates; filtered_candidates.reserve(candidates->size); // Reserve to avoid multiple reallocations + size_t kept_count = 0; // Counter for how many tokens are kept + for (size_t i = 0; i < candidates->size; ++i) { // If a token's probability is above the threshold, we keep it. 
if (candidates->data[i].p >= multiplied_min_p) { filtered_candidates.push_back(candidates->data[i]); + kept_count++; // Increment the counter + } + } + + // If not enough candidates meet the threshold, take the top 'min_keep' ones + if (kept_count < min_keep) { + std::sort(candidates->data, candidates->data + candidates->size, + [](const llama_token_data & a, const llama_token_data & b) { + return a.p > b.p; // Sort by probability in descending order + }); + filtered_candidates.clear(); // Clear the previously filtered candidates + for (size_t i = 0; i < min_keep; ++i) { + filtered_candidates.push_back(candidates->data[i]); } } From 49b68e8226b15475cceb04872feb3649fb034532 Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Sat, 28 Oct 2023 23:12:14 -0500 Subject: [PATCH 14/23] Standardize 0.0 disabling min_p upon feedback --- common/common.cpp | 4 ++-- common/sampling.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index fe15d03a9cc4b..5ff8c579de491 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -685,7 +685,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k); printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); - printf(" --min-p N min-p sampling (default: %.2f, 1.0 = disabled)\n", (double)sparams.min_p); + printf(" --min-p N min-p sampling (default: %.2f, 0.0 = disabled)\n", (double)sparams.min_p); printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z); printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p); printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n); @@ -1282,7 +1282,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "threads: %d # default: %d\n", params.n_threads, std::thread::hardware_concurrency()); fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); - fprintf(stream, "min_p: %f # default: 0.05\n", sparams.min_p); + fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p); fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p); fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? 
"true" : "false"); } diff --git a/common/sampling.cpp b/common/sampling.cpp index f5507dfca60ac..6c4e6bc1fa144 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -190,7 +190,7 @@ llama_token llama_sampling_sample( llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); - if (min_p != 1.0 && min_p != 0.0) { + if (min_p != 0.0) { llama_sample_min_p(ctx_main, &cur_p, min_p, min_keep); } else { llama_sample_top_p(ctx_main, &cur_p, top_p, min_keep); From 6f7cdec38a260dc6fa07b8821bc27671bb1b40d7 Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Sat, 28 Oct 2023 23:37:18 -0500 Subject: [PATCH 15/23] Simplified counter by checking candidates size + fixed 0.0 default for min_p --- common/sampling.h | 2 +- llama.cpp | 9 +++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/common/sampling.h b/common/sampling.h index bbb9ffa077b6f..97f2c170a474b 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -14,7 +14,7 @@ typedef struct llama_sampling_params { int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. int32_t top_k = 40; // <= 0 to use vocab size float top_p = 0.95f; // 1.0 = disabled - float min_p = 1.0f; // 1.0 (or 0.0) = disabled + float min_p = 0.0f; // 1.0 (or 0.0) = disabled float tfs_z = 1.00f; // 1.0 = disabled float typical_p = 1.00f; // 1.0 = disabled float temp = 0.80f; // 1.0 = disabled diff --git a/llama.cpp b/llama.cpp index d311f546189c1..9b0a9b5b2d308 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7377,18 +7377,15 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can std::vector filtered_candidates; filtered_candidates.reserve(candidates->size); // Reserve to avoid multiple reallocations - size_t kept_count = 0; // Counter for how many tokens are kept - for (size_t i = 0; i < candidates->size; ++i) { - // If a token's probability is above the threshold, we keep it. 
- if (candidates->data[i].p >= multiplied_min_p) { + // If a token's probability is above the threshold or if we haven't kept enough tokens yet + if (candidates->data[i].p >= multiplied_min_p || filtered_candidates.size() < min_keep) { filtered_candidates.push_back(candidates->data[i]); - kept_count++; // Increment the counter } } // If not enough candidates meet the threshold, take the top 'min_keep' ones - if (kept_count < min_keep) { + if (filtered_candidates.size() < min_keep) { std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { return a.p > b.p; // Sort by probability in descending order From cb233584cceed3c2ad481522153e06eaa56ca794 Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Sat, 28 Oct 2023 23:40:23 -0500 Subject: [PATCH 16/23] minor whitespace fix --- llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 9b0a9b5b2d308..07ec721e26a4d 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7332,7 +7332,7 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can if (p >= 1.0f) { return; } - + llama_sample_softmax(ctx, candidates); const int64_t t_start_sample_us = ggml_time_us(); From fcbbfc1666d5a5604da00b176cd5461770ffce48 Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Sat, 28 Oct 2023 23:52:22 -0500 Subject: [PATCH 17/23] Even formatting + exclusively 0.0f to disable now --- common/sampling.cpp | 4 ++-- common/sampling.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 6c4e6bc1fa144..bd4f34b9e11ad 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -191,9 +191,9 @@ llama_token llama_sampling_sample( llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); if (min_p != 0.0) { - llama_sample_min_p(ctx_main, &cur_p, min_p, min_keep); + llama_sample_min_p(ctx_main, &cur_p, min_p, min_keep); } else { - llama_sample_top_p(ctx_main, &cur_p, top_p, min_keep); + llama_sample_top_p(ctx_main, &cur_p, top_p, min_keep); } llama_sample_temp (ctx_main, &cur_p, temp); diff --git a/common/sampling.h b/common/sampling.h index 97f2c170a474b..84051d62ab5d3 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -14,7 +14,7 @@ typedef struct llama_sampling_params { int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. 
int32_t top_k = 40; // <= 0 to use vocab size float top_p = 0.95f; // 1.0 = disabled - float min_p = 0.0f; // 1.0 (or 0.0) = disabled + float min_p = 0.00f; // 0.0 = disabled float tfs_z = 1.00f; // 1.0 = disabled float typical_p = 1.00f; // 1.0 = disabled float temp = 0.80f; // 1.0 = disabled From 69e638e56aec3fa67df2a8ee192059191805b697 Mon Sep 17 00:00:00 2001 From: cebtenzzre Date: Sun, 29 Oct 2023 00:23:21 -0400 Subject: [PATCH 18/23] cleanup --- common/common.cpp | 6 +++--- common/sampling.cpp | 2 +- llama.cpp | 43 ++++++++++++++----------------------------- llama.h | 2 +- 4 files changed, 19 insertions(+), 34 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 5ff8c579de491..44893a219d7c8 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -218,12 +218,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } sparams.top_p = std::stof(argv[i]); - } else if (arg == "--min-p") { // Adding min_p argument + } else if (arg == "--min-p") { if (++i >= argc) { invalid_param = true; break; } - sparams.min_p = std::stof(argv[i]); // Parsing and setting the min_p value from command line + sparams.min_p = std::stof(argv[i]); } else if (arg == "--temp") { if (++i >= argc) { invalid_param = true; @@ -685,7 +685,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k); printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); - printf(" --min-p N min-p sampling (default: %.2f, 0.0 = disabled)\n", (double)sparams.min_p); + printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p); printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z); printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p); printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n); diff --git a/common/sampling.cpp b/common/sampling.cpp index bd4f34b9e11ad..ccac97a247d7f 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -191,7 +191,7 @@ llama_token llama_sampling_sample( llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); if (min_p != 0.0) { - llama_sample_min_p(ctx_main, &cur_p, min_p, min_keep); + llama_sample_min_p(ctx_main, &cur_p, min_p, min_keep); } else { llama_sample_top_p(ctx_main, &cur_p, top_p, min_keep); } diff --git a/llama.cpp b/llama.cpp index 07ec721e26a4d..1f27016c65da8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7361,44 +7361,29 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can } void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - float base_min_p = p; // This will hold the base minimum probability value - float multiplied_min_p; // This will hold the adjusted minimum probability threshold + if (p <= 0.0f || !candidates->size) { + return; + } - // Ensure the probabilities are calculated. llama_sample_softmax(ctx, candidates); - // Calculate the multiplication factor based on the highest scoring token. - float multiplication_factor = candidates->data[0].p; - - // Calculate the minimum percentage requirement. 
- multiplied_min_p = base_min_p * multiplication_factor; + const int64_t t_start_sample_us = ggml_time_us(); - // Store the tokens that meet the threshold in a new list. - std::vector filtered_candidates; - filtered_candidates.reserve(candidates->size); // Reserve to avoid multiple reallocations + float scale = candidates->data[0].p; // scale by max prob + size_t i = 1; // first token always matches - for (size_t i = 0; i < candidates->size; ++i) { - // If a token's probability is above the threshold or if we haven't kept enough tokens yet - if (candidates->data[i].p >= multiplied_min_p || filtered_candidates.size() < min_keep) { - filtered_candidates.push_back(candidates->data[i]); + for (; i < candidates->size; ++i) { + if (candidates->data[i].p < p * scale && i >= min_keep) { + break; // prob too small } } - // If not enough candidates meet the threshold, take the top 'min_keep' ones - if (filtered_candidates.size() < min_keep) { - std::sort(candidates->data, candidates->data + candidates->size, - [](const llama_token_data & a, const llama_token_data & b) { - return a.p > b.p; // Sort by probability in descending order - }); - filtered_candidates.clear(); // Clear the previously filtered candidates - for (size_t i = 0; i < min_keep; ++i) { - filtered_candidates.push_back(candidates->data[i]); - } - } + // Resize the output vector to keep only the matching tokens + candidates->size = i; - // Now we replace the original candidates with the filtered list. - std::copy(filtered_candidates.begin(), filtered_candidates.end(), candidates->data); - candidates->size = filtered_candidates.size(); + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } } void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { diff --git a/llama.h b/llama.h index 62addffdfc693..157d755755311 100644 --- a/llama.h +++ b/llama.h @@ -600,7 +600,7 @@ extern "C" { float p, size_t min_keep); - /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841#issue-1966758357 + /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 LLAMA_API void llama_sample_min_p( struct llama_context * ctx, llama_token_data_array * candidates, From 3ddfd67d137841c51e991f920a4ed7b865dd5e23 Mon Sep 17 00:00:00 2001 From: cebtenzzre Date: Sun, 29 Oct 2023 01:31:14 -0400 Subject: [PATCH 19/23] permit simultaneous use of top_p and min_p --- common/sampling.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index ccac97a247d7f..673d67a6d5380 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -190,11 +190,8 @@ llama_token llama_sampling_sample( llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); - if (min_p != 0.0) { - llama_sample_min_p(ctx_main, &cur_p, min_p, min_keep); - } else { - llama_sample_top_p(ctx_main, &cur_p, top_p, min_keep); - } + llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); + llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); llama_sample_temp (ctx_main, &cur_p, temp); id = llama_sample_token(ctx_main, &cur_p); From 9248325f82cd1d8cda8f598ba56a28ba2cd5c174 Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Tue, 31 Oct 2023 11:25:23 -0500 Subject: [PATCH 20/23] Update README & set 0.05 default --- common/sampling.h | 2 
+- examples/main/README.md | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/common/sampling.h b/common/sampling.h index 84051d62ab5d3..7c9b8dcf23bcb 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -14,7 +14,7 @@ typedef struct llama_sampling_params { int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. int32_t top_k = 40; // <= 0 to use vocab size float top_p = 0.95f; // 1.0 = disabled - float min_p = 0.00f; // 0.0 = disabled + float min_p = 0.05f; // 0.0 = disabled float tfs_z = 1.00f; // 1.0 = disabled float typical_p = 1.00f; // 1.0 = disabled float temp = 0.80f; // 1.0 = disabled diff --git a/examples/main/README.md b/examples/main/README.md index a9561c383c0cb..1db03b991266d 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -208,6 +208,14 @@ Top-p sampling, also known as nucleus sampling, is another text generation metho Example usage: `--top-p 0.95` +### Min P Sampling + +- `--min-p N`: Sets a minimum base probability threshold for token selection (default: 0.0). + +Min-P sampling is a sampling method where the value represents the base required probability percentage. This value is adjusted based on the probability of the most likely token. For example, with a Min-P value set at 0.05 and the highest token probability being 90%, the minimum required threshold becomes 4.5%. This approach ensures a balance of quality and variety in the results. The default value is 0.05. + +Example usage: `--min-p 0.05` + ### Tail Free Sampling (TFS) - `--tfs N`: Enable tail free sampling with parameter z (default: 1.0, 1.0 = disabled). From 512cac630c8ca7c20964d9b488b5a93d159a55ae Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Tue, 31 Oct 2023 11:32:01 -0500 Subject: [PATCH 21/23] added a bit more context to the README --- examples/main/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/main/README.md b/examples/main/README.md index 1db03b991266d..952b6f752d87c 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -212,7 +212,7 @@ Example usage: `--top-p 0.95` - `--min-p N`: Sets a minimum base probability threshold for token selection (default: 0.0). -Min-P sampling is a sampling method where the value represents the base required probability percentage. This value is adjusted based on the probability of the most likely token. For example, with a Min-P value set at 0.05 and the highest token probability being 90%, the minimum required threshold becomes 4.5%. This approach ensures a balance of quality and variety in the results. The default value is 0.05. +Min-P sampling is a sampling method where the value represents the base required probability percentage. This value is scaled down based on the probability of the most likely token. For example, with a Min-P value set at 0.05 and the highest token probability being 90%, the percentage requirement threshold becomes 4.5%. This approach ensures a balance of quality and variety in the results, and was designed as an alternative to improve upon Top P. The default value is 0.05. 
Example usage: `--min-p 0.05` From 974640ac255c51661f6d25d08f5e707f3a6d69c9 Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Tue, 31 Oct 2023 13:18:48 -0500 Subject: [PATCH 22/23] Update README for consistency --- examples/main/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/main/README.md b/examples/main/README.md index 952b6f752d87c..568d9b7e14cbe 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -212,7 +212,7 @@ Example usage: `--top-p 0.95` - `--min-p N`: Sets a minimum base probability threshold for token selection (default: 0.0). -Min-P sampling is a sampling method where the value represents the base required probability percentage. This value is scaled down based on the probability of the most likely token. For example, with a Min-P value set at 0.05 and the highest token probability being 90%, the percentage requirement threshold becomes 4.5%. This approach ensures a balance of quality and variety in the results, and was designed as an alternative to improve upon Top P. The default value is 0.05. +The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter *p* represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with *p*=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out. Example usage: `--min-p 0.05` From 3b58af26488d15e32a35db0d5cb637bee334bb31 Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Tue, 31 Oct 2023 13:19:49 -0500 Subject: [PATCH 23/23] forgot one small thing! --- examples/main/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/main/README.md b/examples/main/README.md index 568d9b7e14cbe..a3428b48763d0 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -210,7 +210,7 @@ Example usage: `--top-p 0.95` ### Min P Sampling -- `--min-p N`: Sets a minimum base probability threshold for token selection (default: 0.0). +- `--min-p N`: Sets a minimum base probability threshold for token selection (default: 0.05). The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter *p* represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with *p*=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out.
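
For reference, the behavior this series converges on — scale the base probability p by the top token's probability, drop candidates below that scaled threshold, but never keep fewer than min_keep tokens — can be tried in isolation. Below is a minimal standalone C++ sketch of that idea; the names token_prob and min_p_filter are invented for the example and are not llama.cpp API (the real entry point is the llama_sample_min_p added above), and the toy probabilities are made up.

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

struct token_prob {
    int   id; // hypothetical token id
    float p;  // softmax probability
};

// Keep tokens whose probability is at least p * (top probability), but never fewer than min_keep.
static void min_p_filter(std::vector<token_prob> & cands, float p, size_t min_keep) {
    if (p <= 0.0f || cands.empty()) {
        return; // p = 0.0 disables the filter, matching the final convention in the series
    }

    // The in-tree version relies on llama_sample_softmax having sorted the candidates;
    // here we sort explicitly, highest probability first.
    std::sort(cands.begin(), cands.end(),
              [](const token_prob & a, const token_prob & b) { return a.p > b.p; });

    const float threshold = p * cands[0].p; // base p scaled by the top token's probability

    size_t i = 1; // the top token always survives
    for (; i < cands.size(); ++i) {
        if (cands[i].p < threshold && i >= min_keep) {
            break; // everything from here on is below the scaled threshold
        }
    }

    cands.resize(i); // truncate, mirroring candidates->size = i in llama_sample_min_p
}

int main() {
    // Toy distribution: top token at 0.90, so with p = 0.05 the cutoff is 0.045.
    std::vector<token_prob> cands = { {1, 0.90f}, {2, 0.05f}, {3, 0.04f}, {4, 0.01f} };

    min_p_filter(cands, 0.05f, /*min_keep =*/ 1);

    for (const auto & c : cands) {
        std::printf("token %d  p = %.3f\n", c.id, c.p);
    }
    return 0;
}

With this toy distribution, p = 0.05 against a 0.90 top token gives a cutoff of 0.045, so only the 0.90 and 0.05 tokens survive — the same arithmetic as the README example above. Scaling by the top probability is what makes the cutoff adaptive: a peaked distribution prunes aggressively, while a flat one keeps more candidates.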