Reduce rounding errors by 10x
This change uses the ruler function to modify the TinyBLAS GEMM
function on CPU so that roundoff error accumulation is always 10x
smaller, without sacrificing any performance. This is important for
larger models such as Command-R+, which has dimensions as large as
26000 elements. For example:

                      average error bits
             0b000000000000000000000000xxxxxxxx naive
             0b0000000000000000000000000000xxxx ruler
             0b00000000000000000000000000000xxx kahan
               └──────┬───────┘
                  original
                  fidelity

                    worst case error bits
             0b00000000xxxxxxxxxxxxxxxxxxxxxxxx naive
             0b00000000000000xxxxxxxxxxxxxxxxxx ruler
             0b0000000000000000xxxxxxxxxxxxxxxx kahan
               └──────┬───────┘
                 bf16 & f16

The new implementation uses a non-recursive divide-and-conquer
technique for reducing dot products. It's not as good as Kahan
summation, which we previously made available behind the --precise
flag, but it seems to limit error growth nearly as well. This means
that when you use BF16 and F16 weights, llamafile will preserve the
original fidelity. An average case error of 233 ULP (about 2^8, i.e.
the bottom eight mantissa bits are noise, as the first diagram above
shows) may not seem like a big deal, but all it should take is a
single worst case error to flip a concept in the LLM's brain and sow
confusion. So having a better guarantee here matters.

    // CHUNK is the leaf block size and bsr(x) is the index of the
    // highest set bit of x; both are defined elsewhere.
    float fsumf_ruler(const float *p, size_t n) {
      size_t i, sp = 0;
      int rule, step = 2;
      float stack[bsr(n / CHUNK + 1) + 1]; // O(log n) scratch space
      for (i = 0; i + CHUNK * 4 <= n; i += CHUNK * 4, step += 2) {
        float sum = 0;
        for (size_t j = 0; j < CHUNK * 4; ++j) // sum one leaf block
          sum += p[i + j];
        for (rule = bsr(step & -step); --rule;) // pop count follows A001511
          sum += stack[--sp]; // merge partial sums of equal depth
        stack[sp++] = sum;
      }
      float res = 0;
      while (sp) // drain the remaining partial sums
        res += stack[--sp];
      while (i < n) // add the ragged tail
        res += p[i++];
      return res;
    }

The reference impl above is for what I call ruler summation. It's a
very fast way to sum a sequence of floating point numbers without
accumulating too much error, and it delivers superior performance
under both IEEE and fast math semantics. The only weird thing about
this algorithm is that it requires a variable length array, but since
that array only takes up logarithmic space, you should be able to run
just about any LLM with a stack size of 64kb, although we'll be
increasing it to 128kb in llamafile just to be safe.
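
For comparison, here is a minimal sketch of Kahan summation, the
compensated technique behind the --precise flag (a sketch only: the
helper name fsumf_kahan is mine, and llamafile's actual code may
differ):

    float fsumf_kahan(const float *p, size_t n) {
      float sum = 0; // running total
      float c = 0;   // compensation: low-order bits lost so far
      for (size_t i = 0; i < n; ++i) {
        float y = p[i] - c; // reintroduce the previous rounding error
        float t = sum + y;  // big + small loses the low bits of y
        c = (t - sum) - y;  // algebraically zero; recovers those bits
        sum = t;
      }
      return sum;
    }

The compensation only survives under strict IEEE semantics, since fast
math is free to reassociate it away, which is part of why ruler
summation, needing no such trick, fares better there.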

See also https://oeis.org/A001511
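
The pop count bsr(step & -step) in the loop above traces out exactly
that sequence. Here's a quick sanity check sketch (my test harness,
not part of this change) which prints the first few terms and compares
fsumf_ruler against a double precision reference; it assumes CHUNK,
bsr(), and fsumf_ruler as defined above:

    #include <math.h>
    #include <stdio.h>

    int main(void) {
      // step takes the values 2, 4, 6, ... so this prints the ruler
      // sequence 1 2 1 3 1 2 1 4 (OEIS A001511)
      for (int step = 2; step <= 16; step += 2)
        printf("%d ", bsr(step & -step));
      printf("\n");
      enum { N = 26000 }; // a Command-R+ sized reduction
      static float p[N]; // static so the data doesn't eat stack space
      double exact = 0;
      for (int i = 0; i < N; ++i) {
        p[i] = sinf(i); // arbitrary test data
        exact += p[i];
      }
      printf("ruler %.9g vs exact %.9g\n", fsumf_ruler(p, N), (float)exact);
    }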
jart committed Jul 29, 2024
1 parent a73ea13 commit cb817f5
Showing 8 changed files with 203 additions and 237 deletions.
1 change: 1 addition & 0 deletions llama.cpp/ggml.c
@@ -19209,6 +19209,7 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
 
     pthread_attr_t attr;
     pthread_attr_init(&attr);
+    pthread_attr_setstacksize(&attr, 128 * 1024);
     pthread_attr_setguardsize(&attr, sysconf(_SC_PAGESIZE));
     pthread_attr_setsigaltstacksize_np(&attr, sysconf(_SC_MINSIGSTKSZ) + 16384);
     const int rc = ggml_thread_create((pthread_t *)&workers[j].thrd, &attr,
4 changes: 3 additions & 1 deletion llamafile/ansiblas.h
@@ -20,6 +20,8 @@
 #include <cmath>
 #include <unistd.h>
 
+int cpu_get_num_math();
+
 namespace {
 namespace ansiBLAS {
 
@@ -133,7 +135,7 @@ void sgemm(int m, int n, int k, //
            const float *A, int lda, //
            const float *B, int ldb, //
            float *C, int ldc) {
-    int nth = sysconf(_SC_NPROCESSORS_ONLN);
+    static int nth = cpu_get_num_math();
 #pragma omp parallel for
     for (int ith = 0; ith < nth; ++ith) {
         ansiBLAS tb{k, A, lda, B, ldb, C, ldc, ith, nth};
2 changes: 1 addition & 1 deletion llamafile/bench.h
@@ -16,5 +16,5 @@
         x; \
         __asm__ volatile("" ::: "memory"); \
     } \
-    printf("%9lld us %s\n", (micros() - start + ITERATIONS - 1) / ITERATIONS, #x); \
+    printf("%12lld us %s\n", (micros() - start + ITERATIONS - 1) / ITERATIONS, #x); \
 } while (0)
26 changes: 17 additions & 9 deletions llamafile/sgemm_matmul_test.cpp
@@ -29,10 +29,12 @@
 #define ITERATIONS 30
 #define ALLOC(n) (float *)memalign(4096, sizeof(float) * (n))
 
+int cpu_get_num_math();
+
 void llamafile_sgemm_openmp(long m, long n, long k, const void *A, long lda, const void *B,
                             long ldb, void *C, long ldc, int task, int Atype, int Btype, int Ctype,
                             int precision) {
-    int nth = sysconf(_SC_NPROCESSORS_ONLN);
+    static int nth = cpu_get_num_math();
 #pragma omp parallel for
     for (int ith = 0; ith < nth; ++ith) {
         bool res = llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, task, Atype, Btype,
@@ -44,7 +46,7 @@ void llamafile_sgemm_openmp(long m, long n, long k, const void *A, long lda, con
 int test(void) {
     int m = 256;
     int n = 500;
-    int k = 32768 * 2;
+    int k = 260000;
     int lda = ROUNDUP(k, 16);
     int ldb = ROUNDUP(k, 16);
     int ldc = ROUNDUP(m, 16);
@@ -63,12 +65,15 @@ int test(void) {
     BENCH(llamafile_sgemm_openmp(m, n, k, A, lda, B, ldb, C, ldc, GGML_TASK_TYPE_COMPUTE,
                                  GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, GGML_PREC_DEFAULT));
 
+    int flips = 0;
     double err_sum = 0;
     long long err_worst = 0;
     for (int i = 0; i < m; ++i)
         for (int j = 0; j < n; ++j) {
             float g = G[ldc * j + i];
             float c = C[ldc * j + i];
+            if (signbit(g) != signbit(c))
+                ++flips;
             if (flt::isnan(g)) {
                 fprintf(stderr, "%s:%d: found nan in reference matrix: i=%d j=%d\n", __FILE__,
                         __LINE__, i, j);
@@ -90,8 +95,9 @@ int test(void) {
     }
 
     double err_avg = err_sum / (m * n);
-    fprintf(stderr, "%9g ulp average\n", err_avg);
-    fprintf(stderr, "%9lld ulp worst\n", err_worst);
+    fprintf(stderr, "%12g ulp average\n", err_avg);
+    fprintf(stderr, "%12lld ulp worst\n", err_worst);
+    fprintf(stderr, "%12d flips\n", flips);
 
     // using one accumulator
     // 87015 us gemm
@@ -125,15 +131,17 @@ int test(void) {
 int main(int argc, char *argv[]) {
     int rc;
 
-    llamafile_trapping_enabled(+1);
+    // llamafile_trapping_enabled(+1);
 
     printf("\n");
     if ((rc = test()))
         return rc;
 
+    printf("\nFLAG_precise = false;\n");
+    FLAG_precise = false;
     printf("\n");
     if ((rc = test()))
         return rc;
 
+    printf("\nFLAG_precise = true;\n");
     FLAG_precise = true;
     printf("\n");
     if ((rc = test()))
         return rc;
 }
22 changes: 13 additions & 9 deletions llamafile/sgemm_sss_test.cpp
@@ -28,10 +28,12 @@
 #define ITERATIONS 5
 #define ALLOC(n) (float *)memalign(4096, sizeof(float) * (n))
 
+int cpu_get_num_math();
+
 void llamafile_sgemm_openmp(long m, long n, long k, const void *A, long lda, const void *B,
                             long ldb, void *C, long ldc, int task, int Atype, int Btype, int Ctype,
                             int precision) {
-    int nth = sysconf(_SC_NPROCESSORS_ONLN);
+    static int nth = cpu_get_num_math();
 #pragma omp parallel for
     for (int ith = 0; ith < nth; ++ith) {
         bool res = llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, task, Atype, Btype,
@@ -44,7 +46,7 @@ int test(void) {
     float tolerance = 2e-5;
     int m = 510;
     int n = 513;
-    int k = 512 * 8;
+    int k = 260000;
     int lda = ROUNDUP(k, 16);
     int ldb = ROUNDUP(k, 16);
     int ldc = ROUNDUP(m, 16);
@@ -60,8 +62,8 @@ int test(void) {
     randomize(k, n, B, ldb);
 
     BENCH(ansiBLAS::sgemm(m, n, k, A, lda, B, ldb, G, ldc));
-    BENCH(llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, 0, 1, GGML_TASK_TYPE_COMPUTE,
-                          GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, GGML_PREC_DEFAULT));
+    BENCH(llamafile_sgemm_openmp(m, n, k, A, lda, B, ldb, C, ldc, GGML_TASK_TYPE_COMPUTE,
+                                 GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, GGML_PREC_DEFAULT));
 
     double err_sum = 0;
     long long err_worst = 0;
@@ -104,15 +106,17 @@ int test(void) {
 int main(int argc, char *argv[]) {
     int rc;
 
-    llamafile_trapping_enabled(+1);
+    // llamafile_trapping_enabled(+1);
 
     printf("\n");
     if ((rc = test()))
         return rc;
 
+    printf("\nFLAG_precise = false;\n");
+    FLAG_precise = false;
     printf("\n");
     if ((rc = test()))
         return rc;
 
+    printf("\nFLAG_precise = true;\n");
     FLAG_precise = true;
     printf("\n");
     if ((rc = test()))
         return rc;
 }
37 changes: 27 additions & 10 deletions llamafile/sgemm_vecdot_test.cpp
@@ -22,14 +22,29 @@
 #include "macros.h"
 #include "numba.h"
 #include "sgemm.h"
+#include <assert.h>
 
 #define ITERATIONS 30
 #define ALLOC(n) (float *)memalign(4096, sizeof(float) * (n))
 
+int cpu_get_num_math();
+
+void llamafile_sgemm_openmp(long m, long n, long k, const void *A, long lda, const void *B,
+                            long ldb, void *C, long ldc, int task, int Atype, int Btype, int Ctype,
+                            int precision) {
+    static int nth = cpu_get_num_math();
+#pragma omp parallel for
+    for (int ith = 0; ith < nth; ++ith) {
+        bool res = llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, task, Atype, Btype,
+                                   Ctype, precision);
+        assert(res);
+    }
+}
+
 int test(void) {
-    int m = 1025;
+    int m = 1024;
     int n = 1;
-    int k = 32768;
+    int k = 260000;
     int lda = ROUNDUP(k, 16);
     int ldb = ROUNDUP(k, 16);
     int ldc = ROUNDUP(m, 16);
@@ -45,8 +60,8 @@ int test(void) {
     randomize(k, n, B, ldb);
 
     BENCH(ansiBLAS::sgemm(m, n, k, A, lda, B, ldb, G, ldc));
-    BENCH(llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, 0, 1, GGML_TASK_TYPE_COMPUTE,
-                          GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, GGML_PREC_DEFAULT));
+    BENCH(llamafile_sgemm_openmp(m, n, k, A, lda, B, ldb, C, ldc, GGML_TASK_TYPE_COMPUTE,
+                                 GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, GGML_PREC_DEFAULT));
 
     double err_sum = 0;
     long long err_worst = 0;
@@ -75,8 +90,8 @@ int test(void) {
     }
 
     double err_avg = err_sum / (m * n);
-    fprintf(stderr, "%9g ulp average\n", err_avg);
-    fprintf(stderr, "%9lld ulp worst\n", err_worst);
+    fprintf(stderr, "%12g ulp average\n", err_avg);
+    fprintf(stderr, "%12lld ulp worst\n", err_worst);
 
     // using one accumulator
     // 40209 us gemm
@@ -110,13 +125,15 @@ int test(void) {
 int main(int argc, char *argv[]) {
     int rc;
 
+    printf("\nFLAG_precise = false;\n");
+    FLAG_precise = false;
     printf("\n");
     if ((rc = test()))
         return rc;
 
-    printf("\n");
-    if ((rc = test()))
-        return rc;
-
+    printf("\nFLAG_precise = true;\n");
     FLAG_precise = true;
     printf("\n");
     if ((rc = test()))
         return rc;
 }