
CUDA: refactor ggml_cuda_op + lower GPU latency via quantization on main GPU and tiling #3110

Merged

Conversation

JohannesGaessler
Collaborator

This PR adds the following 2 optimizations to the CUDA multi GPU code:

  1. When using mul_mat_q and mul_mat_vec_q, the data from the hidden state is quantized on the main GPU and then distributed, instead of being distributed as f32 and quantized on each device. This reduces latency and the PCIe bandwidth used by a factor of up to 3.56.
  2. For matrix-matrix multiplications, on master the full hidden state is distributed from the main device, each device then does its part of the matrix-matrix multiplication, and the full partial result is written back to the main device. With this PR the hidden state is instead split into tiles of size 128 (in the direction of the batch size). The devices receive the tiles and start working on them, writing back results before the rest of the data has arrived. Because GPUs can transfer data and do calculations in parallel, this reduces latency by a factor of up to 4 (assuming batch size 512); see the sketch after this list.

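The 3.56 factor presumably comes from the size ratio between f32 activations and their q8_1 quantization (roughly 128 bytes vs. 36 bytes per 32 values). The tiling itself is the usual CUDA copy/compute overlap pattern; below is a minimal sketch of that pattern under assumed names (tile sizes, streams, buffers, and the launch_tile helper are illustrative, this is not the actual ggml_cuda_op_mul_mat code):

```cpp
#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

// Sketch: the tile copy goes on a dedicated copy stream, the kernel on a
// compute stream; an event makes the kernel for tile i wait only for its own
// data, so it can run while the copy of tile i+1 is still in flight.
void mul_mat_tiled(const char * src_buf, char * dst_buf, int main_device, int device,
                   int64_t ne11 /* batch size */, int64_t row_size, int64_t col_stride,
                   cudaStream_t copy_stream, cudaStream_t compute_stream, cudaEvent_t tile_ready,
                   void (*launch_tile)(const char * tile, int64_t ncols, cudaStream_t stream)) {
    for (int64_t col0 = 0; col0 < ne11; col0 += col_stride) {
        const int64_t ncols = std::min(col_stride, ne11 - col0);

        // asynchronously copy only this tile of the hidden state from the main device
        cudaMemcpyPeerAsync(dst_buf + col0*row_size, device,
                            src_buf + col0*row_size, main_device,
                            ncols*row_size, copy_stream);
        cudaEventRecord(tile_ready, copy_stream);

        // re-recording the event next iteration does not disturb this wait:
        // cudaStreamWaitEvent captures the event state at enqueue time
        cudaStreamWaitEvent(compute_stream, tile_ready, 0);
        launch_tile(dst_buf + col0*row_size, ncols, compute_stream);
    }
}
```
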
These are the results:

| model | GPU | test | t/s master | t/s PR | Speedup |
| --- | --- | --- | --- | --- | --- |
| 7b q4_0 | 1x P40 | pp 512 | 877.43 ± 0.81 | 871.14 ± 1.55 | 0.99 |
| 7b q4_0 | 1x P40 | tg 128 | 59.50 ± 0.01 | 59.27 ± 0.03 | 1.00 |
| 7b q4_0 | 2x P40 | pp 512 | 477.88 ± 0.26 | 785.01 ± 1.06 | 1.64 |
| 7b q4_0 | 2x P40 | tg 128 | 48.08 ± 0.02 | 49.40 ± 0.02 | 1.03 |
| 7b q4_0 | 3x P40 | pp 512 | 425.70 ± 1.31 | 893.17 ± 7.60 | 2.10 |
| 7b q4_0 | 3x P40 | tg 128 | 51.26 ± 0.03 | 53.29 ± 0.03 | 1.04 |
| 70b q6_K | 3x P40 | pp 512 | 69.13 ± 0.02 | 147.37 ± 0.05 | 2.14 |
| 70b q6_K | 3x P40 | tg 128 | 7.89 ± 0.00 | 8.10 ± 0.00 | 1.03 |

As part of the above changes I have refactored ggml_cuda_op. I have split the function into ggml_cuda_op_flatten and ggml_cuda_op_mul_mat. All operations other than matrix multiplication use ggml_cuda_op_flatten, which is equivalent to the old ggml_cuda_op with flatten_rows=true (but greatly simplified). ggml_cuda_op_mul_mat is more complicated due to the various performance considerations, but it should hopefully not be any more difficult to understand than the old ggml_cuda_op.
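Schematically, the split looks roughly like this (parameter lists abbreviated and partly illustrative; see ggml-cuda.cu for the exact signatures):

```cpp
#include <cuda_runtime.h>
#include "ggml.h"

// simple path: one contiguous launch per device, used by all non-matmul ops
typedef void (*ggml_cuda_op_flatten_t)(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t stream);

void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1,
                          ggml_tensor * dst, ggml_cuda_op_flatten_t op);

// matrix multiplication path: splits src0 rows across devices, optionally
// quantizes src1 to q8_1 on the main GPU before distributing it, and tiles
// src1 in the batch dimension to overlap transfers with compute
void ggml_cuda_op_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1,
                          ggml_tensor * dst, bool convert_src1_to_q8_1);
```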

@slaren
Collaborator

slaren commented Sep 11, 2023

Unfortunately this seems to reduce performance under WSL2, even with a single GPU.

ggml_init_cublas: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 3090 Ti, compute capability 8.6

| model | test | master t/s | PR t/s | Speedup |
| --- | --- | --- | --- | --- |
| LLaMA 7B mostly Q4_0 | pp 512 | 2244.27 ± 73.47 | 2081.22 ± 53.19 | 0.92 |
| LLaMA 7B mostly Q4_0 | tg 128 | 42.44 ± 0.38 | 37.01 ± 0.68 | 0.87 |

ggml_init_cublas: found 2 CUDA devices:
Device 0: NVIDIA GeForce RTX 3090 Ti, compute capability 8.6
Device 1: NVIDIA GeForce RTX 3080, compute capability 8.6

| model | test | master t/s | PR t/s | Speedup |
| --- | --- | --- | --- | --- |
| LLaMA 7B mostly Q4_0 | pp 512 | 886.21 ± 4.66 | 763.63 ± 11.12 | 0.86 |
| LLaMA 7B mostly Q4_0 | tg 128 | 5.82 ± 0.52 | 5.72 ± 0.44 | 0.98 |

@JohannesGaessler
Collaborator Author

What could be happening is that waiting on CUDA events is much more expensive on Windows than it is on Linux. I've pushed a commit that removes some unnecessary event waiting. Hopefully that will fix the single GPU performance (because I don't know what else could be the problem).

Also, please check how varying the define MUL_MAT_SRC1_COL_STRIDE on line 402 changes performance. If it is set to >= the batch size, the multi GPU code should be identical to master except for the quantization happening prior to the GPU->GPU data transfer.
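For reference, here is roughly how such a stride bounds the tile size (illustrative, not the exact loop in ggml-cuda.cu): with MUL_MAT_SRC1_COL_STRIDE >= ne11 (the batch size) the loop runs only once, so the whole hidden state is sent in one piece as on master and only the quantize-before-transfer change remains in effect.

```cpp
#include <cstdint>

#define MUL_MAT_SRC1_COL_STRIDE 128

void process_src1(int64_t ne11 /* batch size */) {
    for (int64_t col0 = 0; col0 < ne11; col0 += MUL_MAT_SRC1_COL_STRIDE) {
        // transfer and compute on columns [col0, min(col0 + MUL_MAT_SRC1_COL_STRIDE, ne11))
    }
}
```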

@slaren
Collaborator

slaren commented Sep 11, 2023

The last commit didn't affect single or multi GPU performance, but setting MUL_MAT_SRC1_COL_STRIDE to 512 significantly improves multi GPU pp performance, making it faster than master.

ggml_init_cublas: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 3090 Ti, compute capability 8.6

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | pp 512 | 2090.50 ± 56.08 |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | tg 128 | 37.21 ± 0.65 |

ggml_init_cublas: found 2 CUDA devices:
Device 0: NVIDIA GeForce RTX 3090 Ti, compute capability 8.6
Device 1: NVIDIA GeForce RTX 3080, compute capability 8.6

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | pp 512 | 763.14 ± 10.99 |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | tg 128 | 5.74 ± 0.52 |

build: 54da1a2 (1209)

With MUL_MAT_SRC1_COL_STRIDE=512:

ggml_init_cublas: found 2 CUDA devices:
Device 0: NVIDIA GeForce RTX 3090 Ti, compute capability 8.6
Device 1: NVIDIA GeForce RTX 3080, compute capability 8.6

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | pp 512 | 1250.09 ± 26.87 |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | tg 128 | 5.78 ± 0.56 |

@JohannesGaessler
Collaborator Author

Wait, now that I look at your numbers, isn't the multi GPU performance terrible either way though? I mean, ~5 t/s for 7b q4_0 token generation is really bad for an RTX 3090 Ti + an RTX 3080. So maybe for this PR it's sufficient to fix the single GPU performance?

@slaren
Collaborator

slaren commented Sep 11, 2023

Yes, multi GPU performance with WSL2 has always been very bad for me. The biggest issue is the degradation in single GPU performance.

@JohannesGaessler
Collaborator Author

JohannesGaessler commented Sep 11, 2023

The branch for this PR is rebased, with WIP commits removed. Can you check the single GPU performance on the branch cuda-tensor-tiling-2, which has all of the WIP commits? The performance on some of the commits is bad, but dbddcd1, c26cc71, and 55f7419 would be of interest to me. Basically, if performance for those commits is good, then the issue comes from the tiling; otherwise it comes from the ggml_cuda_op refactor.

@slaren
Collaborator

slaren commented Sep 11, 2023

Looks like the performance degradation started with 55f7419:

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | pp 512 | 2275.98 ± 16.62 |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | tg 128 | 42.09 ± 0.54 |

build: dbddcd1 (1221)

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | pp 512 | 2274.08 ± 12.07 |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | tg 128 | 41.63 ± 0.61 |

build: c26cc71 (1245)

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | pp 512 | 2104.83 ± 11.23 |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | tg 128 | 41.68 ± 0.23 |

build: 55f7419 (1258)

@JohannesGaessler
Collaborator Author

I don't understand what is causing the performance regression. Can you try to find the last commit with good performance?

@slaren
Collaborator

slaren commented Sep 11, 2023

I did a bisect between 55f7419 and c26cc71 and got this:

da6cb22db7b6478797c597dc8120ec7a2cf4e1ca is the first bad commit
commit da6cb22db7b6478797c597dc8120ec7a2cf4e1ca
Author: JohannesGaessler <johannesg@5d6.de>
Date:   Sat Sep 9 21:02:35 2023 +0200

    reorder loops

@JohannesGaessler
Collaborator Author

I've pushed a version that has the loops in the old order. Can you check performance?

@slaren
Collaborator

slaren commented Sep 11, 2023

This improved pp performance, but it is still a bit slower than master. tg is still slower. Let me know if you need me to do another bisect.

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | pp 512 | 2202.92 ± 61.62 |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | tg 128 | 36.47 ± 0.76 |

build: bd79c94 (1211)

@JohannesGaessler
Collaborator Author

The only thing in the loop that could plausibly make a difference for performance is cudaSetDevice. It should be trivial for the CUDA library to add a check that only invokes the driver if the new device is actually different from the last one, but maybe they just forgot to do that for Windows? I've pushed a version that caches the last device in ggml-cuda.cu, please check.

@slaren
Collaborator

slaren commented Sep 11, 2023

No change.

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | pp 512 | 2202.11 ± 59.57 |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | tg 128 | 36.14 ± 1.00 |

build: 866b502 (1212)

@slaren
Collaborator

slaren commented Sep 11, 2023

I also tested removing all the calls to cudaGetDevice, but that didn't make any difference either, so I don't think these calls add any latency.

@slaren
Collaborator

slaren commented Sep 11, 2023

I noticed that ggml_cuda_set_device doesn't actually save the current device, so I fixed that. And the result is pretty significant.

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | pp 512 | 2370.14 ± 70.47 |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | tg 128 | 121.21 ± 0.51 |

build: 866b502 (1212)

@JohannesGaessler
Collaborator Author

Okay so my intuition was correct and I just did the implementation wrong lol. It's pretty ridiculous though that this is apparently not being cached in the CUDA libraries.

@slaren
Collaborator

slaren commented Sep 11, 2023

With multi GPU it crashes in quantize_row_q8_1_cuda; any idea what may be causing that?

CUDA error 400 at ggml-cuda.cu:5976: invalid resource handle

The only change is this:

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 1551517..deb5571 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -416,6 +416,8 @@ cudaError_t ggml_cuda_set_device(int device) {
         return cudaSuccess;
     }

+    current_device = device;
+
     return cudaSetDevice(device);
 }

@JohannesGaessler
Collaborator Author

I am not able to reproduce the issue. Can you share the exact commands that you used?

@slaren
Collaborator

slaren commented Sep 11, 2023

The command is CUDA_VISIBLE_DEVICES=0,1 ./llama-bench. The weird part is that removing the current_device = device; line fixes the issue. I tried initializing current_device to different values, but that doesn't change anything.

@slaren
Collaborator

slaren commented Sep 11, 2023

This seems to fix the issue and doesn't affect performance:

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 1551517..611b080 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -410,7 +410,8 @@ struct ggml_tensor_extra_gpu {
 };

 cudaError_t ggml_cuda_set_device(int device) {
-    static int current_device = -1;
+    int current_device;
+    CUDA_CHECK(cudaGetDevice(&current_device));

     if (device == current_device) {
         return cudaSuccess;

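For reference, the helper after this change reads roughly as follows (reconstructed from the diff above, surrounding code elided):

```cpp
// Query the runtime for the currently active device instead of caching it in a
// static variable, and only call cudaSetDevice when the device actually changes.
cudaError_t ggml_cuda_set_device(int device) {
    int current_device;
    CUDA_CHECK(cudaGetDevice(&current_device));

    if (device == current_device) {
        return cudaSuccess;
    }

    return cudaSetDevice(device);
}
```
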
Multi GPU performance is still very bad, but I guess that's expected since the cudaSetDevice calls can't be avoided there.

ggml_init_cublas: found 2 CUDA devices:
Device 0: NVIDIA GeForce RTX 3090 Ti, compute capability 8.6
Device 1: NVIDIA GeForce RTX 3080, compute capability 8.6

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | pp 512 | 793.20 ± 8.43 |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | tg 128 | 6.02 ± 1.11 |

build: 866b502 (1212)

@JohannesGaessler
Collaborator Author

So presumably the cause of the performance regression you experienced is fixed now. In the initial version, cudaSetDevice was called in a loop for the KV cache matrix-matrix multiplication, which made prompt processing slower. Additionally, because I separated the original loop into two loops, there was an extra call to cudaSetDevice for each call of ggml_cuda_op_mul_mat. That made both prompt processing and token generation slower, but with a greater impact on token generation because per generated token it needs roughly 512 times as many calls to ggml_cuda_op_mul_mat.

ggml-cuda.cu Outdated
return cudaSuccess;
}

current_device = device;
Collaborator

This line can be removed.

@slaren
Collaborator

slaren commented Sep 11, 2023

Yeah, it is very strange that cudaSetDevice is so slow. I guess that was the difference in the ggml-backends branch. I'll see if I can figure out a way to improve performance in WSL2 with multi GPU, but the calls to cudaSetDevice seem unavoidable there.

@JohannesGaessler
Collaborator Author

On native Windows using my RTX 3090 I am also experiencing a performance uplift. master:

| model | size | params | backend | ngl | threads | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| LLaMA 7B mostly Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 99 | 1 | pp 512 | 1078.05 ± 108.80 |
| LLaMA 7B mostly Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 99 | 1 | tg 128 | 63.50 ± 0.23 |

build: 1b0d092 (1213)

PR:

| model | size | params | backend | ngl | threads | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| LLaMA 7B mostly Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 99 | 1 | pp 512 | 1695.71 ± 45.13 |
| LLaMA 7B mostly Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 99 | 1 | tg 128 | 69.48 ± 13.08 |

build: c923de7 (1212)

@JohannesGaessler
Collaborator Author

I'll see if I can figure a way to improve performance in WSL2 with multi GPU, but the calls to cudaSetDevice seem unavoidable there.

Use nsys profile and look at the graphs that show compute and I/O utilization over time. This may help determine whether the issue occurs before or after the actual computation. For instance, on my 3x P40 system the biggest issue (for prompt processing) is currently writing back the f32 results to the main GPU.

@slaren
Collaborator

slaren commented Sep 11, 2023

Unfortunately, Nsight Systems doesn't work under WSL2. I'll test on native Windows though.

@slaren
Collaborator

slaren commented Sep 11, 2023

The performance improvement under native Windows is similar (to WSL2) for me, which is good since it improves single GPU tg performance massively. Multi GPU tg is faster than master, but still very slow.

WINDOWS:

MASTER:

ggml_init_cublas: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 3090 Ti, compute capability 8.6

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | pp 512 | 2279.97 ± 60.85 |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | tg 128 | 42.28 ± 0.14 |

build: 1b0d092 (1213)

ggml_init_cublas: found 2 CUDA devices:
Device 0: NVIDIA GeForce RTX 3090 Ti, compute capability 8.6
Device 1: NVIDIA GeForce RTX 3080, compute capability 8.6

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | pp 512 | 914.66 ± 12.79 |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | tg 128 | 6.78 ± 1.68 |

build: 1b0d092 (1213)

PR:

Device 0: NVIDIA GeForce RTX 3090 Ti, compute capability 8.6

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | pp 512 | 2343.56 ± 90.51 |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | tg 128 | 114.97 ± 0.82 |

build: a599006 (1212)

Device 0: NVIDIA GeForce RTX 3090 Ti, compute capability 8.6
Device 1: NVIDIA GeForce RTX 3080, compute capability 8.6

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | pp 512 | 855.89 ± 13.49 |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | tg 128 | 11.12 ± 1.33 |

build: a599006 (1212)

I'll check the multi GPU trace with Nsight Systems, but I think this can be merged already.

@slaren
Collaborator

slaren commented Sep 11, 2023

Btw, I wonder if the difference between your results and mine (there is no way that the 3090 Ti is that much faster than the 3090) is due to hardware-accelerated GPU scheduling; I have it disabled. I am also on Windows 11.

@JohannesGaessler
Collaborator Author

JohannesGaessler commented Sep 11, 2023

I had hardware-accelerated GPU scheduling enabled. This is the performance after disabling it:

| model | size | params | backend | ngl | threads | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | 1 | pp 512 | 2084.13 ± 97.16 |
| LLaMA 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | 1 | tg 128 | 87.60 ± 20.60 |

build: c923de7 (1212)

| model | size | params | backend | ngl | threads | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| LLaMA 7B mostly Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 99 | 1 | pp 512 | 1902.89 ± 77.68 |
| LLaMA 7B mostly Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 99 | 1 | tg 128 | 72.73 ± 9.38 |

build: c923de7 (1212)

@JohannesGaessler merged commit d54a402 into ggerganov:master Sep 11, 2023
28 checks passed
@slaren
Collaborator

slaren commented Sep 11, 2023

I am not entirely sure what is going on, but it looks like the latency is so bad that by the time the second GPU starts working on its part, the first GPU is already done, so there is very little parallelism. Lots of dead time between operations.
[screenshot: profiler trace of the multi GPU run showing long idle gaps between operations]

@Dampfinchen

Strange. I'm getting an out of memory error when trying to fully offload 7B with this PR. On a previous commit I was able to offload Q3_K_L on my RTX 2060 laptop with 6 GB VRAM, but this PR throws an out of memory error.

-n 180 -c 2048 -t 6 --gpu-layers 99, with a prompt of around 1800 tokens in both cases.

@slaren
Collaborator

slaren commented Sep 11, 2023

Looks like this broke LoRA support.

#0  0x0000000000000000 in ?? ()
#1  0x000055555568eb4f in ggml_cuda_op_mul_mat_cublas (src0=0x7fff49070030, src1=0x7fff49138180, dst=0x7fff492002d0, src0_dd_i=0x1326a20400 "",
    src1_ddf_i=0x1326af2400, src1_ddq_i=0x0, dst_dd_i=0x1327400000, row_low=0, row_high=3200, src1_ncols=3200, src1_padded_row_size=512,
    stream=@0x7ffffffcef18: 0x5555566dad10) at /home/slaren/code/llama.cpp/ggml-cuda.cu:5637

The type of src0 is already f32, so to_fp32_cuda is NULL.
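A guard along these lines avoids the NULL call; this is only a sketch of the problem, not necessarily the fix that eventually landed, and the buffer/size names (src0_as_f32, row_diff, ne00) are illustrative:

```cpp
// In ggml_cuda_op_mul_mat_cublas: if src0 is already f32 there is nothing to
// convert, so use its data directly instead of calling a NULL to_fp32_cuda.
const float * src0_ddf_i;
if (src0->type == GGML_TYPE_F32) {
    src0_ddf_i = (const float *) src0_dd_i;
} else {
    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
    to_fp32_cuda(src0_dd_i, src0_as_f32, row_diff*ne00, stream); // convert into a temporary f32 buffer
    src0_ddf_i = src0_as_f32;
}
```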

@cebtenzzre
Collaborator

This PR broke #2506.

CUDA error 400 at /home/cebtenzzre/src/forks/llama.cpp/ggml-cuda.cu:6146: invalid resource handle
current device: 0
