
feat: add MultiLevelCascadeAttentionWrapper API #462

Merged (1 commit) on Aug 22, 2024


yzh119 (Collaborator) commented on Aug 22, 2024

Our existing cascade inference APIs all assume that the shared-prefix KV-cache is stored in standalone tensors, which is not the case in real-world LLM serving.

This PR adds a more general MultiLevelCascadeAttentionWrapper API, which not only supports multi-level cascade inference but also stores the KV-cache of all levels in a unified paged KV-cache, so it can integrate seamlessly with existing LLM serving frameworks.
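Cascade inference works because softmax attention over a concatenated KV sequence can be computed piecewise: attend to the shared-prefix KV and the unique-suffix KV separately, keep each partial output together with the log-sum-exp (LSE) of its attention scores, and then merge the partial results. The minimal pure-Python sketch below shows that merge for a single query. `attention_state` and `merge_states` are hypothetical names, plain lists stand in for tensors, and this illustrates the math rather than FlashInfer's kernels.

```python
import math
from typing import List, Tuple

Vector = List[float]


def attention_state(q: Vector, keys: List[Vector], values: List[Vector]) -> Tuple[Vector, float]:
    """Softmax attention of one query over one KV segment.

    Returns the normalized output and the log-sum-exp (LSE) of the scaled scores.
    """
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    out = [sum(w * v[d] for w, v in zip(weights, values)) / denom for d in range(len(values[0]))]
    return out, m + math.log(denom)


def merge_states(o_a: Vector, lse_a: float, o_b: Vector, lse_b: float) -> Tuple[Vector, float]:
    """Combine two partial attention states computed over disjoint KV segments."""
    m = max(lse_a, lse_b)
    w_a = math.exp(lse_a - m)  # proportional to segment a's total softmax mass
    w_b = math.exp(lse_b - m)
    total = w_a + w_b
    out = [(w_a * a + w_b * b) / total for a, b in zip(o_a, o_b)]
    return out, m + math.log(total)


# Shared prefix = first two KV entries, unique suffix = last two.
q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

o_prefix, lse_prefix = attention_state(q, keys[:2], values[:2])
o_suffix, lse_suffix = attention_state(q, keys[2:], values[2:])
o_merged, lse_merged = merge_states(o_prefix, lse_prefix, o_suffix, lse_suffix)
o_full, lse_full = attention_state(q, keys, values)  # reference: attention over everything
```

Because the merge is associative, the same step extends to more than two levels: merge level 0 with level 1, then merge the result with level 2, and so on.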

Tutorials, tests, and examples have been updated accordingly.

The old BatchDecodeWithSharedPrefixPagedKVCacheWrapper and BatchPrefillWithSharedPrefixPagedKVCacheWrapper should be deprecated, starting from 0.2.0.

yzh119 merged commit 1e37989 into main on Aug 22, 2024
yzh119 deleted the multi-level-cascade-api branch on August 23, 2024 21:22
yzh119 added a commit that referenced this pull request Aug 27, 2024
🤖 I have created a release *beep* *boop*
---


## [0.1.6](v0.1.5...v0.1.6) (2024-08-27)

### SM75 Support

Starting from
[0.1.6](v0.1.5...v0.1.6),
our pre-built wheels include experimental support for sm75 (Turing
architecture GPUs such as Tesla T4, Quadro RTX 6000 and RTX 2080).

### API Changes

#### `plan`/`run`

Starting from
[0.1.6](v0.1.5...v0.1.6),
the `begin_forward`/`forward`/`end_forward` APIs are replaced by the
new `plan`/`run` API.
- `forward` is renamed to `run`, which is more precise and consistent
with the naming convention of CUTLASS's Python API.
- `begin_forward` is renamed to `plan`, which is consistent with the
naming convention of the nvmath API.
- `end_forward` is deprecated and has no effect after this PR.

There are some slight differences between the old `forward` and the new
`run` API:
- All extra arguments, such as `causal` and `logits_soft_cap`, are now
provided to the `plan` API (previously `begin_forward`) and cached until
the next `plan` call; the `run` API only takes the query and KV-cache
tensors.
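The `plan`/`run` split can be illustrated with a toy wrapper: configuration such as `causal` and `logits_soft_cap` is recorded once in `plan` and reused by every `run` until the next `plan` call, so `run` only receives the tensors. The class below is hypothetical (plain Python lists stand in for tensors) and only computes attention logits; it assumes the common `cap * tanh(score / cap)` form of logits soft-capping and is not FlashInfer's implementation.

```python
import math
from typing import List, Optional


class ToyPlanRunWrapper:
    """Hypothetical illustration of the plan/run split (not FlashInfer's class).

    plan() records configuration once; run() only receives the tensors.
    """

    def __init__(self) -> None:
        self._planned = False
        self._causal = False
        self._logits_soft_cap: Optional[float] = None

    def plan(self, causal: bool = False, logits_soft_cap: Optional[float] = None) -> None:
        # Cached until the next plan() call; run() never takes these arguments.
        self._causal = causal
        self._logits_soft_cap = logits_soft_cap
        self._planned = True

    def run(self, q: List[List[float]], k: List[List[float]]) -> List[List[float]]:
        # Returns attention logits only; softmax and value aggregation are omitted.
        if not self._planned:
            raise RuntimeError("plan() must be called before run()")
        cap = self._logits_soft_cap
        logits = []
        for i, qi in enumerate(q):
            row = []
            for j, kj in enumerate(k):
                if self._causal and j > i:
                    row.append(float("-inf"))  # mask keys after the query position
                    continue
                s = sum(a * b for a, b in zip(qi, kj))
                if cap is not None:
                    s = cap * math.tanh(s / cap)  # squash the logit into (-cap, cap)
                row.append(s)
            logits.append(row)
        return logits


wrapper = ToyPlanRunWrapper()
wrapper.plan(causal=True)  # configure once...
logits = wrapper.run([[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 3.0]])  # ...then pass only tensors
```

In a serving loop, `plan` would typically be called once whenever the batch composition changes, while `run` could be called once per layer with the same cached configuration.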

The old `begin_forward`/`forward`/`end_forward` APIs are still
functional, but we will gradually deprecate them in future releases.

Check [#466](#466) for
more details.

#### `MultiLevelCascadeAttentionWrapper`

Starting from
[0.1.6](v0.1.5...v0.1.6),
we introduce a new `MultiLevelCascadeAttentionWrapper` API for
cascade inference, which supports multi-level cascade inference where
the KV-cache of all levels can be managed in a unified paged KV-cache.

See the
[documentation](https://docs.flashinfer.ai/api/python/cascade.html#flashinfer.cascade.MultiLevelCascadeAttentionWrapper)
and the
[tutorial](https://docs.flashinfer.ai/tutorials/kv_layout.html#multi-level-cascade-inference-data-layout)
for API usage and an explanation of the data layout.
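To make the unified layout concrete, the sketch below assumes that each cascade level is described by CSR-style arrays (query ranges, page-id ranges, page ids, and last-page lengths) that all index the same page pool, with queries ordered so that every group at each level covers a contiguous range of queries. `LevelLayout`, `build_level`, and `num_tokens` are hypothetical helpers for illustration only; the actual argument names and layout are described in the documentation and tutorial linked above.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class LevelLayout:
    qo_indptr: List[int]         # query range of each group at this level
    kv_indptr: List[int]         # range of each group's page ids inside kv_indices
    kv_indices: List[int]        # page ids into the single shared page pool
    kv_last_page_len: List[int]  # number of valid tokens in each group's last page


def build_level(groups: List[Tuple[int, List[int], int]]) -> LevelLayout:
    """groups: one (num_queries, page_ids, last_page_len) tuple per KV segment at this level."""
    qo_indptr, kv_indptr, kv_indices, last_page_len = [0], [0], [], []
    for num_queries, page_ids, last_len in groups:
        qo_indptr.append(qo_indptr[-1] + num_queries)
        kv_indices.extend(page_ids)
        kv_indptr.append(len(kv_indices))
        last_page_len.append(last_len)
    return LevelLayout(qo_indptr, kv_indptr, kv_indices, last_page_len)


def num_tokens(layout: LevelLayout, group: int, page_size: int) -> int:
    """KV length of one group: full pages plus the partially filled last page."""
    num_pages = layout.kv_indptr[group + 1] - layout.kv_indptr[group]
    if num_pages == 0:
        return 0
    return (num_pages - 1) * page_size + layout.kv_last_page_len[group]


PAGE_SIZE = 4
# Four decode requests (one query each) sharing a single page pool (pages 0..10).
levels = [
    # Level 0: a system prompt shared by all four requests (pages 0-2, 12 tokens).
    build_level([(4, [0, 1, 2], 4)]),
    # Level 1: requests 0-1 share one document, requests 2-3 share another.
    build_level([(2, [3, 4], 2), (2, [5], 3)]),
    # Level 2: each request's unique suffix.
    build_level([(1, [6], 1), (1, [7], 2), (1, [8], 4), (1, [9, 10], 1)]),
]
```

No level owns a separate KV tensor: a shared prefix is stored once in the pool and referenced by a single group, while each request still sees its full context by walking through the levels (12 + 3 + 5 = 20 tokens for request 3 in this example).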

The old `BatchDecodeWithSharedPrefixPagedKVCacheWrapper` and
`BatchPrefillWithSharedPrefixPagedKVCacheWrapper` will be deprecated in
future releases.

### Features

* sm75 support
([#448](#448),
[#449](#449))
* add `MultiLevelCascadeAttentionWrapper` API
([#462](#462))
([1e37989](1e37989))
* add accept num, emit num metric for ChainSpeculativeSampling
([#450](#450))
([fa38b5e](fa38b5e))
* support bmm fp8
([#469](#469))
([f1c0b68](f1c0b68))

### Refactor

* refactor: replace `begin_forward`/`forward`/`end_forward` with
`plan`/`run`
[#466](#466)

### Misc

* misc: improve error handling of sampling kernels
([#456](#456))
([0dce178](0dce178))

### Performance Improvements

* slight optimization on f16->f8 fragment layout swizzling
([#453](#453))
([0d61871](0d61871))
* slight optimization on fragment layout swizzle
([#458](#458))
([7c397cb](7c397cb))
* use persistent kernel for merging attention states
([#459](#459))
([be6bf5b](be6bf5b))

### Acknowledgement

We thank [@LiuXiaoxuanPKU](https://github.com/LiuXiaoxuanPKU) for
enhancing the speculative sampling operator,
[@merrymercy](https://github.com/merrymercy) for suggesting the API
changes, and [@zhyncs](https://github.com/zhyncs) for integrating the
fp8 BMM cuBLAS implementation.

---
This PR was generated with [Release
Please](https://github.com/googleapis/release-please). See
[documentation](https://github.com/googleapis/release-please#release-please).

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zihao Ye <expye@outlook.com>