-
Notifications
You must be signed in to change notification settings - Fork 10.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
llama : greatly reduce output buffer memory usage #6122
Conversation
The first logits used to evaluate the second choice were not from the end of the common prefix; instead, they were the logits from the end of the first choice. This has been corrected. The previous implementation sometimes had outliers in the scores of choices for some tasks, and the logic to skip choices words in the log-likelihood evaluation probably was an attempt to reduce those, but it was complex and didn't quite seem to be the right thing. This is simpler now, and the outlier scores aren't there anymore.
A mismatch happened when using a smaller n_ubatch than n_batch and then using llama_batch_get_one(). The decision of what n_outputs should be now almost fully depends on how lctx.n_outputs is set in llama_decode_internal. The conditions are simpler this way. * llama : when saving the state, recalculate n_outputs This ensures the correct number of outputs for the entire previous batch is stored in the session file, even when n_ubatch is smaller than n_batch.
As it is, this breaks pipeline parallelism because changes in the graph topology force a synchronization. I think it should be possible to fix this if the final You can test this without multiple GPUs by building in debug with |
It previously worked because lctx.inp_out_ids was not initialized, so it pointed to some garbage address which was somehow still valid when I ran my tests.
Yes, this should be possible. I initially put the skip of the rest of the graph there because tensors with 0 rows caused division by zero problems in
I'll see if I can try this on my Intel UHD Graphics 615 (since splits don't seem to happen on CPU-only).
which may be problematic when using a GPU. I'll try to avoid changing the graph topology. |
Yes, I forgot that the reallocation is automatic if there is only one backend, but that message also shows the problem. If that message disappears, then it should also work with pipeline parallelism. |
* ggml : future-proof ggml_is_empty by using GGML_MAX_DIMS - 1
Test result using Radeon VII HipBLAS build under Linux environment.
pr:
|
That sounds really promising! However, when trying to run this PR, I'm getting this error:
I was compiling it with CMake, Cuda Toolkit 12.3. RTX 2060, Core i7 9750H, 32 GB. Does this PR only support llama or its derivatives like Mistral as well? My model was Fimbulvetr-v2 at IQ4_XS, which is a Solar 10.7B merge. |
* llama : rework reallocation logic for llama_output_reserve Now comparing the actual size with the new total size of the output buffer to allow more efficient enabling and disabling of the embeddings and/or logits output in the future.
Answering 2 messages at once:
@fgdfgfthgr-fox The expected decrease is at most
It should be significant when using a very large batch size. For example, with
@Dampfinchen Thanks for testing this with CUDA. A more complete backtrace would be helpful1, but in this case this is probably happening because I didn't yet make the other backends than CPU skip empty tensors, so what you're seeing is likely the symptoms of a GPU-accelerated division by zero. Hopefully fixed by 8b826c5. This didn't happen to @fgdfgfthgr-fox because they didn't offload enough layers for the last one to be offloaded (which can now have tensors with no elements when no logits are used, causing problems when dividing dimensions with each other if not skipped).
The goal is for this to support every of the 23+ model architectures supported by Footnotes
|
Works well with CUDA, improves pp performance by 2-3% with a single GPU.
|
Note that the output buffer is always allocated in a CPU buffer, so this shouldn't affect VRAM usage. |
@compilade I can confirm the issue I've had has been fixed. Good work! |
command-r seems broken, it works until build b2536. INFO [ server_params_parse] logging to file is disabled. | tid="196984" timestamp=1711615937 |
@Wuzzooy I'm truly sorry about that, it seems the model's graph was structured a bit differently than the other models ( Should be fixed by #6367 (hopefully). |
* Support xverse model convert to gguf format. * 1. Convert xverse models to gguf; 2. Add LLM_ARCH_XVERSE inference in llama.cpp; 3. Add xverse item in Supported models in README.md; * * gguf-py: remove redundant logs * llama: remove the init_mapping_prefetch custom parameter * llama.cpp: Include the changes from #6122 to exclude the unused outputs of the last layers. * - Fix format issues - Remove duplicate set kqv_out to llm_build_kv * Update llama.cpp --------- Co-authored-by: willhe <willhe@xverse.cn> Co-authored-by: willhe <hexin@xverse.cn>
op_getrows_f32 is required since ggerganov#6122 for the Vulkan w/ Kompute backend to be functional. As such, implement this op to make this backend functional again.
op_getrows_f32 is required since ggerganov#6122 for the Vulkan w/ Kompute backend to be functional. As such, implement this op to make this backend functional again.
* llama : greatly reduce logits memory usage * llama : more compact state saving and reloading * llama : fix lctx.n_outputs not being set before building graph * perplexity : adapt to the logits API changes * perplexity : fix Winogrande, use correct logits for second choice start The first logits used to evaluate the second choice were not from the end of the common prefix; instead, they were the logits from the end of the first choice. This has been corrected. The previous implementation sometimes had outliers in the scores of choices for some tasks, and the logic to skip choices words in the log-likelihood evaluation probably was an attempt to reduce those, but it was complex and didn't quite seem to be the right thing. This is simpler now, and the outlier scores aren't there anymore. * perplexity : normalize spaces and punctuation in Winogrande sentences * llama : fix embedding conditions * llama : fix llama_get_embeddings_ith when the resulting id is 0 * llama : fix wrong n_outputs in llama_set_inputs A mismatch happened when using a smaller n_ubatch than n_batch and then using llama_batch_get_one(). The decision of what n_outputs should be now almost fully depends on how lctx.n_outputs is set in llama_decode_internal. The conditions are simpler this way. * llama : when saving the state, recalculate n_outputs This ensures the correct number of outputs for the entire previous batch is stored in the session file, even when n_ubatch is smaller than n_batch. * llama : fix not-skipping outputs of non-causal models * llama : fix running a batch with n_outputs == 0 It previously worked because lctx.inp_out_ids was not initialized, so it pointed to some garbage address which was somehow still valid when I ran my tests. * llama : keep same graph topology even when n_outputs == 0 * ggml : saner ggml_can_repeat with empty tensors * ggml : future-proof ggml_is_empty by using GGML_MAX_DIMS - 1 * ggml : do not multi-thread ops returning empty tensors * ggml : make ggml_is_empty public and work with views * llama : use a vector for ctx->output_ids * llama : rework reallocation logic for llama_output_reserve Now comparing the actual size with the new total size of the output buffer to allow more efficient enabling and disabling of the embeddings and/or logits output in the future. * ggml : skip empty tensors in all backends * llama : fix llama_output_reserve nullptr deref when new_size is 0 * perplexity : make Winogrande work as it does on master The problems with the Winogrande implementation will need to be fixed in a separate PR to ease review. * llama : clearer error messages for invalid logits or embeddings ids * llama : assert all models that can have inp_out_ids Since the graph topology is now constant, this presence check can be done even when there are no outputs. * llama : assert logits and embd buffers exist before writing to them * llama : handle errors from llama_output_reserve at call sites * perplexity : make hellaswag and multiple-choice outputs identical to master Due to how the KV cache is updated, the logprobs for tokens in a batch are very slightly affected by the other tokens present in the batch, so to make hellaswag and multiple-choice return exactly the same results as on master, the last token of each sequence needs to be evaluated even though its output is not used at all. This will probably be changed back in the future to make these benchmarks a tiny bit faster. * perplexity : fix division by zero when using less than 100 multiple-choice tasks * llama : allow loading state saved with a different ctx size When loading a session file, the context size is now only required to be at least enough to load the KV cells contained in that session file, instead of requiring to use exactly the same context size as when saving. Doing this enables the use-case of extending or shrinking the context size of a saved session. This breaks existing session files because the meaning of kv_buf_size is slightly changed (previously it was the size of the whole KV cache, now it's only the size of the saved part of it). This allows for finer-grained sanity checks when loading in an effort to keep kv_buf_size useful even when the kv_size is changed. * llama : minor ggml-ci * readme : update recent API changes, and warn about Vulkan --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* Support xverse model convert to gguf format. * 1. Convert xverse models to gguf; 2. Add LLM_ARCH_XVERSE inference in llama.cpp; 3. Add xverse item in Supported models in README.md; * * gguf-py: remove redundant logs * llama: remove the init_mapping_prefetch custom parameter * llama.cpp: Include the changes from ggerganov#6122 to exclude the unused outputs of the last layers. * - Fix format issues - Remove duplicate set kqv_out to llm_build_kv * Update llama.cpp --------- Co-authored-by: willhe <willhe@xverse.cn> Co-authored-by: willhe <hexin@xverse.cn>
* llama : greatly reduce logits memory usage * llama : more compact state saving and reloading * llama : fix lctx.n_outputs not being set before building graph * perplexity : adapt to the logits API changes * perplexity : fix Winogrande, use correct logits for second choice start The first logits used to evaluate the second choice were not from the end of the common prefix; instead, they were the logits from the end of the first choice. This has been corrected. The previous implementation sometimes had outliers in the scores of choices for some tasks, and the logic to skip choices words in the log-likelihood evaluation probably was an attempt to reduce those, but it was complex and didn't quite seem to be the right thing. This is simpler now, and the outlier scores aren't there anymore. * perplexity : normalize spaces and punctuation in Winogrande sentences * llama : fix embedding conditions * llama : fix llama_get_embeddings_ith when the resulting id is 0 * llama : fix wrong n_outputs in llama_set_inputs A mismatch happened when using a smaller n_ubatch than n_batch and then using llama_batch_get_one(). The decision of what n_outputs should be now almost fully depends on how lctx.n_outputs is set in llama_decode_internal. The conditions are simpler this way. * llama : when saving the state, recalculate n_outputs This ensures the correct number of outputs for the entire previous batch is stored in the session file, even when n_ubatch is smaller than n_batch. * llama : fix not-skipping outputs of non-causal models * llama : fix running a batch with n_outputs == 0 It previously worked because lctx.inp_out_ids was not initialized, so it pointed to some garbage address which was somehow still valid when I ran my tests. * llama : keep same graph topology even when n_outputs == 0 * ggml : saner ggml_can_repeat with empty tensors * ggml : future-proof ggml_is_empty by using GGML_MAX_DIMS - 1 * ggml : do not multi-thread ops returning empty tensors * ggml : make ggml_is_empty public and work with views * llama : use a vector for ctx->output_ids * llama : rework reallocation logic for llama_output_reserve Now comparing the actual size with the new total size of the output buffer to allow more efficient enabling and disabling of the embeddings and/or logits output in the future. * ggml : skip empty tensors in all backends * llama : fix llama_output_reserve nullptr deref when new_size is 0 * perplexity : make Winogrande work as it does on master The problems with the Winogrande implementation will need to be fixed in a separate PR to ease review. * llama : clearer error messages for invalid logits or embeddings ids * llama : assert all models that can have inp_out_ids Since the graph topology is now constant, this presence check can be done even when there are no outputs. * llama : assert logits and embd buffers exist before writing to them * llama : handle errors from llama_output_reserve at call sites * perplexity : make hellaswag and multiple-choice outputs identical to master Due to how the KV cache is updated, the logprobs for tokens in a batch are very slightly affected by the other tokens present in the batch, so to make hellaswag and multiple-choice return exactly the same results as on master, the last token of each sequence needs to be evaluated even though its output is not used at all. This will probably be changed back in the future to make these benchmarks a tiny bit faster. * perplexity : fix division by zero when using less than 100 multiple-choice tasks * llama : allow loading state saved with a different ctx size When loading a session file, the context size is now only required to be at least enough to load the KV cells contained in that session file, instead of requiring to use exactly the same context size as when saving. Doing this enables the use-case of extending or shrinking the context size of a saved session. This breaks existing session files because the meaning of kv_buf_size is slightly changed (previously it was the size of the whole KV cache, now it's only the size of the saved part of it). This allows for finer-grained sanity checks when loading in an effort to keep kv_buf_size useful even when the kv_size is changed. * llama : minor ggml-ci * readme : update recent API changes, and warn about Vulkan --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* Support xverse model convert to gguf format. * 1. Convert xverse models to gguf; 2. Add LLM_ARCH_XVERSE inference in llama.cpp; 3. Add xverse item in Supported models in README.md; * * gguf-py: remove redundant logs * llama: remove the init_mapping_prefetch custom parameter * llama.cpp: Include the changes from ggerganov#6122 to exclude the unused outputs of the last layers. * - Fix format issues - Remove duplicate set kqv_out to llm_build_kv * Update llama.cpp --------- Co-authored-by: willhe <willhe@xverse.cn> Co-authored-by: willhe <hexin@xverse.cn>
* llama : greatly reduce logits memory usage * llama : more compact state saving and reloading * llama : fix lctx.n_outputs not being set before building graph * perplexity : adapt to the logits API changes * perplexity : fix Winogrande, use correct logits for second choice start The first logits used to evaluate the second choice were not from the end of the common prefix; instead, they were the logits from the end of the first choice. This has been corrected. The previous implementation sometimes had outliers in the scores of choices for some tasks, and the logic to skip choices words in the log-likelihood evaluation probably was an attempt to reduce those, but it was complex and didn't quite seem to be the right thing. This is simpler now, and the outlier scores aren't there anymore. * perplexity : normalize spaces and punctuation in Winogrande sentences * llama : fix embedding conditions * llama : fix llama_get_embeddings_ith when the resulting id is 0 * llama : fix wrong n_outputs in llama_set_inputs A mismatch happened when using a smaller n_ubatch than n_batch and then using llama_batch_get_one(). The decision of what n_outputs should be now almost fully depends on how lctx.n_outputs is set in llama_decode_internal. The conditions are simpler this way. * llama : when saving the state, recalculate n_outputs This ensures the correct number of outputs for the entire previous batch is stored in the session file, even when n_ubatch is smaller than n_batch. * llama : fix not-skipping outputs of non-causal models * llama : fix running a batch with n_outputs == 0 It previously worked because lctx.inp_out_ids was not initialized, so it pointed to some garbage address which was somehow still valid when I ran my tests. * llama : keep same graph topology even when n_outputs == 0 * ggml : saner ggml_can_repeat with empty tensors * ggml : future-proof ggml_is_empty by using GGML_MAX_DIMS - 1 * ggml : do not multi-thread ops returning empty tensors * ggml : make ggml_is_empty public and work with views * llama : use a vector for ctx->output_ids * llama : rework reallocation logic for llama_output_reserve Now comparing the actual size with the new total size of the output buffer to allow more efficient enabling and disabling of the embeddings and/or logits output in the future. * ggml : skip empty tensors in all backends * llama : fix llama_output_reserve nullptr deref when new_size is 0 * perplexity : make Winogrande work as it does on master The problems with the Winogrande implementation will need to be fixed in a separate PR to ease review. * llama : clearer error messages for invalid logits or embeddings ids * llama : assert all models that can have inp_out_ids Since the graph topology is now constant, this presence check can be done even when there are no outputs. * llama : assert logits and embd buffers exist before writing to them * llama : handle errors from llama_output_reserve at call sites * perplexity : make hellaswag and multiple-choice outputs identical to master Due to how the KV cache is updated, the logprobs for tokens in a batch are very slightly affected by the other tokens present in the batch, so to make hellaswag and multiple-choice return exactly the same results as on master, the last token of each sequence needs to be evaluated even though its output is not used at all. This will probably be changed back in the future to make these benchmarks a tiny bit faster. * perplexity : fix division by zero when using less than 100 multiple-choice tasks * llama : allow loading state saved with a different ctx size When loading a session file, the context size is now only required to be at least enough to load the KV cells contained in that session file, instead of requiring to use exactly the same context size as when saving. Doing this enables the use-case of extending or shrinking the context size of a saved session. This breaks existing session files because the meaning of kv_buf_size is slightly changed (previously it was the size of the whole KV cache, now it's only the size of the saved part of it). This allows for finer-grained sanity checks when loading in an effort to keep kv_buf_size useful even when the kv_size is changed. * llama : minor ggml-ci * readme : update recent API changes, and warn about Vulkan --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* Support xverse model convert to gguf format. * 1. Convert xverse models to gguf; 2. Add LLM_ARCH_XVERSE inference in llama.cpp; 3. Add xverse item in Supported models in README.md; * * gguf-py: remove redundant logs * llama: remove the init_mapping_prefetch custom parameter * llama.cpp: Include the changes from ggerganov#6122 to exclude the unused outputs of the last layers. * - Fix format issues - Remove duplicate set kqv_out to llm_build_kv * Update llama.cpp --------- Co-authored-by: willhe <willhe@xverse.cn> Co-authored-by: willhe <hexin@xverse.cn>
op_getrows_f32 is required since #6122 for the Vulkan w/ Kompute backend to be functional. As such, implement this op to make this backend functional again.
== Relevant log messages from source repo: commit 3d7ebf63123b8652fb7bbecef7ba731202309901 Author: 0cc4m <picard12@live.de> Date: Mon Jun 3 10:59:14 2024 +0200 Vulkan Mixture of Experts (MoE) support (#7628) * Finish Vulkan mul_mat_id implementation * Add Vulkan sum_rows and div ops * Fix MUL_MAT_ID matrix matrix shader * Fix MUL_MAT_ID matrix vector shader dispatch size * Fix MUL_MAT_ID matrix vector shader and dispatch code * Update Vulkan CPU offload for MUL_MAT_ID * Fix crash when using split mode none and setting a main GPU commit a10cda58d3199cd85305e0f03a8c6056714ae2e8 Author: Andy Tai <andy-tai@users.noreply.github.com> Date: Mon Jun 3 01:06:24 2024 -0700 cmake : add pkg-config spec file for llama.cpp (#7702) commit 6f28a333c1e3fdfdc7b4f9d0367f2b41a9b7e9d4 Author: zhangkaihuo <zhangkaihuo@gmail.com> Date: Mon Jun 3 15:49:30 2024 +0800 llama : MiniCPM support tied embeddings (#7664) * support lm_head * remove the code block --------- Co-authored-by: zhangkaihuo <zhangkaihuo@modelbest.cn> commit 549279d8049d78620a2b081e26edb654f83c3bbd Author: Georgi Gerganov <ggerganov@gmail.com> Date: Mon Jun 3 08:34:43 2024 +0300 llama : avoid double token-to-piece cache (#7654) ggml-ci commit 9e405b6e2ecb888e860f7b92720b4809e21b3915 Author: woachk <24752637+woachk@users.noreply.github.com> Date: Mon Jun 3 07:32:16 2024 +0200 kompute : implement op_getrows_f32 (#6403) op_getrows_f32 is required since ggerganov/llama.cpp#6122 for the Vulkan w/ Kompute backend to be functional. As such, implement this op to make this backend functional again.
op_getrows_f32 is required since ggerganov/llama.cpp#6122 for the Vulkan w/ Kompute backend to be functional. As such, implement this op to make this backend functional again.
op_getrows_f32 is required since ggerganov/llama.cpp#6122 for the Vulkan w/ Kompute backend to be functional. As such, implement this op to make this backend functional again.
op_getrows_f32 is required since ggerganov/llama.cpp#6122 for the Vulkan w/ Kompute backend to be functional. As such, implement this op to make this backend functional again.
op_getrows_f32 is required since ggerganov/llama.cpp#6122 for the Vulkan w/ Kompute backend to be functional. As such, implement this op to make this backend functional again.
op_getrows_f32 is required since ggerganov/llama.cpp#6122 for the Vulkan w/ Kompute backend to be functional. As such, implement this op to make this backend functional again.
op_getrows_f32 is required since ggerganov/llama.cpp#6122 for the Vulkan w/ Kompute backend to be functional. As such, implement this op to make this backend functional again.
Supersedes #2700.
As I've noted in #6017 (comment), the logits buffer is way too big and mostly unused when the batch size is very large. This PR fixes this waste of memory by introducing another, smaller buffer of ids (
ctx->output_ids
) which point into the logits and embeddings buffers, allowing to keep most of the behavior ofllama_get_logits_ith()
andllama_get_embeddings_ith()
while making the output buffer's content contiguous. While I was at it, I've also noticed it was relatively easy to skip computing unused logits with the new output buffer layout introduced in this change.From @slaren's suggestion in #6017 (comment), this PR allocates space for
n_seq_max
outputs, then dynamically reallocates the output buffer when more logits and/or embeddings are necessary.(Thanks for making me think to look for #2700 when alluding to it in #6017 (comment))
API changes
llama_get_logits
andllama_get_embeddings
now return contiguous arrays (stillfloat *
).llama_batch.logits[i] == true
are stored contiguously in the order they have in the batch.logits_all == true
and when usingllama_batch_get_one()
(which doesn't setllama_batch.logits
).logits_all
now has a slight performance incentivellama_get_logits_ith
andllama_get_embeddings_ith
now always verify if the passed id is valid.perplexity
example previously relied on the non-verification of this (assertion error when running with a Debug build onmaster
), which has been fixed. (in other words, this fixes perplexity assert fails #6246)512
previously used at least63 MiB
, but now the equivalent session file takes only1 MiB
(mostly composed of the used KV cache cells from my test prompt, it now only includes the needed logits from the last batch), regardless of the batch size, whereas a batch size of1024
previously would use around126 MiB
with Tinyllama.Notes
The
perplexity
example used the previous layout of the logits very extensively.I've adapted most of it, but the parts of it which still use
logits_all
(even though this still works) will need to be changed in the future to benefit from skipping the computation of unused logits.The
perplexity
example should have the exact same output asmaster
when compiled with the same flags.Well, except with Winogrande, because I've modified the implementation.(EDIT: Winogrande output should be the same asmaster
as of 8f70dcb.)(note to @ikawrakow: the logic for skipping the choice words in
winogrande_score()
shouldn't be required, but I still kept it.so I've simplified it by also fixing the logits used when evaluating the second choice (previously, the logits of the end of the first choice were used there instead of the logits of the end of the common prefix, which caused a big skew in the log-likelyhood of some second choices). This will need to be fixed in a separate PR.)TODO
Since a (small) model-specific change was required for each of the 23+ architectures supported by
llama.cpp
, I'd like to at least ensure I didn't break any of them. I'd really like to know if it works on GPU and/or with MoE models.Feel free to edit the list below to add more tested models and examples.
Compare the output with
--temp 0
when using this PR vsmaster
.main
,parallel
(with and without-cb
),perplexity
(v1, v2, hellaswag, multiple-choices),imatrix
,save-load-state
,embedding
)speculative
(with Llama-160M as a draft model),save-load-state
))master
since common : disable repeat penalties by default #6127, so for the outputs to be the same,--repeat-penalty 1
now has to be passed.main
,save-load-state
)server
)llama-bench
,main
))Known issues:
server
with embeddingsGGML_OP_GET_ROWS
and don't fallback properly, (e.g. the Vulkan backend)