Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support bmm fp8 #469

Merged
merged 16 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ FlashInfer is a library for Large Language Models that provides high-performance
api/python/sparse
api/python/page
api/python/sampling
api/python/group_gemm
api/python/gemm
api/python/norm
api/python/rope
api/python/quantization
200 changes: 200 additions & 0 deletions include/flashinfer/bmm_fp8.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_BMM_FP8_CUH_
#define FLASHINFER_BMM_FP8_CUH_

#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <cublasLt.h>
#include <cuda_fp8.h>
#include <torch/extension.h>

#include <stdexcept>
#include <type_traits>

namespace flashinfer {

namespace bmm_fp8 {

template <typename T, cublasStatus_t (*destructor)(T*)>
struct CuBlasLtDeleter {
void operator()(T* x) {
if (x != nullptr) {
TORCH_CUDABLAS_CHECK(destructor(x));
}
}
};

template <typename T, cublasStatus_t (*destructor)(T*)>
class CuBlasLtDescriptor {
public:
T* descriptor() const { return descriptor_.get(); }
T* descriptor() { return descriptor_.get(); }

protected:
std::unique_ptr<T, CuBlasLtDeleter<T, destructor>> descriptor_;
};

class CuBlasLtMatmulDescriptor
: public CuBlasLtDescriptor<cublasLtMatmulDescOpaque_t, &cublasLtMatmulDescDestroy> {
public:
CuBlasLtMatmulDescriptor(cublasComputeType_t compute_type, cudaDataType_t scale_type) {
cublasLtMatmulDesc_t raw_descriptor = nullptr;
TORCH_CUDABLAS_CHECK(cublasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
descriptor_.reset(raw_descriptor);
}
template <typename T>
inline void setAttribute(cublasLtMatmulDescAttributes_t attr, const T value) {
TORCH_CUDABLAS_CHECK(::cublasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T)));
}
};

class CuBlasLtMatrixLayout
: public CuBlasLtDescriptor<cublasLtMatrixLayoutOpaque_t, &cublasLtMatrixLayoutDestroy> {
public:
CuBlasLtMatrixLayout(cudaDataType_t type, uint64_t rows, uint64_t cols, int64_t ld,
bool t = false) {
cublasLtMatrixLayout_t raw_descriptor = nullptr;
TORCH_CUDABLAS_CHECK(
cublasLtMatrixLayoutCreate(&raw_descriptor, type, t ? cols : rows, t ? rows : cols, ld));
descriptor_.reset(raw_descriptor);
}
template <typename T>
inline void setAttribute(cublasLtMatrixLayoutAttribute_t attr, const T value) {
TORCH_CUDABLAS_CHECK(::cublasLtMatrixLayoutSetAttribute(descriptor(), attr, &value, sizeof(T)));
}
};

class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<cublasLtMatmulPreferenceOpaque_t,
&cublasLtMatmulPreferenceDestroy> {
public:
CuBlasLtMatmulPreference() {
cublasLtMatmulPreference_t raw_descriptor = nullptr;
TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceCreate(&raw_descriptor));
descriptor_.reset(raw_descriptor);
}
template <typename T>
inline void setAttribute(cublasLtMatmulPreferenceAttributes_t attr, const T value) {
TORCH_CUDABLAS_CHECK(
::cublasLtMatmulPreferenceSetAttribute(descriptor(), attr, &value, sizeof(T)));
}
};

template <typename T>
cudaDataType_t get_cuda_data_type() {
if constexpr (std::is_same_v<T, __nv_fp8_e4m3>) {
return CUDA_R_8F_E4M3;
} else if constexpr (std::is_same_v<T, __nv_fp8_e5m2>) {
return CUDA_R_8F_E5M2;
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
return CUDA_R_16BF;
} else if constexpr (std::is_same_v<T, half>) {
return CUDA_R_16F;
} else {
throw std::runtime_error("Unsupported type");
}
}

template <typename AT, typename BT, typename DT>
void bmm_fp8_internal_cublaslt(const AT* A, const BT* B, DT* D, int batch_size, int m, int n, int k,
const float* A_scale, const float* B_scale) {
const void* A_scale_ptr = static_cast<const void*>(A_scale);
const void* B_scale_ptr = static_cast<const void*>(B_scale);
auto matmul_desp = CuBlasLtMatmulDescriptor(CUBLAS_COMPUTE_32F, CUDA_R_32F);
matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, CUBLAS_OP_T);
matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, CUBLAS_OP_N);
int8_t fast_accum = 1;
matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fast_accum);

matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, A_scale_ptr);
matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, B_scale_ptr);

cudaDataType_t a_type = get_cuda_data_type<AT>();
cudaDataType_t b_type = get_cuda_data_type<BT>();
cudaDataType_t d_type = get_cuda_data_type<DT>();
if (std::is_same_v<AT, __nv_fp8_e5m2> && std::is_same_v<BT, __nv_fp8_e5m2>) {
throw std::runtime_error("Unsupported combination: both A and B are e5m2");
}

auto a_desp = CuBlasLtMatrixLayout(a_type, m, k, k, true);
auto b_desp = CuBlasLtMatrixLayout(b_type, k, n, k);
auto d_desp = CuBlasLtMatrixLayout(d_type, m, n, m);

if (batch_size > 1) {
int64_t stride_a = m * k;
int64_t stride_b = k * n;
int64_t stride_d = m * n;
a_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_size);
a_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride_a);
b_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_size);
b_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride_b);
d_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_size);
d_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride_d);
}

CuBlasLtMatmulPreference preference;
size_t workspace_size = 1024 * 1024; // 1 MiB
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspace_size);
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto workspace = allocator.allocate(workspace_size);
cublasLtMatmulHeuristicResult_t heuristic_result = {};
int returned_result = 0;
auto lt_handle = at::cuda::getCurrentCUDABlasLtHandle();
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
lt_handle, matmul_desp.descriptor(), a_desp.descriptor(), b_desp.descriptor(),
d_desp.descriptor(), d_desp.descriptor(), preference.descriptor(), 1, &heuristic_result,
&returned_result));
if (returned_result == 0) {
TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
}

const float alpha = 1.0f;
const float beta = 0.0f;
cublasStatus_t status = cublasLtMatmul(
lt_handle, matmul_desp.descriptor(), &alpha, A, a_desp.descriptor(), B, b_desp.descriptor(),
&beta, nullptr, d_desp.descriptor(), D, d_desp.descriptor(), &heuristic_result.algo,
workspace.mutable_get(), workspace_size, at::cuda::getCurrentCUDAStream());
TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, at::cuda::blas::_cublasGetErrorEnum(status));
}

template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>(
const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, __nv_bfloat16* D, int batch_size, int m, int n,
int k, const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, half>(
const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, half* D, int batch_size, int m, int n, int k,
const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e5m2, __nv_bfloat16>(
const __nv_fp8_e4m3* A, const __nv_fp8_e5m2* B, __nv_bfloat16* D, int batch_size, int m, int n,
int k, const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e5m2, half>(
const __nv_fp8_e4m3* A, const __nv_fp8_e5m2* B, half* D, int batch_size, int m, int n, int k,
const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, __nv_bfloat16>(
const __nv_fp8_e5m2* A, const __nv_fp8_e4m3* B, __nv_bfloat16* D, int batch_size, int m, int n,
int k, const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, half>(
const __nv_fp8_e5m2* A, const __nv_fp8_e4m3* B, half* D, int batch_size, int m, int n, int k,
const float* A_scale, const float* B_scale);

} // namespace bmm_fp8
} // namespace flashinfer

#endif // FLASHINFER_BMM_FP8_CUH_
68 changes: 68 additions & 0 deletions python/csrc/bmm_fp8.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include <flashinfer/bmm_fp8.cuh>

#include "flashinfer_ops.h"
#include "pytorch_extension_utils.h"

using namespace flashinfer;

void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D,
torch::Tensor& A_scale, torch::Tensor& B_scale) {
TORCH_CHECK(A.is_cuda(), "A must be a CUDA tensor");
TORCH_CHECK(B.is_cuda(), "B must be a CUDA tensor");
TORCH_CHECK(D.is_cuda(), "D must be a CUDA tensor");
TORCH_CHECK(A.dim() == 3, "Expected 3D tensor for A");
TORCH_CHECK(B.dim() == 3, "Expected 3D tensor for B");
TORCH_CHECK(D.dim() == 3, "Expected 3D tensor for D");
TORCH_CHECK(A.size(0) == B.size(0) && A.size(0) == D.size(0), "Batch sizes must match");
TORCH_CHECK(A.size(2) == B.size(1), "Incompatible matrix sizes");
TORCH_CHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2),
"Result tensor has incorrect shape");
TORCH_CHECK(A.scalar_type() == torch::kFloat8_e4m3fn || A.scalar_type() == torch::kFloat8_e5m2,
"A must be Float8_e4m3fn or Float8_e5m2");
TORCH_CHECK(B.scalar_type() == torch::kFloat8_e4m3fn || B.scalar_type() == torch::kFloat8_e5m2,
"B must be Float8_e4m3fn or Float8_e5m2");
TORCH_CHECK(D.scalar_type() == torch::kBFloat16 || D.scalar_type() == torch::kHalf,
"D must be BFloat16 or Half");

TORCH_CHECK(A_scale.scalar_type() == torch::kFloat32 && B_scale.scalar_type() == torch::kFloat32,
"A_scale and B_scale must be Float32");

auto batch_size = A.size(0);
auto m = A.size(1);
auto k = A.size(2);
auto n = B.size(2);

// PyTorch is row major by default. cuBLASLt is column major by default.
// We need row major D as expected.
// A ^ T * B = D, so D ^ T = B ^ T * A
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(B.scalar_type(), b_type, [&] {
return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(A.scalar_type(), a_type, [&] {
return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(D.scalar_type(), d_type, [&] {
flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(
static_cast<b_type*>(B.data_ptr()), static_cast<a_type*>(A.data_ptr()),
static_cast<d_type*>(D.data_ptr()), batch_size, n, m, k,
static_cast<float*>(B_scale.data_ptr()), static_cast<float*>(A_scale.data_ptr()));
return true;
});
});
});
}
1 change: 1 addition & 0 deletions python/csrc/flashinfer_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE");
m.def("packbits", &packbits, "GPU packbits operator");
m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator");
m.def("bmm_fp8", &bmm_fp8, "BMM FP8");
py::class_<CutlassSegmentGEMMPyTorchWrapper>(m, "CutlassSegmentGEMMPyTorchWrapper")
.def(py::init<torch::Tensor>())
.def("register_workspace", &CutlassSegmentGEMMPyTorchWrapper::RegisterWorkspaceBuffer)
Expand Down
3 changes: 3 additions & 0 deletions python/csrc/flashinfer_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ torch::Tensor packbits(torch::Tensor x, const std::string& bitorder);
torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr,
torch::Tensor output_indptr, const std::string& bitorder);

void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D,
torch::Tensor& A_scale, torch::Tensor& B_scale);

class CutlassSegmentGEMMPyTorchWrapper {
public:
void RegisterWorkspaceBuffer(torch::Tensor workspace_buffer);
Expand Down
12 changes: 6 additions & 6 deletions python/flashinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
limitations under the License.
"""

from .activation import gelu_tanh_and_mul, silu_and_mul
from .cascade import (
MultiLevelCascadeAttentionWrapper,
BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
MultiLevelCascadeAttentionWrapper,
merge_state,
merge_state_in_place,
merge_states,
Expand All @@ -27,8 +28,7 @@
CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
single_decode_with_kv_cache,
)
from .activation import gelu_tanh_and_mul, silu_and_mul
from .group_gemm import SegmentGEMMWrapper
from .gemm import SegmentGEMMWrapper, bmm_fp8
from .norm import fused_add_rmsnorm, rmsnorm
from .page import append_paged_kv_cache
from .prefill import (
Expand All @@ -46,15 +46,15 @@
)
from .sampling import (
chain_speculative_sampling,
min_p_sampling_from_probs,
sampling_from_probs,
top_k_renorm_prob,
top_k_mask_logits,
top_k_renorm_prob,
top_k_sampling_from_probs,
top_k_top_p_sampling_from_probs,
top_k_top_p_sampling_from_logits,
top_k_top_p_sampling_from_probs,
top_p_renorm_prob,
top_p_sampling_from_probs,
min_p_sampling_from_probs,
)
from .sparse import BlockSparseAttentionWrapper

Expand Down
Loading