
8-bit Quantization #298

Open · wants to merge 31 commits into master
Conversation

kroggen
Contributor

kroggen commented Aug 15, 2023

This PR also includes work from Aniket.

It implements very basic but understandable 8-bit quantization (via quantize.c), plus on-the-fly dequantization in matmul, rmsnorm and dequantize_token.

The RoPE weights are intentionally left unquantized, since quantizing them may cause some accuracy loss (not tested).

Example Usage

gcc quantize.c -o quantize
./quantize stories110M.bin
gcc -Ofast -march=native runq.c -o runq
./runq data.bin

@karpathy
Owner

Nice, this will be a helpful reference. This is the Q8_1 scheme. A few things that are on my mind for quantization:

  • I think I will change the Python export script to emit int8 directly, instead of using a separate quantize.c
  • I think I'll go for Q8_0, which is simpler and just as good
  • I think we have to quantize the activation vector x (dynamically), instead of keeping it float; otherwise we don't realize all the gains we'd want

Roughly some of the things that come to mind.
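
For reference, a minimal sketch of what group-wise Q8_0 quantization could look like (group size, function name, and layout here are illustrative assumptions, not the final design):

#include <math.h>
#include <stdint.h>

#define GROUP_SIZE 64   /* illustrative choice; llama.cpp uses small fixed-size blocks */

/* Q8_0-style: one float scale per group, no zero point.
   Assumes n is a multiple of GROUP_SIZE. */
void quantize_q8_0(const float* x, int8_t* q, float* scales, int n) {
    for (int g = 0; g < n / GROUP_SIZE; g++) {
        const float* xg = x + g * GROUP_SIZE;
        float amax = 0.0f;                       /* max absolute value in the group */
        for (int i = 0; i < GROUP_SIZE; i++) {
            float a = fabsf(xg[i]);
            if (a > amax) amax = a;
        }
        float scale = amax / 127.0f;
        scales[g] = scale;
        for (int i = 0; i < GROUP_SIZE; i++) {
            q[g * GROUP_SIZE + i] = (int8_t) roundf(scale != 0.0f ? xg[i] / scale : 0.0f);
        }
    }
}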

@byte-6174
Contributor

Activation quantization would most likely need fine-tuning though, wouldn't it?
Much more work is needed to get that into good shape, but there are huge potential runtime gains once we have it.

@karpathy
Owner

@byte-6174 Not to my knowledge? It's possible to do quantization-aware finetuning to improve a model for quantization, but you can quantize it anyway.

@byte-6174
Contributor

@karpathy I have been poring through the literature on this for a few days now and most of it points to a need for fine-tuning; see e.g. I-BERT. And there are many more papers pointing to the need for fine-tuning.
However, I'm not saying it won't work for sure, just going off of what I am seeing out there :)
We should try it anyhow :D

@byte-6174
Contributor

Btw, this PR also works for quantizing the llama2 7B model. Compression from 25GB to 6.2GB. 🎆

@byte-6174
Contributor

byte-6174 commented Aug 16, 2023

Following up on the little pseudo-code above, I checked in a small experiment that quantizes weights/activations on the fly. It needs more experiments, since blindly making all matmuls run in int8 will run everything into the ground :)

The code of interest is this:

void get_quants_and_max(float *ptr, int size, int8_t *out_ptr, float* pmax, char* label){
    // single-scale quantization: the scale is derived from the maximum absolute value
    float max = 0.0f;
    for (int i = 0; i < size; i++){
        float a = fabsf(ptr[i]);          // use |x| so negative values don't overflow the int8 range
        if (a > max) max = a;
    }
    *pmax = max;
    for (int i = 0; i < size; i++){
        out_ptr[i] = (int8_t) roundf(127.0f / max * ptr[i]);
    }
}

void matmulint(float* xout, float* x, float* w, int n, int d) {
    // W (d,n) @ x (n,) -> xout (d,)
    // by far the most amount of time is spent inside this little function

    // calculate instantaneous max and quantize both operands on the fly
    float maxx, maxw;
    int8_t *intx = calloc(n, sizeof(int8_t));
    int8_t *intw = calloc(n*d, sizeof(int8_t));
    get_quants_and_max(x, n, intx, &maxx, "x");
    get_quants_and_max(w, d * n, intw, &maxw, "w");

    #pragma omp parallel for
    for (int i = 0; i < d; i++) {
        int32_t vali = 0;                 // 32-bit accumulator: int16 would overflow for large n
        for (int j = 0; j < n; j++) {
            // int8 x int8 multiplies, accumulated in int32
            vali += intw[i*n + j] * intx[j];
        }
        xout[i] = (vali * (maxx * maxw)) / (127.0f * 127.0f);
    }

    free(intx);
    free(intw);
}

@mgrabban

gcc -Ofast -march=native runq.c -o runq

btw, this PR also works for quantizing the llama2 7B model as well. compression from 25GB to 6.2GB. 🎆

Hello @byte-6174 ,

When I try to run the quantized llama2 7B chat version, I get gibberish. I did get a coherent response from the quantized stories42M.

using run

llama/llama2.c $ make run
gcc -O3 -o run run.c -lm
llama/llama2.c $ ./run bin/llama2_7b_chat.bin -n 16 -i "Why is sky blue?"
Why is sky blue?
How does the sky appear blue?
What is
achieved tok/s: 0.167125

using runq

llama/llama2.c $ gcc -Ofast -march=native runq.c -o runq -lm
llama/llama2.c $ ./runq bin/data.bin -n 16 -i "Why is sky blue?"
Why is sky blue?dj aj grandsls swo refuge花роз Louisiana Alb Alb
achieved tok/s: 1.536885

Not sure what could be wrong.

@kroggen
Contributor Author

kroggen commented Aug 16, 2023

@mgrabban How did you quantize the 7B model?

Can you show the output of the quantization? (it can be a link)

@kroggen
Contributor Author

kroggen commented Aug 17, 2023

For the CUDA implementation, check #310

@mgrabban

@mgrabban How did you quantize the 7B model?

Can you show the output of the quantization? (it can be a link)

You can find the output here
I followed a two-step process:

  1. convert the original llama2_7b_chat *.pth file (from Meta) into a llama2.c *.bin file
  2. quantize the llama2.c *.bin file from step 1 into the data.bin file

@kroggen
Contributor Author

kroggen commented Aug 17, 2023

I suspect it is related to the shared_weights flag.

With stories110M.bin it is equal to 1:

$ ./quantize stories110M.bin
vocab size = 32000  shared_weights=1

And with that chat model it is 0.

But the quantize step is not processing the additional wcls tensor.

Can you check with the last commit I sent?
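
For illustration, a minimal sketch of handling the unshared classifier weights, assuming the legacy llama2.c export where a negative vocab_size signals shared_weights = 0 (quantize_and_write is a hypothetical helper, not the actual quantize.c code):

#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

/* Hypothetical helper: quantize `size` floats with a single scale and
   append the scale followed by the int8 data to the output file. */
void quantize_and_write(FILE* out, const float* w, size_t size) {
    float amax = 0.0f;
    for (size_t i = 0; i < size; i++) {
        float a = fabsf(w[i]);
        if (a > amax) amax = a;
    }
    float scale = amax / 127.0f;
    fwrite(&scale, sizeof(float), 1, out);
    for (size_t i = 0; i < size; i++) {
        int8_t q = (int8_t) roundf(scale != 0.0f ? w[i] / scale : 0.0f);
        fwrite(&q, sizeof(int8_t), 1, out);
    }
}

/* In the legacy export a negative vocab_size marks unshared classifier
   weights, so wcls is an extra tensor at the end of the file:

   int shared_weights = config.vocab_size > 0 ? 1 : 0;
   config.vocab_size = abs(config.vocab_size);
   ...
   if (!shared_weights) {
       quantize_and_write(out, wcls, (size_t)config.vocab_size * config.dim);
   }
*/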

@byte-6174
Contributor

Yes, that change was needed for the llama2 7B model. Thanks @kroggen!

./runq data.bin -n 16 -i "why is sky blue?"

why is sky blue? Here's a theory
 everyone's been looking
achieved tok/s: 0.107060

@karpathy
Owner

Quantization here is per layer instead of groups. That feels risky? I'd expect llama.cpp does groups?

@byte-6174
Contributor

Yes. llama.cpp has groups of 64 etc.
Why risky?
Why risky ?

@karpathy
Owner

One outlier nukes the whole tensor. I'm starting a branch for int8 quantization now. I'll do groups.

@byte-6174
Contributor

Hmm, trying to understand this. So how do groups of 64 avoid this?
You mean an outlier in the magnitude sense, I'm presuming?

@karpathy
Owner

If there is a bad outlier somewhere, only e.g. up to 63 elements get "messed up" with high error, not the entire tensor. So breaking things up into groups makes things more robust to outliers.
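
For a rough sense of the numbers (an illustrative example, not measured on this model): if the weights in a tensor mostly lie in [-0.1, 0.1] but a single outlier is 10.0, a per-tensor scale of 10/127 ≈ 0.079 maps all of the typical weights to just 0 or ±1, wiping out almost all of their resolution. With groups of 64, only the outlier's group pays that price; every other group keeps a scale of about 0.1/127 ≈ 0.0008 and the full 8-bit range.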

@byte-6174
Contributor

byte-6174 commented Aug 17, 2023

Btw, as a side note: there is experimental evidence, in llama.cpp and also in places like LLM.int8(), of needing mixed precision to tackle outliers. Thought we might want to / have to consider it?!

@karpathy
Owner

That's a good point... The outliers are really annoying because they necessitate ugly special casing. Hmmm

@byte-6174
Contributor

Yeah, this is where weight-statistics analysis and all that business is needed; not as clean 🤢

@byte-6174
Contributor

Right, and to solve that to the extreme we can (or need to) scale rows and columns, convert to int8, and send that to matmul.

@byte-6174
Contributor

byte-6174 commented Aug 17, 2023

There is another cool approach I remember from Song Han, using k-means clustering. But that's another added complexity. 😜

@kroggen
Contributor Author

kroggen commented Aug 17, 2023

Quantization here is per layer instead of groups. That feels risky? I'd expect llama.cpp does groups?

The reason for this choice is that this repo is mainly for educational purposes, and using one scale factor per layer is the simplest way to show on-the-fly dequantization.

The code is easier to understand.

It is not the most precise. If we want to go to the edge, we can just use llama.cpp.

It would be good to check the perplexity (ppl) values to compare how much it decreases with this approach.
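
As a cheap stand-in until a proper perplexity eval is wired up, one could at least measure the round-trip reconstruction error of each quantized tensor. A minimal sketch (function and variable names are illustrative, not part of this PR):

#include <math.h>
#include <stdint.h>
#include <stdio.h>

/* Round-trip a tensor through int8 with a single scale and report the
   RMS and max reconstruction error. */
void report_quant_error(const char* name, const float* w, int size) {
    float amax = 0.0f;
    for (int i = 0; i < size; i++) {
        float a = fabsf(w[i]);
        if (a > amax) amax = a;
    }
    float scale = amax / 127.0f;
    double sum_sq = 0.0, max_err = 0.0;
    for (int i = 0; i < size; i++) {
        int8_t q = (int8_t) roundf(scale != 0.0f ? w[i] / scale : 0.0f);
        double err = fabs(w[i] - q * scale);
        sum_sq += err * err;
        if (err > max_err) max_err = err;
    }
    printf("%s: rms err %.6f, max err %.6f\n", name, sqrt(sum_sq / size), max_err);
}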

@karpathy
Owner

Btw this PR doesn't actually perform integer dot products, it dequantizes the weight to float in the inner loop and uses that. Are people still obtaining speed improvements from this "weak" version of quantization?

@mgrabban

Btw this PR doesn't actually perform integer dot products, it dequantizes the weight to float in the inner loop and uses that. Are people still obtaining speed improvements from this "weak" version of quantization?

I tested stories110M.bin. On a normal CPU (e.g. i7-8700K), runq is faster than runomp (which is faster than runfast and run). On a high-end CPU, runq is faster than runfast and run, but runomp is the fastest (since there are many cores available).

@kroggen
Contributor Author

kroggen commented Aug 18, 2023

These are my results

When compiled with these commands:

gcc -Ofast -march=native runq.c -o runq -lm
gcc -Ofast -march=native run.c -o run -lm

On MacBook Pro Intel:

llama2.c bernardo$ ./run stories110M.bin -t 0 | grep tok
achieved tok/s: 36.963495
llama2.c bernardo$ ./runq data.bin -t 0 | grep tok
achieved tok/s: 78.974013

On Linux with AMD EPYC-Rome:

root@tests-br:~/llama2.c# ./run stories110M.bin -t 0 | grep tok
achieved tok/s: 39.503754
root@tests-br:~/llama2.c# ./runq data.bin -t 0 | grep tok
achieved tok/s: 73.079325

On Apple M1/M2 chips there is no difference in performance, probably because they have unified on-package memory while the other chips use external RAM.

When compiling both with OpenMP on Linux/AMD using these commands:

gcc -Ofast -march=native -fopenmp run.c -o run -lm
gcc -Ofast -march=native -fopenmp runq.c -o runq -lm

I got these results:

root@tests-br:~/llama2.c# ./run stories110M.bin -t 0 | grep tok
achieved tok/s: 44.485294
root@tests-br:~/llama2.c# ./runq data.bin -t 0 | grep tok
achieved tok/s: 72.897196

@kroggen
Contributor Author

kroggen commented Aug 18, 2023

Btw this PR doesn't actually perform integer dot products, it dequantizes the weight to float in the inner loop and uses that

It is intentional. Doing the multiplication in int8 requires a lot of additional computation, which makes it slower than using float32 multiplication.

With dequantization on-the-fly, the code is simpler and faster at the same time.

The downside of the current implementation is the use of only one scale factor per layer. This can be changed.
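
For illustration, roughly the shape of that on-the-fly dequantization with a single per-tensor scale (a sketch with an illustrative name; the actual code in runq.c may differ):

// Weights stay int8 in memory and are converted back to float inside the
// inner loop, using one scale per tensor.
void matmul_dequant(float* xout, const float* x, const int8_t* w, float w_scale,
                    int n, int d) {
    // W (d,n) @ x (n,) -> xout (d,)
    #pragma omp parallel for
    for (int i = 0; i < d; i++) {
        float val = 0.0f;
        for (int j = 0; j < n; j++) {
            val += (w[i * n + j] * w_scale) * x[j];   // dequantize the weight on the fly
        }
        xout[i] = val;
    }
}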

@byte-6174
Contributor

Yes, and moreover even that downside is not that significant, as shown by our results for models up to at least 7B. The real deal breaker is the outliers, and the solution is mixed precision of some sort.
Whether any of that belongs in this repo or not is a different question.
