
Conversation

@yaochengji
Collaborator

Description

Tests

  1. pytest -s -v tests/kernels/ragged_kv_cache_update_test.py
  2. server: TPU_BACKEND_TYPE=torchax VLLM_TORCHAX_ENABLED=1 vllm serve meta-llama/Llama-3.1-8B-Instruct --disable-log-requests --gpu-memory-utilization 0.98 --max-num-batched-tokens 2048 --max-num-seqs 128 --max-model-len 2048 --no-enable-prefix-caching --tensor_parallel_size=1
    client: python3 ./benchmarks/benchmark_serving.py --model meta-llama/Llama-3.1-8B-Instruct --dataset-name sonnet --dataset-path benchmarks/sonnet_4x.txt --sonnet-input-len 1800 --sonnet-output-len 128 --ignore_eos

We can observe that the throughput increases from 8.11 req/s to 8.14 req/s.

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have added necessary comments to my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

Signed-off-by: Chengji Yao <chengjiyao@google.com>
@yaochengji requested a review from bythew3i July 31, 2025 21:53
@yaochengji
Collaborator Author

cc @lsy323 @xiangxu-google

Collaborator

@xiangxu-google left a comment

Thanks!

  num_kv_pages_per_block=None,
  num_queries_per_block=None,
- vmem_limit_bytes=None,
+ vmem_limit_bytes=100 * 1024 * 1024,
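
For context, a minimal, self-contained sketch (the helper name is hypothetical, not the actual vLLM call site) that just collects the three tuning knobs touched in the hunk above and shows what the new cap evaluates to:

def attention_kernel_tuning_kwargs() -> dict:
    """Illustration only: the attention-kernel tuning knobs changed by this PR."""
    return {
        # None lets the kernel/compiler pick its own block sizes.
        "num_kv_pages_per_block": None,
        "num_queries_per_block": None,
        # Explicit VMEM cap instead of the previous None (compiler default).
        "vmem_limit_bytes": 100 * 1024 * 1024,  # 104,857,600 bytes = 100 MiB
    }

if __name__ == "__main__":
    print(attention_kernel_tuning_kwargs())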
Collaborator

@bythew3i Aug 1, 2025

Nit: it would be better to test the throughput before and after this change. The kernel itself won't be affected, but the next op's prefetch will be.

Collaborator Author

It would crash due to vmem OOM before the change.

  page_size: int,
  num_slices_per_block: int,
  dynamic_validate_inputs: bool,
+ vmem_limit_bytes: int = 40 * 1024 * 1024,
Collaborator

How is this calculated?

Collaborator Author

Basically I'd like to have 32 MB of VMEM available for the scratch buffer, and it's rounded up to 40 MB.
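
To restate that sizing rationale as a tiny sketch (the constant names are hypothetical, not from the kernel code):

MIB = 1024 * 1024
SCRATCH_BUFFER_TARGET_BYTES = 32 * MIB  # desired VMEM scratch-buffer budget
DEFAULT_VMEM_LIMIT_BYTES = 40 * MIB     # default vmem_limit_bytes: scratch target plus headroom

assert DEFAULT_VMEM_LIMIT_BYTES > SCRATCH_BUFFER_TARGET_BYTES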

Collaborator

I would suggest doing autotuning and benchmarking in the Google workspace internally first, like we do for RPA and quantized_matmul.

Collaborator Author

Thanks for the suggestion! Since I'd like to get the KV cache update kernel on par with the vLLM torch/XLA path first, can we add the auto-tuning and benchmarking internally at Google later?

Collaborator

@bythew3i left a comment

LGTM - approving since the current change is blocking you.

@yaochengji merged commit a54ea5a into main Aug 1, 2025
1 of 3 checks passed