
flashinfer paged attention #2772

Open: sumo43 wants to merge 22 commits into main
Conversation

@sumo43

sumo43 commented Feb 5, 2024

Description

This PR integrates FlashInfer's GQA PagedAttention kernels, which are up to 2-3x faster than vLLM's PagedAttention. FlashInfer is used for both prefill and decode, while cache_ops is still used to manage the KV cache.

https://github.com/flashinfer-ai/flashinfer/
https://flashinfer.ai/2024/02/02/introduce-flashinfer.html
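
For context, the snippet below is a minimal sketch of how the decode path can be driven through FlashInfer's Python wrapper API, assuming the flashinfer 0.x interface (BatchDecodeWithPagedKVCacheWrapper with begin_forward/forward) and a Mistral-7B-like head configuration; the actual wiring in this PR lives in vLLM's attention layer and may differ in details.

    # Minimal sketch of FlashInfer paged decode, assuming the flashinfer 0.x
    # wrapper API and the "NHD" kv-cache layout; shapes are illustrative.
    import torch
    import flashinfer

    num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16

    # Workspace buffer allocated once and reused across decode steps.
    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
    decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

    def decode_step(q, paged_kv_cache, kv_indptr, kv_indices, kv_last_page_len):
        # q: [num_seqs, num_qo_heads, head_dim]
        # paged_kv_cache: [num_pages, 2, page_size, num_kv_heads, head_dim]
        decode_wrapper.begin_forward(
            kv_indptr, kv_indices, kv_last_page_len,
            num_qo_heads, num_kv_heads, head_dim, page_size,
        )
        out = decode_wrapper.forward(q, paged_kv_cache)
        decode_wrapper.end_forward()
        return out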


Performance Results

I used the following setup:

python3 benchmarks/benchmark_throughput.py \
        --input-len 1000 \
        --output-len 1000 \
        --model mistralai/Mistral-7B-v0.1 \
        --num-prompts 100 \
        --enforce-eager

Throughput with flashinfer: 2.63 requests/s, 5258.27 tokens/s
Throughput without flashinfer: 1.82 requests/s, 3642.17 tokens/s

TODOS

  • Reduce overhead when creating prefill_wrapper, decode_wrapper (DONE; see the sketch after this list)
  • Check kv cache indexing; I think there is a bug there (DONE)
  • Run and debug with tp and different models
  • Fix sampler delay (?) (DONE)
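
On the first item, one way to cut that overhead (not necessarily what this branch ended up doing) is to build the workspace buffers and both wrappers once per model runner and only call begin_forward/end_forward each step. A hedged sketch, with a hypothetical FlashInferWrappers holder class:

    # Hypothetical helper: construct the FlashInfer wrappers once and reuse
    # them, so only cheap begin_forward/end_forward calls happen per step.
    import torch
    import flashinfer

    class FlashInferWrappers:
        def __init__(self, device: str = "cuda"):
            # Separate workspace buffers for prefill and decode, allocated once.
            self.prefill_workspace = torch.empty(
                128 * 1024 * 1024, dtype=torch.uint8, device=device)
            self.decode_workspace = torch.empty(
                128 * 1024 * 1024, dtype=torch.uint8, device=device)
            self.prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
                self.prefill_workspace, "NHD")
            self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
                self.decode_workspace, "NHD")

Keeping these objects alive for the lifetime of the model runner avoids reallocating the workspace on every forward pass.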

@esmeetu
Collaborator

esmeetu commented Feb 6, 2024

I got an error on a T4 GPU with a half-dtype model.
RuntimeError: BatchPrefillWithPagedKVCache failed to dispatch with dtype Half

@sumo43
Author

sumo43 commented Feb 6, 2024

I got an error on a T4 GPU with a half-dtype model. RuntimeError: BatchPrefillWithPagedKVCache failed to dispatch with dtype Half

Are you using the kvcache2 branch? Also, try setting the dtype to float16 instead of half.
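
If it helps to rule out dtype selection, here is a small illustrative snippet that forces float16 through vLLM's offline API; whether this actually resolves the T4 dispatch error is unverified.

    # Request float16 explicitly instead of "auto"/"half"; enforce_eager avoids
    # the CUDA-graph path, which this branch does not support yet.
    from vllm import LLM, SamplingParams

    llm = LLM(model="mistralai/Mistral-7B-v0.1",
              dtype="float16",
              enforce_eager=True)
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)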

@esmeetu
Collaborator

esmeetu commented Feb 7, 2024

Hi @sumo43, yes I am.
I switched to the Yi-6B-Chat model, and it throws a new error:

  File "/.conda/envs/infer/lib/python3.10/site-packages/flashinfer/prefill.py", line 461, in forward
    return self._wrapper.forward(
RuntimeError: BatchPrefillWithPagedKVCache failed with error code no kernel image is available for execution on the device

And when I use tp=2 for testing, the engine gets stuck. It seems tensor parallelism is not supported.
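
For what it's worth, "no kernel image is available for execution on the device" usually means the installed FlashInfer wheel was not compiled for that GPU's compute capability; the prebuilt wheels at the time may not have covered older architectures, so a source build could be needed. A quick way to check what you are running on:

    # Print each visible GPU's compute capability so it can be compared against
    # the architectures the installed FlashInfer wheel was built for.
    import torch

    for i in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(i)
        print(f"GPU {i}: {torch.cuda.get_device_name(i)} (sm_{major}{minor})")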

@zhuohan123
Member

@sumo43 Please feel free to ping me when the PR is ready for review!

@sumo43
Author

sumo43 commented Feb 7, 2024

@sumo43 Please feel free to ping me when the PR is ready for review!

Sounds good. So far I've made the kv cache compatible with FlashInfer and checked that the outputs are coherent. I'm currently debugging a few issues, like the sampler potentially taking longer to run (?), but I'll make it ready for review soon. Thanks.

@sumo43 sumo43 changed the title [WIP] flashinfer paged attention flashinfer paged attention Feb 7, 2024
@sumo43 sumo43 marked this pull request as ready for review February 7, 2024 22:33
@sumo43
Author

sumo43 commented Feb 7, 2024

So, I tested the core functionality and it works. However, my code doesn't support CUDA graphs, so those tests fail (with eager mode they pass). Also, FlashInfer only ships Python 3.10 and 3.11 wheels, so the Docker tests that use Python 3.8 don't pass.

@Yard1
Collaborator

Yard1 commented Feb 8, 2024

Regarding CUDA graphs, this PR should help (though it may not be the only thing needed) - flashinfer-ai/flashinfer#111

@WoosukKwon
Collaborator

Hi @sumo43, thanks for submitting the PR! To accelerate the merge, we'd like to directly push some modifications to the PR. For example, we'd like to use FlashInfer's C++ APIs rather than the Python APIs. Would you allow us to directly commit the changes to this PR? Of course, you'll remain as a co-author of the PR.

@sumo43
Author

sumo43 commented Feb 12, 2024

Hi @WoosukKwon. Absolutely, feel free to make any changes you need.

@shanshanpt
Contributor

shanshanpt commented Feb 20, 2024

Hi @sumo43,
docker image: nvcr.io/nvidia/pytorch:23.07-py3, Python 3.10.6, A100 x 8

I tried to run the kvcache2 branch and hit the following error:

script command:
python3 benchmarks/benchmark_throughput.py --input-len 1000 --output-len 1000 --model /model/Mistral-7B-v0.1 --num-prompts 100 --enforce-eager

error log:
INFO 02-20 03:21:17 llm_engine.py:327] # GPU blocks: 27153, # CPU blocks: 2048
Processed prompts: 0%| | 0/100 [00:00<?, ?it/s]Traceback (most recent call last):
File "/model/vllm-flashinfer/benchmarks/benchmark_throughput.py", line 336, in
main(args)
File "/model/vllm-flashinfer/benchmarks/benchmark_throughput.py", line 209, in main
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
File "/model/vllm-flashinfer/benchmarks/benchmark_throughput.py", line 111, in run_vllm
llm._run_engine(use_tqdm=True)
File "/model/vllm-flashinfer/vllm/entrypoints/llm.py", line 208, in _run_engine
step_outputs = self.llm_engine.step()
File "/model/vllm-flashinfer/vllm/engine/llm_engine.py", line 802, in step
all_outputs = self._run_workers(
File "/model/vllm-flashinfer/vllm/engine/llm_engine.py", line 989, in _run_workers
driver_worker_output = getattr(self.driver_worker,
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/model/vllm-flashinfer/vllm/worker/worker.py", line 219, in execute_model
output = self.model_runner.execute_model(seq_group_metadata_list,
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/model/vllm-flashinfer/vllm/worker/model_runner.py", line 613, in execute_model
hidden_states = model_executable(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/model/vllm-flashinfer/vllm/model_executor/models/mistral.py", line 304, in forward
hidden_states = self.model(input_ids, positions, kv_caches,
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/model/vllm-flashinfer/vllm/model_executor/models/mistral.py", line 257, in forward
hidden_states, residual = layer(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/model/vllm-flashinfer/vllm/model_executor/models/mistral.py", line 205, in forward
hidden_states = self.self_attn(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/model/vllm-flashinfer/vllm/model_executor/models/mistral.py", line 155, in forward
attn_output = self.attn(q, k, v, kv_cache, input_metadata)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/model/vllm-flashinfer/vllm/model_executor/layers/attention.py", line 168, in forward
output = flashinfer.single_prefill_with_kv_cache(query, key.contiguous(), value.contiguous(), causal=True)
File "/usr/local/lib/python3.10/dist-packages/flashinfer/prefill.py", line 139, in single_prefill_with_kv_cache
return _kernels.single_prefill_with_kv_cache(
ValueError: When causal is true, kv_len must be greater than or equal to qo_len, got kv_len32032 and qo_len 128128

@markluofd

(quoting @shanshanpt's traceback above)
ValueError: When causal is true, kv_len must be greater than or equal to qo_len, got kv_len32032 and qo_len 128128

Got the same error

@Qiubo1

Qiubo1 commented Feb 21, 2024

I used 300 requests to test LLaMA-13B with FlashInfer and with the original PagedAttention; the original PagedAttention's throughput is about 10% higher than FlashInfer's. I wonder if FlashInfer only helps for GQA models?

@sumo43
Author

sumo43 commented Feb 21, 2024

@pythonononer Yeah, I noticed it too. I'm looking into whether the C++ API is faster or not. Also, @shanshanpt, I'd recommend using the pass-ci branch, since it's a bit newer and passes the CI tests.

@Qiubo1

Qiubo1 commented Feb 21, 2024

@pythonononer Yeah, I noticed it too. I'm looking into whether the C++ API is faster or not. Also, @shanshanpt, I'd recommend using the pass-ci branch, since it's a bit newer and passes the CI tests.

I think the C++ API is roughly equal to the Python API; using pybind to expose the C++ interface to Python only improves things a little. Maybe we need the author to do the optimization.

There are some restrictions I've found: 1. Python >= 3.9, torch >= 2.1, CUDA > 11.8. 2. Eager mode must be enabled and tp == 1, so LLaMA-70B does not work.
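
A quick sanity check for those constraints (the exact minimums here come from the comment above, not from FlashInfer's official documentation):

    # Check the prerequisites listed above: Python >= 3.9, torch >= 2.1, CUDA > 11.8.
    import sys
    import torch

    assert sys.version_info >= (3, 9), "Python is too old for the flashinfer wheels"
    torch_ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
    assert torch_ver >= (2, 1), f"torch {torch.__version__} is too old"
    assert torch.version.cuda is not None, "a CUDA build of torch is required"
    cuda_ver = tuple(int(x) for x in torch.version.cuda.split(".")[:2])
    assert cuda_ver > (11, 8), f"CUDA {torch.version.cuda} is too old"
    print("Environment matches the constraints above (remember: eager mode, tp=1).")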

@MikeChenfu

Is this PR still active? I also get the same error:

 RuntimeError: BatchPrefillWithPagedKVCache failed to dispatch with dtype bFloat16. 

@jon-chuang
Contributor

Please close as stale

AFAIK flashinfer is now merged, right?


github-actions bot commented Nov 5, 2024

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions github-actions bot added the stale label Nov 5, 2024
@mergify mergify bot added the ci/build label Nov 5, 2024

mergify bot commented Nov 5, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. @sumo43 please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 5, 2024
@github-actions github-actions bot added unstale and removed stale labels Nov 7, 2024