Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix: Fix sm75 kernel configuration #449

Merged
merged 12 commits into from
Aug 27, 2024
Merged

bugfix: Fix sm75 kernel configuration #449

merged 12 commits into from
Aug 27, 2024

Conversation

yzh119
Copy link
Collaborator

@yzh119 yzh119 commented Aug 16, 2024

Some kernel configurations are not compatible with sm75, this pr fix these issues.

@yzh119 yzh119 requested a review from zhyncs August 16, 2024 06:18
@zhyncs
Copy link
Member

zhyncs commented Aug 16, 2024

Ok I will verify it locally asap.

@zhyncs zhyncs self-assigned this Aug 16, 2024
Copy link
Member

@zhyncs zhyncs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. It works when

python3 -m sglang.launch_server --model Qwen/Qwen1.5-1.8B-Chat --dtype float16 --disable-flashinfer-sampling --disable-cuda-graph

@yzh119
Copy link
Collaborator Author

yzh119 commented Aug 16, 2024

why disable sampling and cudagraph?

@zhyncs
Copy link
Member

zhyncs commented Aug 16, 2024

why disable sampling and cudagraph?

CUDA Graph doesn't work with Google Colab's T4 but works with GCP's T4, which is confusing. Also, Sampling doesn't work on sm75.

RuntimeError: TopKRenormProb failed with error code invalid device function

@yzh119
Copy link
Collaborator Author

yzh119 commented Aug 16, 2024

Sampling doesn't work on sm75.

I should fix this.

@zhyncs
Copy link
Member

zhyncs commented Aug 16, 2024

I reviewed the code and it should be compatible with sm75. I am going to extract this kernel and validate it on T4 separately.

@zhyncs
Copy link
Member

zhyncs commented Aug 17, 2024

Maybe we don't need to modify the const expr part subsequently, and choose to add --expt-relaxed-constexpr in the compilation parameters instead.

@zhyncs
Copy link
Member

zhyncs commented Aug 17, 2024

CUDA Graph doesn't work with Google Colab's T4 but works with GCP's T4, which is confusing.

I figured it out. This is because mem-frac was set when running on GCP T4, so actually CUDA Graph can be enabled. Since the VRAM of T4 is relatively small, the default mem-frac will prevent CUDA Graph from being enabled. This has nothing to do with FlashInfer. Cheers.

@zhyncs zhyncs mentioned this pull request Aug 17, 2024
4 tasks
@yzh119
Copy link
Collaborator Author

yzh119 commented Aug 17, 2024

Maybe we don't need to modify the const expr part subsequently, and choose to add --expt-relaxed-constexpr in the compilation parameters instead.

The main issue is __CUDA_ARCH__ only works in device functions (inside kernels), and we can't rely on them in host functions.

@yzh119 yzh119 force-pushed the sm75-kernel-config branch from 7b7c7f4 to bc9b9ab Compare August 26, 2024 22:21
@yzh119 yzh119 merged commit 3d38d0d into main Aug 27, 2024
yzh119 added a commit that referenced this pull request Aug 27, 2024
🤖 I have created a release *beep* *boop*
---


##
[0.1.6](v0.1.5...v0.1.6)
(2024-08-27)

### SM75 Support

Starting from
[0.1.6](v0.1.5...v0.1.6),
our pre-built wheels include experimental support sm75 (Turing
architecture GPUs such as Tesla T4, Quadro RTX 6000 and RTX 2080).

### API Changes

#### `plan`/`run`

Since
[0.1.6](v0.1.5...v0.1.6)
on, `begin_forward`/`forward`/`end_forward` APIs are replaced with the
new `plan`/`run` API.
- `forward` is renamed to `run`, which is more precise and consistent
with the naming convention of cutlass's python API.
- `begin_forward` is renamed to `plan`, which is consistent with the
naming convention of nvmath API.
- `end_forward` is deprecated and has no effect after this PR.

There is some slight difference between the old `forward` and the new
`run` API:
- All extra arguments such as `causal` and `logits_soft_cap` will be
provided in `plan` (previously `begin_forward`) API, and cached until
next `plan` call, and we only need to provide query and KV-Cache tensors
in `run` API.

The old `begin_forward`/`forward`/`end_forward` APIs are still
functional, but we will gradually deprecate them in future releases.

Check [#466](#466) for
more details.

#### `MultiLevelCascadeAttentionWrapper`

Since
[0.1.6](v0.1.5...v0.1.6)
on, we introduce a new `MultiLevelCascadeAttentionWrapper` API for
cascade inference,
which supports multi-level cascade inference where all levels' KV-Cache
can be managed in a unified Paged KV-Cache.

See
[documentation](https://docs.flashinfer.ai/api/python/cascade.html#flashinfer.cascade.MultiLevelCascadeAttentionWrapper)
and
[tutorial](https://docs.flashinfer.ai/tutorials/kv_layout.html#multi-level-cascade-inference-data-layout)
on API usage and layout explaination.

The old `BatchDecodeWithSharedPrefixPagedKVCacheWrapper` and
`BatchPrefillWithSharedPrefixPagedKVCacheWrapper` will be deprecated in
future releases.

### Features

* sm75 support
([#448](#448),
[#449](#449))
* add `MultiLevelCascadeAttentionWrapper` API
([#462](#462))
([1e37989](1e37989))
* add accept num, emit num metric for ChainSpeculativeSampling
([#450](#450))
([fa38b5e](fa38b5e))
* support bmm fp8
([#469](#469))
([f1c0b68](f1c0b68))

### Refactor

* refactor: replace `begin_forward`/`forward`/`end_forward` with
`plan`/`run`
[#466](#466)

### Misc

* misc: improve error handling of sampling kernels
([#456](#456))
([0dce178](0dce178))

### Performance Improvements

* slight optimization on f16->f8 fragment layout swizzling
([#453](#453))
([0d61871](0d61871))
* slight optimization on fragment layout swizzle
([#458](#458))
([7c397cb](7c397cb))
* use persistent kernel for merging attention states
([#459](#459))
([be6bf5b](be6bf5b))

### Acknowledgement

We thank [@LiuXiaoxuanPKU](https://github.com/LiuXiaoxuanPKU) on enhance
of speculative sampling operator,
[@merrymercy](https://github.com/merrymercy) on API change suggestion
and [@zhyncs](https://github.com/zhyncs) on integrating fp8 BMM cublas
implementation.

---
This PR was generated with [Release
Please](https://github.com/googleapis/release-please). See
[documentation](https://github.com/googleapis/release-please#release-please).

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zihao Ye <expye@outlook.com>
@yzh119 yzh119 deleted the sm75-kernel-config branch August 27, 2024 04:51
@zhyncs
Copy link
Member

zhyncs commented Aug 27, 2024

Nice work! I will integrate it into SGLang.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants