[Kernel] optimize kv cache update kernel block size #360
Conversation
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Thanks!
```diff
  num_kv_pages_per_block=None,
  num_queries_per_block=None,
- vmem_limit_bytes=None,
+ vmem_limit_bytes=100 * 1024 * 1024,
```
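For context, `vmem_limit_bytes` is a Mosaic compiler knob supplied when the Pallas kernel is built. A minimal sketch of the pattern, assuming the `jax.experimental.pallas.tpu` API (`TPUCompilerParams`; the class name may differ across JAX releases) and a placeholder kernel body rather than the actual ragged paged attention kernel:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def copy_kernel(x_ref, o_ref):
    # Placeholder body; the real kernel updates kv cache pages.
    o_ref[...] = x_ref[...]

x = jnp.ones((8, 128), jnp.float32)

out = pl.pallas_call(
    copy_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    compiler_params=pltpu.TPUCompilerParams(
        # Cap on how much vmem the compiler may use for this kernel;
        # the PR raises this to 100MB so larger block sizes fit without OOM.
        vmem_limit_bytes=100 * 1024 * 1024,
    ),
)(x)
```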
Nit: it would be better to test the throughput before and after this change. The kernel itself won't be affected, but the next op's prefetch will be.
It would crash due to vmem OOM before the change.
```diff
  page_size: int,
  num_slices_per_block: int,
  dynamic_validate_inputs: bool,
+ vmem_limit_bytes: int = 40 * 1024 * 1024,
```
How is this calculated?
Basically I'd like 32MB of vmem to be used for the scratch buffer, and it's rounded up to 40MB.
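A minimal sketch of that arithmetic; the 8MB of headroom on top of the 32MB scratch budget is an assumption, since the comment above only says the limit is rounded up:

```python
MiB = 1024 * 1024

scratch_bytes = 32 * MiB   # desired scratch-buffer budget from the comment above
headroom = 8 * MiB         # assumption: slack for the kernel's other vmem buffers

vmem_limit_bytes = scratch_bytes + headroom
assert vmem_limit_bytes == 40 * 1024 * 1024  # the default in the diff above
```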
I would suggest running autotuning and benchmarking in the Google workspace internally first, like we do for RPA and quantized_matmul.
Thanks for the suggestion! Since I'd like to bring the kv cache update kernel on par with the vLLM torch/xla path first, can we add the auto-tuning and benchmarking internally at Google later?
LGTM - approving since the current change is blocking you.
Description
Raise the ragged paged attention kernel's vmem limit to 100 * 1024 * 1024 bytes and add a vmem_limit_bytes parameter (default 40 * 1024 * 1024, i.e. a 32MB scratch budget rounded up) to the kv cache update kernel.
Tests
Client:
```
python3 ./benchmarks/benchmark_serving.py --model meta-llama/Llama-3.1-8B-Instruct --dataset-name sonnet --dataset-path benchmarks/sonnet_4x.txt --sonnet-input-len 1800 --sonnet-output-len 128 --ignore_eos
```
We can observe that the throughput increases from 8.11 req/s to 8.14 req/s.
Checklist
Before submitting this PR, please make sure: