Skip to content

Commit

Permalink
more perfo with llamafile tinyblas
Browse files Browse the repository at this point in the history
- change dispache strategie (thanks:
ikawrakow/ik_llama.cpp#71 )
- more cache freindly
  • Loading branch information
Djip007 committed Dec 11, 2024
1 parent 8fa1702 commit b542e64
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 182 deletions.
299 changes: 147 additions & 152 deletions llamafile/tinyblas_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
#pragma GCC diagnostic ignored "-Wpedantic"
#pragma GCC diagnostic ignored "-Wignored-attributes"

#define CHUNK 8
#define CHUNK 16
#define ROW_ALIGN 64
#define MATRIX_ALIGN 4096
#define MAX_ALIGN 4096
Expand Down Expand Up @@ -416,6 +416,12 @@ inline void store(ggml_bf16_t *p, float f) {
////////////////////////////////////////////////////////////////////////////////////////////////////
// FLOATING POINT MATRIX MULTIPLICATION

template <int M>
static long BLOCK_SIZE(long m) {
const long NB_BLOC_M = (m + M - 1) / M;
return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
}

template <int CONFIG, int KN, typename D, typename V, typename TA, typename TB, typename TC>
class tinyBLAS {
public:
Expand All @@ -424,180 +430,169 @@ class tinyBLAS {
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}

void matmul(long m, long n) {
mnpack(0, m, 0, n);
}

private:
NOINLINE void mnpack(long m0, long m, long n0, long n) {
long mc, nc, mp, np;

bool matmul(long m, long n) {
#if VECTOR_REGISTERS == 32
switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
case 0x55:
mc = 5;
nc = 5;
gemm<5, 5>(m0, m, n0, n);
break;
case 0x54:
case 0x53:
case 0x52:
case 0x45:
case 0x44:
case 0x43:
case 0x42:
case 0x35:
case 0x34:
case 0x33:
case 0x32:
case 0x25:
case 0x24:
case 0x23:
case 0x22:
mc = 2;
nc = 2;
gemm<2, 2>(m0, m, n0, n);
break;
case 0x51:
case 0x41:
case 0x31:
case 0x21:
mc = 2;
nc = 1;
gemm<2, 1>(m0, m, n0, n);
break;
case 0x15:
case 0x14:
case 0x13:
case 0x12:
mc = 1;
nc = 2;
gemm<1, 2>(m0, m, n0, n);
break;
case 0x11:
mc = 1;
nc = 1;
gemm<1, 1>(m0, m, n0, n);
break;
default:
return;
if (m % 8 == 0 && n < 4) {
mnpack<8, 3, 1>(m, n, n);
return true;
}
#endif

#if VECTOR_REGISTERS == 16
switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 3)) {
case 0x43:
mc = 4;
nc = 3;
gemm<4, 3>(m0, m, n0, n);
break;
case 0x42:
case 0x33:
case 0x32:
case 0x23:
case 0x22:
mc = 2;
nc = 2;
gemm<2, 2>(m0, m, n0, n);
break;
case 0x41:
case 0x31:
case 0x21:
mc = 2;
nc = 1;
gemm<2, 1>(m0, m, n0, n);
break;
case 0x13:
case 0x12:
mc = 1;
nc = 2;
gemm<1, 2>(m0, m, n0, n);
break;
case 0x11:
mc = 1;
nc = 1;
gemm<1, 1>(m0, m, n0, n);
break;
default:
return;
if (m % 16 == 0) {
const long SIZE_N = BLOCK_SIZE<6>(n);
mnpack<4, 6, 4>(m, n, SIZE_N);
return true;
}
if (m % 8 == 0) {
const long SIZE_N = BLOCK_SIZE<6>(n);
mnpack<4, 6, 2>(m, n, SIZE_N);
return true;
}
if (m % 4 == 0) {
const long SIZE_N = BLOCK_SIZE<6>(n);
mnpack<4, 6, 1>(m, n, SIZE_N);
return true;
}
#else // VECTOR_REGISTERS == 16
if (m % 4 == 0 && n < 3) {
mnpack<4, 2, 1>(m, n, n);
return true;
}
if (m % 16 == 0) {
const long SIZE_N = BLOCK_SIZE<3>(n);
mnpack<4, 3, 4>(m, n, SIZE_N);
return true;
}
if (m % 8 == 0) {
const long SIZE_N = BLOCK_SIZE<3>(n);
mnpack<4, 3, 2>(m, n, SIZE_N);
return true;
}
if (m % 4 == 0) {
const long SIZE_N = BLOCK_SIZE<3>(n);
mnpack<4, 3, 1>(m, n, SIZE_N);
return true;
}
#endif
return false;
}

mp = m0 + (m - m0) / mc * mc;
np = n0 + (n - n0) / nc * nc;
mnpack(mp, m, n0, np);
mnpack(m0, m, np, n);
private:
template <int RM, int RN, int BM>
inline void mnpack(long m, long n, long SIZE_N) {
if (SIZE_N == RN) {
return gemm<RM, RN, BM>(m, n);
}
if constexpr (RN > 1) {
return mnpack<RM, RN-1, BM>(m, n, SIZE_N);
//} else {
// GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
// GGML_ASSERT(false); // we have miss something.
}
}

template <int RM, int RN>
NOINLINE void gemm(long m0, long m, long n0, long n) {
D stack[bsr(k / CHUNK + 1) + 1][RN][RM];
long ytiles = RM > 1 ? (m - m0) / RM : 1;
long xtiles = RN > 1 ? (n - n0) / RN : 1;
long tiles = xtiles * ytiles;
long duty = (tiles + nth - 1) / nth;
long start = duty * ith;
long end = start + duty;
if (end > tiles)
end = tiles;
for (long job = start; job < end; ++job) {
long ii = m0 + job / xtiles * RM;
long jj = n0 + job % xtiles * RN;

size_t chunk, sp = 0;
int i, j, rule, step = 2;
for (chunk = 0; chunk + KN * CHUNK * 4 <= k; chunk += KN * CHUNK * 4, step += 2, ++sp) {

D Cv[RN][RM] = {};
for (long l = 0; l < KN * CHUNK * 4; l += KN)
inline void gemm_bloc(long ii, long jj, long l, D Cv[RN][RM]) {
// help compiler for op order.
if constexpr (RM <= RN) {
V Av[RM];
#pragma GCC unroll 100
for (j = 0; j < RN; ++j)
for (int64_t i = 0; i < RM; ++i) {
Av[i] = load<V>(A + lda * (ii + i) + l);
}
#pragma GCC unroll 100
for (i = 0; i < RM; ++i)
Cv[j][i] = madd(load<V>(INDEX(A, lda, ii + i, chunk + l)), //
load<V>(INDEX(B, ldb, jj + j, chunk + l)), //
Cv[j][i]);

for (rule = bsr(step & -step); --rule;)
for (--sp, j = 0; j < RN; ++j)
for (i = 0; i < RM; ++i)
Cv[j][i] += stack[sp][j][i];

for (j = 0; j < RN; ++j)
for (i = 0; i < RM; ++i)
stack[sp][j][i] = Cv[j][i];
for (int64_t j = 0; j < RN; ++j) {
V Bv = load<V>(B + ldb * (jj + j) + l);
#pragma GCC unroll 100
for (int64_t i = 0; i < RM; ++i) {
Cv[j][i] = madd(Av[i], Bv, Cv[j][i]);
}
}

D Cv[RN][RM] = {};
for (; chunk + KN <= k; chunk += KN)
} else {
V Bv[RN];
#pragma GCC unroll 100
for (j = 0; j < RN; ++j)
for (int64_t j = 0; j < RN; ++j) {
Bv[j] = load<V>(B + ldb * (jj + j) + l);
}
#pragma GCC unroll 100
for (i = 0; i < RM; ++i)
Cv[j][i] = madd(load<V>(INDEX(A, lda, ii + i, chunk)), //
load<V>(INDEX(B, ldb, jj + j, chunk)), //
Cv[j][i]);
for (int64_t i = 0; i < RM; ++i) {
V Av = load<V>(A + lda * (ii + i) + l);
#pragma GCC unroll 100
for (int64_t j = 0; j < RN; ++j) {
Cv[j][i] = madd(Av, Bv[j], Cv[j][i]);
}
}
}
}

template <int RM, int RN>
inline void gemm_bloc(long ii, long jj) {
D stack[bsr(k / CHUNK + 1) + 1][RN][RM];
long chunk, sp = 0;
int i, j, rule, step = 2;
for (chunk = 0; chunk + KN * CHUNK * 4 <= k; chunk += KN * CHUNK * 4, step += 2, ++sp) {

while (sp--)
for (j = 0; j < RN; ++j)
D Cv[RN][RM] = {};
for (long l = 0; l < KN * CHUNK * 4; l += KN)
gemm_bloc<RM, RN>(ii, jj, chunk + l, Cv);

for (rule = bsr(step & -step); --rule;)
for (--sp, j = 0; j < RN; ++j)
for (i = 0; i < RM; ++i)
Cv[j][i] += stack[sp][j][i];

float Cf[RN][RM];
for (j = 0; j < RN; ++j)
for (i = 0; i < RM; ++i)
Cf[j][i] = hsum(Cv[j][i]);
stack[sp][j][i] = Cv[j][i];
}

for (; chunk < k; ++chunk)
for (j = 0; j < RN; ++j)
for (i = 0; i < RM; ++i)
Cf[j][i] = fmaf(load<float>(INDEX(A, lda, ii + i, chunk)), //
load<float>(INDEX(B, ldb, jj + j, chunk)), //
Cf[j][i]);
D Cv[RN][RM] = {};
for (; chunk + KN <= k; chunk += KN)
gemm_bloc<RM, RN>(ii, jj, chunk, Cv);

while (sp--)
for (j = 0; j < RN; ++j)
for (i = 0; i < RM; ++i)
Cv[j][i] += stack[sp][j][i];

float Cf[RN][RM];
for (j = 0; j < RN; ++j)
for (i = 0; i < RM; ++i)
Cf[j][i] = hsum(Cv[j][i]);

for (; chunk < k; ++chunk)
for (j = 0; j < RN; ++j)
for (i = 0; i < RM; ++i)
store(INDEX(C, ldc, jj + j, ii + i), Cf[j][i]);
Cf[j][i] = fmaf(load<float>(INDEX(A, lda, ii + i, chunk)), //
load<float>(INDEX(B, ldb, jj + j, chunk)), //
Cf[j][i]);

for (j = 0; j < RN; ++j)
for (i = 0; i < RM; ++i)
store(INDEX(C, ldc, jj + j, ii + i), Cf[j][i]);
}

template <int RM, int RN, int BM>
NOINLINE void gemm(long m, long n) {
// GGML_ASSERT(m % (RM * BM) == 0);
const long ytiles = m / (RM * BM);
const long xtiles = (n + RN -1) / RN;
const long jj_RN = (xtiles - (xtiles * RN - n));

long tiles = xtiles * ytiles;
long duty = (tiles + nth - 1) / nth;
long start = duty * ith;
long end = start + duty;
if (end > tiles)
end = tiles;
for (int64_t job = start; job < end; ++job) {
const int64_t ii = job / xtiles;
const int64_t jj = job % xtiles;
for (int64_t bi = 0; bi < BM; ++bi) {
if (jj < jj_RN) {
gemm_bloc<RM, RN>((ii * BM + bi) * RM, jj * RN);
} else if constexpr (RN > 1) {
gemm_bloc<RM, RN - 1>((ii * BM + bi) * RM, jj_RN * RN + (jj - jj_RN) * (RN - 1));
}
}
}
}

Expand Down
Loading

0 comments on commit b542e64

Please sign in to comment.