
whisper : fix excessive memory usage #2443

Merged 2 commits from gg/fix-kv-cache into master on Oct 5, 2024

Conversation

@ggerganov (Owner) commented Oct 2, 2024

Alternative to #2433.

  • dynamically resize the KV cache based on the number of decoders needed (see the sketch after this list)
  • do not allocate unused intermediate tensors in the compute graphs
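
As a rough illustration of the first bullet, here is a minimal standalone sketch. The names (kv_cache, kv_cache_ensure) are hypothetical and this is not the actual whisper.cpp implementation; it only mirrors the idea of provisioning the self-attention cache for the number of decoders the current request actually uses instead of a fixed worst case.

#include <cstdint>
#include <vector>

struct kv_cache {
    int64_t bytes_per_decoder = 0;   // size of K + V for a single decoder
    int64_t n_decoders        = 0;   // decoders currently provisioned
    std::vector<uint8_t> buf;        // stands in for the backend buffer
};

// Provision the self-attention cache for exactly the number of decoders the
// current request needs (e.g. the beam size), instead of a fixed worst case.
static bool kv_cache_ensure(kv_cache & cache, int64_t n_decoders) {
    if (n_decoders == cache.n_decoders) {
        return true;                 // already the right size, nothing to do
    }
    cache.buf.resize(static_cast<size_t>(n_decoders * cache.bytes_per_decoder));
    cache.n_decoders = n_decoders;
    return true;                     // real code must handle allocation failure
}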

@WilliamTambellini (Contributor) commented Oct 2, 2024

OK, thanks very much @ggerganov.
I'm testing this change and at the moment I see a huge increase in memory usage, so high that the library actually OOMs easily:

...
ggml_backend_cuda_buffer_type_alloc_buffer: allocating 560.00 MiB on device 0: cudaMalloc failed: out of memory
whisper_kv_cache_init: failed to allocate memory for the kv cache
whisper_full_with_state: whisper_kv_cache_init() failed for self-attention cache
bin/main: failed to process audio

I now have to decrease the beam size to 4 in order to be able to decode anything (using the v3 model) without OOM (6 GB GPU).
My guess:

  • my change in PR 2433 only increases the size of the text kv cache (encoder?)
  • your current change increases the size of all caches

Is that correct?
best
W

@ggerganov (Owner) replied:

> my change in PR 2433 only increases the size of the text kv cache (encoder?)

This PR likewise increases only the decoder (text) KV cache. The KV cache size for the large model is 84 MB per decoder, so if you want to run 8 beams, you will need an extra 672 MB of VRAM.
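
For reference, a back-of-the-envelope check of those numbers (a standalone sketch, not whisper.cpp code): the 32 layers and the 1280-wide state are the published large-v3 decoder dimensions; FP16 K/V storage and padding of the 448-token text context up to 512 are assumptions, chosen because they reproduce the ~84 MB per-decoder figure and the "kv self size = 83.89 MB" line that shows up in the logs later in this thread.

#include <cstdio>

int main() {
    const double n_layer = 32;     // large-v3 text (decoder) layers
    const double n_state = 1280;   // model width
    const double n_ctx   = 512;    // decoder KV context, assumed padded up from 448
    const double fp16    = 2;      // bytes per element, assuming FP16 K/V

    const double per_decoder = 2 /* K and V */ * n_layer * n_ctx * n_state * fp16;
    std::printf("KV cache per decoder : %.2f MB\n", per_decoder / 1e6);      // ~83.89 MB
    std::printf("extra for 8 decoders : %.2f MB\n", 8 * per_decoder / 1e6);  // ~671 MB
    return 0;
}

Eight decoders at ~84 MB each gives roughly the 672 MB of extra VRAM mentioned above.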

> your current change increases the size of all caches

No, it doesn't.

I don't think this PR uses more memory than it should. Could you double-check your findings?

@WilliamTambellini (Contributor) commented Oct 3, 2024

Sure.
With the v3 model:
whisper 1.5.4:

whisper_backend_init: using CUDA backend
whisper_model_load:     CUDA buffer size =  3094.86 MB
whisper_model_load: model size    = 3094.36 MB
whisper_backend_init: using CUDA backend
whisper_init_state: kv self size  =  220.20 MB
whisper_init_state: kv cross size =  245.76 MB
whisper_init_state: compute buffer (conv)   =   32.42 MB
whisper_init_state: compute buffer (encode) =  212.42 MB
whisper_init_state: compute buffer (cross)  =    9.38 MB
whisper_init_state: compute buffer (decode) =   99.24 MB
...

beamsize 5: no OOM.
beamsize higher than 5 fails because it runs out of KV cache slots (the text cache factor is 3).

Today's master:
beamsize 5: no problem but low quality, some repetitions.
beamsize 8:

whisper_init_state: kv self size  =  251.66 MB
whisper_init_state: kv cross size =  251.66 MB
whisper_init_state: kv pad  size  =    7.86 MB
whisper_init_state: compute buffer (conv)   =   36.13 MB
whisper_init_state: compute buffer (encode) =  926.53 MB
whisper_init_state: compute buffer (cross)  =    9.25 MB
whisper_init_state: compute buffer (decode) =  215.82 MB
...
whisper_full_with_state: failed to decode

It's just running out of KV cache slots.

With your current change (2443):
beamsize 5 or more: OOM at the very beginning:

main: processing 'samples/proselytism.wav' (44928941 samples, 2808.1 sec), 4 threads, 1 processors, 5 beams + best of 5, lang = en, task = transcribe, timestamps = 1 ...
ggml_backend_cuda_buffer_type_alloc_buffer: allocating 560.00 MiB on device 0: cudaMalloc failed: out of memory
whisper_kv_cache_init: failed to allocate memory for the kv cache
whisper_full_with_state: whisper_kv_cache_init() failed for self-attention cache

beamsize 4: OOM after 7 min:

whisper_model_load:    CUDA0 total size =  3094.36 MB
whisper_model_load: model size    = 3094.36 MB
whisper_backend_init_gpu: using CUDA backend
whisper_init_state: kv self size  =   83.89 MB
whisper_init_state: kv cross size =  251.66 MB
whisper_init_state: kv pad  size  =    7.86 MB
whisper_init_state: compute buffer (conv)   =   36.13 MB
whisper_init_state: compute buffer (encode) =  926.53 MB
whisper_init_state: compute buffer (cross)  =    9.25 MB
whisper_init_state: compute buffer (decode) =  208.56 MB
...
[00:06:59.980 --> 00:07:00.980]   But, I think you're going to be less caught up in these fads, but you're going to get
[00:07:00.980 --> 00:07:01.980]   less caught up in these fads, because you're going to be less caught up in these fads, because
[00:07:01.980 --> 00:07:02.980]   you're going to be less caught up in these fads, because you're going to be less caught up
ggml_backend_cuda_buffer_type_alloc_buffer: allocating 560.00 MiB on device 0: cudaMalloc failed: out of memory

beamsize 3: OOM at the end :(
beamsize 2: OOM after 16 min (after repeating a bunch of tokens a dozen times)

So to confirm again: for some reason, doesn't this PR use more memory than expected?
Thanks

@ggerganov (Owner) commented Oct 4, 2024

Good news - I looked deeper into this and indeed found a huge memory overhead across the implementation. The regression started back in ggerganov/ggml#731. I wasn't careful to adapt the whisper.cpp codebase to the new allocator changes so we ended up allocating way more extra memory than necessary.
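
For context on what the extra memory was: per the PR description ("do not allocate unused intermediate tensors in the compute graphs") and the second commit in this PR ("reduce memory overhead from unused input tensors"), tensors were being created while building the compute graphs even when the current configuration never used them, and buffer space was still reserved for them. A heavily simplified, generic sketch of the pattern follows; the names and sizes are hypothetical and this is not the ggml or whisper.cpp API.

#include <cstdint>
#include <cstdio>

struct graph_builder {
    int64_t reserved = 0;                      // bytes the allocator will set aside
    void new_tensor(int64_t bytes) { reserved += bytes; }
};

static void build_decode_graph(graph_builder & g, bool need_aux) {
    g.new_tensor(64 * 1024 * 1024);            // tensors every decode step needs
    // Before the fix, a tensor like this was created unconditionally, so its
    // memory was reserved even on code paths that never read it.
    if (need_aux) {
        g.new_tensor(128 * 1024 * 1024);       // only pay for it when it is used
    }
}

int main() {
    graph_builder g;
    build_decode_graph(g, /*need_aux=*/false);
    std::printf("reserved: %lld bytes\n", static_cast<long long>(g.reserved));
    return 0;
}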

I pushed a fix into this branch that should significantly reduce the memory usage. Also, it would be useful if you could run some additional tests with Flash Attention enabled (add -fa to the CLI). This will additionally speed up the computation and reduce the memory usage even further.
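
For example, a Flash Attention run would look something like bin/main -m models/ggml-large-v3.bin -f samples/proselytism.wav -bs 8 -fa, where only -fa comes from the comment above; the model path, input file, and beam-size flags are the usual CLI options, assumed here for illustration.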

Looking for extra feedback whenever you get the time. If everything is good, we should make a new release as soon as possible to resolve this problem. Thanks.

P.S. I temporarily lost access to my CUDA workstation for the next few days, so I am unable to run tests with CUDA and will rely on feedback from this discussion.

@WilliamTambellini (Contributor) commented Oct 4, 2024

Of course.
I just reran and indeed:

  • no OOM, regardless of beam size
  • no more hallucinations
  • no more repetitions

I even suspect that the quality is now better than the original Python implementation.
I will try FA later, but we should merge this change ASAP and indeed trigger a new release.
Big congrats.
Thanks,
W.
PS: if it would help, we could likely give you access to free GPU machines for your testing/implementation.

@ggerganov changed the title from "whisper : fix KV cache allocation" to "whisper : fix excessive memory usage" on Oct 4, 2024
@ggerganov (Owner) commented:

Thanks, I also ran some tests today and all seems good.

@ggerganov merged commit f62a546 into master on Oct 5, 2024
87 checks passed
@ggerganov deleted the gg/fix-kv-cache branch on October 5, 2024 at 09:36
lyapple2008 pushed a commit to lyapple2008/whisper.cpp.mars that referenced this pull request Nov 2, 2024
* whisper : fix KV cache allocation

* whisper : reduce memory overhead from unused input tensors