
perf: initial cuda graph support #256

Merged
merged 6 commits into from
May 24, 2024

yzh119 (Collaborator) commented May 24, 2024

As requested in #187, this PR adds initial CUDAGraph compatibility to FlashInfer's batch decode attention kernels. It is the first step toward full CUDAGraph support; CUDAGraph-compatible prefill operators will follow in later PRs.

Proposed APIs

We add a new wrapper, CUDAGraphBatchDecodeWithPagedKVCacheWrapper. Users must pre-allocate the buffers for the page-table data structures and pass them in to initialize this wrapper class. Once the wrapper is initialized, these buffers stay allocated on the GPU at fixed addresses for the whole lifetime of the wrapper.
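A minimal pure-Python sketch of this buffer contract follows. The class name `StaticPagedKVBuffers`, its `update` method, and the use of `array` in place of GPU tensors are hypothetical; the three buffers are assumed to mirror the CSR-style page table (indptr, indices, last page length) that the paged-KV decode wrappers consume. The point is that capacity is reserved once and every later batch is copied into the same objects.

```python
from array import array


class StaticPagedKVBuffers:
    """Hypothetical fixed-capacity page-table buffers (not FlashInfer's API).

    The arrays are allocated once, sized for the largest batch and page count
    the application expects, and are only ever updated in place.
    """

    def __init__(self, max_batch_size: int, max_num_pages: int) -> None:
        # CSR-style offsets: request i owns indices[indptr[i]:indptr[i + 1]].
        self.indptr = array("i", [0] * (max_batch_size + 1))
        self.indices = array("i", [0] * max_num_pages)
        # Number of valid tokens in each request's last (partially filled) page.
        self.last_page_len = array("i", [0] * max_batch_size)
        self.max_batch_size = max_batch_size
        self.max_num_pages = max_num_pages

    def update(self, page_lists: list[list[int]], last_page_lens: list[int]) -> int:
        """Copy a new batch's page table into the existing buffers; return batch size."""
        batch_size = len(page_lists)
        total_pages = sum(len(pages) for pages in page_lists)
        if batch_size > self.max_batch_size or total_pages > self.max_num_pages:
            raise ValueError("batch exceeds the capacity reserved at construction")
        offset = 0
        self.indptr[0] = 0
        for i, pages in enumerate(page_lists):
            # Same-length slice assignment: writes in place, never reallocates.
            self.indices[offset:offset + len(pages)] = array("i", pages)
            offset += len(pages)
            self.indptr[i + 1] = offset
            self.last_page_len[i] = last_page_lens[i]
        return batch_size


bufs = StaticPagedKVBuffers(max_batch_size=4, max_num_pages=16)
bufs.update([[0, 1, 2], [5]], [7, 3])
print(list(bufs.indptr[:3]))  # [0, 3, 4]
```

Rejecting oversized batches instead of growing the arrays is the deliberate design choice here: growing would allocate new memory, which is exactly what a captured graph cannot tolerate.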

CUDAGraphBatchDecodeWithPagedKVCacheWrapper behaves slightly differently from BatchDecodeWithPagedKVCacheWrapper: in CUDAGraph mode it always runs a fixed set of kernels regardless of the input shape, whereas the original implementation dispatches to different kernels depending on the input shape.
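The difference between the two dispatch strategies can be sketched as follows. Everything here is hypothetical: `DecodeConfig`, the heuristic in `dynamic_dispatch`, the assumed SM count of 108, and the choice to pin the split-KV path are illustrations, not FlashInfer's actual dispatch logic. A graph records whichever kernels ran during capture, so a shape-dependent choice would replay the kernel picked for the capture-time shape even when a later shape needs a different one; a fixed choice avoids that.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DecodeConfig:
    split_kv: bool
    # Hypothetical tile size; real kernels derive this from head_dim, dtype, etc.
    tile_size: int


def dynamic_dispatch(batch_size: int, max_kv_len: int, num_sms: int = 108) -> DecodeConfig:
    """Shape-dependent choice (hypothetical heuristic): split the KV sequence
    only when the batch alone cannot fill the GPU and sequences are long."""
    split = batch_size < num_sms and max_kv_len > 1024
    return DecodeConfig(split_kv=split, tile_size=16 if split else 32)


class GraphSafeDispatcher:
    """Chooses one configuration at construction and reuses it for every call,
    so the kernel sequence recorded in a graph never depends on the input shape."""

    def __init__(self) -> None:
        # Assumption for illustration: the split-KV path handles every shape
        # correctly, trading some efficiency on shapes that would not need it.
        self.config = DecodeConfig(split_kv=True, tile_size=16)

    def dispatch(self, batch_size: int, max_kv_len: int) -> DecodeConfig:
        # Shape arguments are accepted but deliberately ignored.
        return self.config


print(dynamic_dispatch(1, 4096) == dynamic_dispatch(256, 4096))  # False
```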

This PR also fixes the addresses of all kernel input pointers, to accommodate the CUDAGraph capture constraint that a replayed graph reuses the exact memory addresses recorded during capture.
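Why the addresses must stay fixed can be shown with a toy model of capture and replay. `ToyGraph` and `make_sum_op` are hypothetical stand-ins: a real CUDA graph stores the device pointers passed to each captured kernel, and this sketch mimics that by binding each recorded operation to specific Python objects. Updating a buffer in place is visible on replay; rebinding the name to a new buffer is not.

```python
from typing import Callable


class ToyGraph:
    """Toy stand-in for a captured CUDA graph: it records operations bound to
    specific buffer objects and replays them later. Like a real graph, it never
    re-resolves which buffer to use, it only reads the buffer's current contents."""

    def __init__(self) -> None:
        self._ops: list[Callable[[], None]] = []

    def record(self, op: Callable[[], None]) -> None:
        self._ops.append(op)

    def replay(self) -> None:
        for op in self._ops:
            op()


def make_sum_op(src: list[float], dst: list[float]) -> Callable[[], None]:
    # The op is bound to these exact list objects, analogous to a kernel launch
    # whose parameters contain fixed device pointers.
    def op() -> None:
        dst[0] = sum(src)
    return op


static_in = [1.0, 2.0]
out = [0.0]
graph = ToyGraph()
graph.record(make_sum_op(static_in, out))

static_in[:] = [3.0, 4.0]  # in-place update: replay sees the new data
graph.replay()
print(out[0])  # 7.0

static_in = [10.0, 20.0]   # rebinding to a new buffer: replay still reads the old one
graph.replay()
print(out[0])  # 7.0
```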

Examples

See test_cuda_graph_batch_decode_with_paged_kv_cache in the unit tests.
Note that begin_forward functions should not be captured, because some of the operators they use cannot be captured in a CUDAGraph.
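The intended call order (plan outside capture, capture only the forward pass, then re-plan and replay at each step) can be sketched with a toy model. `ToyCapture`, `ToyDecodeWrapper`, and their bodies are hypothetical; only the begin_forward/forward split comes from the description above, and the method signatures here are not FlashInfer's. In a real library, host-side planning and the host-to-device copies it triggers are assumed to be the kind of work that cannot be captured.

```python
from contextlib import contextmanager
from typing import Callable, Iterator


class ToyCapture:
    """Records callables instead of running them, mimicking stream capture."""

    def __init__(self) -> None:
        self.active = False
        self.ops: list[Callable[[], None]] = []

    @contextmanager
    def capture(self) -> Iterator[None]:
        self.active = True
        try:
            yield
        finally:
            self.active = False

    def replay(self) -> None:
        for op in self.ops:
            op()


class ToyDecodeWrapper:
    """Hypothetical wrapper: begin_forward plans on the host, forward enqueues
    device work that reads only pre-allocated buffers."""

    def __init__(self, cap: ToyCapture, max_batch_size: int) -> None:
        self.cap = cap
        self.kv_indptr = [0] * (max_batch_size + 1)  # static, updated in place
        self.out = [0]

    def begin_forward(self, kv_lens: list[int]) -> None:
        # Host-side planning must happen outside capture.
        if self.cap.active:
            raise RuntimeError("begin_forward must be called outside capture")
        for i, n in enumerate(kv_lens):
            self.kv_indptr[i + 1] = self.kv_indptr[i] + n

    def forward(self, batch_size: int) -> None:
        def op() -> None:
            # Stand-in for the attention kernel: total KV tokens in the batch.
            self.out[0] = self.kv_indptr[batch_size]
        if self.cap.active:
            self.cap.ops.append(op)  # recorded, not executed
        else:
            op()


cap = ToyCapture()
wrapper = ToyDecodeWrapper(cap, max_batch_size=8)
wrapper.begin_forward([3, 5])      # 1. plan outside capture
with cap.capture():
    wrapper.forward(batch_size=2)  # 2. capture only forward
wrapper.begin_forward([4, 6])      # 3. re-plan for the next step
cap.replay()                       # 4. replay reads the updated buffer
print(wrapper.out[0])              # 10
```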

cc @AgrawalAmey @LiuXiaoxuanPKU @comaniac

yzh119 changed the title from "perm: initial cuda graph support" to "perf: initial cuda graph support" May 24, 2024
yzh119 (Collaborator, Author) commented May 24, 2024

Let's merge this PR first, and then iterate on updating this feature.

yzh119 merged commit 7e9cc7f into main May 24, 2024
MasterJH5574 deleted the cudagraph branch May 28, 2024 20:32
yzh119 added a commit that referenced this pull request Jun 2, 2024
yzh119 added a commit that referenced this pull request Jun 20, 2024
🤖 I have created a release *beep* *boop*
---


## [0.1.0](v0.0.4...v0.1.0) (2024-06-20)

### Highlights

* Support any GQA group size for tensor-cores kernels.
* Support any page size for tensor-cores kernels.
* Support CUDA-Graph for prefill/decode APIs.
* Add an option to accelerate decode kernels with Tensor Cores.
* Support custom attention mask.
(https://docs.flashinfer.ai/tutorials/kv_layout.html#mask-layout-2d-ragged-tensor)
* Support logits cap in Grok-1 models.
* Fused GPU-sampling kernels: top-p, top-k, speculative verification.
(https://docs.flashinfer.ai/api/python/sampling.html)
* PyTorch wrapper of group-gemm cutlass kernels.
(https://docs.flashinfer.ai/api/python/sampling.html)
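For readers unfamiliar with the filtering that the fused sampling kernels implement, here is a plain-Python reference of top-k followed by top-p (nucleus) filtering. It only spells out the semantics under one common convention (apply top-k first, then keep the smallest highest-probability prefix whose mass reaches `p`, then renormalize); it is not how FlashInfer's GPU kernels compute the result, and the function names are hypothetical.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_top_p_filter(probs: list[float], k: int, p: float) -> list[float]:
    """Keep the k most likely tokens, then the shortest prefix of those (by
    descending probability) whose mass reaches p; renormalize the survivors."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    kept: list[int] = []
    mass = 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= p:
            break
    out = [0.0] * len(probs)
    for i in kept:
        out[i] = probs[i] / mass
    return out


filtered = top_k_top_p_filter([0.5, 0.2, 0.2, 0.1], k=3, p=0.6)
print([round(x, 3) for x in filtered])  # [0.714, 0.286, 0.0, 0.0]
```

A sampler would then draw a token index from `filtered`; the fused kernels combine these steps on the GPU.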

### Acknowledgement

We thank [@ibsidorenko](https://github.com/ibsidorenko),
[@LiuXiaoxuanPKU](https://github.com/LiuXiaoxuanPKU),
[@Yard1](https://github.com/Yard1),
[@AgrawalAmey](https://github.com/AgrawalAmey),
[@xuzhenqi](https://github.com/xuzhenqi),
[@mgerstgrasser](https://github.com/mgerstgrasser),
[@esmeetu](https://github.com/esmeetu),
[@yz-tang](https://github.com/yz-tang),
[@HSQ79815](https://github.com/HSQ79815),
[@Qubitium](https://github.com/Qubitium),
[@shreygupta2809](https://github.com/shreygupta2809),
[@sighingnow](https://github.com/sighingnow),
[@vinx13](https://github.com/vinx13),
[@tqchen](https://github.com/tqchen),
[@merrymercy](https://github.com/merrymercy),
[@comaniac](https://github.com/comaniac) and many others for their
contributions and helpful discussions for the 0.0.5 release.

### Refactor

* support any GQA group size for tensor-cores kernels
([#301](#301))
([c111ca](c111ca6))
* support any page size for tensor-cores kernels
([#306](#306))
([82fd8c](82fd8c7))
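As background for the GQA refactor: under grouped-query attention, each KV head is shared by a contiguous group of query heads, and the group size is the ratio of query heads to KV heads. The sketch below shows that head mapping in plain Python; the function name is hypothetical, and the 28/4 example (group size 7) is chosen only to show a ratio that is not a power of two.

```python
def kv_head_for_query_head(qo_head: int, num_qo_heads: int, num_kv_heads: int) -> int:
    """Map a query head to the KV head it shares under grouped-query attention."""
    if num_qo_heads % num_kv_heads != 0:
        raise ValueError("num_qo_heads must be a multiple of num_kv_heads")
    group_size = num_qo_heads // num_kv_heads
    # Consecutive query heads form one group that reads the same KV head.
    return qo_head // group_size


print([kv_head_for_query_head(h, 32, 8) for h in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
print(kv_head_for_query_head(27, 28, 4))  # 3
```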


### Features

* add `use_tensor_cores` option to decode kernels to accelerate GQA
([#317](#317))
([3b50dd5](3b50dd5))
* add group gemm operators
([#282](#282))
([e08ba42](e08ba42))
* initial support of distributed operators
([#289](#289))
([03553da](03553da))
* initial support of logits hook
([#298](#298))
([ab1e2ad](ab1e2ad))
* Separate Q and KV dtypes for decode
([#286](#286))
([5602659](5602659))
* support cuda graph for batched multi-query (prefill/append) attention
([#275](#275))
([83ceb67](83ceb67))
* support cuda graph for batched multi-query (prefill/append) attention
([#277](#277))
([24cc583](24cc583))
* support custom attention mask in prefill/append attention kernels
([#266](#266))
([7304282](7304282))
* fused speculative sampling kernels
([#259](#259))
([cea2bb](cea2bb9))
* expose sampling APIs in pytorch
([#238](#238))
([092902](0929023))
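The custom-mask feature (#266) takes masks for a whole batch of requests with different lengths. Below is a sketch of one way to pack per-request masks into a single ragged buffer, assuming the layout described in the mask-layout tutorial linked in the highlights: each request's `(qo_len, kv_len)` boolean mask is flattened row-major and concatenated, with an offset array marking where each request starts. The function names are hypothetical.

```python
def flatten_ragged_masks(masks: list[list[list[bool]]]) -> tuple[list[bool], list[int]]:
    """Concatenate per-request (qo_len x kv_len) boolean masks row-major into one
    flat list, with offsets so request i owns flat[indptr[i]:indptr[i + 1]]."""
    flat: list[bool] = []
    indptr = [0]
    for mask in masks:
        for row in mask:
            flat.extend(row)
        indptr.append(len(flat))
    return flat, indptr


def causal_mask(qo_len: int, kv_len: int) -> list[list[bool]]:
    """Causal mask for qo_len new tokens appended after kv_len - qo_len cached ones."""
    offset = kv_len - qo_len
    return [[j <= i + offset for j in range(kv_len)] for i in range(qo_len)]


flat, indptr = flatten_ragged_masks([causal_mask(2, 3), causal_mask(1, 4)])
print(indptr)  # [0, 6, 10]
```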


### Performance Improvements

* initial cuda graph support
([#256](#256))
([7e9cc7f](7e9cc7f))
* split kv-cache for prefill/append kernels
([#310](#310))
([f0bb0a3](f0bb0a3))
* use packed bit array for attention mask
([#308](#308))
([3d43dc9](3d43dc9))
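The packed-bit-array change (#308) stores one mask bit per element instead of one byte. A minimal sketch of such packing, assuming little-endian bit order within each byte (bit j of byte b holds element 8*b + j, the same convention as `numpy.packbits(..., bitorder="little")`); the function names are hypothetical and the bit order FlashInfer actually uses is not confirmed here.

```python
def pack_bits_little(bits: list[bool]) -> bytes:
    """Pack booleans 8 per byte; bit j of byte b holds element 8 * b + j."""
    out = bytearray((len(bits) + 7) // 8)
    for idx, bit in enumerate(bits):
        if bit:
            out[idx // 8] |= 1 << (idx % 8)
    return bytes(out)


def get_bit(packed: bytes, idx: int) -> bool:
    """Read element idx back out of the packed buffer."""
    return ((packed[idx // 8] >> (idx % 8)) & 1) == 1


mask = [True, False, True, True, False, False, False, False, True]
packed = pack_bits_little(mask)
print(packed.hex())  # 0d01
```

Nine mask entries fit in two bytes here rather than nine, which is the memory saving the change targets for large attention masks.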

---
This PR was generated with [Release
Please](https://github.com/googleapis/release-please). See
[documentation](https://github.com/googleapis/release-please#release-please).

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zihao Ye <expye@outlook.com>