
Commit

Use bf16 kv cache when it's faster
jart committed Aug 25, 2024
1 parent 98eff09 commit fa4c4e7
Showing 4 changed files with 13 additions and 6 deletions.
3 changes: 3 additions & 0 deletions llama.cpp/common.cpp
@@ -2159,6 +2159,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
     if (s == "f32") {
         return GGML_TYPE_F32;
     }
+    if (s == "bf16") {
+        return GGML_TYPE_BF16;
+    }
     if (s == "f16") {
         return GGML_TYPE_F16;
     }
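
For context, the "bf16" name registered above maps to GGML's bfloat16 storage type: the upper 16 bits of an IEEE 754 binary32 value. It keeps fp32's full exponent range while halving the KV cache footprint, same as f16. A minimal sketch of the conversion (illustrative only; it truncates rather than reproducing ggml's exact rounding behavior):

// Minimal sketch, not ggml's implementation: bf16 is the top 16 bits of a
// float, so converting down is a shift and converting back zero-fills the
// discarded mantissa bits.
#include <cstdint>
#include <cstring>

static uint16_t f32_to_bf16(float f) {      // truncating convert (no rounding)
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof bits);
    return static_cast<uint16_t>(bits >> 16);
}

static float bf16_to_f32(uint16_t h) {      // exact: low mantissa bits are zero
    uint32_t bits = static_cast<uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof f);
    return f;
}

Because the exponent field is unchanged, anything representable in fp32 also fits in bf16 without overflow or underflow; only significand precision drops, from 24 bits to 8.
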
5 changes: 3 additions & 2 deletions llama.cpp/common.h
@@ -22,6 +22,7 @@
 #include <thread>
 #include <unordered_map>
 #include <tuple>
+#include <cosmo.h>

 #ifdef _WIN32
 #define DIRECTORY_SEPARATOR '\\'
@@ -192,8 +193,8 @@ struct gpt_params {
     bool warmup = true; // warmup run
     bool check_tensors = false; // validate tensor data

-    std::string cache_type_k = "f16"; // KV cache data type for the K
-    std::string cache_type_v = "f16"; // KV cache data type for the V
+    std::string cache_type_k = X86_HAVE(AVX512_BF16) ? "bf16" : "f16"; // KV cache data type for the K [jart]
+    std::string cache_type_v = X86_HAVE(AVX512_BF16) ? "bf16" : "f16"; // KV cache data type for the V [jart]

     // multimodal models (see examples/llava)
     std::string mmproj = ""; // path to multimodal projector
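
The new defaults above lean on Cosmopolitan's X86_HAVE() runtime probe, which is why <cosmo.h> is now included. For builds without that header, a rough equivalent would read the CPUID flag directly; this is a sketch under the assumption that AVX512_BF16 is reported in CPUID leaf 7, sub-leaf 1, EAX bit 5, and a production check would also confirm the OS has enabled AVX-512 state:

// Sketch of a non-Cosmopolitan fallback for choosing the same default.
#include <cpuid.h>
#include <string>

static std::string default_kv_cache_type() {
    unsigned eax = 0, ebx = 0, ecx = 0, edx = 0;
    // AVX512_BF16 feature flag: CPUID.(EAX=7, ECX=1):EAX bit 5.
    if (__get_cpuid_count(7, 1, &eax, &ebx, &ecx, &edx) && (eax & (1u << 5)))
        return "bf16";  // mirrors the X86_HAVE(AVX512_BF16) default above
    return "f16";       // previous default on every other CPU
}
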
5 changes: 4 additions & 1 deletion llama.cpp/llama.cpp
@@ -16766,7 +16766,10 @@ struct llama_context * llama_new_context_with_model(
         params.flash_attn = false;
     }

-    if (params.type_v != GGML_TYPE_F16 && !params.flash_attn) {
+    // [jart] allow bf16
+    if (params.type_v != GGML_TYPE_F16 &&
+        params.type_v != GGML_TYPE_BF16 &&
+        !params.flash_attn) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;
     }
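
Restated as a standalone predicate (the function name is invented here for illustration; the real check is the inline condition above), the rule after this change is: with flash attention any V cache type passes, and without it only the two 16-bit float layouts are accepted.

// Illustrative restatement of the guard above; not an actual llama.cpp function.
#include "ggml.h"

static bool v_cache_type_ok(ggml_type type_v, bool flash_attn) {
    if (flash_attn)
        return true;                   // flash-attn kernels handle quantized V
    return type_v == GGML_TYPE_F16     // previously the only non-flash option
        || type_v == GGML_TYPE_BF16;   // newly permitted by this commit
}
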
6 changes: 3 additions & 3 deletions llamafile/tinyblas_cpu_sgemm.inc
@@ -73,7 +73,7 @@ bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const

     case GGML_TYPE_BF16: {
 #if defined(__AVX512BF16__)
-        if (Btype == GGML_TYPE_F32 && n < 2) {
+        if (Btype == GGML_TYPE_F32 && n <= 2) {
             tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
                 k, (const ggml_bf16_t *)A, lda, (const float *)B, ldb, C, ldc, ith, nth};
             tb.matmul(m, n);
@@ -120,7 +120,7 @@ bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const

     case GGML_TYPE_F16: {
 #if defined(__AVX512F__)
-        if (Btype == GGML_TYPE_F32 && n < 2) {
+        if (Btype == GGML_TYPE_F32 && n <= 2) {
             tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, float, TC> tb{
                 k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, C, ldc, ith, nth};
             tb.matmul(m, n);
@@ -136,7 +136,7 @@ bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const
         return true;
 #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
         if (X86_CHECK(F16C)) {
-            if (Btype == GGML_TYPE_F32 && n < 2) {
+            if (Btype == GGML_TYPE_F32 && n <= 2) {
                 tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, float, TC> tb{
                     k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, C, ldc, ith, nth};
                 tb.matmul(m, n);
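
All three hunks above widen the specialized narrow-B path from n < 2 to n <= 2, so B operands with two f32 columns now also reach the kernels that pair a bf16 or f16 A with an f32 B. In scalar form, the inner product those kernels vectorize looks roughly like the following reference sketch (not the tinyBLAS code itself):

// Scalar reference for the mixed-precision dot product: widen each bf16
// element of A to f32, then multiply-accumulate against the f32 column of B.
#include <cstdint>
#include <cstring>

static float bf16_dot_f32(const uint16_t *a_bf16, const float *b_f32, long k) {
    float acc = 0.0f;
    for (long i = 0; i < k; ++i) {
        uint32_t bits = static_cast<uint32_t>(a_bf16[i]) << 16;  // bf16 -> f32
        float ai;
        std::memcpy(&ai, &bits, sizeof ai);
        acc += ai * b_f32[i];
    }
    return acc;
}
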

2 comments on commit fa4c4e7

@Djip007 (Contributor)

@jart (Collaborator, Author) commented on fa4c4e7 Aug 29, 2024

Oooh good catch. c7c4d65
