llama : add option to override model tensor buffers #11397
base: master
Conversation
Is there a chance that the direction you're taking these changes might allow for scheduling specific threads to work on specific tensors? With R1 coming out, I'm very interested in reviving my work on trying to improve memory locality to increase CPU inference speeds.
No, that's something that would need to be handled at a lower level in the CPU backend.
Thanks for the reply @slaren. I figured it wouldn't directly help, but that maybe you'd be adding useful metadata to tensor objects that could help coordinate affinity in the future. I'll start a fresh branch and see how far I get.
I'll also try to pull this branch and test it to see what the speedup and sysmem savings look like.
Quick, non-scientific initial test with Deepseek R1 at q6 on llama-server with -ot exps=CPU: -ngl 0 = 4.65t/s. So there is definitely major speedup potential for this patch. I can't offload all 62 layers for this model because I only have 24GB VRAM, but I expect the trend would continue in the same general direction. This is without dropping caches, so it's inefficient, but I didn't have the time to do a proper drop/reload cycle since it takes so long to be read back into memory on each test run.
@bmtwl
What are the shared expert tensors called in …?
I believe the pattern …
Thanks - I'll give this a try later in the week. This PR, together with this Reddit post: https://old.reddit.com/r/LocalLLaMA/comments/1ibbloy/158bit_deepseek_r1_131gb_dynamic_gguf/, opens up an interesting possibility: quantising up/gate projections to q2_k and down projections to q4_k (or something similar), then keeping everything else as … Sadly I need to move some stuff about to get space to upscale the fp8 download to bf16 before I can try it, but will report back when I do.
It might be worth trying …
Just being able to split the experts between NUMA nodes would make a big difference, but I'm not sure how easy that would be, as IIRC the experts' tensors are all in one huge tensor now?
During normal operation, when I fit a model between RAM and VRAM, does the offloading follow a set layer sequence (layer 0 is chosen first to be offloaded to GPU, then layer 1, etc.)? Between GPU offloading and RAM, which takes priority?
Do you remember how much of a speedup? No need for extensive benchmarks, just the rough % estimate.
I can't seem to offload more than 29 layers of R1 (unsloth's UD-IQ2_XXS) via RPC. 29 layers and below work fine, but 30 just crashes my rpc_server, with no error output. It is not a VRAM issue: even with the context set very low, so that it takes up nowhere near my GPU's limits, it still crashes.
I had a similar problem where if I used a single GPU (via …) … If I didn't use either of these, it tried to allocate this 1.4TB monster buffer:
After some searching I found this issue: … and recompiled using … (It's likely nothing to do with this PR, but thought it might help!)
I figured it out: you have to reorder the devices so the local … and mainly these:
Means this works: --device "RPC[IP1:PORT1],RPC[IP1:PORT2],RPC[IP1:PORT1],RPC[IP2:PORT2],CUDA0,CUDA1" But if I don't do this, I get OOM errors with plenty of VRAM left, like you had.
I'm testing this with and without #11446: without it, on unsloth's UD-IQ2_XXS, I was only able to offload 29 layers, and with it I was able to allocate only 28 (on a Q4_K_S quant). This is not a VRAM issue: it would have plenty of spare VRAM, it would even get past allocation and get to warmup, where the rpc-server would then just crash. The other issue is performance: the more layers I allocate, the worse performance gets, while bmtwl shows a performance increase with more layers offloaded with non-RPC offloading.
I am able to load the model with …
But as soon as I send the prompt I receive:
Without the … Testing with 4x RTX 3090 and 320GiB RAM. Built with …
Maybe try …
No luck, still the same issue. Oddly enough, the issue only happens when sending more than 450 tokens.
It's trying to allocate a tensor of size 2^64, which suggests there is an integer overflow somewhere. If you set the environment variable …
It is the … Is it possible to try to force this particular one to be allocated into the GPU buffer?
This is most likely a bug, we need to understand why it is happening and fix it. Since you mentioned that it only happens with large prompts, I suspect that this is caused by a zero-size tensor. When evaluating a batch where no logits are required (which happens when evaluating a prompt that needs to be split into multiple ubatches), zero-size tensors are created to skip the calculation of the logits.
diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c
index 9a3bf9f29..470ef13e6 100644
--- a/ggml/src/ggml-alloc.c
+++ b/ggml/src/ggml-alloc.c
@@ -179,6 +179,9 @@ static size_t ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t siz
// this should never happen
GGML_LOG_ERROR("%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n",
__func__, size, max_avail);
+ GGML_LOG_ERROR("%s: tensor: %s, shape: %ld %ld %ld %ld, size: %zu",
+ __func__, tensor->name, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3],
+ ggml_nbytes(tensor));
GGML_ABORT("not enough space in the buffer");
}
}
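As an aside on the overflow class mentioned above: if any unsigned size computation for a zero-size tensor dips below zero, it wraps around to roughly 2^64 bytes, which matches the huge allocation request reported. A minimal sketch, assuming a hypothetical alignment subtraction rather than the actual ggml-alloc code path:

```cpp
// Illustration only (not the actual ggml code path): how a zero-size tensor
// can surface as a request for ~2^64 bytes. Unsigned arithmetic wraps around
// instead of going negative.
#include <cstdio>
#include <cstddef>

int main() {
    size_t tensor_bytes = 0;                                // zero-size tensor (no logits needed)
    size_t alignment_adjustment = 32;                       // hypothetical value subtracted somewhere
    size_t request = tensor_bytes - alignment_adjustment;   // wraps to 2^64 - 32
    std::printf("requested allocation: %zu bytes\n", request);
    return 0;
}
```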
Ok nvm, I think I see the problem. I will push a possible fix soon.
I've got it working:
I think this could be a super powerful command line option when mixed with RPC! Thanks for adding this! If anybody has a Mac Studio they want to test this on then I can help craft the regexes to test it - I'm interested to see what sort of boost you could get without so many stages of latency.
I'm up for the testing - I have a Mac Studio M2 Ultra 192GB <---10Gbps---> 13700K + 192GB DDR5 + RTX 6000 Ada. Also, if it's helpful (seems to be?), I can get a Thunderbolt Gen4 eGPU case and plug my RTX 6000 Ada in there...
It didn't help me due to the latency between the parts all pushing the hidden state. I used 10Gbit Ethernet for all the machines, so I'm not sure upping to 40Gbit (or whatever Thunderbolt is) will make that much difference - I think the problem is latency rather than bandwidth for this part, sadly. Possibly using InfiniBand might help as IIRC it has lower latency, but not sure. I think the eventual solution would be to have RPC use a better method of pipeline parallelism like DeepSpeed: it would definitely help batch processing, and mixed data and pipeline parallelism would remove some latency if you have multiple GPUs per machine like I do.
Just figured an eGPU won't help, as Apple silicon cannot run CUDA...
Can you post speeds (with whatever configurations you tested)? Also, not sure how much flash attention would impact speed, but it would shrink that compute buffer.
I think the RPC stuff is never really gonna work properly until it can do async buffering: the way it is set up now, each stage in the pipeline stalls for every communication, and this adds the full latency. If it were async and buffered, the next job would start almost immediately with no wait, and you could probably optimise this even more by having the fastest devices at the start of the pipeline and the slowest at the end, to get almost no degradation from latency.
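For illustration, here is a minimal sketch of the double-buffering idea described above: overlap the network send of micro-batch N with the compute of micro-batch N+1 instead of stalling for the full round trip each time. The names compute_stage and send_activations are placeholders, not the llama.cpp RPC backend's API:

```cpp
// Conceptual sketch of async buffering in a pipeline stage: while the result
// of micro-batch N is being sent to the next device in the background, the
// stage immediately starts computing micro-batch N+1.
#include <future>
#include <vector>

using Activations = std::vector<float>;

// Placeholder for the local GPU/CPU work done by this pipeline stage.
static Activations compute_stage(const Activations & in) {
    return in; // pretend we transformed the hidden state
}

// Placeholder for shipping the hidden state to the next device over the network.
static void send_activations(const Activations & /*out*/) {
    // blocking network I/O would live here
}

static void run_pipeline(const std::vector<Activations> & micro_batches) {
    std::future<void> pending_send;
    for (const Activations & mb : micro_batches) {
        Activations out = compute_stage(mb);   // compute micro-batch N
        if (pending_send.valid()) {
            pending_send.wait();               // make sure the previous send finished
        }
        // hand the result to a background send so the loop can start on N+1
        // right away, hiding the link latency behind useful compute
        pending_send = std::async(std::launch::async, send_activations, std::move(out));
    }
    if (pending_send.valid()) {
        pending_send.wait();
    }
}

int main() {
    run_pipeline({ Activations(8, 0.0f), Activations(8, 1.0f) });
    return 0;
}
```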
The GPU utilization could be an illusion. I'll try to get some numbers across different setups.
Many many thanks, @slaren, for this PR. I really hope it gets merged. I have used this … What I have learned so far (don't know if it is applicable to R1):
This is the order I found best improves token generation for Mixtral on my system. This is with non-RPC devices (CUDA), and may not correspond in the same way with other kinds of backends. Many thanks again for this PR.
If you don't mind, can you post some more info:
Sure. I will try to summarize.
Note 1: It is no longer a 70% improvement, as I have been able to fit two extra layers by using q8 KV cache. It is now "only" a 60% improvement.
Note 2: When I say "All layers" I mean 56 layers. The last layer is always kept on CPU.
Note 3: All tests were performed with llama-server, with a long prompt of about 18K tokens.
Model and system information:
Summary table:
TG improves, but PP is a bit worse (but acceptable). More details below.
Command and results
Baseline (normal layer offloading):
Partial command: GPU_LAYERS="-fa -ngl 33 -b 1024 -ub 512 --tensor_split 15/15/3 -ctv q8_0 -ctk q8_0" Raw results: prompt eval time = 180153.06 ms / 18020 tokens ( 10.00 ms per token, 100.03 tokens per second)
eval time = 95548.09 ms / 199 tokens ( 480.14 ms per token, 2.08 tokens per second)
total time = 275701.15 ms / 18219 tokens
With override tensors (best):
Partial command: GPU_LAYERS="-fa -ngl 56 -b 1024 -ub 512 --tensor_split 35/21/0 \
-ot ([2][3-9]|[3][0-9]|[4][0-3]).ffn_up_exps=CUDA1 \
-ot ([4][4-9]|[5][0-9]).ffn_up_exps=CUDA2 \
-ot ffn_gate_exps=CPU -ot ffn_up_exps=CPU -ctv q8_0 -ctk q8_0"
Raw results:
prompt eval time = 242236.42 ms / 18020 tokens ( 13.44 ms per token, 74.39 tokens per second)
eval time = 62097.34 ms / 214 tokens ( 290.17 ms per token, 3.45 tokens per second)
total time = 304333.76 ms / 18234 tokens
Experiments regarding exps tensors "importance"
These are the tests I performed to find what … During these experiments:
Summary table of offloaded …
Best was offloading the …
Experiment details
Baseline (all exps tensors on CPU):
Partial command: GPU_LAYERS="-fa -ngl 56 -b 1024 -ub 512 --tensor_split 34/22/0 -ot exps=CPU -ctv q8_0 -ctk q8_0"
Raw results:
prompt eval time = 275711.20 ms / 18020 tokens ( 15.30 ms per token, 65.36 tokens per second)
eval time = 148688.36 ms / 273 tokens ( 544.65 ms per token, 1.84 tokens per second)
total time = 424399.56 ms / 18293 tokens
Offloading …
Test with all layers offloaded to GPU (automatic for …):
Partial command: GPU_LAYERS="-fa -ngl 56 -b 1024 -ub 512 --tensor_split 34/22/0 \
-ot ffn_up_exps=CPU -ot ffn_down_exps=CPU -ctv q8_0 -ctk q8_0"
Raw results:
prompt eval time = 241260.09 ms / 18020 tokens ( 13.39 ms per token, 74.69 tokens per second)
eval time = 133838.82 ms / 309 tokens ( 433.14 ms per token, 2.31 tokens per second)
total time = 375098.92 ms / 18329 tokens
Offloading …
Test with all layers offloaded to GPU (automatic for …):
Partial command: GPU_LAYERS="-fa -ngl 56 -b 1024 -ub 512 --tensor_split 34/22/0 \
-ot ffn_down_exps=CPU -ot ffn_gate_exps=CPU -ctv q8_0 -ctk q8_0"
Raw results:
prompt eval time = 224444.22 ms / 18020 tokens ( 12.46 ms per token, 80.29 tokens per second)
eval time = 116046.18 ms / 272 tokens ( 426.64 ms per token, 2.34 tokens per second)
total time = 340490.41 ms / 18292 tokens
Offloading …
Test with all layers offloaded to GPU (automatic for …):
Partial command: GPU_LAYERS="-fa -ngl 56 -b 1024 -ub 512 --tensor_split 34/22/0 \
-ot ffn_up_exps=CPU -ot ffn_gate_exps=CPU -ctv q8_0 -ctk q8_0"
Raw results:
prompt eval time = 192436.88 ms / 18020 tokens ( 10.68 ms per token, 93.64 tokens per second)
eval time = 102894.79 ms / 261 tokens ( 394.23 ms per token, 2.54 tokens per second)
total time = 295331.68 ms / 18281 tokens
(Summary table moved to top of section)
I thought I would give this a test, but maybe I am doing something wrong. It seems to give me no change in performance. My system is a Threadripper 7965WX, 512 GB system memory, 3090. I am trying to run this on Windows 10. It seems to fill up my GPU as well as system memory, so I imagine it is using the GPU. I've tried over a dozen commands to get it working, from a simple "llama-server --model DeepSeek-R1-Q4_K_M-00001-of-00011.gguf -ngl 99 -ot exps=CPU" to much more complicated ones with different options. I either get the same t/s or lower. Maybe it is my GPU? I have a 5090 on the way and will test this again when that arrives. If there are any launch commands you want me to try, I will give it a go.
@slaren : Thank you so much for this PR. Hopefully some of these test results will be useful feedback. Running with the following (AMD EPYC 7713 64-Core Processor, 256GB RAM + 2x A6000 Ada):
./build/bin/llama-cli -ub 256 --no-mmap --tensor-split 19,20 --model /data/gguf/DeepSeek-R1-GGUF/DeepSeek-R1-UD-IQ1_S/DeepSeek-R1-UD-IQ1_S-00001-of-00003.gguf --cache-type-k q4_0 --threads 16 --prio 2 --temp 0.6 --ctx-size 8192 --seed 3407 --n-gpu-layers 36 -no-cnv --prompt "<|User|>You are an expert python developer. Write a factorial function using python.<|Assistant|>" -ot '^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$=CUDA0' -ot 'ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding=CUDA0' -ot '^model.embed_tokens=CPU' -ot '^model\\.layers\\..*\\.mlp$=CUDA0' -ot '^model\\.layers\\..*\\.self_attn$=CUDA1'
yields:
compared with no -ot flag:
GPUs are running at no more than 25%. Hope this is useful. Finally, running with:
./build/bin/llama-cli -ub 256 --no-mmap --tensor-split 19,20 --model /data/gguf/DeepSeek-R1-GGUF/DeepSeek-R1-UD-IQ1_S/DeepSeek-R1-UD-IQ1_S-00001-of-00003.gguf --cache-type-k q4_0 --threads 16 --prio 2 --temp 0.6 --ctx-size 8192 --seed 3407 --n-gpu-layers 36 -no-cnv --prompt "<|User|>You are an expert python developer. Write a factorial function using python.<|Assistant|>" -ot ffn_up_exps=CPU -ot ffn_gate_exps=CPU
Performance is lower given that GPU memory does not appear to be fully utilised:
The names of these tensors do not match the names of the tensors in llama.cpp. I suggest running with …
You would need to increase the value of …
@slaren : thanks for pointing out that I was using the incorrect tensor names (in fact, the ktransformers patterns were using the model names from safetensors-format files and not GGUF). So now I have rerun some tests and can see improved GPU usage, increasing to 50%:
However, using the -ot option it seems impossible to utilise the full memory on the GPUs; the ffn_gate/ffn_up/ffn_down layers are simply too large to be loaded into 48GB VRAM. But this results in ~3.2 tok/s. The best combination appears to be --ngl 36 --tensor-split 19,20, where I can get over 4.2 tok/s. It seems that the bottleneck is CPU memory. With -ot we get more GPU utilisation, but this doesn't seem to make up for the time lost to having some of the layers in slower CPU memory. @jukofyork : I'm using --ctk q4_0 as per the unsloth blog post (https://unsloth.ai/blog/deepseekr1-dynamic). If I remove this and use the default, I get CUDA OOM. I have tried the different ctk values, but there don't appear to be any noticeable performance improvements.
Have you tried recompiling with the right CUDA architecture flag set? #4215 Looking at the NVIDIA docs, sm_100 is what you need:
I looked into that and it seems to have done the trick.
cmake -B build -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES="100"
This changed it to sm=100 while it compiled. I still need to mess with settings to get the best speed, but here is the very first run.
llama-server --model DeepSeek-R1-Q4_K_M-00001-of-00011.gguf --flash-attn --threads 36 --temp 0.6 --min-p 0.05 --ctx-size 2048 --no-mmap -ngl 36 -ot exps=CPU
I am getting about 28% higher t/s for eval_time. For prompt eval_time, around a 50% improvement (6.2 t/s / 14.1 t/s). This one leaves a lot of room for context as it only uses 17 GB of GPU memory.
llama-server --model DeepSeek-R1-Q4_K_M-00001-of-00011.gguf --flash-attn --threads 36 --temp 0.6 --min-p 0.05 --ctx-size 2048 --no-mmap -ngl 62 -ot exps=CPU
This command uses 26 GB of GPU memory, so there is still 6 GB for extra context over 2k context (I tested this and it uses 31.1 GB at 4096 context). This gets me around 7.8 t/s eval_time / 20.5 t/s prompt eval_time. Overall, the changes you made lead to a 66% performance increase on eval_time and around a 100% performance increase on prompt eval_time vs CPU only, on a Threadripper 7965WX, 512 GB memory, 5090. You are an absolute genius. If you have some proper benches you want me to run, let me know.
Another update.
llama-server --model DeepSeek-R1-Q4_K_M-00001-of-00011.gguf --flash-attn --threads 40 --temp 0.6 --min-p 0.05 --ctx-size 4096 --no-mmap -ngl 62 -ot exps=CPU
This uses up all my threads completely and I get a small performance bump: an 82% performance increase now on eval time. Makes me really want a 64 core Threadripper now. Also, a second 5090 for more context. Using 31 GB of GPU memory right now at 4k. I am also curious if getting double the system memory bandwidth will make a difference after the 64 core Threadripper upgrade. Maybe I can get up to 10-15 t/s. Another thing I noticed is that it no longer drops off a cliff in inference speed as I continue a story. After 1k context generated, then another new 2k context, the new t/s was still 8.01 t/s. If this was CPU, it would have dropped by 25% by then. The only real limiting factor is that 3.5k context seems like the absolute upper limit. I was having trouble with 4k context. I really need more context. Another issue is that prompt eval time is actually all over the place. Sometimes it is fast, sometimes it does this: …
Another update: I found that --flash-attn makes no difference. Also, I changed --no-mmap to --mlock and I get consistent prompt eval now, around 12 t/s. Still pretty amazing for running Q4 of R1 on CPU with one consumer-grade GPU.
Yet another update. This time using Unsloth DeepSeek-R1-UD-IQ2_XXS-00001-of-00004.gguf. This model is still really good and uses only ~200 GB system memory and 27.5 GB GPU memory at 3k context. Was able to get 3600 context max with this unsloth model. The only real limiting factor with this setup is context. Any chance KV cache allocation will resolve this issue?
Thanks for all the testing, I will try to get this ready for merging over the next few days.
Yeah, flash-attn is not supported yet in llama.cpp for DeepSeek-R1, pretty sure - check out #11557
This is likely because without those args, llama.cpp defaults to normal …
Thanks for the benchmarks. I'm revisiting this exciting branch after playing with ktransformers and trying to figure out how they get almost 2x inference speeds on R1. I noticed that when I disabled CUDA Graphs on ktransformers, it performs almost the same as llama.cpp again... however, CUDA Graphs only work when not offloading any experts into VRAM, hrmm... Anyway, enjoying the quest for more tok/sec! Cheers!
You can try it with the PR the comment is from and the modification shown at the bottom of the comment: #11446 (comment). This further comment showed it worked: #11446 (comment)
@Reactantvr Thanks for sharing your test results. Just curious, what are the ratings of the DIMM memory you are using on your setup? If you run nvtop, do you see your GPU running at max compute? In my testing, it seems that CPU memory is the limiting factor/bottleneck.
My memory is 8x64 V-Color DDR5 6000 running at 4800. I didn't bother overclocking it yet because I am on 4 CCDs, which should limit me to around 230 GB/s. I assume I would not get more bandwidth until I upgrade to a 64 core Threadripper. Waiting on Shimada Peak for that. I'll probably run it at 6400 once I get that CPU. I've never used nvtop. Plus, I am doing everything in Windows 10, so not sure if I can use it. I can give you stats from GPU-Z. Looks like GPU load is around 18-19%. This was using DeepSeek-R1-UD-IQ2_XXS.
Works perfectly for me with dual E5 v2 + 2080Ti running DeepSeek-R1-UD-Q2_K_XL. It boosts the token generation speed from 1.8 tps to 3.3 tps. With one NUMA node disabled, it can increase to 3.8 tps.
Adds command line parameter --override-tensor (-ot) that allows changing the buffer type where a model tensor is allocated. This gives the user fine-grained control over which tensors are offloaded to each device.
How is this useful: for example, to force the experts in MoE models to stay on the CPU while offloading the rest to the GPU, you could use -ngl 99 -ot exps=CPU. This may allow more efficient offloading schemes.
The syntax is <tensor name pattern>=<buffer type>. Currently the pattern is just a string search (edit: this is no longer the case, it is a C++ regex search), i.e. any tensor whose name contains the characters in <tensor name pattern> will be matched and loaded into the given buffer type. Multiple overrides can be given by separating them with commas, or by passing the -ot option multiple times. To see which tensors are being matched, enable debugging output with -v.
At this point it is just a demo, feel free to experiment and report if you find any interesting uses.
Edit: added regex support; for example, to keep the experts of layers 20-99 on the CPU you could use -ot "[2-9][0-9]\.ffn_.*_exps\.=CPU"
TODO: