From 058e4c6d8cb82bbae796b82fcd179d0b9cb45a54 Mon Sep 17 00:00:00 2001
From: Djip007 <3705339+Djip007@users.noreply.github.com>
Date: Sat, 7 Dec 2024 20:49:49 +0100
Subject: [PATCH] more performance with llamafile tinyblas on x86_64.

- add bf16 support
- change dispatch strategy (thanks: https://github.com/ikawrakow/ik_llama.cpp/pull/71 )
- reduce memory bandwidth

---
 ggml/src/ggml-cpu/llamafile/sgemm.cpp | 358 +++++++++++++-------------
 1 file changed, 186 insertions(+), 172 deletions(-)

diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
index da4146ec4f6886..9b290d66deb5cf 100644
--- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp
+++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
@@ -134,6 +134,16 @@ inline __m512 madd(__m512 a, __m512 b, __m512 c) {
     return _mm512_fmadd_ps(a, b, c);
 }
 #endif
+#if defined(__AVX512BF16__)
+template <>
+inline __m512 madd(__m512bh a, __m512bh b, __m512 c) {
+    return _mm512_dpbf16_ps(c, a, b);
+}
+template <>
+inline __m256 madd(__m256bh a, __m256bh b, __m256 c) {
+    return _mm256_dpbf16_ps(c, a, b);
+}
+#endif
 #endif
 
 #if defined(__ARM_FEATURE_FMA)
@@ -225,6 +235,13 @@ template <> inline __m256 load(const float *p) {
 }
 #endif // __AVX__
 
+#if defined(__AVX2__) || defined(__AVX512F__)
+template <> inline __m256 load(const ggml_bf16_t *p) {
+    return _mm256_castsi256_ps(
+        _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)p)), 16));
+}
+#endif // __AVX2__
+
 #if defined(__F16C__)
 template <> inline __m256 load(const ggml_fp16_t *p) {
     return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
@@ -238,8 +255,27 @@ template <> inline __m512 load(const float *p) {
 template <> inline __m512 load(const ggml_fp16_t *p) {
     return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
 }
+template <> inline __m512 load(const ggml_bf16_t *p) {
+    return _mm512_castsi512_ps(
+        _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)p)), 16));
+}
 #endif // __AVX512F__
 
+#if defined(__AVX512BF16__)
+template <> inline __m512bh load(const ggml_bf16_t *p) {
+    return (__m512bh)_mm512_loadu_ps((const float *)p);
+}
+template <> inline __m256bh load(const ggml_bf16_t *p) {
+    return (__m256bh)_mm256_loadu_ps((const float *)p);
+}
+template <> inline __m512bh load(const float *p) {
+    return _mm512_cvtne2ps_pbh(_mm512_loadu_ps(p + 16), _mm512_loadu_ps(p));
+}
+template <> inline __m256bh load(const float *p) {
+    return _mm512_cvtneps_pbh(_mm512_loadu_ps(p));
+}
+#endif
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // CONSTANTS
 
@@ -251,6 +287,13 @@ static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // FLOATING POINT MATRIX MULTIPLICATION
 
+template <int64_t M>
+static int64_t BLOCK_SIZE(size_t m) {
+    if (m % M == 0) return M;
+    const int64_t NB_BLOC_M = (m + M - 1) / M;
+    return (m / NB_BLOC_M) + 1;
+}
+
 template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
 class tinyBLAS {
   public:
@@ -263,184 +306,121 @@ class tinyBLAS {
     }
 
     void matmul(int64_t m, int64_t n) {
-        mnpack(0, m, 0, n);
+        // choose tile sizes so that only tiles of size RM/RM-1 x RN/RN-1 are needed
+#if VECTOR_REGISTERS == 32
+        if (n < 3) {
+            // max block: 8x2
+            const int64_t SIZE_M = BLOCK_SIZE<8>(m);
+            mnpack(m, n, SIZE_M, n);
+            return;
+        }
+        // max block: 5x5
+        const int64_t SIZE_M = BLOCK_SIZE<5>(m);
+        const int64_t SIZE_N = BLOCK_SIZE<5>(n);
+#else // VECTOR_REGISTERS == 16
+        if (n == 1) {
+            // max block: 8x1
+            const int64_t SIZE_M = BLOCK_SIZE<8>(m);
+            mnpack(m, n, SIZE_M, 1);
+            return;
+        }
+        if (n == 2) {
+            // max block: 5x2
+            const int64_t SIZE_M = BLOCK_SIZE<5>(m);
+            mnpack(m, n, SIZE_M, 2);
+            return;
+        }
+        // max block: 3x4
+        const int64_t SIZE_M = BLOCK_SIZE<3>(m);
+        const int64_t SIZE_N = BLOCK_SIZE<4>(n);
+#endif
+        mnpack(m, n, SIZE_M, SIZE_N);
     }
 
   private:
-    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-        int64_t mc, nc, mp, np;
-        switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
+    NOINLINE void mnpack(int64_t m, int64_t n, int64_t SIZE_M, int64_t SIZE_N) {
+        switch ((SIZE_M << 4) | SIZE_N) {
+        case 0x81: gemm<8, 1>(m, n); break;
 #if VECTOR_REGISTERS == 32
-        case 0x55:
-            mc = 5;
-            nc = 5;
-            gemm<5, 5>(m0, m, n0, n);
-            break;
-        case 0x45:
-            mc = 4;
-            nc = 5;
-            gemm<4, 5>(m0, m, n0, n);
-            break;
-        case 0x54:
-            mc = 5;
-            nc = 4;
-            gemm<5, 4>(m0, m, n0, n);
-            break;
-        case 0x44:
-            mc = 4;
-            nc = 4;
-            gemm<4, 4>(m0, m, n0, n);
-            break;
-        case 0x53:
-            mc = 5;
-            nc = 3;
-            gemm<5, 3>(m0, m, n0, n);
-            break;
-        case 0x35:
-            mc = 3;
-            nc = 5;
-            gemm<3, 5>(m0, m, n0, n);
-            break;
-        case 0x43:
-            mc = 4;
-            nc = 3;
-            gemm<4, 3>(m0, m, n0, n);
-            break;
-#else
-        case 0x55:
-        case 0x54:
-        case 0x53:
-        case 0x45:
-        case 0x44:
-        case 0x43:
-            mc = 4;
-            nc = 3;
-            gemm<4, 3>(m0, m, n0, n);
-            break;
-        case 0x35:
+        case 0x82: gemm<8, 2>(m, n); break;
+        case 0x55: gemm<5, 5>(m, n); break;
+        case 0x54: gemm<5, 4>(m, n); break;
+        case 0x53: gemm<5, 3>(m, n); break;
+        case 0x51: gemm<5, 1>(m, n); break;
+        case 0x45: gemm<4, 5>(m, n); break;
+        case 0x44: gemm<4, 4>(m, n); break;
+        case 0x43: gemm<4, 3>(m, n); break;
+        case 0x35: gemm<3, 5>(m, n); break;
 #endif
-        case 0x34:
-            mc = 3;
-            nc = 4;
-            gemm<3, 4>(m0, m, n0, n);
-            break;
-        case 0x52:
-            mc = 5;
-            nc = 2;
-            gemm<5, 2>(m0, m, n0, n);
-            break;
-        case 0x33:
-            mc = 3;
-            nc = 3;
-            gemm<3, 3>(m0, m, n0, n);
-            break;
-        case 0x25:
-            mc = 2;
-            nc = 5;
-            gemm<2, 5>(m0, m, n0, n);
-            break;
-        case 0x42:
-            mc = 4;
-            nc = 2;
-            gemm<4, 2>(m0, m, n0, n);
-            break;
-        case 0x24:
-            mc = 2;
-            nc = 4;
-            gemm<2, 4>(m0, m, n0, n);
-            break;
-        case 0x32:
-            mc = 3;
-            nc = 2;
-            gemm<3, 2>(m0, m, n0, n);
-            break;
-        case 0x23:
-            mc = 2;
-            nc = 3;
-            gemm<2, 3>(m0, m, n0, n);
-            break;
-        case 0x51:
-            mc = 5;
-            nc = 1;
-            gemm<5, 1>(m0, m, n0, n);
-            break;
-        case 0x41:
-            mc = 4;
-            nc = 1;
-            gemm<4, 1>(m0, m, n0, n);
-            break;
-        case 0x22:
-            mc = 2;
-            nc = 2;
-            gemm<2, 2>(m0, m, n0, n);
-            break;
-        case 0x15:
-            mc = 1;
-            nc = 5;
-            gemm<1, 5>(m0, m, n0, n);
-            break;
-        case 0x14:
-            mc = 1;
-            nc = 4;
-            gemm<1, 4>(m0, m, n0, n);
-            break;
-        case 0x31:
-            mc = 3;
-            nc = 1;
-            gemm<3, 1>(m0, m, n0, n);
-            break;
-        case 0x13:
-            mc = 1;
-            nc = 3;
-            gemm<1, 3>(m0, m, n0, n);
-            break;
-        case 0x21:
-            mc = 2;
-            nc = 1;
-            gemm<2, 1>(m0, m, n0, n);
-            break;
-        case 0x12:
-            mc = 1;
-            nc = 2;
-            gemm<1, 2>(m0, m, n0, n);
-            break;
-        case 0x11:
-            mc = 1;
-            nc = 1;
-            gemm<1, 1>(m0, m, n0, n);
-            break;
+        case 0x52: gemm<5, 2>(m, n); break;
+        case 0x42: gemm<4, 2>(m, n); break;
+        case 0x41: gemm<4, 1>(m, n); break;
+        case 0x34: gemm<3, 4>(m, n); break;
+        case 0x33: gemm<3, 3>(m, n); break;
+        case 0x32: gemm<3, 2>(m, n); break;
+        case 0x31: gemm<3, 1>(m, n); break;
+        case 0x25: gemm<2, 5>(m, n); break;
+        case 0x24: gemm<2, 4>(m, n); break;
+        case 0x23: gemm<2, 3>(m, n); break;
+        case 0x22: gemm<2, 2>(m, n); break;
+        case 0x21: gemm<2, 1>(m, n); break;
+        case 0x15: gemm<1, 5>(m, n); break;
+        case 0x14: gemm<1, 4>(m, n); break;
+        case 0x13: gemm<1, 3>(m, n); break;
+        case 0x12: gemm<1, 2>(m, n); break;
+        case 0x11: gemm<1, 1>(m, n); break;
         default:
-            return;
+            GGML_LOG_ERROR("mnpack<%d, %d> block size not supported\n", (int)SIZE_M, (int)SIZE_N);
+            GGML_ASSERT(false); // we have missed something.
         }
-        mp = m0 + (m - m0) / mc * mc;
-        np = n0 + (n - n0) / nc * nc;
-        mnpack(mp, m, n0, np);
-        mnpack(m0, m, np, n);
     }
 
     template <int RM, int RN>
-    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-        int64_t ytiles = (m - m0) / RM;
-        int64_t xtiles = (n - n0) / RN;
-        int64_t tiles = xtiles * ytiles;
-        int64_t duty = (tiles + nth - 1) / nth;
-        int64_t start = duty * ith;
+    inline void gemm_bloc(int64_t ii, int64_t jj) {
+        D Cv[RN][RM] = {};
+        for (int64_t l = 0; l < k; l += KN)
+            for (int64_t j = 0; j < RN; ++j)
+                for (int64_t i = 0; i < RM; ++i)
+                    Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
+                                    load<V>(B + ldb * (jj + j) + l),
+                                    Cv[j][i]);
+        for (int64_t j = 0; j < RN; ++j)
+            for (int64_t i = 0; i < RM; ++i)
+                C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+    }
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int64_t m, int64_t n) {
+        const int64_t ytiles = (m + RM - 1) / RM;
+        const int64_t xtiles = (n + RN - 1) / RN;
+        const int64_t ii_RM = (ytiles - (ytiles * RM - m));
+        const int64_t jj_RN = (xtiles - (xtiles * RN - n));
+
+        const int64_t tiles = xtiles * ytiles;
+        const int64_t duty = (tiles + nth - 1) / nth;
+        const int64_t start = duty * ith;
         int64_t end = start + duty;
         if (end > tiles)
             end = tiles;
         for (int64_t job = start; job < end; ++job) {
-            int64_t ii = m0 + job / xtiles * RM;
-            int64_t jj = n0 + job % xtiles * RN;
-            D Cv[RN][RM] = {};
-            for (int64_t l = 0; l < k; l += KN)
-                for (int64_t j = 0; j < RN; ++j)
-                    for (int64_t i = 0; i < RM; ++i)
-                        Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
-                                        load<V>(B + ldb * (jj + j) + l),
-                                        Cv[j][i]);
-            for (int64_t j = 0; j < RN; ++j)
-                for (int64_t i = 0; i < RM; ++i)
-                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+            const int64_t ii = job / xtiles;
+            const int64_t jj = job % xtiles;
+            if (ii < ii_RM) {
+                if (jj < jj_RN) {
+                    gemm_bloc<RM, RN>(ii * RM, jj * RN);
+                } else {
+                    gemm_bloc<RM, RN - 1>(ii * RM,
+                                          jj_RN * RN + (jj - jj_RN) * (RN - 1));
+                }
+            } else {
+                if (jj < jj_RN) {
+                    gemm_bloc<RM - 1, RN>(ii_RM * RM + (ii - ii_RM) * (RM - 1),
+                                          jj * RN);
+                } else {
+                    gemm_bloc<RM - 1, RN - 1>(ii_RM * RM + (ii - ii_RM) * (RM - 1),
+                                              jj_RN * RN + (jj - jj_RN) * (RN - 1));
+                }
+            }
         }
     }
 
@@ -1727,15 +1707,49 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
 #endif
     }
 
+    case GGML_TYPE_BF16: {
+#if defined(__AVX512BF16__)
+        if ((k % 32) == 0 && Btype == GGML_TYPE_BF16) {
+            tinyBLAS<32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, float> tb{ k,
+                (const ggml_bf16_t *)A, lda,
+                (const ggml_bf16_t *)B, ldb,
+                (float *)C, ldc,
+                ith, nth};
+            tb.matmul(m, n);
+            return true;
+        }
+#elif defined(__AVX512F__)
+        if ((k % 16) == 0 && Btype == GGML_TYPE_BF16) {
+            tinyBLAS<16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, float> tb{ k,
+                (const ggml_bf16_t *)A, lda,
+                (const ggml_bf16_t *)B, ldb,
+                (float *)C, ldc,
+                ith, nth};
+            tb.matmul(m, n);
+            return true;
+        }
+#elif defined(__AVX2__)
+        if ((k % 8) == 0 && Btype == GGML_TYPE_BF16) {
+            tinyBLAS<8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, float> tb{ k,
+                (const ggml_bf16_t *)A, lda,
+                (const ggml_bf16_t *)B, ldb,
+                (float *)C, ldc,
+                ith, nth};
+            tb.matmul(m, n);
+            return true;
+        }
+#endif
+        return false;
+    }
     case GGML_TYPE_F16: {
 #if defined(__AVX512F__)
         if (k % 16)
             return false;
-        if (Btype != GGML_TYPE_F32)
+        if (Btype != GGML_TYPE_F16)
            return false;
-        tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{
-            k, (const ggml_fp16_t *)A, lda,
-            (const float *)B, ldb,
+        tinyBLAS<16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float> tb{ k,
+            (const ggml_fp16_t *)A, lda,
+            (const ggml_fp16_t *)B, ldb,
             (float *)C, ldc,
             ith, nth};
         tb.matmul(m, n);
@@ -1743,11 +1757,11 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
 #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
         if (k % 8)
             return false;
-        if (Btype != GGML_TYPE_F32)
+        if (Btype != GGML_TYPE_F16)
             return false;
-        tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
-            k, (const ggml_fp16_t *)A, lda,
-            (const float *)B, ldb,
+        tinyBLAS<8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, float> tb{ k,
+            (const ggml_fp16_t *)A, lda,
+            (const ggml_fp16_t *)B, ldb,
             (float *)C, ldc,
             ith, nth};
         tb.matmul(m, n);