feat: pytorch api of fp8 kv-cache #156

Merged: 1 commit, Mar 5, 2024
4 changes: 2 additions & 2 deletions include/flashinfer/attention/decode.cuh
@@ -1263,7 +1263,7 @@ cudaError_t BatchDecodeWithPagedKVCache(

template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT,
PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut>
cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeIn* o,
cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o,
DTypeOut* tmp, float* lse, uint32_t batch_size,
uint32_t padded_kv_len, uint32_t num_qo_heads,
float sm_scale, float rope_scale,
@@ -1304,7 +1304,7 @@ cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DType
}

template <typename DTypeIn, typename DTypeOut>
cudaError_t BatchDecodeWithPaddedKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeIn* o,
cudaError_t BatchDecodeWithPaddedKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o,
DTypeOut* tmp, float* lse, uint32_t batch_size,
uint32_t padded_kv_len, uint32_t num_qo_heads,
uint32_t num_kv_heads, uint32_t head_dim,
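
The fix above changes the output pointer from DTypeIn* to DTypeOut*, which only matters once the two template types diverge: with an fp8 KV-cache the inputs are fp8 while the output buffer is fp16. A minimal Python-side sketch of that dtype contract, assuming the bindings below keep allocating the output this way:

```python
# Sketch only: mirrors the output-dtype rule used by the bindings in this PR.
# fp8 inputs produce an fp16 (nv_half) output; other dtypes keep the input dtype.
import torch

def expected_output_dtype(input_dtype: torch.dtype) -> torch.dtype:
    if input_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return torch.float16  # DTypeOut = nv_half when DTypeIn is an fp8 type
    return input_dtype

assert expected_output_dtype(torch.float8_e4m3fn) == torch.float16
assert expected_output_dtype(torch.float16) == torch.float16
```
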
157 changes: 110 additions & 47 deletions python/csrc/batch_decode.cu
@@ -47,25 +47,45 @@ std::vector<torch::Tensor> batch_decode_with_padded_kv_cache(
}

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
auto o = torch::empty_like(q, q.options());
auto o = torch::empty_like(
q, q.options().dtype(is_float8_tensor(q) ? torch::kFloat16 : q.scalar_type()));
torch::Tensor lse = torch::empty({0});
if (return_lse) {
lse = torch::empty({batch_size, num_qo_heads}, q.options()).to(torch::kFloat32);
}

bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] {
c_type* tmp = nullptr;
cudaError_t status = BatchDecodeWithPaddedKVCache<c_type, c_type>(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k_padded.data_ptr()),
static_cast<c_type*>(v_padded.data_ptr()), static_cast<c_type*>(o.data_ptr()),
/*tmp=*/tmp,
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr, batch_size,
padded_kv_len, num_qo_heads, num_kv_heads, head_dim, kv_layout,
PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta, torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPaddedKVCache failed with error code ",
status);
return true;
});
bool success;
if (is_float8_tensor(q)) {
success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), c_type, [&] {
nv_half* tmp = nullptr;
cudaError_t status = BatchDecodeWithPaddedKVCache<c_type, nv_half>(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k_padded.data_ptr()),
static_cast<c_type*>(v_padded.data_ptr()), static_cast<nv_half*>(o.data_ptr()),
/*tmp=*/tmp,
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr, batch_size,
padded_kv_len, num_qo_heads, num_kv_heads, head_dim, kv_layout,
PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta,
torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPaddedKVCache failed with error code ",
status);
return true;
});
} else {
success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] {
c_type* tmp = nullptr;
cudaError_t status = BatchDecodeWithPaddedKVCache<c_type, c_type>(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k_padded.data_ptr()),
static_cast<c_type*>(v_padded.data_ptr()), static_cast<c_type*>(o.data_ptr()),
/*tmp=*/tmp,
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr, batch_size,
padded_kv_len, num_qo_heads, num_kv_heads, head_dim, kv_layout,
PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta,
torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPaddedKVCache failed with error code ",
status);
return true;
});
}
TORCH_CHECK(success, "BatchDecodeWithPaddedKVCache kernel launch failed: unsupported data type");

if (return_lse) {
@@ -93,19 +113,36 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
handler_.SetCUDAStream(torch_current_stream);

bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_data.scalar_type(), c_type, [&] {
DISPATCH_LAYOUT(kv_layout_, KV_LAYOUT, {
cudaError_t status =
handler_.BeginForward<PageStorage::kIndices, KV_LAYOUT, c_type, c_type, int32_t>(
static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
num_kv_heads, head_dim, page_size, PosEncodingMode(pos_encoding_mode));
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));
return true;
})
});
bool success;
if (is_float8_tensor(empty_data)) {
success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(empty_data.scalar_type(), c_type, [&] {
DISPATCH_LAYOUT(kv_layout_, KV_LAYOUT, {
cudaError_t status =
handler_.BeginForward<PageStorage::kIndices, KV_LAYOUT, c_type, nv_half, int32_t>(
static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
num_kv_heads, head_dim, page_size, PosEncodingMode(pos_encoding_mode));
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));
return true;
})
});
} else {
success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_data.scalar_type(), c_type, [&] {
DISPATCH_LAYOUT(kv_layout_, KV_LAYOUT, {
cudaError_t status =
handler_.BeginForward<PageStorage::kIndices, KV_LAYOUT, c_type, c_type, int32_t>(
static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
num_kv_heads, head_dim, page_size, PosEncodingMode(pos_encoding_mode));
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));
return true;
})
});
}

TORCH_CHECK(success, "BatchDecodeWithPagedKVCache failed to dispatch with dtype ",
empty_data.scalar_type());
@@ -151,31 +188,57 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
CHECK_EQ(paged_kv_last_page_len.scalar_type(), torch::kInt32);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
torch::Tensor o = torch::empty_like(q, q.options());
torch::Tensor o = torch::empty_like(
q, q.options().dtype(is_float8_tensor(q) ? torch::kFloat16 : q.scalar_type()));
torch::Tensor lse;
if (return_lse) {
lse = torch::empty({batch_size, num_qo_heads}, q.options()).to(torch::kFloat32);
}
bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] {
DISPATCH_LAYOUT(kv_layout_, KV_LAYOUT, {
paged_kv_t<PageStorage::kIndices, KV_LAYOUT, c_type, int32_t> paged_kv(
num_kv_heads, page_size, head_dim, batch_size,
static_cast<c_type*>(paged_kv_data.data_ptr()),
static_cast<int32_t*>(paged_kv_indices.data_ptr()),
static_cast<int32_t*>(paged_kv_indptr.data_ptr()),
static_cast<int32_t*>(paged_kv_last_page_len.data_ptr()));
cudaError_t status = BatchDecodeWithPagedKVCacheWrapper<PageStorage::kIndices, KV_LAYOUT,
c_type, c_type, int32_t>(
&handler_, static_cast<c_type*>(q.data_ptr()), /*q_offset=*/nullptr, paged_kv,
static_cast<c_type*>(o.data_ptr()),
/*lse=*/(return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr), num_qo_heads,
PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta,
/*stream=*/torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));

bool success;
if (is_float8_tensor(q)) {
success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), c_type, [&] {
DISPATCH_LAYOUT(kv_layout_, KV_LAYOUT, {
paged_kv_t<PageStorage::kIndices, KV_LAYOUT, c_type, int32_t> paged_kv(
num_kv_heads, page_size, head_dim, batch_size,
static_cast<c_type*>(paged_kv_data.data_ptr()),
static_cast<int32_t*>(paged_kv_indices.data_ptr()),
static_cast<int32_t*>(paged_kv_indptr.data_ptr()),
static_cast<int32_t*>(paged_kv_last_page_len.data_ptr()));
cudaError_t status = BatchDecodeWithPagedKVCacheWrapper<PageStorage::kIndices, KV_LAYOUT,
c_type, nv_half, int32_t>(
&handler_, static_cast<c_type*>(q.data_ptr()), /*q_offset=*/nullptr, paged_kv,
static_cast<nv_half*>(o.data_ptr()),
/*lse=*/(return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr), num_qo_heads,
PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta,
/*stream=*/torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));
});
return true;
});
return true;
});
} else {
success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] {
DISPATCH_LAYOUT(kv_layout_, KV_LAYOUT, {
paged_kv_t<PageStorage::kIndices, KV_LAYOUT, c_type, int32_t> paged_kv(
num_kv_heads, page_size, head_dim, batch_size,
static_cast<c_type*>(paged_kv_data.data_ptr()),
static_cast<int32_t*>(paged_kv_indices.data_ptr()),
static_cast<int32_t*>(paged_kv_indptr.data_ptr()),
static_cast<int32_t*>(paged_kv_last_page_len.data_ptr()));
cudaError_t status = BatchDecodeWithPagedKVCacheWrapper<PageStorage::kIndices, KV_LAYOUT,
c_type, c_type, int32_t>(
&handler_, static_cast<c_type*>(q.data_ptr()), /*q_offset=*/nullptr, paged_kv,
static_cast<c_type*>(o.data_ptr()),
/*lse=*/(return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr), num_qo_heads,
PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta,
/*stream=*/torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));
});
return true;
});
}

TORCH_CHECK(success, "BatchDecodeWithPagedKVCache failed to dispatch with dtype ",
q.scalar_type());
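
For context, a hedged usage sketch of the paged-KV decode path with an fp8 cache. The wrapper class and method names (BatchDecodeWithPagedKVCacheWrapper, begin_forward, forward, end_forward) and the data_type argument are assumptions about the surrounding Python API rather than something shown in this diff; the point is the dtype flow: fp8 q and kv-cache in, fp16 output out.

```python
# Hedged sketch, not part of this PR: exercising the fp8 paged-KV decode path.
import torch
import flashinfer  # assumes a build with FLASHINFER_ENABLE_FP8

num_pages, page_size, num_kv_heads, num_qo_heads, head_dim = 64, 16, 8, 32, 128
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

# One request spanning all pages; the KV-cache itself is stored as fp8 (e4m3).
kv_data = torch.randn(
    num_pages, 2, page_size, num_kv_heads, head_dim,
    device="cuda", dtype=torch.float16,
).to(torch.float8_e4m3fn)
kv_indptr = torch.tensor([0, num_pages], dtype=torch.int32, device="cuda")
kv_indices = torch.arange(num_pages, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.tensor([page_size], dtype=torch.int32, device="cuda")

wrapper.begin_forward(
    kv_indptr, kv_indices, kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
    data_type=torch.float8_e4m3fn,  # assumed: dtype hint that selects the fp8 dispatch
)
q = torch.randn(1, num_qo_heads, head_dim, device="cuda", dtype=torch.float16)
o = wrapper.forward(q.to(torch.float8_e4m3fn), kv_data)  # o comes back as torch.float16
wrapper.end_forward()
```
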
28 changes: 28 additions & 0 deletions python/csrc/pytorch_extension_utils.h
@@ -15,9 +15,16 @@
*/
#pragma once
#include <c10/cuda/CUDAStream.h>
#include <cuda_fp16.h>
#include <torch/extension.h>

#include "generated/dispatch.inc"
#ifdef FLASHINFER_ENABLE_BF16
#include <cuda_bf16.h>
#endif
#ifdef FLASHINFER_ENABLE_FP8
#include <cuda_fp8.h>
#endif

#ifdef FLASHINFER_ENABLE_BF16
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \
@@ -49,6 +56,22 @@
}()
#endif

#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...) \
[&]() -> bool { \
switch (pytorch_dtype) { \
case at::ScalarType::Float8_e4m3fn: { \
using c_type = __nv_fp8_e4m3; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Float8_e5m2: { \
using c_type = __nv_fp8_e5m2; \
return __VA_ARGS__(); \
} \
default: \
return false; \
} \
}()

#define _DISPATCH_SWITCH(cond, ...) \
[&]() -> bool { \
switch (cond) { \
@@ -99,3 +122,8 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)

#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)

inline bool is_float8_tensor(const torch::Tensor& tensor) {
return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn ||
tensor.scalar_type() == at::ScalarType::Float8_e5m2;
}
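
The header changes above boil down to accepting two PyTorch fp8 scalar types and mapping them to the matching CUDA element types (__nv_fp8_e4m3 and __nv_fp8_e5m2). On the user side, producing such tensors is a plain cast; scaling and quantization policy is out of scope for this PR:

```python
# Illustration of the two fp8 dtypes the new dispatch macro and is_float8_tensor accept.
import torch

x = torch.randn(16, 128, device="cuda", dtype=torch.float16)
x_e4m3 = x.to(torch.float8_e4m3fn)  # dispatched as __nv_fp8_e4m3
x_e5m2 = x.to(torch.float8_e5m2)    # dispatched as __nv_fp8_e5m2
print(x_e4m3.dtype, x_e5m2.dtype)   # torch.float8_e4m3fn torch.float8_e5m2
```
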
40 changes: 28 additions & 12 deletions python/csrc/single_decode.cu
@@ -44,19 +44,35 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc
kv_len = k.size(1);
}
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
auto o = torch::empty_like(q, q.options());
auto o = torch::empty_like(
q, q.options().dtype(is_float8_tensor(q) ? torch::kFloat16 : q.scalar_type()));

bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] {
cudaError_t status = SingleDecodeWithKVCache(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
static_cast<c_type*>(v.data_ptr()), static_cast<c_type*>(o.data_ptr()),
static_cast<c_type*>(tmp.data_ptr()), num_qo_heads, num_kv_heads, kv_len, head_dim,
kv_layout, PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta,
torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "SingleDecodeWithKVCache kernel launch failed, error: " +
std::string(cudaGetErrorString(status)));
return true;
});
bool success;
if (is_float8_tensor(q)) {
success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), c_type, [&] {
cudaError_t status = SingleDecodeWithKVCache(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
static_cast<c_type*>(v.data_ptr()), static_cast<nv_half*>(o.data_ptr()),
static_cast<nv_half*>(tmp.data_ptr()), num_qo_heads, num_kv_heads, kv_len, head_dim,
kv_layout, PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta,
torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "SingleDecodeWithKVCache kernel launch failed, error: " +
std::string(cudaGetErrorString(status)));
return true;
});
} else {
success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] {
cudaError_t status = SingleDecodeWithKVCache(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
static_cast<c_type*>(v.data_ptr()), static_cast<c_type*>(o.data_ptr()),
static_cast<c_type*>(tmp.data_ptr()), num_qo_heads, num_kv_heads, kv_len, head_dim,
kv_layout, PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta,
torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "SingleDecodeWithKVCache kernel launch failed, error: " +
std::string(cudaGetErrorString(status)));
return true;
});
}

TORCH_CHECK(success, "SingleDecodeWithKVCache kernel launch failed, error: unsupported dtype");
return o;
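
A hedged usage sketch of the single-request decode path with fp8 q/k/v. The Python entry point flashinfer.single_decode_with_kv_cache and the default "NHD" layout are assumed from the package's public API; they are not part of this diff.

```python
# Hedged sketch: single-request decode with fp8 Q/K/V; the output comes back as fp16.
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim, kv_len = 32, 8, 128, 2048
q = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16)

o = flashinfer.single_decode_with_kv_cache(
    q.to(torch.float8_e4m3fn), k.to(torch.float8_e4m3fn), v.to(torch.float8_e4m3fn)
)
assert o.dtype == torch.float16  # fp8 inputs yield an fp16 output, as allocated above
```
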
42 changes: 42 additions & 0 deletions python/flashinfer/prefill.py
@@ -17,6 +17,7 @@
import math
from typing import Optional
import torch
import logging

try:
from . import _kernels
@@ -36,6 +37,7 @@
expand_5d,
check_pos_encoding_mode,
check_kv_layout,
is_float8,
)
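
The is_float8 helper imported above is not shown in this diff; a minimal sketch of what it presumably checks (the real helper lives elsewhere in the package):

```python
# Minimal sketch of the is_float8 helper assumed by the wrappers below (illustrative).
import torch

def is_float8(x: torch.Tensor) -> bool:
    return x.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
```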


Expand Down Expand Up @@ -248,6 +250,14 @@ def single_prefill_with_kv_cache_return_lse(
rope_scale = 1.0
if rope_theta is None:
rope_theta = 1e4
if is_float8(q):
logging.warning(
"Our current prefill kernel implementation needs f16 input; the f8 inputs "
"are cast to f16, which could result in performance degradation."
)
q = q.to(torch.float16)
k = k.to(torch.float16)
v = v.to(torch.float16)
return _kernels.single_prefill_with_kv_cache(
q,
k,
@@ -485,6 +495,14 @@ def forward(
rope_scale = 1.0
if rope_theta is None:
rope_theta = 1e4
if is_float8(q):
logging.warning(
"Our current prefill kernel implementation needs f16 input; the f8 inputs "
"are cast to f16, which could result in performance degradation."
)
q = q.to(torch.float16)
paged_kv_data = paged_kv_data.to(torch.float16)

paged_kv_data = expand_5d(paged_kv_data, self._kv_layout)
return self._wrapper.forward(
q,
@@ -557,6 +575,14 @@ def forward_return_lse(
rope_scale = 1.0
if rope_theta is None:
rope_theta = 1e4
if is_float8(q):
logging.warning(
"Our current prefill kernel implementation needs f16 input; the f8 inputs "
"are cast to f16, which could result in performance degradation."
)
q = q.to(torch.float16)
paged_kv_data = paged_kv_data.to(torch.float16)

paged_kv_data = expand_5d(paged_kv_data, self._kv_layout)
return self._wrapper.forward(
q,
@@ -769,6 +795,14 @@ def forward(
rope_scale = 1.0
if rope_theta is None:
rope_theta = 1e4
if is_float8(q):
logging.warning(
"Our current prefill kernel implementation needs f16 input; the f8 inputs "
"are cast to f16, which could result in performance degradation."
)
q = q.to(torch.float16)
k = k.to(torch.float16)
v = v.to(torch.float16)
return self._wrapper.forward(
q,
self._qo_indptr,
@@ -838,6 +872,14 @@ def forward_return_lse(
rope_scale = 1.0
if rope_theta is None:
rope_theta = 1e4
if is_float8(q):
logging.warning(
"Our current prefill kernel implementation needs f16 input; the f8 inputs "
"are cast to f16, which could result in performance degradation."
)
q = q.to(torch.float16)
k = k.to(torch.float16)
v = v.to(torch.float16)
return self._wrapper.forward(
q,
self._qo_indptr,
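
A hedged usage sketch of the prefill fallback: fp8 inputs are accepted, but they are cast to fp16 before the fp16 prefill kernel runs, with the warning above logged on each call. The entry point flashinfer.single_prefill_with_kv_cache and the default "NHD" layout are assumptions about the public API, not part of this diff.

```python
# Hedged sketch: prefill with fp8 inputs currently falls back to an fp16 kernel.
import torch
import flashinfer

qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim = 128, 2048, 32, 8, 128
q = torch.randn(qo_len, num_qo_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16)

# Logs the performance warning and casts q/k/v to fp16 internally before launching.
o = flashinfer.single_prefill_with_kv_cache(
    q.to(torch.float8_e4m3fn), k.to(torch.float8_e4m3fn), v.to(torch.float8_e4m3fn)
)
print(o.shape, o.dtype)
```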