perf: accelerate JIT compilation speed #618

Merged: 4 commits, Nov 20, 2024
2 changes: 1 addition & 1 deletion .gitignore
@@ -13,7 +13,7 @@ src/generated/
python/csrc/generated/
python/flashinfer/_build_meta.py
python/flashinfer/jit/aot_config.py
-python/csrc_aot/generated/
+python/csrc-aot/generated/

# Package files
python/flashinfer/data/
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -52,7 +52,7 @@ repos:
- id: clang-format
types_or: [c++, c, cuda]
exclude: |
-(?x)^(3rdparty/.*|src/generated/.*|python/flashinfer/jit/aot_config.py|python/csrc_aot/generated/.*)$
+(?x)^(3rdparty/.*|src/generated/.*|python/flashinfer/jit/aot_config.py)$

- repo: https://github.com/cheshirekow/cmake-format-precommit
rev: v0.6.13
5 changes: 3 additions & 2 deletions include/flashinfer/allocator.h
@@ -18,7 +18,8 @@

#include <memory>
#include <sstream>
-#include <stdexcept>

+#include "exception.h"

namespace flashinfer {

@@ -44,7 +45,7 @@ struct AlignedAllocator {
std::ostringstream oss;
oss << "Failed to allocate memory for " << name << " with size " << size << " and alignment "
<< alignment << " in AlignedAllocator";
-throw std::runtime_error(oss.str());
+FLASHINFER_ERROR(oss.str());
}
return nullptr;
}
2 changes: 1 addition & 1 deletion include/flashinfer/attention/decode.cuh
@@ -687,7 +687,7 @@ cudaError_t SingleDecodeWithKVCacheDispatched(typename AttentionVariant::ParamsT
if (nblks.x == 0 || nblks.y == 0) {
std::ostringstream err_msg;
err_msg << "Invalid kernel configuration: nblks=(" << nblks.x << "," << nblks.y << ")";
-throw std::runtime_error(err_msg.str());
+FLASHINFER_ERROR(err_msg.str());
}
dim3 nthrs = dim3(bdx, bdy, bdz);
float* tmp_lse = (float*)(tmp + num_chunks * num_qo_heads * HEAD_DIM);
8 changes: 4 additions & 4 deletions include/flashinfer/attention/prefill.cuh
@@ -1375,7 +1375,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::Params
err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be greater than or equal "
"to qo_len, got kv_len"
<< kv_len << " and qo_len " << qo_len;
-throw std::invalid_argument(err_msg.str());
+FLASHINFER_ERROR(err_msg.str());
}

const uint32_t group_size = num_qo_heads / num_kv_heads;
@@ -1442,7 +1442,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::Params
<< " NUM_WARPS_Q=" << NUM_WARPS_Q << " NUM_WARPS_KV=" << NUM_WARPS_KV
<< " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)"
" and report the issue to the developers.";
-throw std::invalid_argument(err_msg.str());
+FLASHINFER_ERROR(err_msg.str());
} else {
constexpr uint32_t num_threads = (NUM_WARPS_Q * NUM_WARPS_KV) * WARP_SIZE;
constexpr uint32_t num_rows_per_cta = NUM_FRAGS_Q * NUM_WARPS_Q * 16;
@@ -2165,7 +2165,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::P
<< " NUM_WARPS_Q=" << NUM_WARPS_Q << " NUM_WARPS_KV=" << NUM_WARPS_KV
<< " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)"
" and report the issue to the developers.";
-throw std::invalid_argument(err_msg.str());
+FLASHINFER_ERROR(err_msg.str());
} else {
// TODO(Zihao): fix the following computation
uint32_t smem_size = (NUM_FRAGS_Q * NUM_WARPS_Q * sizeof(DTypeQ) +
@@ -2267,7 +2267,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::Pa
<< " NUM_WARPS_Q=" << NUM_WARPS_Q << " NUM_WARPS_KV=" << NUM_WARPS_KV
<< " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)"
" and report the issue to the developers.";
-throw std::invalid_argument(err_msg.str());
+FLASHINFER_ERROR(err_msg.str());
} else {
// TODO(Zihao): fix the following computation
uint32_t smem_size = (NUM_FRAGS_Q * NUM_WARPS_Q * sizeof(DTypeQ) +
10 changes: 5 additions & 5 deletions include/flashinfer/attention/scheduler.cuh
@@ -333,7 +333,7 @@ struct DecodePlanInfo {
if (vec.size() != 10) {
std::ostringstream err_msg;
err_msg << "DecodePlanInfo::FromVector: vec.size() should be 10, but got " << vec.size();
-throw std::invalid_argument(err_msg.str());
+FLASHINFER_ERROR(err_msg.str());
}
padded_batch_size = vec[0];
v_offset = vec[1];
@@ -440,14 +440,14 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uin
std::ostringstream err_msg;
err_msg << "qo_indptr[" << i + 1 << "]" << qo_indptr_h[i + 1] << " - qo_indptr[" << i << "]"
<< qo_indptr_h[i] << " should be non-negative";
-throw std::invalid_argument(err_msg.str());
+FLASHINFER_ERROR(err_msg.str());
}
kv_len_arr[i] = int64_t(kv_indptr_h[i + 1] - kv_indptr_h[i]);
if (kv_len_arr[i] < 0) {
std::ostringstream err_msg;
err_msg << "kv_indptr[" << i + 1 << "]" << kv_indptr_h[i + 1] << " - kv_indptr[" << i << "]"
<< kv_indptr_h[i] << " should be non-negative";
-throw std::invalid_argument(err_msg.str());
+FLASHINFER_ERROR(err_msg.str());
}
sum_packed_qo_len += packed_qo_len_arr[i];
}
@@ -570,7 +570,7 @@ struct PrefillPlanInfo {
if (vec.size() != 14) {
std::ostringstream err_msg;
err_msg << "PrefillPlanInfo::FromVector: vec.size() should be 14, but got " << vec.size();
-throw std::invalid_argument(err_msg.str());
+FLASHINFER_ERROR(err_msg.str());
}
padded_batch_size = vec[0];
total_num_rows = vec[1];
@@ -601,7 +601,7 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i
std::ostringstream err_msg;
err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads "
<< num_kv_heads;
-throw std::invalid_argument(err_msg.str());
+FLASHINFER_ERROR(err_msg.str())
}

// step 0: get the number of SMs
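Note: the checks above guard CSR-style offset arrays, where qo_indptr[i+1] - qo_indptr[i] is the query length of request i and kv_indptr[i+1] - kv_indptr[i] its KV length, so a negative difference means the offsets are not non-decreasing. A minimal standalone sketch of that invariant, written against the FLASHINFER_ERROR macro this PR introduces (see include/flashinfer/exception.h below); the helper name and values here are illustrative, not PR code:

// indptr_demo.cc -- illustrative sketch, not part of the PR.
#include <cstdint>
#include <sstream>
#include <vector>

#include "flashinfer/exception.h"

using namespace flashinfer;

// Per-request lengths from a CSR-style indptr array, rejecting
// non-monotonic offsets the same way PrefillSplitQOKVIndptr does.
std::vector<int64_t> LengthsFromIndptr(const std::vector<int32_t>& indptr) {
  std::vector<int64_t> len(indptr.size() - 1);
  for (size_t i = 0; i + 1 < indptr.size(); ++i) {
    len[i] = int64_t(indptr[i + 1]) - int64_t(indptr[i]);
    if (len[i] < 0) {
      std::ostringstream err_msg;
      err_msg << "indptr[" << i + 1 << "] - indptr[" << i << "] should be non-negative";
      FLASHINFER_ERROR(err_msg.str());
    }
  }
  return len;
}

int main() {
  std::vector<int32_t> kv_indptr = {0, 5, 5, 12};  // batch of 3: lengths 5, 0, 7
  auto kv_len = LengthsFromIndptr(kv_indptr);
  return 0;
}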
48 changes: 48 additions & 0 deletions include/flashinfer/exception.h
@@ -0,0 +1,48 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_EXCEPTION_H_
#define FLASHINFER_EXCEPTION_H_

#include <exception>
#include <sstream>

namespace flashinfer {

class Error : public std::exception {
private:
std::string message_;

public:
Error(const std::string& func, const std::string& file, int line, const std::string& message) {
std::ostringstream oss;
oss << "Error in function '" << func << "' "
<< "at " << file << ":" << line << ": " << message;
message_ = oss.str();
}

virtual const char* what() const noexcept override { return message_.c_str(); }
};

#define FLASHINFER_ERROR(message) throw Error(__FUNCTION__, __FILE__, __LINE__, message)

#define FLASHINFER_CHECK(condition, message) \
if (!(condition)) { \
FLASHINFER_ERROR(message); \
}

} // namespace flashinfer

#endif // FLASHINFER_EXCEPTION_H_
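The Error class and the FLASHINFER_ERROR / FLASHINFER_CHECK macros above replace the scattered throw std::runtime_error / std::invalid_argument calls throughout this PR, so every message now carries the throwing function, file, and line. A minimal usage sketch; the validator function and include path are illustrative assumptions, not PR code:

// exception_demo.cc -- illustrative sketch, not part of the PR.
#include <iostream>
#include <string>

#include "flashinfer/exception.h"

using namespace flashinfer;

// Hypothetical validator in the style of the checks this PR rewrites.
void CheckHeadDim(unsigned head_dim) {
  FLASHINFER_CHECK(head_dim == 64 || head_dim == 128 || head_dim == 256,
                   "Unsupported head_dim: " + std::to_string(head_dim));
}

int main() {
  try {
    CheckHeadDim(96);  // fails the condition, so FLASHINFER_CHECK throws
  } catch (const Error& e) {
    // what() reads like:
    // Error in function 'CheckHeadDim' at exception_demo.cc:11: Unsupported head_dim: 96
    std::cerr << e.what() << std::endl;
  }
  return 0;
}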
20 changes: 11 additions & 9 deletions include/flashinfer/gemm/bmm_fp8.cuh
@@ -19,15 +19,17 @@
#include <cublasLt.h>
#include <cuda_fp8.h>

-#include <stdexcept>
+#include <iostream>
+#include <memory>
#include <type_traits>

-#define FLASHINFER_CUBLAS_CHECK(EXPR) \
-{ \
-cublasStatus_t e = (EXPR); \
-if (e != CUBLAS_STATUS_SUCCESS) { \
-throw std::runtime_error("CUBLAS Error: " + std::string(cublasGetStatusString(e))); \
-} \
+#include "../exception.h"
+
+#define FLASHINFER_CUBLAS_CHECK(EXPR) \
+{ \
+cublasStatus_t e = (EXPR); \
+FLASHINFER_CHECK(e == CUBLAS_STATUS_SUCCESS, \
+"CUBLAS Error: " + std::string(cublasGetStatusString(e))); \
}

#ifndef NDEBUG
@@ -131,7 +133,7 @@ cudaDataType_t get_cuda_data_type() {
} else if constexpr (std::is_same_v<T, half>) {
return CUDA_R_16F;
} else {
throw std::runtime_error("Unsupported type");
FLASHINFER_ERROR("Unsupported type");
}
}

@@ -155,7 +157,7 @@ cublasStatus_t bmm_fp8_internal_cublaslt(void* workspace, size_t workspace_size_
cudaDataType_t b_type = get_cuda_data_type<BT>();
cudaDataType_t d_type = get_cuda_data_type<DT>();
if (std::is_same_v<AT, __nv_fp8_e5m2> && std::is_same_v<BT, __nv_fp8_e5m2>) {
throw std::runtime_error("Unsupported combination: both A and B are e5m2");
FLASHINFER_ERROR("Unsupported combination: both A and B are e5m2");
}

auto a_desp = CuBlasLtMatrixLayout(a_type, m, k, k, true);
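With FLASHINFER_CHECK available, the rewritten FLASHINFER_CUBLAS_CHECK converts a failing cublasStatus_t into a flashinfer::Error carrying source location instead of a bare std::runtime_error. A sketch of how a caller might wrap cuBLASLt calls; the file name and include path are assumptions, and this is handle setup only, not PR code:

// cublas_check_demo.cu -- illustrative sketch, not part of the PR.
#include <cublasLt.h>

#include "flashinfer/gemm/bmm_fp8.cuh"  // defines FLASHINFER_CUBLAS_CHECK

int main() {
  cublasLtHandle_t handle;
  // Any cublasStatus_t-returning call can be wrapped; a non-success status
  // throws flashinfer::Error with "CUBLAS Error: <status string>".
  FLASHINFER_CUBLAS_CHECK(cublasLtCreate(&handle));
  FLASHINFER_CUBLAS_CHECK(cublasLtDestroy(handle));
  return 0;
}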
4 changes: 2 additions & 2 deletions include/flashinfer/gemm/group_gemm.cuh
@@ -79,13 +79,13 @@ cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffe
if (status != cutlass::Status::kSuccess) {
std::ostringstream err_msg;
err_msg << "cutlass group_gemm.initialize failed: " << cutlassGetStatusString(status);
-throw std::runtime_error(err_msg.str());
+FLASHINFER_ERROR(err_msg.str());
}
status = gemm.run(stream);
if (status != cutlass::Status::kSuccess) {
std::ostringstream err_msg;
err_msg << "cutlass group_gemm.run failed: " << cutlassGetStatusString(status);
-throw std::runtime_error(err_msg.str());
+FLASHINFER_ERROR(err_msg.str());
}
});

2 changes: 1 addition & 1 deletion include/flashinfer/gemm/group_gemm_sm90.cuh
@@ -73,7 +73,7 @@ cudaError_t CutlassSegmentGEMMSM90Run(void* float_buffer, size_t float_buffer_si
sizeof(DTypeIn) == 1) {
std::ostringstream err_msg;
err_msg << "Row-major layout is not supported for fp8 data type";
-throw std::runtime_error(err_msg.str());
+FLASHINFER_ERROR(err_msg.str());
} else {
using LayoutA = cutlass::layout::RowMajor;
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
2 changes: 2 additions & 0 deletions include/flashinfer/math.cuh
@@ -19,6 +19,8 @@
#include <cuda_fp16.h>
#include <cuda_runtime.h>

+#include <cstdint>

namespace flashinfer {
namespace math {

22 changes: 11 additions & 11 deletions include/flashinfer/utils.cuh
@@ -23,10 +23,10 @@

#include <cstdint>
#include <iostream>
-#include <sstream>
-#include <stdexcept>
#include <vector>

#include "exception.h"

#define STR_HELPER(x) #x
#define STR(x) STR_HELPER(x)

@@ -57,7 +57,7 @@

#define DISPATCH_ALLOW_FP16_QK_REDUCTION(allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, ...) \
if (allow_fp16_qk_reduction) { \
throw std::runtime_error("FP16_QK_REDUCTION disabled at compile time"); \
FLASHINFER_ERROR("FP16_QK_REDUCTION disabled at compile time"); \
} else { \
constexpr bool ALLOW_FP16_QK_REDUCTION = false; \
__VA_ARGS__ \
@@ -73,7 +73,7 @@
} else { \
std::ostringstream err_msg; \
err_msg << "Unsupported num_frags_q: " << num_frags_q; \
-throw std::invalid_argument(err_msg.str()); \
+FLASHINFER_ERROR(err_msg.str()); \
}

#define DISPATCH_NUM_FRAGS_KV(max_frags_kv, NUM_FRAGS_KV, ...) \
@@ -92,7 +92,7 @@
} else { \
std::ostringstream err_msg; \
err_msg << "Unsupported max_frags_kv: " << max_frags_kv; \
-throw std::invalid_argument(err_msg.str()); \
+FLASHINFER_ERROR(err_msg.str()); \
}

#define DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, ...) \
@@ -115,7 +115,7 @@
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported cta_tile_q: " << cta_tile_q; \
-throw std::invalid_argument(err_msg.str()); \
+FLASHINFER_ERROR(err_msg.str()); \
} \
}

@@ -138,7 +138,7 @@
} else { \
std::ostringstream err_msg; \
err_msg << "Unsupported group_size: " << group_size; \
-throw std::invalid_argument(err_msg.str()); \
+FLASHINFER_ERROR(err_msg.str()); \
}

#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...) \
@@ -161,7 +161,7 @@
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported mask_mode: " << int(mask_mode); \
-throw std::invalid_argument(err_msg.str()); \
+FLASHINFER_ERROR(err_msg.str()); \
} \
}

@@ -190,7 +190,7 @@
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported head_dim: " << head_dim; \
-throw std::invalid_argument(err_msg.str()); \
+FLASHINFER_ERROR(err_msg.str()); \
} \
}

@@ -214,7 +214,7 @@
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported pos_encoding_mode: " << int(pos_encoding_mode); \
-throw std::invalid_argument(err_msg.str()); \
+FLASHINFER_ERROR(err_msg.str()); \
} \
}

@@ -248,7 +248,7 @@
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \
-throw std::invalid_argument(err_msg.str()); \
+FLASHINFER_ERROR(err_msg.str()); \
} \
}

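All of these dispatch macros share one shape: compare a runtime value against the compiled-in options, bind the match to a constexpr so __VA_ARGS__ is instantiated with a compile-time constant, and route anything else to FLASHINFER_ERROR. A simplified standalone sketch of the pattern; the macro name and supported values are illustrative, not the PR's exact definitions:

// dispatch_demo.cc -- illustrative sketch, not part of the PR.
#include <cstddef>
#include <iostream>
#include <sstream>

#include "flashinfer/exception.h"

using namespace flashinfer;

// Simplified two-way dispatch in the style of DISPATCH_HEAD_DIM.
#define DEMO_DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \
  switch (head_dim) {                                   \
    case 128: {                                         \
      constexpr size_t HEAD_DIM = 128;                  \
      __VA_ARGS__                                       \
      break;                                            \
    }                                                   \
    case 256: {                                         \
      constexpr size_t HEAD_DIM = 256;                  \
      __VA_ARGS__                                       \
      break;                                            \
    }                                                   \
    default: {                                          \
      std::ostringstream err_msg;                       \
      err_msg << "Unsupported head_dim: " << head_dim;  \
      FLASHINFER_ERROR(err_msg.str());                  \
    }                                                   \
  }

template <size_t HEAD_DIM>
void RunKernel() {
  std::cout << "instantiated with HEAD_DIM=" << HEAD_DIM << "\n";
}

int main() {
  size_t head_dim = 128;  // runtime value, e.g. read from a tensor shape
  DEMO_DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { RunKernel<HEAD_DIM>(); });
  return 0;
}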
1 change: 0 additions & 1 deletion python/aot_MANIFEST.in
@@ -2,7 +2,6 @@

prune */__pycache__
prune csrc
-prune csrc_aot
exclude aot_setup.py
exclude setup.py

14 changes: 7 additions & 7 deletions python/aot_setup.py
@@ -64,7 +64,7 @@ def write_if_different(path: pathlib.Path, content: str) -> None:


def get_instantiation_cu() -> Tuple[List[str], List[str], List[str]]:
-path = root / "python" / "csrc_aot" / "generated"
+path = root / "python" / "csrc" / "generated"
path.mkdir(parents=True, exist_ok=True)

head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "64,128,256").split(",")
@@ -423,12 +423,12 @@ def ln(src: str, dst: str, is_dir: bool = False) -> None:
"csrc/quantization.cu",
"csrc/rope.cu",
"csrc/sampling.cu",
"csrc_aot/activation.cu",
"csrc_aot/batch_decode.cu",
"csrc_aot/batch_prefill.cu",
"csrc_aot/flashinfer_ops.cu",
"csrc_aot/single_decode.cu",
"csrc_aot/single_prefill.cu",
"csrc/activation.cu",
"csrc/batch_decode.cu",
"csrc/batch_prefill.cu",
"csrc/single_decode.cu",
"csrc/single_prefill.cu",
"csrc/flashinfer_ops.cu",
]
+ files_decode
+ files_prefill,