
dequantize + matrix multiplication CUDA kernels #2043

Closed

Conversation

JohannesGaessler
Collaborator

@JohannesGaessler commented Jun 28, 2023

I implemented CUDA kernels that do dequantization and matrix-matrix multiplication in one step. This eliminates the need for temporary buffers to hold the dequantized f32 matrix given to cuBLAS. For 33b this saves at least ~600 MiB for -ngl <= 60, and at least ~1400 MiB for -ngl >= 61. As a result, a few more layers can be offloaded. In the following table, pp == prompt processing and tg128 == generation of 128 tokens with an empty prompt:

| GPU      | Model      | Test  | Max. -ngl master | t/s master | Max. -ngl PR | t/s PR |
|----------|------------|-------|------------------|------------|--------------|--------|
| RTX 3090 | 33b q5_1   | pp    | 53               | 131        | 57           | 61     |
| RTX 3090 | 33b q5_1   | tg128 | 55               | 7.86       | 58           | 9.90   |
| RTX 3090 | 33b q5_k_m | pp    | 55               | 135.19     | 61           | 53.73  |
| RTX 3090 | 33b q5_k_m | tg128 | 58               | 11.39      | 61           | 14.77  |

Unfortunately my matrix multiplication kernel is not very good and is significantly slower than cuBLAS for prompt processing. This is particularly noticeable when there is plenty of VRAM anyway:

| GPU      | Model     | Test | t/s master | t/s PR |
|----------|-----------|------|------------|--------|
| RTX 3090 | 7b f16    | pp   | 1023       | 516    |
| RTX 3090 | 7b q4_0   | pp   | 1048       | 477    |
| RTX 3090 | 7b q4_1   | pp   | 1042       | 443    |
| RTX 3090 | 7b q5_0   | pp   | 1032       | 383    |
| RTX 3090 | 7b q5_1   | pp   | 1032       | 359    |
| RTX 3090 | 7b q8_0   | pp   | 1029       | 476    |
| RTX 3090 | 7b q2_K   | pp   | 1028       | 371    |
| RTX 3090 | 7b q3_K_m | pp   | 1009       | 310    |
| RTX 3090 | 7b q4_K_m | pp   | 1010       | 352    |
| RTX 3090 | 7b q5_K_m | pp   | 1040       | 308    |
| RTX 3090 | 7b q6_K   | pp   | 1033       | 372    |

By default, cuBLAS is still used. The new dequantization + matrix multiplication kernels can be enabled by setting the compile option LLAMA_CUDA_DMM.

The implementation works by adding a template that takes a dequantization method. I revised the dequantization functions to require only the index of the data value to be dequantized. I plan to eventually apply this dequantization scheme to the other templates as well, since it is simpler.
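
As a rough illustration of the pattern (simplified, with made-up names and a simplified q8_0 layout rather than the actual ggml-cuda code): a device function dequantizes a single value from its flat index and is passed as a template parameter to a naive matrix multiplication kernel.

```cuda
#include <cuda_fp16.h>
#include <stdint.h>

#define QK8_0 32

// Simplified q8_0-style block: one f16 scale and 32 int8 quants.
typedef struct {
    half   d;
    int8_t qs[QK8_0];
} block_q8_0;

// Dequantize a single element given only its flat index.
static __device__ float dequantize_one_q8_0(const void * vx, const int i) {
    const block_q8_0 * x = (const block_q8_0 *) vx;
    return __half2float(x[i / QK8_0].d) * x[i / QK8_0].qs[i % QK8_0];
}

// Naive fused dequantize + matmul: C = dequantize(A) * B, everything row-major.
// A is quantized (M x K), B is f32 (K x N), C is f32 (M x N).
template <float (*dequantize_one)(const void *, const int)>
static __global__ void dequantize_mul_mat(
        const void * __restrict__ A, const float * __restrict__ B,
        float * __restrict__ C, const int M, const int N, const int K) {
    const int row = blockIdx.y*blockDim.y + threadIdx.y;
    const int col = blockIdx.x*blockDim.x + threadIdx.x;
    if (row >= M || col >= N) return;

    float sum = 0.0f;
    for (int k = 0; k < K; ++k) {
        sum += dequantize_one(A, row*K + k) * B[k*N + col];
    }
    C[row*N + col] = sum;
}

// Other formats only need their own dequantize_one_* function, e.g.:
// dequantize_mul_mat<dequantize_one_q8_0><<<grid, block>>>(A, B, C, M, N, K);
```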

I don't understand why the performance of my matrix multiplication kernel is so bad. I tried several other variants but they all had even worse performance. I would very much welcome it if someone were to write a better one.

@JohannesGaessler
Collaborator Author

While I was working on this, the PR for k-quant super-block sizes of 64 (#2001) was merged. That option is not supported by my implementation.

@JohannesGaessler
Collaborator Author

I don't understand why the macOS CI build is failing. Can someone help?

@Green-Sky
Collaborator

> I don't understand why the macOS CI build is failing. Can someone help?

It's been flaky all day. I manually rerun failed actions, which usually fixes it.

@cmp-nct
Contributor

cmp-nct commented Jun 28, 2023

Quite good results, no small feat despite not matching the speed.
It's painful to see cuBLAS being up to 3x faster, but then again, going without cuBLAS is one potential (chunky) dependency less.

@Green-Sky
Collaborator

Making more room in VRAM matters a lot on constrained systems.

@ggerganov
Owner

ggerganov commented Jun 29, 2023

@JohannesGaessler

I think you might be heading in the wrong direction - let me try to explain.

Before starting work on the GPU code, we were focused on implementing a highly optimized version for CPU-only inference. The key insight was to use the quantized data directly during matrix multiplication in order to reduce the memory bandwidth and improve the inference performance. To do that, we developed optimized SIMD "kernels" for computing the dot product directly on the quantized integer data. That is, it is essential that we do not dequantize back to F16 or F32, but instead quantize src1 to Q8 and perform the QX x Q8 dot product using integer intrinsics. This is much faster than anything else since it minimizes the required memory throughput, and it outperforms all CPU BLAS implementations, simply because all of them require 32-bit FP precision. We have also extensively studied the precision of the quantized dot products and showed that it is good enough (#951).
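
For illustration, a QX x Q8 block dot product written as a CUDA device function could look roughly like the sketch below, using the __dp4a intrinsic (4 packed int8 multiply-adds per instruction, compute capability 6.1+). The block layout and names are simplified stand-ins, not the ggml kernels.

```cuda
#include <cuda_fp16.h>
#include <stdint.h>

#define QK8_0 32

typedef struct {
    half   d;           // per-block scale
    int8_t qs[QK8_0];   // quantized values
} block_q8_0;

// Pack 4 int8 values into one int32 so they can be fed to __dp4a
// (avoids misaligned 32-bit loads on the odd-sized blocks).
static __device__ int pack4(const int8_t * q) {
    const uint32_t u = (uint32_t)(uint8_t)q[0]        | ((uint32_t)(uint8_t)q[1] << 8) |
                      ((uint32_t)(uint8_t)q[2] << 16) | ((uint32_t)(uint8_t)q[3] << 24);
    return (int) u;
}

// Dot product of one quantized block of src0 with one quantized block of src1.
// The sum is accumulated in int32 and scaled back to float only once per block.
static __device__ float vec_dot_q8_0_q8_0(const block_q8_0 * x, const block_q8_0 * y) {
    int sumi = 0;
#pragma unroll
    for (int i = 0; i < QK8_0; i += 4) {
        sumi = __dp4a(pack4(x->qs + i), pack4(y->qs + i), sumi);
    }
    return __half2float(x->d) * __half2float(y->d) * sumi;
}
```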

This was the main reason I asked you, during your first CUDA PR, to investigate the same approach for the GPU:

#1375 (comment)

Although you showed that this approach was not efficient on older GPUs, I think we should revisit it and put some more effort into making it work. The expectation is that using it for matrix x matrix multiplication will be much faster than F16 / F32 cuBLAS, for the same reasons that we found and confirmed during the CPU-only development.

The existing QX x Q8 dot products are a very good start for matrix x vector multiplications, but are not the best for matrix x matrix multiplications. At the very least, we are missing the block-tiling optimization which should allow us to utilize the L2 cache much better.

The goal of ggerganov/ggml#293 will be to add exactly this type of optimization (and potentially others). We will start with the CPU implementation, and then I plan to do the same for Metal. I'm pretty confident that this is the correct approach to take and will recommend doing the same for CUDA and OpenCL.

@slaren
Collaborator

slaren commented Jun 29, 2023

I think that dequantizing to f16 or f32 is going to be necessary to use the tensor cores, so I think this approach has merit. Dequantizing to SRAM should still be much faster than the previous dequantization to DRAM, and if the bottleneck is the compute (and it should be), then tensor cores should improve performance significantly, possibly more than integer math.
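
A rough sketch of that idea (illustrative only, assuming a single-index dequantization function like the one sketched above, dimensions that are multiples of 16, and sm_70+ for the WMMA API): each warp dequantizes a 16x16 tile of the quantized matrix into shared memory as f16 and feeds it straight to the tensor cores.

```cuda
#include <mma.h>
#include <cuda_fp16.h>
using namespace nvcuda;

// One warp (32 threads) per 16x16 output tile; launch with block = 32 threads,
// grid = (N/16, M/16). A is quantized (M x K), B is f16 (K x N), C is f32 (M x N),
// all row-major, with M, N, K multiples of 16.
template <float (*dequantize_one)(const void *, const int)>
static __global__ void dmm_wmma_sketch(
        const void * __restrict__ A, const half * __restrict__ B,
        float * __restrict__ C, const int M, const int N, const int K) {
    const int tile_m = blockIdx.y;
    const int tile_n = blockIdx.x;

    __shared__ __align__(32) half a_sh[16*16];

    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;
    wmma::fill_fragment(c_frag, 0.0f);

    for (int k0 = 0; k0 < K; k0 += 16) {
        // dequantize a 16x16 tile of A into shared memory (8 values per lane)
        for (int i = threadIdx.x; i < 16*16; i += 32) {
            const int r = i / 16;
            const int c = i % 16;
            a_sh[i] = __float2half(dequantize_one(A, (tile_m*16 + r)*K + k0 + c));
        }
        __syncwarp();

        wmma::load_matrix_sync(a_frag, a_sh, 16);
        wmma::load_matrix_sync(b_frag, B + (size_t)k0*N + tile_n*16, N);
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
        __syncwarp();
    }

    wmma::store_matrix_sync(C + (size_t)tile_m*16*N + tile_n*16, c_frag, N, wmma::mem_row_major);
}
```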

@JohannesGaessler
Collaborator Author

JohannesGaessler commented Jun 29, 2023

Regarding floats vs ints: First of all, I think it should be noted that CPUs and GPUs have very different performance characteristics. In particular:

  1. Floating point arithmetic is comparatively much faster on GPUs than it is on CPUs.
  2. (NVIDIA) GPUs are able to load 32-128 contiguous bytes at once. So if a GPU loads two q8_0 blocks, it needs at least 4 memory accesses (because the current blocks have bad memory alignment) to load the quants and the scales, compared to 2 (larger) memory accesses to load two contiguous segments of 32 f16/f32 values.
  3. Given a correct implementation, matrix-matrix multiplication should not be I/O bound. If there is a benefit from using quantized data, it will come from being able to use less shared memory for the tiles. In particular, my implementation previously did f32 x f32 -> f32 matrix-matrix multiplication without on-the-fly dequantization and the performance was essentially the same (but I think an implementation utilizing larger tiles would still be beneficial in terms of reducing I/O).
  4. Per-byte integer intrinsics, particularly those using tensor cores, are not available on all GPUs (or they have bad performance). So we will need an f32 implementation anyway, and it serves as a baseline for comparing the efficiency of the matrix multiplication kernel itself against a highly optimized implementation, i.e. cuBLAS.
  5. On GPUs you cannot save clock cycles by doing the integer dot product first and then scaling afterwards (when not using integer intrinsics). On a CPU you would have to do 96 multiplications if you dequantize first and then do the dot product, compared to 34 multiplications if you first do the dot product and then dequantize. On an NVIDIA GPU both take 3 multiplications because you can do 32 multiplications in parallel. Edit: when using large amounts of shared memory this should actually be possible, but I think it is undesirable.
  6. Tensor cores can use f32, f16, and int8 as input data, so dequantization is not a prerequisite for using tensor cores. Sub-byte integer intrinsics currently do not have a stable API.

I think the biggest improvement currently to be had for quantized CUDA performance will be to rearrange the data and to utilize shared memory. For CUDA to load data efficiently it needs to be aligned to 32 bytes. However, the current blocks have sizes of e.g. 18 bytes. So what I think should be done is to rearrange the data when it's transferred to VRAM so that all quants come first, followed by all scales. Then, when the data is used, the quants are memory aligned and the scales can be read into shared memory as blocks (since they take up comparatively few bytes). Of course this is not mutually exclusive with using quantized dot products.
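
A minimal host-side sketch of that repacking, using an illustrative 18-byte q4_0-style block (one f16 scale plus 16 bytes of 4-bit quants); the names are invented for this example and this is not the existing upload path:

```cuda
#include <cuda_fp16.h>
#include <stdint.h>
#include <string.h>

#define QK4_0 32

// 18-byte block as it exists today: poor alignment when packed into an array.
typedef struct {
    half    d;              // scale
    uint8_t qs[QK4_0/2];    // 32 x 4-bit quants
} block_q4_0;

// Repack into two separate, individually aligned arrays before the cudaMemcpy:
// all quant bytes first, then all scales.
static void repack_q4_0(const block_q4_0 * src, const int nblocks,
                        uint8_t * qs_out,   // nblocks * 16 bytes
                        half    * d_out) {  // nblocks scales
    for (int ib = 0; ib < nblocks; ++ib) {
        memcpy(qs_out + (size_t)ib*(QK4_0/2), src[ib].qs, QK4_0/2);
        d_out[ib] = src[ib].d;
    }
}
```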

@slaren
Collaborator

slaren commented Jun 29, 2023

> Tensor cores can use f32, f16, and int8 as input data so dequantization is not a prerequisite for using tensor cores. Sub-byte integer intrinsics currently do not have a stable API.

Right, but the intermediate results need higher precision than int4 or even int8. On the CPU, we usually do the dot products in int16 before scaling to float. I am not sure if that's possible to do with tensor cores, but I don't think it is. We could do the same on the GPU, just without tensor cores, and it would likely be faster than the current implementation, but I am not sure that it would be faster than converting to f16 and using the tensor cores.

The int4 ops have been deprecated in Hopper, so I don't expect that they will ever have a stable API.

@JohannesGaessler
Collaborator Author

> Right, but the intermediate results need higher precision than int4 or even int8.

According to the documentation the accumulator for int8 tensor cores is int32.
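
For reference, a minimal sketch of how that looks through the WMMA API (sm_72+): signed 8-bit input fragments accumulate into a 32-bit integer fragment, and scaling back to float with the block scales would happen afterwards. Tile pointers and layouts here are illustrative assumptions.

```cuda
#include <mma.h>
using namespace nvcuda;

// a: 16x16 int8 tile (row-major), b: 16x16 int8 tile (column-major),
// c: 16x16 int32 output tile (row-major).
static __device__ void int8_mma_16x16(const signed char * a, const signed char * b, int * c) {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, signed char, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, signed char, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, int> c_frag;  // int32 accumulator

    wmma::fill_fragment(c_frag, 0);
    wmma::load_matrix_sync(a_frag, a, 16);
    wmma::load_matrix_sync(b_frag, b, 16);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    wmma::store_matrix_sync(c, c_frag, 16, wmma::mem_row_major);
}
```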

@cmp-nct
Contributor

cmp-nct commented Jun 29, 2023

> Right, but the intermediate results need higher precision than int4 or even int8.

> According to the documentation the accumulator for int8 tensor cores is int32.

Good to know. Though the 2nd tensor is usually so small that on-the-fly conversion is always possible. In my fblas-16bit I'm just down-converting the 2nd matrix inside the function; I didn't notice a performance difference.

I wonder if Nvidia will actually release the fp8 functions to consumer hardware. Over the past year I've had the impression that they actively work against the use of their gaming-GPU line for ML applications: reducing VRAM, restricting driver licensing, removing SLI. Their datacenter line brings in twice as much revenue because they simply scale the price of the large-VRAM models up tenfold.

@slaren
Collaborator

slaren commented Jun 29, 2023

I think that merging this as an optional feature and iterating on it later would be OK, but duplicating the dequantization functions is really not good; it's going to add too much maintenance overhead in the future. My understanding was that dequantizing two values at a time was done for performance, so I am not sure why we need functions that only return one value now.

If that cannot be fixed in this implementation, I would suggest leaving this open as a draft for people who may be interested in trying it or iterating over it in the future, but not merging.

@JohannesGaessler
Collaborator Author

I'm fine with keeping this PR open; I plan to revise the dequantize_mul_mat_vec kernels anyway.

@JohannesGaessler
Collaborator Author

Also, the reason I added dequantization kernels that dequantize only one value is to get a unified interface for all quantization formats. Especially if you want to try implementations that have a tile size < 32 in one direction, I find that very convenient, since you won't have to care about the data layout of the specific quantization types. If you look at the code for dequantize_mul_mat_vec, you'll find that you have to explicitly consider whether the quantization has 1 or 2 data values per byte and determine the correct corresponding offset.

@ikawrakow
Contributor

Why is the gain in VRAM usage so high? Excluding output.weight (the -ngl < 61 case), the largest tensor is in the range of 100-120 million elements at 33B. If dequantized to fp16, and if it is made sure that the implementation only holds one dequantized tensor at any given time, it would seem that one can gain at most 250 MB by using direct quantized matrix multiplication. What am I missing?

@JohannesGaessler
Collaborator Author

The order in which the pool buffers are allocated is bad: it first allocates a small buffer, then a larger one, and then an even larger one. So you can't just re-use the previously allocated buffers and instead have to allocate new ones. In #1935 I did a hacky patch that just allocates the largest size at the start so buffers can be reused, but I think a preferable solution would be to get rid of the buffers entirely via tensor fusion.
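
To illustrate the reuse problem with a simplified sketch (not the actual ggml-cuda pool code): a pool like the one below can only hand out an existing buffer if it is at least as large as the request, so a small -> larger -> largest request order forces three allocations, whereas requesting the largest size first lets the later, smaller requests reuse it.

```cuda
#include <cuda_runtime.h>
#include <stddef.h>

struct pool_entry {
    void * ptr;
    size_t size;
};

static pool_entry g_pool[16] = {};

// Hand out a free pool buffer if one is big enough, otherwise cudaMalloc a new one.
static void * pool_malloc(const size_t size, size_t * actual_size) {
    for (auto & e : g_pool) {
        if (e.ptr != nullptr && e.size >= size) {
            void * p = e.ptr;
            *actual_size = e.size;
            e.ptr = nullptr;
            return p;
        }
    }
    void * p = nullptr;
    cudaMalloc(&p, size);
    *actual_size = size;
    return p;
}

// Return a buffer to the pool (or free it if the pool is full).
static void pool_free(void * ptr, const size_t size) {
    for (auto & e : g_pool) {
        if (e.ptr == nullptr) {
            e.ptr  = ptr;
            e.size = size;
            return;
        }
    }
    cudaFree(ptr);
}
```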

@ggerganov
Owner

> I think currently the biggest improvement for quantized CUDA performance to be had will be to rearrange the data and to utilize shared memory. For CUDA to efficiently load data it needs to be aligned to 32 bytes.

Maybe a quick and dirty prototype of this idea could be useful to see how much improvement the 32-byte alignment would bring. Some ideas from #1073 could potentially be useful.

@JohannesGaessler
Collaborator Author

Speaking of which, I tried reordering the data and it made no difference, presumably because the blocks are already in cache due to the contiguous access pattern. However, I have a prototype that quantizes the vector and then uses SIMD intrinsics to do the dot product. On my RTX 3090 this was ~10% faster.

@JohannesGaessler
Collaborator Author

After #2140, where not using __restrict__ massively gimped performance, I tried just adding __restrict__ to the kernel in this PR. The result is a 2x speedup, with q4_0 reaching 90% of the speed of cuBLAS. However, after #2067 I'll try an implementation based on integer intrinsics; ideally that will be faster than this PR anyway.
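
For context, the change discussed here is just adding the __restrict__ qualifier to the kernel's pointer parameters, which promises the compiler that the buffers never alias and lets it use the read-only cache and reorder loads. A trivial illustration (not the PR's actual kernel):

```cuda
// Without __restrict__ the compiler must assume dst may alias x or scales and
// has to re-load values conservatively; with it, the loads of x and scales can
// go through the read-only data cache.
static __global__ void scale_rows(
        const float * __restrict__ x,
        const float * __restrict__ scales,
        float       * __restrict__ dst,
        const int ncols, const int nrows) {
    const int row = blockIdx.y;
    const int col = blockIdx.x*blockDim.x + threadIdx.x;
    if (row >= nrows || col >= ncols) return;

    dst[row*ncols + col] = scales[row] * x[row*ncols + col];
}
```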

@ggerganov
Owner

This is a good find - I think this is the first time I've seen restrict make a difference, and it is quite significant.
So far I've been using it just because it is common wisdom to do so, but I have never seen a performance impact.

@JohannesGaessler
Collaborator Author

Superseded by #2160.
