
Conversation

@hliuca (Contributor) commented Apr 30, 2025

The first dimension of tmp_out and exp_sums is inferred from block_tables.size(0), which may differ from query.shape[0]. The latter can be much larger than block_tables.size(0), which may cause OOM.

This PR fixes total_num_seq and the comments.
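
Below is a minimal sketch of the intended sizing, assuming a torch-based call site; the helper name, signature, and dtypes are illustrative, not the actual vLLM code:

import torch

def make_partition_workspaces(query: torch.Tensor,
                              block_tables: torch.Tensor,
                              num_heads: int,
                              head_size: int,
                              max_num_partitions: int):
    # The workspaces are per sequence: block_tables has one row per sequence,
    # so total_num_seq must come from block_tables, not from query. Sizing
    # them with query.shape[0] (the token count) can over-allocate and OOM.
    total_num_seq = block_tables.size(0)
    tmp_out = torch.empty(total_num_seq, num_heads, max_num_partitions, head_size,
                          dtype=query.dtype, device=query.device)
    exp_sums = torch.empty(total_num_seq, num_heads, max_num_partitions,
                           dtype=torch.float32, device=query.device)
    return tmp_out, exp_sums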

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@SageMoore (Contributor) left a comment


Hi @hliuca. Do you have a test and/or example that triggers a failure without this PR?

Contributor

Nit: Can you update the shape comments for the other kernels in this file?

@hliuca (Contributor, Author) Apr 30, 2025

Yes, here are the commands to reproduce. query.shape[0] is much larger than block_tables.shape[0] (the correct value, per the CUDA/HIP source code). Thank you.

Server-side commands:

docker.io/rocm/vllm-dev:nightly_main_20250420

export VLLM_USE_TRITON_FLASH_ATTN=0
export NCCL_MIN_NCHANNELS=112
export VLLM_FP8_PADDING=1
export VLLM_FP8_ACT_PADDING=1
export VLLM_FP8_WEIGHT_PADDING=1
export VLLM_FP8_REDUCE_CONV=1
export HIP_FORCE_DEV_KERNARG=1
export VLLM_USE_V1=1

vllm serve /data/huggingface/hub/amd/Llama-3.3-70B-Instruct-FP8-KV --dtype float16 --tensor-parallel-size 1 --kv-cache-dtype auto --quantization None --swap-space 16 --distributed-executor-backend mp --max-num-seqs 64 --max-model-len 16384 --max-seq-len-to-capture 16384 --max-num-batched-tokens 131072 --no-enable-prefix-caching --enable-chunked-prefill=False --disable-log-requests --uvicorn-log-level warning --port 8000

Client-side command:
python3 /app/vllm/benchmarks/benchmark_serving.py --host localhost --backend openai --port 8000 --model /data/huggingface/hub/amd/Llama-3.3-70B-Instruct-FP8-KV --dataset-name random --num-prompts 24 --random-input-len 8500 --random-output-len 150 --max-concurrency 8 --percentile-metrics ttft,tpot,itl,e2el

hliuca added 2 commits April 30, 2025 12:58
Signed-off-by: Hui Liu <96135754+hliuca@users.noreply.github.com>
Signed-off-by: Hui Liu <96135754+hliuca@users.noreply.github.com>
@hliuca hliuca changed the title fix tmp_out and exp_sums dimensions [Bugfix] fix tmp_out and exp_sums dimensions May 1, 2025
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
Contributor

OK, looking at this again, I think the problem is that this file is conflating num_tokens with num_seqs. num_tokens is the outermost dimension of the query and output tensors, and num_seqs is the outermost dimension of the block_table. So what I'm saying is that we should update the query/output shape comments to be [num_tokens, ...] and leave the other shape arguments alone.
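
For illustration only, the shape relationship described above could look like this (sizes are made up; names follow the discussion rather than the kernel headers):

import torch

num_seqs, num_heads, head_size, max_blocks_per_seq = 4, 8, 128, 64
num_tokens = 32  # can be much larger than num_seqs

query = torch.zeros(num_tokens, num_heads, head_size)        # outermost dim: num_tokens
out = torch.zeros_like(query)                                 # outermost dim: num_tokens
block_tables = torch.zeros(num_seqs, max_blocks_per_seq,
                           dtype=torch.int32)                 # outermost dim: num_seqs

# Partition-level workspaces follow block_tables, not query:
exp_sums = torch.zeros(block_tables.size(0), num_heads, 16)   # [num_seqs, num_heads, max_num_partitions]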

Contributor (Author)

@gshtras and I will work together to address the comments in the source code. Thank you.

Collaborator

The conclusion is that for V0 and V1 the kernel should be called with different values for num_seqs. But in the kernel itself the value does represent the number of sequences, so we'll revert the comment change, leaving just the V1 call-site change.
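
A brief sketch of that call-site distinction (the helper below is hypothetical; only the choice of value passed as num_seqs reflects the discussion):

import torch

def num_seqs_for_kernel(query: torch.Tensor, block_tables: torch.Tensor, is_v1: bool) -> int:
    # V0 decode batches carry one query token per sequence, so query.shape[0]
    # already equals the sequence count; in V1 the query can hold more tokens
    # than sequences, so block_tables is the reliable source.
    return block_tables.shape[0] if is_v1 else query.shape[0]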

float* __restrict__ exp_sums, // [block_tables.size(0), num_heads, max_num_partitions]
float* __restrict__ max_logits, // [block_tables.size(0), num_heads, max_num_partitions]
scalar_t* __restrict__ out, // [block_tables.size(0), num_heads, max_num_partitions, head_size]
OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size]
Contributor

It looks like final_out is a dead argument? Meaning, I don't see it used anywhere in this file. Are we sure these kernels are actually used?

Collaborator

Looks like this one is not related to the PR. We can run a round of cleanups in a follow-up.

Signed-off-by: Hui Liu <96135754+hliuca@users.noreply.github.com>
@gshtras (Collaborator) commented May 1, 2025

@SageMoore I hope the above answers your questions

@robertgshaw2-redhat enabled auto-merge (squash) May 2, 2025 15:01
@github-actions bot added the ready label May 2, 2025
@mgoin added the bug label May 2, 2025
@robertgshaw2-redhat merged commit 4c33d67 into vllm-project:main May 2, 2025
71 checks passed
@hliuca deleted the fix_chunked_prefill branch May 2, 2025 16:54
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: Hui Liu <96135754+hliuca@users.noreply.github.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request May 14, 2025
Signed-off-by: Hui Liu <96135754+hliuca@users.noreply.github.com>
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
Signed-off-by: Hui Liu <96135754+hliuca@users.noreply.github.com>
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>

Labels: bug (Something isn't working), ready (ONLY add when PR is ready to merge/full CI is needed)

5 participants