
feat: support bmm fp8 #469

Merged: 16 commits merged into main from fp8-bmm-scale on Aug 26, 2024
Conversation

zhyncs (Member) commented on Aug 26, 2024

torch.bmm doesn't support fp8 and torch._scaled_mm doesn't support 3d, so I wrote this one. @yzh119 cc @merrymercy @Ying1123 @ispobock

Thanks @yzh119 for assisting with debugging.

- AType: fp8 e4m3, fp8 e5m2
- BType: fp8 e4m3, fp8 e5m2
- DType: bf16, fp16

The combination where both AType and BType are fp8 e5m2 is not supported; see https://docs.nvidia.com/cuda/cublas/#cublasltmatmul
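
A minimal sketch of how the new kernel can be exercised. The exact `bmm_fp8` call shape and the `to_float8` helper below are assumptions for illustration (additional layout requirements, e.g. a column-major B, may apply); `python/tests/test_bmm_fp8.py` is the authoritative reference:

```python
import torch
import flashinfer

def to_float8(x, dtype=torch.float8_e4m3fn):
    # Hypothetical helper: per-tensor quantization to fp8, returning the
    # quantized tensor and its dequantization scale.
    finfo = torch.finfo(dtype)
    scale = finfo.max / x.abs().amax().clamp(min=1e-12)
    x_fp8 = (x * scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
    return x_fp8, scale.float().reciprocal()

# Batched fp8 GEMM: A is (b, m, k), B is (b, k, n); D comes back in bf16.
A_fp8, A_scale = to_float8(torch.randn(16, 48, 64, device="cuda"))
B_fp8, B_scale = to_float8(torch.randn(16, 64, 80, device="cuda"))

# Assumed call shape: bmm_fp8(A, B, A_scale, B_scale, out_dtype).
# Note: A and B must not both be fp8 e5m2 (cublasLtMatmul restriction).
D = flashinfer.bmm_fp8(A_fp8, B_fp8, A_scale, B_scale, torch.bfloat16)
print(D.shape)  # torch.Size([16, 48, 80])
```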

pytest python/tests/test_bmm_fp8.py

Works on H100:

=================================================================================== test session starts ===================================================================================
platform linux -- Python 3.12.4, pytest-8.3.2, pluggy-1.5.0
rootdir: /flashinfer
collected 8 items

python/tests/test_bmm_fp8.py ...s...s                                                                                                                                                                       [100%]

============================================================================== 6 passed, 2 skipped in 2.16s ===============================================================================

zhyncs added the `enhancement` (New feature or request) label on Aug 26, 2024
zhyncs requested a review from yzh119 on August 26, 2024 at 18:04
zhyncs self-assigned this on Aug 26, 2024
yzh119 (Collaborator) commented on Aug 26, 2024

Another suggestion is to move group gemm and bmm fp8 into a common gemm.py; we should also rename group_gemm.rst to gemm.rst accordingly.

zhyncs (Member, Author) commented on Aug 26, 2024

> Another suggestion is to move group gemm and bmm fp8 into a common gemm.py; we should also rename group_gemm.rst to gemm.rst accordingly.

Makes sense.

Review thread on python/flashinfer/gemm.py (outdated, resolved)
yzh119 (Collaborator) left a review comment:

LGTM, thanks for your contribution @zhyncs!

yzh119 merged commit f1c0b68 into main on Aug 26, 2024
zhyncs deleted the fp8-bmm-scale branch on August 26, 2024 at 19:32
yzh119 added a commit that referenced this pull request Aug 27, 2024
The documentation was not indexed properly in #469; this PR fixes the issue.
yzh119 added a commit that referenced this pull request Aug 27, 2024
🤖 I have created a release *beep* *boop*
---


## [0.1.6](v0.1.5...v0.1.6) (2024-08-27)

### SM75 Support

Starting from
[0.1.6](v0.1.5...v0.1.6),
our pre-built wheels include experimental support for sm75 (Turing-architecture
GPUs such as the Tesla T4, Quadro RTX 6000, and RTX 2080).

### API Changes

#### `plan`/`run`

Since
[0.1.6](v0.1.5...v0.1.6),
the `begin_forward`/`forward`/`end_forward` APIs are replaced with the
new `plan`/`run` APIs.
- `forward` is renamed to `run`, which is more precise and consistent
with the naming convention of cutlass's python API.
- `begin_forward` is renamed to `plan`, which is consistent with the
naming convention of nvmath API.
- `end_forward` is deprecated and has no effect after this PR.

There is a slight difference between the old `forward` and the new
`run` APIs:
- All extra arguments such as `causal` and `logits_soft_cap` are now
provided to the `plan` (previously `begin_forward`) API and cached until
the next `plan` call; only the query and KV-Cache tensors need to be
passed to `run` (see the sketch below).
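
A hedged sketch of the new calling pattern. The wrapper choice, `plan()` parameter order, and KV-Cache layout below are assumptions; see [#466](#466) and the API documentation for the authoritative signatures:

```python
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD")

# Paged KV-Cache metadata for 2 requests with query lengths 33 and 31,
# each owning 4 pages (CSR-style indptr/indices over pages).
qo_indptr = torch.tensor([0, 33, 64], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, 4, 8], dtype=torch.int32, device="cuda")
kv_indices = torch.arange(8, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.tensor([3, 15], dtype=torch.int32, device="cuda")

# plan() now receives all configuration (including extras such as `causal`
# and `logits_soft_cap`) and caches it until the next plan() call.
wrapper.plan(
    qo_indptr, kv_indptr, kv_indices, kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
    causal=True,
)

q = torch.randn(64, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
kv_cache = torch.randn(8, 2, page_size, num_kv_heads, head_dim,
                       dtype=torch.float16, device="cuda")

# run() only needs the query and the KV-Cache; this replaces the old
# begin_forward(...) / forward(...) / end_forward() sequence.
out = wrapper.run(q, kv_cache)
```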

The old `begin_forward`/`forward`/`end_forward` APIs are still
functional, but we will gradually deprecate them in future releases.

Check [#466](#466) for
more details.

#### `MultiLevelCascadeAttentionWrapper`

Since
[0.1.6](v0.1.5...v0.1.6),
we introduce a new `MultiLevelCascadeAttentionWrapper` API for cascade
inference. It supports multi-level cascade inference where the KV-Cache
of all levels can be managed in a unified Paged KV-Cache.

See
[documentation](https://docs.flashinfer.ai/api/python/cascade.html#flashinfer.cascade.MultiLevelCascadeAttentionWrapper)
and
[tutorial](https://docs.flashinfer.ai/tutorials/kv_layout.html#multi-level-cascade-inference-data-layout)
for API usage and an explanation of the layout.
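
A hedged usage sketch: the constructor arguments, the per-level `plan()` parameter lists, and the unified KV-Cache layout below are assumptions, and the linked documentation and tutorial are authoritative:

```python
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
# Two levels: level 0 holds the shared prefix, level 1 the per-request suffixes.
wrapper = flashinfer.MultiLevelCascadeAttentionWrapper(2, workspace, "NHD")

# 2 requests with query lengths 3 and 5; the shared prefix occupies pages 0-3
# and each request owns 2 unique pages, all in one unified paged KV-Cache.
shared_qo_indptr = torch.tensor([0, 8], dtype=torch.int32, device="cuda")
unique_qo_indptr = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
shared_kv_indptr = torch.tensor([0, 4], dtype=torch.int32, device="cuda")
unique_kv_indptr = torch.tensor([0, 2, 4], dtype=torch.int32, device="cuda")
shared_kv_indices = torch.arange(0, 4, dtype=torch.int32, device="cuda")
unique_kv_indices = torch.arange(4, 8, dtype=torch.int32, device="cuda")
shared_kv_last_page_len = torch.tensor([page_size], dtype=torch.int32, device="cuda")
unique_kv_last_page_len = torch.tensor([3, 5], dtype=torch.int32, device="cuda")

# One set of paged KV-Cache metadata per level, front to back.
wrapper.plan(
    [shared_qo_indptr, unique_qo_indptr],
    [shared_kv_indptr, unique_kv_indptr],
    [shared_kv_indices, unique_kv_indices],
    [shared_kv_last_page_len, unique_kv_last_page_len],
    num_qo_heads, num_kv_heads, head_dim, page_size,
)

q = torch.randn(8, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
kv_cache = torch.randn(8, 2, page_size, num_kv_heads, head_dim,
                       dtype=torch.float16, device="cuda")
out = wrapper.run(q, kv_cache)
```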

The old `BatchDecodeWithSharedPrefixPagedKVCacheWrapper` and
`BatchPrefillWithSharedPrefixPagedKVCacheWrapper` will be deprecated in
future releases.

### Features

* sm75 support
([#448](#448),
[#449](#449))
* add `MultiLevelCascadeAttentionWrapper` API
([#462](#462))
([1e37989](1e37989))
* add accept num, emit num metric for ChainSpeculativeSampling
([#450](#450))
([fa38b5e](fa38b5e))
* support bmm fp8
([#469](#469))
([f1c0b68](f1c0b68))

### Refactor

* refactor: replace `begin_forward`/`forward`/`end_forward` with
`plan`/`run`
[#466](#466)

### Misc

* misc: improve error handling of sampling kernels
([#456](#456))
([0dce178](0dce178))

### Performance Improvements

* slight optimization on f16->f8 fragment layout swizzling
([#453](#453))
([0d61871](0d61871))
* slight optimization on fragment layout swizzle
([#458](#458))
([7c397cb](7c397cb))
* use persistent kernel for merging attention states
([#459](#459))
([be6bf5b](be6bf5b))

### Acknowledgement

We thank [@LiuXiaoxuanPKU](https://github.com/LiuXiaoxuanPKU) for
enhancing the speculative sampling operator,
[@merrymercy](https://github.com/merrymercy) for API change suggestions,
and [@zhyncs](https://github.com/zhyncs) for integrating the fp8 BMM
cuBLAS implementation.

---
This PR was generated with [Release
Please](https://github.com/googleapis/release-please). See
[documentation](https://github.com/googleapis/release-please#release-please).

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zihao Ye <expye@outlook.com>