4 changes: 2 additions & 2 deletions flashinfer/artifacts.py
@@ -87,7 +87,7 @@ class ArtifactPath:
When compiling new cubins for backend directories, update the corresponding path.
"""

- TRTLLM_GEN_FMHA: str = "463def7494c9fc6792b5aa5b5beef34025e247ac/fmha/trtllm-gen/"
+ TRTLLM_GEN_FMHA: str = "b793e1b2cf7c419f070372ba55bbe53ca6fb9016/fmha/trtllm-gen/"
TRTLLM_GEN_BMM: str = (
"c108f5cc46420e11805467898186533fb48d6a6f/batched_gemm-0d28130-7b26988"
)
@@ -120,7 +120,7 @@ class CheckSumHash:
"""

TRTLLM_GEN_FMHA: str = (
"639c534614e9fdf5a9cfa91f7ea8f53989613019c0e1f8b755f461e1fcc7546f"
"20c017db0761a30130f05080ed2078f6c8044c0c2b3be7c4353ec740034b4432"
)
TRTLLM_GEN_BMM: str = (
"85a4516b7ab25b1a6495398ae934a00e30ccd6662b9ec27be1330d7bba5e1ddf"
43 changes: 23 additions & 20 deletions include/flashinfer/trtllm/fmha/fmhaKernels.cuh
@@ -96,14 +96,15 @@ class TllmGenFmhaKernel {
inline uint64_t hashID(int qkvLayout, int maskType, int kernelType, int scheduler,
int multiCtasKvMode, int headDimPerCtaV, int headDimQk, int headDimV,
int tileSizeKv, int numTokensPerPage, int maxNumHeadsQPerKvInCta,
- bool reuseSmemKForV, bool uses2CtaMma) const {
+ bool reuseSmemKForV, bool uses2CtaMma, bool sparseMla) const {
FLASHINFER_CHECK((headDimPerCtaV >= 32) && (headDimQk >= 32) && (headDimV >= 32) &&
- (headDimPerCtaV <= 2048) && (headDimQk <= 2048) && (headDimV <= 2048) &&
- (numTokensPerPage <= 128),
- "Expect (32 <= headDim <= 2048) && (numTokensPerPage <= 128), "
- "got headDimPerCtaV=%d, headDimQk=%d, "
- "headDimV=%d, numTokensPerPage=%d",
- headDimPerCtaV, headDimQk, headDimV, numTokensPerPage);
+ (headDimPerCtaV <= 1024) && (headDimQk <= 1024) && (headDimV <= 1024),
+ "Expect (32 <= headDim <= 1024), got headDimPerCtaV=%d, headDimQk=%d, "
+ "headDimV=%d",
+ headDimPerCtaV, headDimQk, headDimV);
+ // The numTokensPerPage must be power of 2.
+ FLASHINFER_CHECK((numTokensPerPage & (numTokensPerPage - 1)) == 0,
+ "The numTokensPerPage must be power of 2.");
FLASHINFER_CHECK(maxNumHeadsQPerKvInCta <= 128,
"The maxNumHeadsQPerKvInCta <= 128 is required.");
FLASHINFER_CHECK(tileSizeKv == 64 || tileSizeKv == 128, "The tileSizeKv must be 64 or 128.");
@@ -113,25 +114,26 @@ class TllmGenFmhaKernel {
// Bit 8 - 11: kernelType.
// Bit 12 - 15: tileScheduler.
// Bit 16 - 17: multiCtasKvMode.
- // Bit 18 - 24: (headDimPerCtaV >> 5).
- // Bit 25 - 31: (headDimQk >> 5).
- // Bit 32 - 38: (headDimV >> 5).
- // Bit 39 - 40: (tileSizeKv >> 6).
- // Bit 41 - 48: numTokensPerPage.
+ // Bit 18 - 25: (headDimPerCtaV >> 3).
+ // Bit 26 - 33: (headDimQk >> 3).
+ // Bit 34 - 41: (headDimV >> 3).
+ // Bit 42 - 43: (tileSizeKv >> 6).
+ // Bit 44 - 48: (log2(numTokensPerPage)).
// Bit 49 - 56: maxNumHeadsQPerKvInCta.
// Bit 57 - 57: reuseSmemKForV.
// Bit 58 - 58: uses2CtaMma.
+ // Bit 59 - 59: sparseMla.
return (static_cast<uint64_t>(qkvLayout) << 0) | (static_cast<uint64_t>(maskType) << 4) |
(static_cast<uint64_t>(kernelType) << 8) | (static_cast<uint64_t>(scheduler) << 12) |
(static_cast<uint64_t>(multiCtasKvMode) << 16) |
- (static_cast<uint64_t>(headDimPerCtaV >> 5) << 18) |
- (static_cast<uint64_t>(headDimQk >> 5) << 25) |
- (static_cast<uint64_t>(headDimV >> 5) << 32) |
- (static_cast<uint64_t>(tileSizeKv >> 6) << 39) |
- (static_cast<uint64_t>(numTokensPerPage) << 41) |
+ (static_cast<uint64_t>(headDimPerCtaV >> 3) << 18) |
+ (static_cast<uint64_t>(headDimQk >> 3) << 26) |
+ (static_cast<uint64_t>(headDimV >> 3) << 34) |
+ (static_cast<uint64_t>(tileSizeKv >> 6) << 42) |
+ (static_cast<uint64_t>(log2(numTokensPerPage)) << 44) |
(static_cast<uint64_t>(maxNumHeadsQPerKvInCta) << 49) |
(static_cast<uint64_t>(reuseSmemKForV) << 57) |
- (static_cast<uint64_t>(uses2CtaMma) << 58);
+ (static_cast<uint64_t>(uses2CtaMma) << 58) | (static_cast<uint64_t>(sparseMla) << 59);
}

uint64_t hashID(KernelMeta const& kernelMeta) const {
@@ -140,7 +142,7 @@ class TllmGenFmhaKernel {
kernelMeta.mHeadDimPerCtaV, kernelMeta.mHeadDimQk, kernelMeta.mHeadDimV,
kernelMeta.mTileSizeKv, kernelMeta.mNumTokensPerPage,
kernelMeta.mMaxNumHeadsQPerKvInCta, kernelMeta.mReuseSmemKForV,
- kernelMeta.m2CtaMma);
+ kernelMeta.m2CtaMma, kernelMeta.mSparseMla);
}

std::pair<bool, std::string> checkIfKernelExist(RunnerParams const& params) const {
@@ -552,7 +554,8 @@ class TllmGenFmhaKernel {
static_cast<int>(selectKernelParams.mMultiCtasKvMode),
selectKernelParams.mHeadDimPerCtaV, params.mHeadDimQk, params.mHeadDimV,
selectKernelParams.mTileSizeKv, numTokensPerPage, maxNumHeadsQPerKvInCta,
- selectKernelParams.mReuseSmemKForV, selectKernelParams.mUses2CtaMma),
+ selectKernelParams.mReuseSmemKForV, selectKernelParams.mUses2CtaMma,
+ /* sparseMla */ false),
info);
}

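The revised hashID layout above stores the head dimensions as multiples of 8 (>> 3) and the page size as log2(numTokensPerPage), which is what frees bit 59 for the new sparseMla flag while keeping every field inside the 64-bit hash. The following standalone sketch is illustrative only and not part of this PR; it packs just three of the fields, at the bit positions given in the comment above, to show how such a hash can be built and decoded:

```cpp
#include <cassert>
#include <cstdint>

// Pack a subset of the kernel-selection fields at the bit positions described
// in the hashID() comment above (illustrative only; the real hashID() packs more fields).
uint64_t packExample(int headDimQk, int numTokensPerPage, bool sparseMla) {
  // headDimQk is stored in 8 bits as a multiple of 8 (bits 26 - 33).
  uint64_t h = static_cast<uint64_t>(headDimQk >> 3) << 26;
  // numTokensPerPage must be a power of 2, so only its log2 (5 bits, 44 - 48) is stored.
  assert((numTokensPerPage & (numTokensPerPage - 1)) == 0);
  int log2Pages = 0;
  while ((1 << log2Pages) < numTokensPerPage) ++log2Pages;
  h |= static_cast<uint64_t>(log2Pages) << 44;
  // sparseMla occupies the newly added bit 59.
  h |= static_cast<uint64_t>(sparseMla) << 59;
  return h;
}

int main() {
  uint64_t h = packExample(/*headDimQk=*/576, /*numTokensPerPage=*/64, /*sparseMla=*/true);
  // Unpack each field by shifting back and masking its width.
  assert((((h >> 26) & 0xFF) << 3) == 576);  // headDimQk
  assert(((h >> 44) & 0x1F) == 6);           // log2(64)
  assert(((h >> 59) & 0x1) == 1);            // sparseMla
  return 0;
}
```

Storing log2(numTokensPerPage) shrinks that field from the 8 bits of the old layout (bits 41 - 48) to 5 bits (bits 44 - 48), relying on the power-of-two check added above.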
20 changes: 20 additions & 0 deletions include/flashinfer/trtllm/fmha/kernelParams.h
@@ -104,6 +104,8 @@ struct KernelParams {
// The sequence lengths for K/V. Required by pagedKv kernels to avoid unnecessary computation
// based on (ptrCumSeqLensKv[batchIdx + 1] - ptrCumSeqLensKv[batchIdx]).
int32_t const* ptrSeqLensKv;
+ // The reserved memory buffer.
+ int32_t* ptrReservedMem;
// The softmax stats buffer.
float2* ptrSoftmaxStats;

@@ -139,6 +141,8 @@ struct KernelParams {
int64_t mNumHiddenEltsO;
// The total number of pages in the paged-kv memory pool.
int32_t mNumPagesInMemPool;
+ // The number of tokens per page (used if dynamic numTokensPerPage is enabled).
+ int32_t mNumTokensPerPageLog2;
// The output scale for FP8 quantization.
float mOutputScale;
// The scaling factor for softmax (multiplied by log2 to use faster exp2).
@@ -147,11 +151,15 @@ struct KernelParams {
float mScaleSfKv;
// The SF scale for O.
float mScaleSfO;
+ // The reserved parameter.
+ float mReservedParam;
// The start token index in SF tensor. Used for FP4 SF offset calculation in generation phase
// kernel when inflight batching is enabled in TRT-LLM.
int32_t mStartTokenIdxSfO;
// The sum of sequence lengths for Q and K/V.
int32_t mSumOfSeqLensQ, mSumOfSeqLensKv;
+ // The sparseMla topK value.
+ int32_t mSparseMlaTopK;
// The flag to use block sparse attention.
bool mUseBlockSparseAttention;

@@ -537,6 +545,8 @@ struct KernelParams {
int32_t maxNumCtasQ, int32_t maxNumCtasKv) {
// Create the return struct.
KernelParams params;
+ // Memset the kernel parameters to 0.
+ memset(&params, 0, sizeof(KernelParams));

// Get the device pointers for TMA descriptors.
auto [qPtr, kPtr, vPtr] = getDevicePtrs(options, get_size_in_bytes(kernelMeta.mDataTypeKv));
@@ -681,6 +691,16 @@ struct KernelParams {
// Default 0 means that chunked attention is disabled.
params.mChunkedAttentionSizeLog2 = 0;
}

+ // Compute the log of numTokensPerPage
+ int32_t numTokensPerPageLog2{-1};
+ if (isPagedKv(options.mQkvLayout)) {
+ FLASHINFER_CHECK((options.mNumTokensPerPage & (options.mNumTokensPerPage - 1)) == 0,
+ "NumTokensPerPage must be power of 2");
+ numTokensPerPageLog2 = (int)log2f((float)options.mNumTokensPerPage);
+ }
+ params.mNumTokensPerPageLog2 = numTokensPerPageLog2;
+
params.mMaxSeqLenQ = options.mMaxSeqLenQ;
params.mMaxSeqLenKv = options.mMaxSeqLenKv;
params.mMaxNumCtasQ = maxNumCtasQ;
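The mNumTokensPerPageLog2 field added above hands the kernel the page size as a shift amount; presumably this lets per-token page lookups use shifts and masks instead of integer division. A small host-side sketch of that idiom (hypothetical helper, not code from this PR):

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical helper: locate a KV token inside a paged pool using the
// precomputed log2 of the page size (mirrors mNumTokensPerPageLog2 above).
void locateToken(int32_t tokenIdx, int32_t numTokensPerPageLog2) {
  // pageIdx = tokenIdx / numTokensPerPage, done as a shift.
  int32_t pageIdx = tokenIdx >> numTokensPerPageLog2;
  // offsetInPage = tokenIdx % numTokensPerPage, done as a mask.
  int32_t offsetInPage = tokenIdx & ((1 << numTokensPerPageLog2) - 1);
  std::printf("token %d -> page %d, offset %d\n", tokenIdx, pageIdx, offsetInPage);
}

int main() {
  // With 64 tokens per page (log2 = 6), token 200 lives in page 3 at offset 8.
  locateToken(/*tokenIdx=*/200, /*numTokensPerPageLog2=*/6);
  return 0;
}
```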