diff --git a/llamafile/tinyblas_cpu.h b/llamafile/tinyblas_cpu.h
index bf37ade822..f5876c3f76 100644
--- a/llamafile/tinyblas_cpu.h
+++ b/llamafile/tinyblas_cpu.h
@@ -49,7 +49,7 @@
 #pragma GCC diagnostic ignored "-Wpedantic"
 #pragma GCC diagnostic ignored "-Wignored-attributes"
 
-#define CHUNK 8
+#define CHUNK 16
 #define ROW_ALIGN 64
 #define MATRIX_ALIGN 4096
 #define MAX_ALIGN 4096
@@ -416,6 +416,12 @@ inline void store(ggml_bf16_t *p, float f) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // FLOATING POINT MATRIX MULTIPLICATION
 
+template <long M>
+static long BLOCK_SIZE(long m) {
+    const long NB_BLOC_M = (m + M - 1) / M;
+    return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
+}
+
 template <int CONFIG, int KN, typename D, typename V, typename TA, typename TB, typename TC>
 class tinyBLAS {
   public:
@@ -424,180 +430,169 @@ class tinyBLAS {
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }
 
-    void matmul(long m, long n) {
-        mnpack(0, m, 0, n);
-    }
-
-  private:
-    NOINLINE void mnpack(long m0, long m, long n0, long n) {
-        long mc, nc, mp, np;
-
+    bool matmul(long m, long n) {
 #if VECTOR_REGISTERS == 32
-        switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
-        case 0x55:
-            mc = 5;
-            nc = 5;
-            gemm<5, 5>(m0, m, n0, n);
-            break;
-        case 0x54:
-        case 0x53:
-        case 0x52:
-        case 0x45:
-        case 0x44:
-        case 0x43:
-        case 0x42:
-        case 0x35:
-        case 0x34:
-        case 0x33:
-        case 0x32:
-        case 0x25:
-        case 0x24:
-        case 0x23:
-        case 0x22:
-            mc = 2;
-            nc = 2;
-            gemm<2, 2>(m0, m, n0, n);
-            break;
-        case 0x51:
-        case 0x41:
-        case 0x31:
-        case 0x21:
-            mc = 2;
-            nc = 1;
-            gemm<2, 1>(m0, m, n0, n);
-            break;
-        case 0x15:
-        case 0x14:
-        case 0x13:
-        case 0x12:
-            mc = 1;
-            nc = 2;
-            gemm<1, 2>(m0, m, n0, n);
-            break;
-        case 0x11:
-            mc = 1;
-            nc = 1;
-            gemm<1, 1>(m0, m, n0, n);
-            break;
-        default:
-            return;
+        if (m % 8 == 0 && n < 4) {
+            mnpack<8, 3, 1>(m, n, n);
+            return true;
         }
-#endif
-
-#if VECTOR_REGISTERS == 16
-        switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 3)) {
-        case 0x43:
-            mc = 4;
-            nc = 3;
-            gemm<4, 3>(m0, m, n0, n);
-            break;
-        case 0x42:
-        case 0x33:
-        case 0x32:
-        case 0x23:
-        case 0x22:
-            mc = 2;
-            nc = 2;
-            gemm<2, 2>(m0, m, n0, n);
-            break;
-        case 0x41:
-        case 0x31:
-        case 0x21:
-            mc = 2;
-            nc = 1;
-            gemm<2, 1>(m0, m, n0, n);
-            break;
-        case 0x13:
-        case 0x12:
-            mc = 1;
-            nc = 2;
-            gemm<1, 2>(m0, m, n0, n);
-            break;
-        case 0x11:
-            mc = 1;
-            nc = 1;
-            gemm<1, 1>(m0, m, n0, n);
-            break;
-        default:
-            return;
+        if (m % 16 == 0) {
+            const long SIZE_N = BLOCK_SIZE<6>(n);
+            mnpack<4, 6, 4>(m, n, SIZE_N);
+            return true;
+        }
+        if (m % 8 == 0) {
+            const long SIZE_N = BLOCK_SIZE<6>(n);
+            mnpack<4, 6, 2>(m, n, SIZE_N);
+            return true;
+        }
+        if (m % 4 == 0) {
+            const long SIZE_N = BLOCK_SIZE<6>(n);
+            mnpack<4, 6, 1>(m, n, SIZE_N);
+            return true;
+        }
+#else // VECTOR_REGISTERS == 16
+        if (m % 4 == 0 && n < 3) {
+            mnpack<4, 2, 1>(m, n, n);
+            return true;
+        }
+        if (m % 16 == 0) {
+            const long SIZE_N = BLOCK_SIZE<3>(n);
+            mnpack<4, 3, 4>(m, n, SIZE_N);
+            return true;
+        }
+        if (m % 8 == 0) {
+            const long SIZE_N = BLOCK_SIZE<3>(n);
+            mnpack<4, 3, 2>(m, n, SIZE_N);
+            return true;
+        }
+        if (m % 4 == 0) {
+            const long SIZE_N = BLOCK_SIZE<3>(n);
+            mnpack<4, 3, 1>(m, n, SIZE_N);
+            return true;
         }
 #endif
+        return false;
+    }
 
-        mp = m0 + (m - m0) / mc * mc;
-        np = n0 + (n - n0) / nc * nc;
-        mnpack(mp, m, n0, np);
-        mnpack(m0, m, np, n);
+  private:
+    template <int RM, int RN, int BM>
+    inline void mnpack(long m, long n, long SIZE_N) {
+        if (SIZE_N == RN) {
+            return gemm<RM, RN, BM>(m, n);
+        }
+        if constexpr (RN > 1) {
+            return mnpack<RM, RN - 1, BM>(m, n, SIZE_N);
+            //} else {
+            //    GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
+            //    GGML_ASSERT(false); // we have missed something.
+        }
     }
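
Annotation (not part of the patch): BLOCK_SIZE<M>(n) returns the widest block width that still splits n into nearly equal pieces of at most M columns, and mnpack<RM, RN, BM> walks RN down at compile time until it equals that width, so each supported width gets its own fully instantiated gemm. A minimal standalone sketch of both mechanisms, with a hypothetical main() for illustration:

    // Sketch only -- mirrors BLOCK_SIZE and mnpack's compile-time descent.
    #include <cstdio>

    template <long M>
    static long BLOCK_SIZE(long m) {
        const long NB_BLOC_M = (m + M - 1) / M; // blocks needed to cover m
        return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
    }

    template <int RM, int RN, int BM>
    void mnpack(long SIZE_N) {
        if (SIZE_N == RN) {
            printf("dispatch gemm<%d, %d, %d>\n", RM, RN, BM); // terminal case
            return;
        }
        if constexpr (RN > 1)
            mnpack<RM, RN - 1, BM>(SIZE_N); // shrink RN until it matches
    }

    int main() {
        printf("%ld\n", BLOCK_SIZE<6>(16)); // 16 cols -> 3 blocks (6+5+5), prints 6
        printf("%ld\n", BLOCK_SIZE<6>(12)); // 12 cols -> 2 blocks (6+6), prints 6
        mnpack<4, 6, 2>(5);                 // prints "dispatch gemm<4, 5, 2>"
    }

The descent lets one call site serve every width from RN down to 1 with no runtime switch; the commented-out else branch above would only fire if the dispatch tables and BLOCK_SIZE ever disagreed.
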
 
     template <int RM, int RN>
-    NOINLINE void gemm(long m0, long m, long n0, long n) {
-        D stack[bsr(k / CHUNK + 1) + 1][RN][RM];
-        long ytiles = RM > 1 ? (m - m0) / RM : 1;
-        long xtiles = RN > 1 ? (n - n0) / RN : 1;
-        long tiles = xtiles * ytiles;
-        long duty = (tiles + nth - 1) / nth;
-        long start = duty * ith;
-        long end = start + duty;
-        if (end > tiles)
-            end = tiles;
-        for (long job = start; job < end; ++job) {
-            long ii = m0 + job / xtiles * RM;
-            long jj = n0 + job % xtiles * RN;
-
-            size_t chunk, sp = 0;
-            int i, j, rule, step = 2;
-            for (chunk = 0; chunk + KN * CHUNK * 4 <= k; chunk += KN * CHUNK * 4, step += 2, ++sp) {
-
-                D Cv[RN][RM] = {};
-                for (long l = 0; l < KN * CHUNK * 4; l += KN)
-#pragma GCC unroll 100
-                    for (j = 0; j < RN; ++j)
-#pragma GCC unroll 100
-                        for (i = 0; i < RM; ++i)
-                            Cv[j][i] = madd(load<V>(INDEX(A, lda, ii + i, chunk + l)), //
-                                            load<V>(INDEX(B, ldb, jj + j, chunk + l)), //
-                                            Cv[j][i]);
-
-                for (rule = bsr(step & -step); --rule;)
-                    for (--sp, j = 0; j < RN; ++j)
-                        for (i = 0; i < RM; ++i)
-                            Cv[j][i] += stack[sp][j][i];
-
-                for (j = 0; j < RN; ++j)
-                    for (i = 0; i < RM; ++i)
-                        stack[sp][j][i] = Cv[j][i];
-            }
-
-            D Cv[RN][RM] = {};
-            for (; chunk + KN <= k; chunk += KN)
-#pragma GCC unroll 100
-                for (j = 0; j < RN; ++j)
-#pragma GCC unroll 100
-                    for (i = 0; i < RM; ++i)
-                        Cv[j][i] = madd(load<V>(INDEX(A, lda, ii + i, chunk)), //
-                                        load<V>(INDEX(B, ldb, jj + j, chunk)), //
-                                        Cv[j][i]);
-
-            while (sp--)
-                for (j = 0; j < RN; ++j)
-                    for (i = 0; i < RM; ++i)
-                        Cv[j][i] += stack[sp][j][i];
-
-            float Cf[RN][RM];
-            for (j = 0; j < RN; ++j)
-                for (i = 0; i < RM; ++i)
-                    Cf[j][i] = hsum(Cv[j][i]);
-
-            for (; chunk < k; ++chunk)
-                for (j = 0; j < RN; ++j)
-                    for (i = 0; i < RM; ++i)
-                        Cf[j][i] = fmaf(load(INDEX(A, lda, ii + i, chunk)), //
-                                        load(INDEX(B, ldb, jj + j, chunk)), //
-                                        Cf[j][i]);
-
-            for (j = 0; j < RN; ++j)
-                for (i = 0; i < RM; ++i)
-                    store(INDEX(C, ldc, jj + j, ii + i), Cf[j][i]);
+    inline void gemm_bloc(long ii, long jj, long l, D Cv[RN][RM]) {
+        // help the compiler with operation ordering.
+        if constexpr (RM <= RN) {
+            V Av[RM];
+#pragma GCC unroll 100
+            for (int64_t i = 0; i < RM; ++i) {
+                Av[i] = load<V>(A + lda * (ii + i) + l);
+            }
+#pragma GCC unroll 100
+            for (int64_t j = 0; j < RN; ++j) {
+                V Bv = load<V>(B + ldb * (jj + j) + l);
+#pragma GCC unroll 100
+                for (int64_t i = 0; i < RM; ++i) {
+                    Cv[j][i] = madd(Av[i], Bv, Cv[j][i]);
+                }
+            }
+        } else {
+            V Bv[RN];
+#pragma GCC unroll 100
+            for (int64_t j = 0; j < RN; ++j) {
+                Bv[j] = load<V>(B + ldb * (jj + j) + l);
+            }
+#pragma GCC unroll 100
+            for (int64_t i = 0; i < RM; ++i) {
+                V Av = load<V>(A + lda * (ii + i) + l);
+#pragma GCC unroll 100
+                for (int64_t j = 0; j < RN; ++j) {
+                    Cv[j][i] = madd(Av, Bv[j], Cv[j][i]);
+                }
+            }
+        }
+    }
+
+    template <int RM, int RN>
+    inline void gemm_bloc(long ii, long jj) {
+        D stack[bsr(k / CHUNK + 1) + 1][RN][RM];
+        long chunk, sp = 0;
+        int i, j, rule, step = 2;
+        for (chunk = 0; chunk + KN * CHUNK * 4 <= k; chunk += KN * CHUNK * 4, step += 2, ++sp) {
+
+            D Cv[RN][RM] = {};
+            for (long l = 0; l < KN * CHUNK * 4; l += KN)
+                gemm_bloc<RM, RN>(ii, jj, chunk + l, Cv);
+
+            for (rule = bsr(step & -step); --rule;)
+                for (--sp, j = 0; j < RN; ++j)
+                    for (i = 0; i < RM; ++i)
+                        Cv[j][i] += stack[sp][j][i];
+
+            for (j = 0; j < RN; ++j)
+                for (i = 0; i < RM; ++i)
+                    stack[sp][j][i] = Cv[j][i];
+        }
+
+        D Cv[RN][RM] = {};
+        for (; chunk + KN <= k; chunk += KN)
+            gemm_bloc<RM, RN>(ii, jj, chunk, Cv);
+
+        while (sp--)
+            for (j = 0; j < RN; ++j)
+                for (i = 0; i < RM; ++i)
+                    Cv[j][i] += stack[sp][j][i];
+
+        float Cf[RN][RM];
+        for (j = 0; j < RN; ++j)
+            for (i = 0; i < RM; ++i)
+                Cf[j][i] = hsum(Cv[j][i]);
+
+        for (; chunk < k; ++chunk)
+            for (j = 0; j < RN; ++j)
+                for (i = 0; i < RM; ++i)
+                    Cf[j][i] = fmaf(load(INDEX(A, lda, ii + i, chunk)), //
+                                    load(INDEX(B, ldb, jj + j, chunk)), //
+                                    Cf[j][i]);
+
+        for (j = 0; j < RN; ++j)
+            for (i = 0; i < RM; ++i)
+                store(INDEX(C, ldc, jj + j, ii + i), Cf[j][i]);
+    }
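
Annotation (not part of the patch): the stack[] and bsr(step & -step) bookkeeping in gemm_bloc is a binary-counter pairwise summation. Every slab of KN * CHUNK * 4 columns produces a fresh partial sum, and partials merge like binary carries, so roundoff grows logarithmically in k rather than linearly. A standalone scalar sketch of the same merge rule; bsr() here is a stand-in assumed to return the index of the highest set bit, which is how the patch sizes stack[]:

    // Sketch only -- gemm_bloc's stack merge, applied to scalar partial sums.
    #include <cstdio>

    static int bsr(unsigned long x) { // assumed semantics: highest set bit index
        return 63 - __builtin_clzl(x);
    }

    int main() {
        double stack[8];
        long sp = 0;
        int step = 2;
        for (int chunk = 0; chunk < 32; ++chunk, step += 2, ++sp) {
            double Cv = chunk + 1; // stand-in for one slab's fresh partial sum
            for (int rule = bsr(step & -step); --rule;)
                Cv += stack[--sp]; // carry chain: merge equal-weight partials
            stack[sp] = Cv;
        }
        double total = 0;
        while (sp--)
            total += stack[sp]; // drain the leftovers, as gemm_bloc does
        printf("%g\n", total); // 528 == 1 + 2 + ... + 32
    }

The stack never holds more than about log2(k / CHUNK) partials at once, which is why D stack[bsr(k / CHUNK + 1) + 1][RN][RM] is safe to keep on the stack frame.
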
+
+    template <int RM, int RN, int BM>
+    NOINLINE void gemm(long m, long n) {
+        // GGML_ASSERT(m % (RM * BM) == 0);
+        const long ytiles = m / (RM * BM);
+        const long xtiles = (n + RN - 1) / RN;
+        const long jj_RN = (xtiles - (xtiles * RN - n));
+
+        long tiles = xtiles * ytiles;
+        long duty = (tiles + nth - 1) / nth;
+        long start = duty * ith;
+        long end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            const int64_t ii = job / xtiles;
+            const int64_t jj = job % xtiles;
+            for (int64_t bi = 0; bi < BM; ++bi) {
+                if (jj < jj_RN) {
+                    gemm_bloc<RM, RN>((ii * BM + bi) * RM, jj * RN);
+                } else if constexpr (RN > 1) {
+                    gemm_bloc<RM, RN - 1>((ii * BM + bi) * RM, jj_RN * RN + (jj - jj_RN) * (RN - 1));
+                }
+            }
         }
     }
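
Annotation (not part of the patch): gemm's jj_RN arithmetic converts the single width chosen by BLOCK_SIZE into an exact cover of the n columns: the first jj_RN column tiles are RN wide and the remainder are RN - 1 wide. A standalone sketch with a hypothetical main(), using n = 16 with RN = 6 as in the 32-register path:

    // Sketch only -- the column-tile offsets gemm derives from jj_RN.
    #include <cstdio>

    int main() {
        const long n = 16, RN = 6;
        const long xtiles = (n + RN - 1) / RN;         // 3 tiles
        const long jj_RN = xtiles - (xtiles * RN - n); // 1 tile keeps width RN
        for (long jj = 0; jj < xtiles; ++jj) {
            long start, width;
            if (jj < jj_RN) {
                start = jj * RN;          // full-width tiles come first
                width = RN;
            } else {
                start = jj_RN * RN + (jj - jj_RN) * (RN - 1);
                width = RN - 1;           // narrower tiles absorb the remainder
            }
            printf("tile %ld: cols [%ld, %ld)\n", jj, start, start + width);
        }
        // prints [0, 6), [6, 11), [11, 16): every column covered exactly once
    }

Splitting 16 as 6 + 5 + 5 rather than 6 + 6 + 4 keeps the narrow tiles only one column short of full width, so the RN - 1 kernel instantiated by mnpack stays nearly as busy as the RN one.
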
diff --git a/llamafile/tinyblas_cpu_sgemm.inc b/llamafile/tinyblas_cpu_sgemm.inc
index 99f8378b84..b267d01f72 100644
--- a/llamafile/tinyblas_cpu_sgemm.inc
+++ b/llamafile/tinyblas_cpu_sgemm.inc
@@ -54,18 +54,15 @@ bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const
 #if defined(__AVX512F__)
         tinyBLAS<0, 16, __m512, __m512, float, float, TC> tb{
             k, (const float *)A, lda, (const float *)B, ldb, C, ldc, ith, nth};
-        tb.matmul(m, n);
-        return true;
+        return tb.matmul(m, n);
 #elif defined(__AVX__) || defined(__AVX2__)
         tinyBLAS<0, 8, __m256, __m256, float, float, TC> tb{
             k, (const float *)A, lda, (const float *)B, ldb, C, ldc, ith, nth};
-        tb.matmul(m, n);
-        return true;
+        return tb.matmul(m, n);
 #elif defined(__ARM_NEON)
         tinyBLAS<0, 4, float32x4_t, float32x4_t, float, float, TC> tb{
             k, (const float *)A, lda, (const float *)B, ldb, C, ldc, ith, nth};
-        tb.matmul(m, n);
-        return true;
+        return tb.matmul(m, n);
 #else
         return NOT_SUPPORTED;
 #endif
@@ -73,11 +70,13 @@ bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const
 
     case GGML_TYPE_BF16: {
 #if defined(__AVX512BF16__)
+        if (n < 2 && !FLAG_precise)
+            // TODO(jart): Why is ggml_vec_dot_bf16_unroll() so fast at matvec?
+            return NOT_PROFITABLE;
         if (Btype == GGML_TYPE_F32 && n <= 2) {
             tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
                 k, (const ggml_bf16_t *)A, lda, (const float *)B, ldb, C, ldc, ith, nth};
-            tb.matmul(m, n);
-            return true;
+            return tb.matmul(m, n);
         }
         if (Btype == GGML_TYPE_F32)
             return WANT_QUANTIZATION;
@@ -86,33 +85,33 @@ bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const
         if (n > 1) {
             tinyBLAS<0, 32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, TC> tb{
                 k, (const ggml_bf16_t *)A, lda, (const ggml_bf16_t *)B, ldb, C, ldc, ith, nth};
-            tb.matmul(m, n);
-            return true;
+            return tb.matmul(m, n);
         } else {
             tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, TC> tb{
                 k, (const ggml_bf16_t *)A, lda, (const ggml_bf16_t *)B, ldb, C, ldc, ith, nth};
-            tb.matmul(m, n);
-            return true;
+            return tb.matmul(m, n);
         }
 #elif defined(__AVX512F__)
+        if (Btype != GGML_TYPE_F32)
+            return NOT_SUPPORTED;
         tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
             k, (const ggml_bf16_t *)A, lda, (const float *)B, ldb, C, ldc, ith, nth};
-        tb.matmul(m, n);
-        return true;
+        return tb.matmul(m, n);
 #elif defined(__AVX2__)
+        if (n < 2 && !FLAG_precise)
+            // TODO(jart): Why is ggml_vec_dot_bf16_unroll() so fast at matvec?
+            return NOT_PROFITABLE;
         if (Btype != GGML_TYPE_F32)
             return NOT_SUPPORTED;
         tinyBLAS<0, 8, __m256, __m256, ggml_bf16_t, float, TC> tb{
             k, (const ggml_bf16_t *)A, lda, (const float *)B, ldb, C, ldc, ith, nth};
-        tb.matmul(m, n);
-        return true;
+        return tb.matmul(m, n);
 #elif defined(__ARM_NEON) && !defined(_MSC_VER)
         if (Btype != GGML_TYPE_F32)
             return NOT_SUPPORTED;
         tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_bf16_t, float, TC> tb{
             k, (const ggml_bf16_t *)A, lda, (const float *)B, ldb, C, ldc, ith, nth};
-        tb.matmul(m, n);
-        return true;
+        return tb.matmul(m, n);
 #else
         return NOT_SUPPORTED;
 #endif
@@ -120,11 +119,13 @@ bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const
 
     case GGML_TYPE_F16: {
 #if defined(__AVX512F__)
+        if (n < 2 && !FLAG_precise)
+            // TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
+            return NOT_PROFITABLE;
         if (Btype == GGML_TYPE_F32 && n <= 2) {
             tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, float, TC> tb{
                 k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, C, ldc, ith, nth};
-            tb.matmul(m, n);
-            return true;
+            return tb.matmul(m, n);
         }
         if (Btype == GGML_TYPE_F32)
             return WANT_QUANTIZATION;
@@ -132,15 +133,16 @@ bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const
             return NOT_SUPPORTED;
         tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, TC> tb{
             k, (const ggml_fp16_t *)A, lda, (const ggml_fp16_t *)B, ldb, C, ldc, ith, nth};
-        tb.matmul(m, n);
-        return true;
+        return tb.matmul(m, n);
 #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
         if (X86_CHECK(F16C)) {
+            if (n < 2 && !FLAG_precise)
+                // TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
+                return NOT_PROFITABLE;
             if (Btype == GGML_TYPE_F32 && n <= 2) {
                 tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, float, TC> tb{
                     k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, C, ldc, ith, nth};
-                tb.matmul(m, n);
-                return true;
+                return tb.matmul(m, n);
             }
             if (Btype == GGML_TYPE_F32)
                 return WANT_QUANTIZATION;
@@ -148,8 +150,7 @@ bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const
                 return NOT_SUPPORTED;
             tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, TC> tb{
                 k, (const ggml_fp16_t *)A, lda, (const ggml_fp16_t *)B, ldb, C, ldc, ith, nth};
-            tb.matmul(m, n);
-            return true;
+            return tb.matmul(m, n);
         } else {
             return NOT_SUPPORTED;
         }
@@ -163,8 +164,7 @@ bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const
             return NOT_SUPPORTED;
         tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC> tb{
             k, (const ggml_fp16_t *)A, lda, (const ggml_fp16_t *)B, ldb, C, ldc, ith, nth};
-        tb.matmul(m, n);
-        return true;
+        return tb.matmul(m, n);
 #elif defined(__ARM_NEON) && !defined(_MSC_VER)
         if (n < 2 && !FLAG_precise)
             // TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
@@ -173,8 +173,7 @@ bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const
             return NOT_SUPPORTED;
         tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{
             k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, C, ldc, ith, nth};
-        tb.matmul(m, n);
-        return true;
+        return tb.matmul(m, n);
 #else
         return NOT_SUPPORTED;
 #endif
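
Annotation (not part of the patch): matmul() returning bool is what lets every call site above collapse to return tb.matmul(m, n). When tinyBLAS declines a shape (e.g. m not a multiple of 4), llamafile_sgemm_impl reports false and ggml's generic kernels take over. A standalone sketch of the 16-register dispatch predicate, with hypothetical names for illustration:

    // Sketch only -- which kernel the 16-register matmul() picks per shape.
    #include <cstdio>

    static const char *pick_16reg(long m, long n) {
        if (m % 4 == 0 && n < 3) return "mnpack<4, 2, 1>"; // skinny B
        if (m % 16 == 0) return "mnpack<4, 3, 4>";
        if (m % 8 == 0) return "mnpack<4, 3, 2>";
        if (m % 4 == 0) return "mnpack<4, 3, 1>";
        return "false: caller falls back to ggml";
    }

    int main() {
        printf("%s\n", pick_16reg(512, 1));  // matvec-like: mnpack<4, 2, 1>
        printf("%s\n", pick_16reg(512, 64)); // m % 16 == 0: mnpack<4, 3, 4>
        printf("%s\n", pick_16reg(6, 64));   // unsupported m: fall back
    }

The NOT_PROFITABLE early-outs follow the same philosophy: for n < 2 without FLAG_precise set, the patch defers to ggml's vector dot products, which the TODO comments note are still faster at matvec.
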