
Bug: gemma 2 27B GGML_ASSERT n_dims <= ne0 #8246

Closed
duynt575 opened this issue Jul 1, 2024 · 11 comments · Fixed by #8348
Labels
bug-unconfirmed, low severity (used to report low severity bugs in llama.cpp, e.g. cosmetic issues, non-critical UI glitches)

Comments

@duynt575

duynt575 commented Jul 1, 2024

What happened?

I got the error using different quants from different authors. After asking the LLM a few times, llama.cpp crashed with these two lines:

GGML_ASSERT: D:\a\llama.cpp\llama.cpp\ggml\src\ggml.c:13968: n_dims <= ne0
GGML_ASSERT: D:\a\llama.cpp\llama.cpp\ggml\src\ggml.c:13968: n_dims <= ne0

Name and Version

Tested using b3266 and b3276, same result.

What operating system are you seeing the problem on?

Windows

Relevant log output

GGML_ASSERT: D:\a\llama.cpp\llama.cpp\ggml\src\ggml.c:13968: n_dims <= ne0
GGML_ASSERT: D:\a\llama.cpp\llama.cpp\ggml\src\ggml.c:13968: n_dims <= ne0
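For context on what the assert means: in the ggml RoPE kernels, ne0 is the row length of the tensor being rotated (the per-head dimension) and n_dims is how many leading elements of each row get rotated, so the assert requires the rotation to stay within the row. Below is a minimal sketch of that invariant and of the pairwise rotation the kernel performs; it is a simplified illustration, not the actual ggml.c code.

    #include <assert.h>
    #include <stdint.h>

    /* Sketch only: "row" stands for one per-head vector of length ne0;
     * RoPE rotates its first n_dims elements in adjacent pairs (x0, x1). */
    static void rope_row_sketch(float * row, int64_t ne0, int64_t n_dims,
                                const float * cache /* interleaved cos/sin */) {
        assert(n_dims <= ne0);            /* the invariant the GGML_ASSERT enforces */
        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
            const float cos_theta = cache[i0 + 0];
            const float sin_theta = cache[i0 + 1];
            const float x0 = row[i0 + 0];
            const float x1 = row[i0 + 1]; /* would read past the row if n_dims > ne0 */
            row[i0 + 0] = x0*cos_theta - x1*sin_theta;
            row[i0 + 1] = x0*sin_theta + x1*cos_theta;
        }
    }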
duynt575 added the bug-unconfirmed and low severity labels on Jul 1, 2024
@EliEron

EliEron commented Jul 2, 2024

I'm experiencing the exact same issue, though I'd classify it as critical, since having llama.cpp constantly crash is quite detrimental.

For me the issue only started after logit soft-capping was merged; before that there was no crashing, but the generation quality was obviously much lower without that fix.

And in case it matters I'm on Windows 11.
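For reference, the logit soft-capping mentioned above is the tanh squashing that Gemma 2 applies to its attention and final logits. A minimal sketch of the operation is below; the cap value in the usage comment is an illustrative assumption, not read from the model.

    #include <math.h>

    /* Logit soft-capping sketch: squashes a logit smoothly into (-cap, +cap). */
    static float soft_cap(float logit, float cap) {
        return cap * tanhf(logit / cap);
    }

    /* usage example with an assumed cap value: soft_cap(logit, 30.0f) */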

@derjoshder

Having a similar issue with b3280 llama.cpp (llama-cli.exe, Cu12, 64bit)
Windows 10, started with:
llama-cli.exe -i --interactive-first -r "### Human:" -t 4 --ignore-eos --temp 0.8 --color -c 1024 -n -1 -ngl 20 --repeat_penalty 1.2 -m %scriptpath%models\gemma-2-27b-it-Q5_K_M.gguf
Hardware: 3060 (12GB) and 3050 (8GB)
It interrupts inference in the middle of my second answer (the path it gives here does not exist, in case that matters):
GGML_ASSERT: D:\a\llama.cpp\llama.cpp\ggml\src\ggml.c:13968: n_dims <= ne0
GGML_ASSERT: D:\a\llama.cpp\llama.cpp\ggml\src\ggml.c:13968: n_dims <= ne0

@ko-alex

ko-alex commented Jul 4, 2024

On Debian 12

GGML_ASSERT: ggml/src/ggml.c:13968: n_dims <= ne0
GGML_ASSERT: ggml/src/ggml.c:13968: n_dims <= ne0
GGML_ASSERT: ggml/src/ggml.c:13968: n_dims <= ne0
GGML_ASSERT: ggml/src/ggml.c:13968: n_dims <= ne0
GGML_ASSERT: ggml/src/ggml.c:13968: n_dims <= ne0
GGML_ASSERT: ggml/src/ggml.c:13968: n_dims <= ne0
[New LWP 6459]
[New LWP 6463]
[New LWP 6464]
[New LWP 6465]
[New LWP 6466]
[New LWP 6467]
[New LWP 6468]
[New LWP 6469]
warning: process 6458 is already traced by process 6728
warning: process 6458 is already traced by process 6728
ptrace: Operation not permitted.ptrace: Operation not permitted.warning: process 6458 is already traced by process 6728
warning: process 6458 is already traced by process 6728
warning: process 6458 is already traced by process 6728
ptrace: Operation not permitted.

ptrace: Operation not permitted.ptrace: Operation not permitted.

No stack.
No stack.No stack.

No stack.No stack.

The program is not being run.
The program is not being run.The program is not being run.

The program is not being run.
The program is not being run.
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
0x00007ff8b14f2b57 in wait4 () from /lib/x86_64-linux-gnu/libc.so.6
#0 0x00007ff8b14f2b57 in wait4 () from /lib/x86_64-linux-gnu/libc.so.6
#1 0x000055869825790b in ggml_print_backtrace ()
#2 0x0000558698288218 in ggml_compute_forward_rope_f16 ()
#3 0x000055869828e97b in ggml_graph_compute_thread.constprop.0.isra ()
#4 0x000055869828eb25 in ggml_graph_compute._omp_fn ()
#5 0x00007ff8b84ed0b6 in GOMP_parallel () from /lib/x86_64-linux-gnu/libgomp.so.1
#6 0x0000558698292d10 in ggml_graph_compute ()
#7 0x000055869829f2b1 in ggml_backend_cpu_graph_compute ()
#8 0x00005586982a4295 in ggml_backend_sched_graph_compute_async ()
#9 0x00005586982e6d2e in llama_kv_cache_update ()
#10 0x00005586982f4926 in llama_decode ()
#11 0x00005586981158da in main ()
[Inferior 1 (process 6458) detached]
[1] 6458 IOT instruction (core dumped)

on master branch, commit f8c4c07
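One detail worth noting in the backtrace above: the RoPE kernel is reached from llama_kv_cache_update, i.e. during a K-cache shift, where the cached keys are re-rotated by their position delta after old context is dropped. A conceptual sketch of that step is below; the flat float layout and theta_base parameter are simplifying assumptions, not the actual llama.cpp code.

    #include <math.h>
    #include <stdint.h>

    /* Conceptual K-cache shift sketch: rotate the first n_rot elements of one
     * cached key vector by `delta` positions, so its RoPE encoding matches the
     * key's new position. RoPE rotations compose additively in position, which
     * is why rotating by the delta is enough. */
    static void k_shift_sketch(float * key, int64_t n_rot, int delta, float theta_base) {
        for (int64_t i0 = 0; i0 < n_rot; i0 += 2) {
            const float freq  = powf(theta_base, -(float) i0 / (float) n_rot);
            const float angle = (float) delta * freq;
            const float c = cosf(angle);
            const float s = sinf(angle);
            const float x0 = key[i0 + 0];
            const float x1 = key[i0 + 1];
            key[i0 + 0] = x0*c - x1*s;
            key[i0 + 1] = x0*s + x1*c;
        }
    }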

@ggerganov
Owner

Find the last commit that works

@ko-alex

ko-alex commented Jul 5, 2024

#8156 (Add support for Gemma2ForCausalLM) crashes as well.

To reproduce, I used an 8k context and a 74k prompt. Larger ctx values improve stability somewhat.

make clean && make GGML_CUDA=1 GGML_LTO=1 -j

./llama-cli -m /work/models/misc/gemma-2-27b-it-Q6_K_L.gguf -t 6 --color --interactive --conversation --multiline-input --mirostat 2 --ctx-size 8192 --n-gpu-layers 12 --keep -1

model from https://huggingface.co/bartowski/gemma-2-27b-it-GGUF/tree/main

@duynt575
Author

duynt575 commented Jul 5, 2024

#8156 seems to work for me, but the output quality is bad. Another note: I have an RTX 3060 12GB and a GTX 1660 Super 6GB. Maybe it has something to do with multiple GPUs, but I'm not sure (guessing based on the two identical error lines). I'm on Windows 10. @derjoshder has a similar setup to mine.

@EliEron

EliEron commented Jul 5, 2024

I only have one GPU, an RTX 3080 10GB, so I don't think it's multi-GPU related. It's also worth mentioning that I'm mostly running the model on the CPU, since I don't have a lot of VRAM.

@Nazosan

Nazosan commented Jul 6, 2024

Apparently this issue has made it downstream. I'm getting it in KoboldCPP (which uses llama.cpp internally) on my AMD 7800XT (using hipBLAS). Also a single GPU, though of course I'm not offloading all layers. I checked the ggml.c file it references, and this is the code the crash points to:

                if (!is_neox) {
                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
                        const float cos_theta = cache[i0 + 0];
                        const float sin_theta = cache[i0 + 1];
                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                        ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
                        const float x0 = GGML_FP16_TO_FP32(src[0]);
                        const float x1 = GGML_FP16_TO_FP32(src[1]);
                        dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
                        dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                    }

Specifically, it crashes on the "const float x1" line. The crash always points to the same ggml.c file on the exact same line every time.

For me it happens in a pretty specific way every time too. I'm running Gemma 2 27B Q4_K_L (I also tried Q4_K_M originally, then switched to this with exactly the same result) at 16384 context. I might get a few rare crashes as early as 10K into the context, but they are very rare at that point. Then, once I finally fill the full 16K, it crashes maybe every second generation or so. I'm generally connecting to the API with SillyTavern, and I don't know if it fails to erase back context or what, but I can start a new chat and it may still crash nearly immediately. I have to fully close things out and reload, and then it stops crashing until I reach a higher context again, so I do think this all has something to do with context in some form.

I double checked my model file, even completely redownloading the whole thing, and the md5sum checks out the same (as a side note, it sure would be nice if sites like HuggingFace and co listed checksums so one could verify in a more intelligent way than just downloading multiple times and checking sums of each).

Not sure if this directly helps, but at least it does confirm the issue can happen in different setups, including AMD. I'm on Manjaro 24.0.3.

EDIT: Upon someone's suggestion I disabled the context shift feature of KoboldCPP to see what would happen. This seems to have stopped the crashes, though it also means the full prompt has to be processed almost every time, which is unfortunately very slow on my hardware. Interestingly, with context shift disabled, it seems to limit context to less than the maximum: it generally stays around 14.2K or so (out of my 16K), but when I hit a global key or similar it can get as high as 15.8K -- just under 16K.

Could this be something like estimating the token count wrong and overflowing?

I guess the rest of you aren't using context shifting since you're on llama.cpp. I wonder what happens if you adjust your maximum outputs so they are less than what you've set llama.cpp to load? E.g. something like 6K or so for those of you using 8K.

@EliEron

EliEron commented Jul 7, 2024

> as a side note, it sure would be nice if sites like HuggingFace and co listed checksums so one could verify in a more intelligent way than just downloading multiple times and checking sums of each.

HuggingFace does actually list SHA256 checksums for all large files if you click on them. Down in the "Git LFS Details" section.

@ggerganov
Owner

Check if #8348 fixes the issue

@ko-alex

ko-alex commented Jul 7, 2024

> Check if #8348 fixes the issue

Looks good
