diff --git a/benchmarks/bench_blackwell_attention.py b/benchmarks/bench_blackwell_attention.py
new file mode 100644
index 0000000000..52dfaeda03
--- /dev/null
+++ b/benchmarks/bench_blackwell_attention.py
@@ -0,0 +1,91 @@
+"""
+Copyright (c) 2025 by FlashInfer team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import torch
+from triton.testing import do_bench
+
+import flashinfer
+
+
+def bench_fmha_blackwell(
+    batch_size,
+    qkv_len,
+    num_heads,
+    head_dim,
+    causal,
+    dtype,
+):
+    q = torch.randn(
+        batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
+    )
+    k = torch.randn(
+        batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
+    )
+    v = torch.randn(
+        batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
+    )
+
+    qo_segment_offsets = (
+        torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len
+    )
+    kv_segment_offsets = (
+        torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len
+    )
+
+    o, lse = flashinfer.prefill.fmha_varlen(
+        q, k, v, qo_segment_offsets, kv_segment_offsets, causal=causal
+    )
+
+    ms = do_bench(
+        lambda: flashinfer.prefill.fmha_varlen(
+            q,
+            k,
+            v,
+            qo_segment_offsets,
+            kv_segment_offsets,
+            causal=causal,
+        )
+    )
+
+    def flops(ms):
+        if causal:
+            return batch_size * qkv_len * qkv_len * num_heads * head_dim * 2 / ms / 1e9
+        else:
+            return batch_size * qkv_len * qkv_len * num_heads * head_dim * 4 / ms / 1e9
+
+    print(
+        f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {flops(ms):.3f} TFLOPs/s"
+    )
+
+
+if __name__ == "__main__":
+    bench_fmha_blackwell(128, 512, 32, 128, False, torch.bfloat16)
+    bench_fmha_blackwell(64, 1024, 32, 128, False, torch.bfloat16)
+    bench_fmha_blackwell(32, 2048, 32, 128, False, torch.bfloat16)
+    bench_fmha_blackwell(16, 4096, 32, 128, False, torch.bfloat16)
+    bench_fmha_blackwell(8, 8192, 32, 128, False, torch.bfloat16)
+    bench_fmha_blackwell(4, 16384, 32, 128, False, torch.bfloat16)
+    bench_fmha_blackwell(2, 32768, 32, 128, False, torch.bfloat16)
+    bench_fmha_blackwell(1, 65536, 32, 128, False, torch.bfloat16)
+
+    bench_fmha_blackwell(128, 512, 32, 128, True, torch.bfloat16)
+    bench_fmha_blackwell(64, 1024, 32, 128, True, torch.bfloat16)
+    bench_fmha_blackwell(32, 2048, 32, 128, True, torch.bfloat16)
+    bench_fmha_blackwell(16, 4096, 32, 128, True, torch.bfloat16)
+    bench_fmha_blackwell(8, 8192, 32, 128, True, torch.bfloat16)
+    bench_fmha_blackwell(4, 16384, 32, 128, True, torch.bfloat16)
+    bench_fmha_blackwell(2, 32768, 32, 128, True, torch.bfloat16)
+    bench_fmha_blackwell(1, 65536, 32, 128, True, torch.bfloat16)
diff --git a/csrc/fmha_cutlass_sm100.cu b/csrc/fmha_cutlass_sm100.cu
new file mode 100644
index 0000000000..016aba492e
--- /dev/null
+++ b/csrc/fmha_cutlass_sm100.cu
@@ -0,0 +1,101 @@
+/*
+ * Copyright (c) 2025 by FlashInfer team.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +#include "pytorch_extension_utils.h" + +#define DISPATCH_mask_mode(mask_mode, MASK_MODE, ...) \ + [&]() -> bool { \ + if (mask_mode == MaskMode::kNone) { \ + constexpr MaskMode MASK_MODE = MaskMode::kNone; \ + return __VA_ARGS__(); \ + } else if (mask_mode == MaskMode::kCausal) { \ + constexpr MaskMode MASK_MODE = MaskMode::kCausal; \ + return __VA_ARGS__(); \ + } \ + return false; \ + }() + +#define DISPATCH_head_dim(head_dim_qk, head_dim_vo, HEAD_DIM_QK, HEAD_DIM_VO, ...) \ + [&]() -> bool { \ + if (head_dim_qk == 192 && head_dim_vo == 128) { \ + constexpr int HEAD_DIM_QK = 192; \ + constexpr int HEAD_DIM_VO = 128; \ + return __VA_ARGS__(); \ + } else if (head_dim_qk == 128 && head_dim_vo == 128) { \ + constexpr int HEAD_DIM_QK = 128; \ + constexpr int HEAD_DIM_VO = 128; \ + return __VA_ARGS__(); \ + } \ + return false; \ + }() + +#define DISPATCH_DTYPE_IN_OUT(in_dtype, out_dtype, c_type_in, c_type_out, ...) \ + [&]() -> bool { \ + if (in_dtype == out_dtype) { \ + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(in_dtype, c_type_in, [&] { \ + using c_type_out = c_type_in; \ + return __VA_ARGS__(); \ + }); \ + } \ + return false; \ + }() + +#define DISPATCH_context(DTypeIn, DTypeOut, HEAD_DIM_QK, HEAD_DIM_VO, MaskMode, ...) 
\ + { \ + DISPATCH_mask_mode(mask_mode, MaskMode, [&] { \ + return DISPATCH_DTYPE_IN_OUT(scalar_type_in, scalar_type_out, DTypeIn, DTypeOut, [&] { \ + return DISPATCH_head_dim(head_dim_qk, head_dim_vo, HEAD_DIM_QK, HEAD_DIM_VO, \ + [&] { return __VA_ARGS__(); }); \ + }); \ + }); \ + } + +using namespace flashinfer; + +void FMHACutlassSM100Run(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor qo_lens, + at::Tensor kv_lens, at::Tensor qo_segment_offsets, + at::Tensor kv_segment_offsets, at::Tensor o, + std::optional maybe_lse, int64_t mask_mode_code, + double sm_scale, int64_t num_qo_heads, int64_t num_kv_heads, + int64_t head_dim_qk, int64_t head_dim_vo, int64_t batch_size, + int64_t total_qo_len, int64_t total_kv_len, int64_t max_qo_len, + int64_t max_kv_len) { + CHECK(q.scalar_type() == k.scalar_type()); + auto scalar_type_in = q.scalar_type(); + auto scalar_type_out = o.scalar_type(); + MaskMode mask_mode = static_cast(mask_mode_code); + DISPATCH_context(DTypeIn, DTypeOut, HEAD_DIM_QK, HEAD_DIM_VO, MASK_MODE, [&] { + using cutlass_type_in = cutlass_dtype_t; + using cutlass_type_out = cutlass_dtype_t; + using TILE_Q = _256; + using TILE_KV = _128; + using D_QK = cute::Int; + using D_VO = cute::Int; + using TileShapeQK = Shape; + using TileShapePV = Shape; + using CutlassMaskMode = + typename std::conditional::type; + run_fmha_fwd( + q, k, v, qo_lens, kv_lens, qo_segment_offsets, kv_segment_offsets, o, maybe_lse, + mask_mode_code, sm_scale, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, batch_size, + total_qo_len, total_kv_len, max_qo_len, max_kv_len); + + return true; + }); +} diff --git a/csrc/fmha_cutlass_sm100_pybind.cu b/csrc/fmha_cutlass_sm100_pybind.cu new file mode 100644 index 0000000000..773b79b02b --- /dev/null +++ b/csrc/fmha_cutlass_sm100_pybind.cu @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2023-2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "pytorch_extension_utils.h" + +void FMHACutlassSM100Run(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor qo_lens, + at::Tensor kv_lens, at::Tensor qo_segment_offsets, + at::Tensor kv_segment_offsets, at::Tensor o, + std::optional maybe_lse, int64_t mask_mode_code, + double sm_scale, int64_t num_qo_heads, int64_t num_kv_heads, + int64_t head_dim_qk, int64_t head_dim_vo, int64_t batch_size, + int64_t total_qo_len, int64_t total_kv_len, int64_t max_qo_len, + int64_t max_kv_len); + +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { m.def("run", FMHACutlassSM100Run); } diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index c0b39c3741..d4ad402da8 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -43,6 +43,7 @@ from .attention import ( gen_customize_single_prefill_module as gen_customize_single_prefill_module, ) +from .attention import gen_fmha_cutlass_sm100a_module as gen_fmha_cutlass_sm100a_module from .attention import gen_pod_module as gen_pod_module from .attention import gen_sampling_tvm_binding as gen_sampling_tvm_binding from .attention import gen_single_decode_module as gen_single_decode_module diff --git a/flashinfer/jit/attention/__init__.py b/flashinfer/jit/attention/__init__.py index 69e1551193..ffc556290c 100644 --- a/flashinfer/jit/attention/__init__.py +++ b/flashinfer/jit/attention/__init__.py @@ -31,6 +31,7 @@ from .pytorch import ( gen_customize_single_prefill_module as gen_customize_single_prefill_module, ) +from .pytorch import gen_fmha_cutlass_sm100a_module as gen_fmha_cutlass_sm100a_module from .pytorch import gen_pod_module as gen_pod_module from .pytorch import gen_single_decode_module as gen_single_decode_module from .pytorch import gen_single_prefill_module as gen_single_prefill_module diff --git a/flashinfer/jit/attention/pytorch.py b/flashinfer/jit/attention/pytorch.py index 524ea19e51..ec3577d41c 100644 --- a/flashinfer/jit/attention/pytorch.py +++ b/flashinfer/jit/attention/pytorch.py @@ -20,7 +20,7 @@ import jinja2 import torch -from ..core import load_cuda_ops, logger, sm90a_nvcc_flags +from ..core import load_cuda_ops, logger, sm90a_nvcc_flags, sm100a_nvcc_flags from ..env import FLASHINFER_CSRC_DIR, FLASHINFER_GEN_SRC_DIR from ..utils import ( dtype_map, @@ -1340,3 +1340,63 @@ def gen_customize_batch_prefill_module( ) else: raise ValueError(f"Invalid backend: {backend}") + + +def get_fmha_cutlass_sm100a_uri( + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + dtype_idx: torch.dtype, + head_dim_qk: int, + head_dim_vo: int, + pos_encoding_mode: int, + use_sliding_window: bool, + use_logits_soft_cap: bool, +) -> str: + # NOTE(Zihao): use different uri after when support customize attention + return "fmha_cutlass_sm100a" + # return ( + # f"fmha_cutlass_sm100a_dtype_q_{filename_safe_dtype_map[dtype_q]}_" + # f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" + # f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" + # f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_" + # f"head_dim_qk_{head_dim_qk}_" + # f"head_dim_vo_{head_dim_vo}_" + # f"posenc_{pos_encoding_mode}_" + # f"use_swa_{use_sliding_window}_" + # f"use_logits_cap_{use_logits_soft_cap}" + # ) + + +def gen_fmha_cutlass_sm100a_module( + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + dtype_idx: torch.dtype, + head_dim_qk: int, + head_dim_vo: int, + pos_encoding_mode: int, + use_sliding_window: bool, + use_logits_soft_cap: bool, +): + uri = get_fmha_cutlass_sm100a_uri( + dtype_q, + dtype_kv, + dtype_o, + 
dtype_idx, + head_dim_qk, + head_dim_vo, + pos_encoding_mode, + use_sliding_window, + use_logits_soft_cap, + ) + + source_paths = [ + FLASHINFER_CSRC_DIR / "fmha_cutlass_sm100.cu", + FLASHINFER_CSRC_DIR / "fmha_cutlass_sm100_pybind.cu", + ] + return load_cuda_ops( + uri, + source_paths, + extra_cuda_cflags=sm100a_nvcc_flags, + ) diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py index 7797ad2aaa..6d1564d497 100644 --- a/flashinfer/jit/core.py +++ b/flashinfer/jit/core.py @@ -78,6 +78,7 @@ def remove_unwanted_pytorch_nvcc_flags(): remove_unwanted_pytorch_nvcc_flags() sm90a_nvcc_flags = ["-gencode", "arch=compute_90a,code=sm_90a"] +sm100a_nvcc_flags = ["-gencode", "arch=compute_100a,code=sm_100a"] def load_cuda_ops( @@ -113,6 +114,7 @@ def load_cuda_ops( "-lineinfo", "--ptxas-options=-v", "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", + "-DCUTLASS_DEBUG_TRACE_LEVEL=2", ] else: # non debug mode diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 88a773a547..6200b05f9b 100644 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -25,6 +25,7 @@ from .jit import ( gen_batch_prefill_module, gen_customize_batch_prefill_module, + gen_fmha_cutlass_sm100a_module, gen_single_prefill_module, get_batch_prefill_uri, get_single_prefill_uri, @@ -47,6 +48,7 @@ canonicalize_torch_dtype, determine_attention_backend, is_float8, + is_sm100a_supported, register_custom_op, register_fake_op, ) @@ -58,6 +60,35 @@ _batch_prefill_jit_modules = {} +@functools.cache +def get_fmha_module( + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + dtype_idx: torch.dtype, + head_dim_qk: int, + head_dim_vo: int, + pos_encoding_mode: PosEncodingMode, + use_sliding_window: bool, + use_logits_soft_cap: bool, + use_fp16_qk_reduction: bool = False, +): + if is_sm100a_supported(torch.device("cuda")): + return gen_fmha_cutlass_sm100a_module( + dtype_q, + dtype_kv, + dtype_o, + dtype_idx, + head_dim_qk, + head_dim_vo, + pos_encoding_mode, + use_sliding_window, + use_logits_soft_cap, + ) + else: + raise ValueError(f"SM100A is not supported on this device") + + def get_single_prefill_module(backend): def backend_module(*args): global _single_prefill_modules, _single_prefill_sm90_modules @@ -2335,9 +2366,12 @@ def plan( logits_soft_cap > 0, # use_logits_soft_cap use_fp16_qk_reduction, ) - self._cached_module = get_batch_prefill_module(self._backend)( - *get_module_args - ) + if self._backend == "cutlass": + self._cached_module = get_cutlass_mha_module()(*get_module_args) + else: + self._cached_module = get_batch_prefill_module(self._backend)( + *get_module_args + ) self._plan_info = self._cached_module.plan( self._float_workspace_buffer, @@ -2573,3 +2607,259 @@ def forward_return_lse( def end_forward(self) -> None: r"""Warning: this function is deprecated and has no effect.""" pass + + +def fmha_varlen( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + qo_segment_offsets: torch.Tensor, + kv_segment_offsets: torch.Tensor, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + causal: bool = False, + sm_scale: Optional[float] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + module = get_fmha_module( + q.dtype, + k.dtype, + v.dtype, + torch.int32, + q.shape[2], + v.shape[2], + PosEncodingMode.NONE.value, + False, # use_sliding_window + False, # use_logits_soft_cap + ) + nnz_qo, num_qo_heads, head_dim_qk = q.shape + nnz_kv, num_kv_heads, head_dim_vo = v.shape + + mask_mode_code = 1 if causal else 0 + sm_scale = 1.0 / 
math.sqrt(head_dim_qk) + + qo_lens = qo_segment_offsets[1:] - qo_segment_offsets[:-1] + kv_lens = kv_segment_offsets[1:] - kv_segment_offsets[:-1] + batch_size = qo_lens.shape[0] + max_qo_len = qo_lens.max() + max_kv_len = kv_lens.max() + qo_total_len = nnz_qo + + if out is None: + out = torch.empty( + qo_total_len + max(max_qo_len, 128), + num_qo_heads, + head_dim_vo, + device=q.device, + dtype=q.dtype, + )[max(max_qo_len, 128) :] + + if lse is None: + lse = torch.empty( + qo_total_len, num_qo_heads, device=q.device, dtype=torch.float32 + ) + + module.run( + q, + k, + v, + qo_lens, + kv_lens, + qo_segment_offsets, + kv_segment_offsets, + out, + lse, + mask_mode_code, + sm_scale, + num_qo_heads, + num_kv_heads, + head_dim_qk, + head_dim_vo, + batch_size, + nnz_qo, + nnz_kv, + max_qo_len, + max_kv_len, + ) + + return out, lse + + +@functools.cache +def get_cutlass_mha_module(): + def backend_module(*args): + modules_dict = _batch_prefill_modules + + if args not in modules_dict: + uri = get_batch_prefill_uri("cutlass", *args) + module = get_fmha_module(*args) + + @register_custom_op( + f"flashinfer::{uri}_ragged_run", + mutates_args=( + "float_workspace_buffer", + "int_workspace_buffer", + "o", + "maybe_lse", + ), + ) + def ragged_run( + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + qo_indptr: torch.Tensor, + kv_indptr: torch.Tensor, + o: torch.Tensor, + maybe_lse: Optional[torch.Tensor], + mask_mode: int, + layout: int, + window_left: int, + maybe_custom_mask: Optional[torch.Tensor], + maybe_mask_indptr: Optional[torch.Tensor], + maybe_alibi_slopes: Optional[torch.Tensor], + maybe_prefix_len_ptr: Optional[torch.Tensor], + maybe_token_pos_in_items_ptr: Optional[torch.Tensor], + maybe_max_item_len_ptr: Optional[torch.Tensor], + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + token_pos_in_items_len: int, + ) -> None: + return fmha_varlen( + q, + k, + v, + qo_indptr, + kv_indptr, + o, + maybe_lse, + mask_mode == MaskMode.CAUSAL.value, + sm_scale, + ) + + @register_custom_op( + f"flashinfer::{uri}_paged_run", + mutates_args=( + "float_workspace_buffer", + "int_workspace_buffer", + "paged_k_cache", + "paged_v_cache", + "o", + "maybe_lse", + ), + ) + def paged_run( + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + paged_k_cache: torch.Tensor, + paged_v_cache: torch.Tensor, + qo_indptr: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + o: torch.Tensor, + maybe_lse: Optional[torch.Tensor], + mask_mode: int, + layout: int, + window_left: int, + maybe_custom_mask: Optional[torch.Tensor], + maybe_mask_indptr: Optional[torch.Tensor], + maybe_alibi_slopes: Optional[torch.Tensor], + maybe_prefix_len_ptr: Optional[torch.Tensor], + maybe_token_pos_in_items_ptr: Optional[torch.Tensor], + maybe_max_item_len_ptr: Optional[torch.Tensor], + logits_soft_cap: float, + sm_scale: float, + scale_q: Optional[torch.Tensor], + scale_k: Optional[torch.Tensor], + scale_v: Optional[torch.Tensor], + rope_scale: float, + rope_theta: float, + token_pos_in_items_len: int, + ) -> None: + pass + + @register_fake_op(f"flashinfer::{uri}_ragged_run") + def _fake_ragged_run( + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + k: torch.Tensor, + v: 
torch.Tensor, + qo_indptr: torch.Tensor, + kv_indptr: torch.Tensor, + o: torch.Tensor, + maybe_lse: Optional[torch.Tensor], + mask_mode: int, + layout: int, + window_left: int, + maybe_custom_mask: Optional[torch.Tensor], + maybe_mask_indptr: Optional[torch.Tensor], + maybe_alibi_slopes: Optional[torch.Tensor], + maybe_prefix_len_ptr: Optional[torch.Tensor], + maybe_token_pos_in_items_ptr: Optional[torch.Tensor], + maybe_max_item_len_ptr: Optional[torch.Tensor], + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + token_pos_in_items_len: int, + ) -> None: + pass + + @register_fake_op(f"flashinfer::{uri}_paged_run") + def _fake_paged_run( + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + paged_k_cache: torch.Tensor, + paged_v_cache: torch.Tensor, + qo_indptr: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + o: torch.Tensor, + maybe_lse: Optional[torch.Tensor], + mask_mode: int, + layout: int, + window_left: int, + maybe_custom_mask: Optional[torch.Tensor], + maybe_mask_indptr: Optional[torch.Tensor], + maybe_alibi_slopes: Optional[torch.Tensor], + maybe_prefix_len_ptr: Optional[torch.Tensor], + maybe_token_pos_in_items_ptr: Optional[torch.Tensor], + maybe_max_item_len_ptr: Optional[torch.Tensor], + logits_soft_cap: float, + sm_scale: float, + scale_q: Optional[torch.Tensor], + scale_k: Optional[torch.Tensor], + scale_v: Optional[torch.Tensor], + rope_scale: float, + rope_theta: float, + token_pos_in_items_len: int, + ) -> None: + pass + + def plan(*args): + return None + + # Register the module. + # + # Note that plan is not part of model logic. It should not be included in + # Cuda Graph or torch.compile. So, we don't provide a torch library for plan. + modules_dict[args] = SimpleNamespace( + plan=plan, + ragged_run=ragged_run, + paged_run=paged_run, + ) + + return modules_dict[args] + + return backend_module diff --git a/flashinfer/triton/__init__.py b/flashinfer/triton/__init__.py index 9543034366..6247c071fd 100644 --- a/flashinfer/triton/__init__.py +++ b/flashinfer/triton/__init__.py @@ -1,6 +1,2 @@ from . import cascade # noqa: F401 from . import sm_constraint_gemm # noqa: F401 -from .format_conversion import pack_ragged_tensor as pack_ragged_tensor -from .format_conversion import ( - pad_ragged_tensor_to_multiple_of as pad_ragged_tensor_to_multiple_of, -) diff --git a/flashinfer/triton/format_conversion.py b/flashinfer/triton/format_conversion.py deleted file mode 100644 index a6b8f8e23a..0000000000 --- a/flashinfer/triton/format_conversion.py +++ /dev/null @@ -1,284 +0,0 @@ -""" -Copyright (c) 2025 by FlashInfer team. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-""" - -from typing import Optional - -import torch -import triton -import triton.language as tl - - -@triton.jit -def _compute_padded_indptr( - indptr_ptr, padded_indptr_ptr, n_rows, multiple_of, BLOCK_SIZE: tl.constexpr -): - pid = tl.program_id(0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_rows - - # Load row lengths - row_start = tl.load(indptr_ptr + offsets, mask=mask, other=0) - row_end = tl.load(indptr_ptr + offsets + 1, mask=mask, other=0) - row_lengths = row_end - row_start - - # Compute padded lengths (round up to multiple_of) - padded_lengths = ((row_lengths + multiple_of - 1) // multiple_of) * multiple_of - - # Compute cumulative sum for padded indptr - if pid == 0: - # First element is always 0 - tl.store(padded_indptr_ptr + 0, 0) - - # Store the padded lengths at the correct positions - tl.store(padded_indptr_ptr + offsets + 1, padded_lengths, mask=mask) - - -@triton.jit -def _pad_ragged_tensor( - ragged_tensor_ptr, - padded_tensor_ptr, - indptr_ptr, - padded_indptr_ptr, - n_rows, - dim, - BLOCK_SIZE: tl.constexpr, - fill_zeros: tl.constexpr, -): - pid = tl.program_id(0) - - # Process one row per program - if pid >= n_rows: - return - - # Get original and padded row information - row_start = tl.load(indptr_ptr + pid) - row_end = tl.load(indptr_ptr + pid + 1) - row_length = row_end - row_start - - padded_row_start = tl.load(padded_indptr_ptr + pid) - padded_row_end = tl.load(padded_indptr_ptr + pid + 1) - padded_row_length = padded_row_end - padded_row_start - - # Copy the original data - for i in range(0, row_length): - col_idx = i - src_offset = (row_start + i) * dim - dst_offset = (padded_row_start + i) * dim - - # Copy the entire feature vector for this position - for j in range(0, dim, BLOCK_SIZE): - j_offsets = j + tl.arange(0, BLOCK_SIZE) - j_mask = j_offsets < dim - values = tl.load(ragged_tensor_ptr + src_offset + j_offsets, mask=j_mask) - tl.store(padded_tensor_ptr + dst_offset + j_offsets, values, mask=j_mask) - - # Zero-pad the remaining positions - if fill_zeros: - for i in range(row_length, padded_row_length): - col_idx = i - dst_offset = (padded_row_start + i) * dim - - # Zero out the entire feature vector for this position - for j in range(0, dim, BLOCK_SIZE): - j_offsets = j + tl.arange(0, BLOCK_SIZE) - j_mask = j_offsets < dim - tl.store(padded_tensor_ptr + dst_offset + j_offsets, 0.0, mask=j_mask) - - -@triton.jit -def _pack_ragged_tensor( - padded_tensor_ptr, - packed_tensor_ptr, - padded_indptr_ptr, - original_indptr_ptr, - n_rows, - dim, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) - - # Process one row per program - if pid >= n_rows: - return - - # Get original and padded row information - original_row_start = tl.load(original_indptr_ptr + pid) - original_row_end = tl.load(original_indptr_ptr + pid + 1) - original_row_length = original_row_end - original_row_start - - padded_row_start = tl.load(padded_indptr_ptr + pid) - - # Copy only the original data (not the padding) - for i in range(0, original_row_length): - src_offset = (padded_row_start + i) * dim - dst_offset = (original_row_start + i) * dim - - # Copy the entire feature vector for this position - for j in range(0, dim, BLOCK_SIZE): - j_offsets = j + tl.arange(0, BLOCK_SIZE) - j_mask = j_offsets < dim - values = tl.load(padded_tensor_ptr + src_offset + j_offsets, mask=j_mask) - tl.store(packed_tensor_ptr + dst_offset + j_offsets, values, mask=j_mask) - - -def max_power_of_2_leq(x: int) -> int: - r"""Return the maximum power 
of 2 less than or equal to x.""" - return 1 << (x - 1).bit_length() - - -def pad_ragged_tensor_to_multiple_of( - ragged_tensor: torch.Tensor, - indptr: torch.Tensor, - multiple_of: int, - fill_zeros: bool = False, - output_ragged_tensor: Optional[torch.Tensor] = None, - output_indptr: Optional[torch.Tensor] = None, -) -> tuple[torch.Tensor, torch.Tensor]: - r"""Pad each row of ragged tensor to a multiple of ``multiple_of``. - - Suppose the ragged tensor has shape (150, 1024), and the indptr is [0, 100, 150] (which means there are 2 rows, - the first row has 100 columns, the second row has 50 columns), and the multiple_of is 16. - We will pad the first row to 112 columns, and the second row to 64 columns. - The padded ragged tensor will have shape (176, 1024), and the returned indptr will be [0, 112, 176]. - - Parameters - ---------- - ragged_tensor: torch.Tensor - The ragged tensor to pad, expected shape: (nnz, D) - indptr: torch.Tensor - The indptr of the ragged tensor, expected shape: (n_rows + 1,) - multiple_of: int - The multiple of to pad to, e.g. 256 - fill_zeros: bool - If True, the padded positions will be filled with zeros, otherwise they will be random values, - default is False. - output_ragged_tensor: Optional[torch.Tensor] - If provided, the padded ragged tensor will be stored in this tensor, - otherwise a new tensor will be allocated. - output_indptr: Optional[torch.Tensor] - If provided, the padded indptr will be stored in this tensor, - otherwise a new tensor will be allocated. - - Returns - ------- - padded_ragged_tensor: torch.Tensor - The padded ragged tensor, expected shape: (n_rows, padded_nnz, D) - padded_indptr: torch.Tensor - The padded indptr, expected shape: (n_rows + 1,) - """ - # Get dimensions - n_rows = indptr.shape[0] - 1 - nnz = ragged_tensor.shape[0] - dim = ragged_tensor.shape[1] - - # First compute padded indptr - if output_indptr is None: - padded_indptr = torch.zeros_like(indptr) - else: - padded_indptr = output_indptr - - grid_size = triton.cdiv(n_rows, 128) - _compute_padded_indptr[(grid_size,)]( - indptr, padded_indptr, n_rows, multiple_of, BLOCK_SIZE=128 - ) - - # Perform exclusive scan to get final padded_indptr - padded_indptr[1:] = torch.cumsum(padded_indptr[1:], dim=0) - - # Allocate padded tensor - if output_ragged_tensor is None: - total_padded_length = padded_indptr[-1].item() - padded_ragged_tensor = torch.empty( - (total_padded_length, dim), - dtype=ragged_tensor.dtype, - device=ragged_tensor.device, - ) - else: - padded_ragged_tensor = output_ragged_tensor - - # Pad the tensor - _pad_ragged_tensor[(n_rows,)]( - ragged_tensor, - padded_ragged_tensor, - indptr, - padded_indptr, - n_rows, - dim, - BLOCK_SIZE=min(max_power_of_2_leq(dim), 16384), - num_stages=2, - fill_zeros=fill_zeros, - ) - - return padded_ragged_tensor, padded_indptr - - -def pack_ragged_tensor( - padded_tensor: torch.Tensor, - padded_indptr: torch.Tensor, - original_indptr: torch.Tensor, - output_tensor: Optional[torch.Tensor] = None, -) -> torch.Tensor: - r"""Convert a padded ragged tensor back to packed format. - - This function reverses the operation of pad_ragged_tensor_to_multiple_of by - removing the padding and returning the original packed tensor. 
- - Parameters - ---------- - padded_tensor: torch.Tensor - The padded ragged tensor, expected shape: (padded_nnz, D) - padded_indptr: torch.Tensor - The padded indptr, expected shape: (n_rows + 1,) - original_indptr: torch.Tensor - The original indptr before padding, expected shape: (n_rows + 1,) - output_tensor: Optional[torch.Tensor] - If provided, the packed tensor will be stored in this tensor, - otherwise a new tensor will be allocated. - - Returns - ------- - packed_tensor: torch.Tensor - The packed tensor with padding removed, expected shape: (original_nnz, D) - """ - # Get dimensions - n_rows = padded_indptr.shape[0] - 1 - dim = padded_tensor.shape[1] - original_nnz = original_indptr[-1].item() - - # Allocate output tensor if not provided - if output_tensor is None: - packed_tensor = torch.empty( - (original_nnz, dim), - dtype=padded_tensor.dtype, - device=padded_tensor.device, - ) - else: - packed_tensor = output_tensor - - # Pack the tensor by removing padding - _pack_ragged_tensor[(n_rows,)]( - padded_tensor, - packed_tensor, - padded_indptr, - original_indptr, - n_rows, - dim, - BLOCK_SIZE=min(max_power_of_2_leq(dim), 16384), - num_stages=2, - ) - - return packed_tensor diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 0f5696a675..c388aca52b 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -361,8 +361,8 @@ def is_sm90a_supported(device: torch.device) -> bool: def is_sm100a_supported(device: torch.device) -> bool: - major, minor = get_compute_capability(device) - return major == 10 and minor == 0 and torch.version.cuda >= "12.9" + major, _ = get_compute_capability(device) + return major == 10 and torch.version.cuda >= "12.8" def determine_mla_backend(device: torch.device) -> str: diff --git a/include/flashinfer/attention/blackwell/collective/fmha_common.hpp b/include/flashinfer/attention/blackwell/collective/fmha_common.hpp new file mode 100644 index 0000000000..e799725128 --- /dev/null +++ b/include/flashinfer/attention/blackwell/collective/fmha_common.hpp @@ -0,0 +1,136 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/tensor.hpp" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template +CUTLASS_DEVICE auto get_local_tile_tensor(const MTensor& m_tensor, const Shape& tile_shape, + int head_idx, int offset, int seq_len) { + // (N, D, H) + auto g_offset = local_tile(m_tensor(_, _, head_idx), cute::make_shape(1, get<1>(tile_shape)), + make_coord(offset, _0{})); + auto g_sequence = + make_tensor(g_offset.data(), + make_layout(cute::make_shape(seq_len, get<1>(tile_shape)), g_offset.stride())); + auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{})); + return g_tensor; +} + +template +CUTLASS_DEVICE auto get_local_tile_t_tensor(const MTensor& m_tensor, const Shape& tile_shape, + int head_idx, int offset, int seq_len) { + // (D, N, H) + auto g_offset = local_tile(m_tensor(_, _, head_idx), cute::make_shape(get<0>(tile_shape), 1), + make_coord(_0{}, offset)); + auto g_sequence = + make_tensor(g_offset.data(), + make_layout(cute::make_shape(get<0>(tile_shape), seq_len), g_offset.stride())); + auto g_tensor = local_tile(g_offset, tile_shape, make_coord(_0{}, _)); + return g_tensor; +} + +template +CUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + constexpr int rA = decltype(rank(tA))::value; + constexpr int rB = decltype(rank(tB))::value; + constexpr int rC = decltype(rank(tC))::value; + static_assert(rA == 3 && rB == 3 && rC == 3); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tA); k_block++) { + cute::gemm(atom, tA(_, _, k_block), tB(_, _, k_block), tC); + atom.accumulate_ = decltype(atom.accumulate_)::One; + } +} + +template +CUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + atom.accumulate_ = decltype(atom.accumulate_)::Zero; + gemm_reset_zero_acc(atom, tA, tB, tC); +} + +template +CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) { + return composition(layout, prepend(make_layout(stages), _)); +} + +template +CUTE_DEVICE T warp_uniform(T a) { + return __shfl_sync(0xffffffff, a, 0); +} + +template +CUTE_HOST_DEVICE constexpr auto to_tiled_mma_sm100_ts( + TiledMMA, + cute::C, cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant>, + TAs...>, + TMs...>) { + return TiledMMA< + MMA_Atom>, + TAs...>, + TMs...>{}; +} + +template +CUTE_HOST_DEVICE constexpr auto to_tiled_mma_sm100_ts( + TiledMMA< + MMA_Atom, + TAs...>, + TMs...>) { + return TiledMMA, + TAs...>, + TMs...>{}; +} + +template +CUTLASS_DEVICE void warpgroup_reg_set() { + if constexpr (RegCount < 128) { + cutlass::arch::warpgroup_reg_dealloc(); + } else { + cutlass::arch::warpgroup_reg_alloc(); + } +} + +} // namespace cutlass::fmha::collective diff --git 
a/include/flashinfer/attention/blackwell/collective/fmha_fusion.hpp b/include/flashinfer/attention/blackwell/collective/fmha_fusion.hpp new file mode 100644 index 0000000000..a8ddbd05ba --- /dev/null +++ b/include/flashinfer/attention/blackwell/collective/fmha_fusion.hpp @@ -0,0 +1,208 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" + +namespace cutlass::fmha::collective { + +using namespace cute; + +struct NoMask { + template + CUTLASS_DEVICE int get_trip_count(BlkCoord const& blk_coord, TileShape const& tile_shape, + ProblemSize const& problem_size) { + return ceil_div(get<1>(problem_size), get<1>(tile_shape)); + } + + template + CUTLASS_DEVICE int get_masked_trip_count(BlkCoord const& blk_coord, TileShape const& tile_shape, + ProblemSize const& problem_size) { + return 0; + } + + template + CUTLASS_DEVICE int get_unmasked_trip_count(BlkCoord const& blk_coord, TileShape const& tile_shape, + ProblemSize const& problem_size) { + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE void apply_mask(AccQK& acc_qk, IndexQK const& index_qk, + ProblemSize const& problem_size) { + return; + } +}; + +struct ResidualMask : NoMask { + using Base = NoMask; + + template + CUTLASS_DEVICE int get_masked_trip_count(BlkCoord const& blk_coord, TileShape const& tile_shape, + ProblemSize const& problem_size) { + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return 1; + } + return 0; + } + + template + CUTLASS_DEVICE int get_unmasked_trip_count(BlkCoord const& blk_coord, TileShape const& tile_shape, + ProblemSize const& problem_size) { + // if the sequence length does not divide the tile size evenly + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return get_trip_count(blk_coord, tile_shape, problem_size) - 1; + } + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE void apply_mask(AccQK& acc_qk, IndexQK const& index_qk, + ProblemSize const& problem_size) { + // This is useful is seqlen_k % kBlockN != 0 since it masks + // the remaining elements out from softmax. + // d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar + // issues as they are transparently taken care of by TMA and the + // epilogue, if it is instantiated with predication support. 
+ CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if (get<1>(pos) >= get<1>(problem_size)) { + acc_qk(i) = -INFINITY; + } + } + } +}; + +struct CausalMask : NoMask { + using Base = NoMask; + + template + CUTLASS_DEVICE int get_trip_count(BlkCoord const& blk_coord, TileShape const& tile_shape, + ProblemSize const& problem_size) { + // See note below on different ways to think about causal attention + // Again, we'd add the offset_q into the max_blocks_q calculation + int offset_q = int(get<1>(problem_size)) - int(get<0>(problem_size)); + int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size); + int max_blocks_q = + ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape) + offset_q, get<1>(tile_shape)); + return std::min(max_blocks_k, max_blocks_q); + } + + template + CUTLASS_DEVICE int get_masked_trip_count(BlkCoord const& blk_coord, TileShape const& tile_shape, + ProblemSize const& problem_size) { + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE int get_unmasked_trip_count(BlkCoord const& blk_coord, TileShape const& tile_shape, + ProblemSize const& problem_size) { + return get_trip_count(blk_coord, tile_shape, problem_size) - + get_masked_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE void apply_mask(AccQK& acc_qk, IndexQK const& index_qk, + ProblemSize const& problem_size) { + // There are two ways to do causal if N_Q != N_K + // (1) is to assume that the Q is at the beginning of the matrix + // - this is what we demonstrate here + // (2) is that it is at the end of the matrix + // - this is usually what we want for inference settings + // where we only compute the next row and use cache for the rest + // - if you'd like this, you only need to add an offset like so: + // get<0>(pos) + offset_q < get<1>(pos) + int offset_q = int(get<1>(problem_size)) - int(get<0>(problem_size)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if ((get<0>(pos) + offset_q < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) { + acc_qk(i) = -INFINITY; + } + } + } +}; + +struct VariableLength { + int max_length; + int* segment_offsets = nullptr; + int* lengths = nullptr; + + CUTE_HOST_DEVICE operator int() const { return max_length; } +}; + +template +struct is_variable_length : std::false_type {}; +template <> +struct is_variable_length : std::true_type {}; +template +constexpr bool is_variable_length_v = is_variable_length::value; + +template +CUTE_HOST_DEVICE constexpr auto apply_variable_length(Shape const& shape, Idx const& idx) { + return transform_leaf(shape, [&](auto const& s) { + if constexpr (is_variable_length_v>) { + return s.lengths[idx]; + } else { + return s; + } + }); +} + +template +CUTE_HOST_DEVICE constexpr auto apply_variable_length(Shape const& shape, Coord const& coord, + Idx const& idx) { + auto new_shape = apply_variable_length(shape, idx); + auto new_coord = transform_leaf(shape, coord, [&](auto const& s, auto const& c) { + if constexpr (is_variable_length_v>) { + return cute::make_tuple(c, s.segment_offsets[idx]); + } else { + return c; + } + }); + return cute::make_tuple(new_shape, new_coord); +} + +} // namespace cutlass::fmha::collective + +namespace cute { + +template <> +struct is_integral : true_type {}; + +CUTE_HOST_DEVICE +void print(cutlass::fmha::collective::VariableLength a) { + printf("Varlen<%d, %p, %p>", a.max_length, a.segment_offsets, a.lengths); +} + +} // namespace cute 
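
Note on the causal trip-count logic in fmha_fusion.hpp above: the arithmetic is easy to lose inside the CuTe accessors, so the following standalone C++ sketch mirrors it with plain integers. It is illustration only: causal_trip_count and ceil_div are helpers written for this note (not part of the header), and the tile sizes and sequence lengths in main are made-up example values that merely echo the 256x128 TileShapeQK chosen in csrc/fmha_cutlass_sm100.cu.

// Illustration only: mirrors CausalMask::get_trip_count with plain ints.
#include <algorithm>
#include <cstdio>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Number of KV tiles a given Q tile visits under causal masking.
// offset_q shifts the causal diagonal when q_len != kv_len, matching the
// offset term used in the header above.
int causal_trip_count(int q_tile_idx, int tile_q, int tile_kv, int q_len, int kv_len) {
  int offset_q = kv_len - q_len;
  int max_blocks_k = ceil_div(kv_len, tile_kv);  // NoMask trip count
  int max_blocks_q = ceil_div((q_tile_idx + 1) * tile_q + offset_q, tile_kv);
  return std::min(max_blocks_k, max_blocks_q);
}

int main() {
  // Example: q_len == kv_len == 1024 with 256x128 tiles. Q tiles 0..3 visit
  // 2, 4, 6, 8 KV tiles, i.e. 20 of the 32 tile pairs the unmasked case would visit.
  for (int q_tile = 0; q_tile < ceil_div(1024, 256); ++q_tile) {
    std::printf("q_tile %d visits %d of %d kv tiles\n", q_tile,
                causal_trip_count(q_tile, 256, 128, 1024, 1024),
                ceil_div(1024, 128));
  }
  return 0;
}

Note that in the header, CausalMask reports every visited tile as masked (get_masked_trip_count equals get_trip_count), whereas ResidualMask only flags the final partial tile; this sketch reproduces just the total visited-tile count.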
diff --git a/include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp b/include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp new file mode 100644 index 0000000000..102dffcc2f --- /dev/null +++ b/include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp @@ -0,0 +1,181 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/layout.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "fmha_common.hpp" + +namespace cutlass::fmha::collective { + +template +struct Sm100FmhaFwdEpilogueTmaWarpspecialized { + using Pipeline = cutlass::PipelineAsync<2>; + // using ShapeT = cute::Shape>; + // using StrideO = cute::Shape>; + // using LayoutO = cute::Layout; + using ShapeT = cute::Shape, int32_t>>; + using StrideO = cute::Shape, int32_t>>; + using LayoutO = cute::Layout; + + using ShapeLSE = cute::Shape>; + using StrideLSE = cute::Shape<_1, cute::Shape>; + using LayoutLSE = cute::Layout; + + // using SmemLayoutO = decltypa(make_layout(append<3>(select<0,1>(TileShape_WG{}), _2{}))); + using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::K, Element, tuple_element_t<0, TileShape>, + tuple_element_t<1, TileShape>>()); + // using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, + // _0>{})); + using SmemLayoutO = + decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{})); + using SmemLayoutO_ = SmemLayoutO; + + struct TensorStorage { + using SmemLayoutO = SmemLayoutO_; + cute::array_aligned> smem_o; + }; + struct Arguments { + Element* ptr_O; + LayoutO layout_O; + + ElementAcc* ptr_LSE; + LayoutLSE layout_LSE; + }; + + using TMA_O = decltype(make_tma_copy( + SM90_TMA_STORE{}, make_tensor((Element*)nullptr, repeat_like(StrideO{}, 0), StrideO{}), + SmemLayoutO{}(_, _, _0{}))); + + struct Params { + TMA_O tma_store_o; + LayoutO layout_O; + ElementAcc* ptr_LSE; + LayoutLSE layout_LSE; + }; + + template + static Params to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, + void* workspace = nullptr) { + static_assert(is_variable_length_v>); + auto ptr_O = args.ptr_O; + LayoutO layout_O = args.layout_O; + + auto tma_store_o = + make_tma_copy(SM90_TMA_STORE{}, make_tensor(ptr_O, layout_O), SmemLayoutO{}(_, _, _0{})); + + return {tma_store_o, layout_O, args.ptr_LSE, args.layout_LSE}; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE auto store(BlkCoord const& blk_coord, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& shared_storage, Pipeline& pipeline, + typename Pipeline::PipelineState& pipeline_consumer_state) { + int qo_tile_idx = get<0>(blk_coord); + int qo_head_idx = get<2, 0>(blk_coord); + int batch_idx = get<2, 1>(blk_coord); + int qo_len = get<0>(problem_shape); + int qo_segment_offset = get<0>(params_problem_shape).segment_offsets[batch_idx]; + uint32_t lane_predicate = cute::elect_one_sync(); + + using X = Underscore; + + int o0_index = 2 * get<0>(blk_coord); + int o1_index = 2 * get<0>(blk_coord) + 1; + + // Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE), params.layout_LSE); + // Tensor gLSE = get_lse_local_tile_tensor(mLSE, Shape>{}, qo_head_idx, qo_indptr, + // qo_len)(_, qo_tile_idx); + + int max_length_q = get<0>(params_problem_shape).max_length; + int offs_0 = max_length_q - qo_len; + int offs_2_1 = qo_segment_offset + qo_len; + BlkCoord blk_coord_updated = blk_coord; + get<2, 
1>(blk_coord_updated) = 0; + + Tensor mO = params.tma_store_o.get_tma_tensor(params.layout_O.shape()); + + Tensor mO_qdl = domain_offset(make_coord(offs_0, _0{}, make_coord(_0{}, offs_2_1)), mO); + + Tensor gO_qdl = local_tile(mO_qdl, TileShape{}, make_coord(_, _, _), Step<_1, _1, X>{}); + Tensor gO = gO_qdl(_, _, _, _0{}, get<2>(blk_coord_updated)); + + // auto gO = get_local_tile_tensor(mO, select<0, 1>(TileShape{}), qo_head_idx, + // qo_segment_offset, + // qo_len); + Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{}); + auto block_tma = params.tma_store_o.get_slice(0); + Tensor tOsO = block_tma.partition_S(sO); + Tensor tOgO = block_tma.partition_D(gO); + + auto pipeline_release_state = pipeline_consumer_state; + + // O1 O2 + // one pipeline: O + // wait from corr, issue tma store on smem + pipeline.consumer_wait(pipeline_consumer_state); + ++pipeline_consumer_state; + + if (lane_predicate) { + copy(params.tma_store_o, tOsO(_, _, _, _0{}), tOgO(_, _, _, o0_index)); + } + tma_store_arrive(); + + pipeline.consumer_wait(pipeline_consumer_state); + ++pipeline_consumer_state; + + if (lane_predicate) { + copy(params.tma_store_o, tOsO(_, _, _, _1{}), tOgO(_, _, _, o1_index)); + } + tma_store_arrive(); + + tma_store_wait<1>(); + + pipeline.consumer_release(pipeline_release_state); + ++pipeline_release_state; + + tma_store_wait<0>(); + + pipeline.consumer_release(pipeline_release_state); + ++pipeline_release_state; + } +}; + +} // namespace cutlass::fmha::collective diff --git a/include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp b/include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp new file mode 100644 index 0000000000..b95edbcf0f --- /dev/null +++ b/include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp @@ -0,0 +1,1061 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/arch/simd_sm100.hpp" +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "fmha_common.hpp" +#include "fmha_fusion.hpp" +#include "sm100_fmha_load_tma_warpspecialized.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template > +struct Sm100FmhaFwdMainloopTmaWarpspecialized { + using Element = Element_; + using ElementQK = ElementQK_; + using ElementPV = ElementPV_; + using TileShape = decltype(select<0, 1>(TileShapeQK_{})); + using TileShapeQK = decltype(shape_div(TileShapeQK_{}, ThreadShape{})); + using TileShapePV = decltype(shape_div(TileShapePV_{}, ThreadShape{})); + using StrideQ = StrideQ_; + using StrideK = StrideK_; + using StrideV = StrideV_; + using Mask = Mask_; + + static constexpr int StageCountQ = 2; + static constexpr int StageCountKV = + get<2>(TileShapeQK{}) == 128 ? 2 : 1; // sizeof(Element_) == 1 ? 2 : 2; + + using StagesQ = cutlass::gemm::collective::StageCount; + using StagesKV = cutlass::gemm::collective::StageCount; + + using ClusterShape = Shape<_1, _1, _1>; + + static const int Alignment = 128 / sizeof_bits_v; + + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, Element, StrideQ, Alignment, Element, + StrideK, Alignment, ElementQK, TileShapeQK, ClusterShape, + cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // the stride for A does not matter since we do not load from smem at all + Element, StrideK, Alignment, Element, StrideV, Alignment, ElementPV, TileShapePV, + ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using SmemLayoutQ = + decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK = + decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int{})); + using SmemLayoutV = + decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int{})); + + struct TensorStorage { + cute::array_aligned> smem_q; + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + + enum class TmemAllocation : uint32_t { + kSizeS = 128, + kSizeO = 128, + kSizeP = 32, + S0 = 0, + S1 = S0 + kSizeS, // 128 + V0 = S0, // 0 // stats storage from softmax to correction + V1 = S1, // 128 + P0 = S0 + kSizeP, // 32 + P1 = S1 + kSizeP, // 160 + O0 = S1 + kSizeS, // 256 + O1 = O0 + kSizeO, // 384 + kEnd = O1 + kSizeO // 512 + }; + + // indices for V0 / V1 + enum : int { 
kIdxOldRowMax = 0, kIdxNewRowMax = 1, kIdxFinalRowSum = 0, kIdxFinalRowMax = 1 }; + + // from load to mma warp, protects q in smem + using PipelineQ = + cutlass::PipelineTmaUmmaAsync; + + // from load to mma warp, protects k/v in smem + using PipelineK = + cutlass::PipelineTmaUmmaAsync; + + using PipelineV = + cutlass::PipelineTmaUmmaAsync; + + // from mma to softmax0/1 warp, protects S in tmem + // (not sure yet about the reverse direction) + // there is one pipe per softmax warp, and the mma warp alternates between them + using PipelineS = cutlass::PipelineUmmaAsync<1>; + + // from softmax0/1/ to correction wg + using PipelineC = cutlass::PipelineAsync<1>; + + // from mma to correction + using PipelineO = cutlass::PipelineUmmaAsync<2>; + + // from corr to epilogue + using PipelineE = cutlass::PipelineAsync<2>; + + using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier< + /*stages*/ 1, /*groups*/ 2>; + + static const int TransactionBytesLoadQ = + cutlass::bits_to_bytes(cosize(take<0, 3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + + static const int TransactionBytesLoadK = + cutlass::bits_to_bytes(cosize(take<0, 3>(SmemLayoutK{})) * cute::sizeof_bits_v); + + static const int TransactionBytesLoadV = + cutlass::bits_to_bytes(cosize(take<0, 3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + // static_assert( + // cutlass::bits_to_bytes(cosize(take<0, 3>(SmemLayoutK{})) * cute::sizeof_bits_v) == + // cutlass::bits_to_bytes(cosize(take<0, 3>(SmemLayoutV{})) * + // cute::sizeof_bits_v), + // "K and V smem layouts must be of equal size"); + + using Load = Sm100FmhaLoadTmaWarpspecialized; + using LayoutQ = typename Load::LayoutQ; + using LayoutK = typename Load::LayoutK; + using LayoutV = typename Load::LayoutV; + + struct Arguments { + typename Load::Arguments load; + + // if zero, defaults to 1/sqrt(D) + float scale_softmax = 0.0f; + + // scaling factors to dequantize QKV + float scale_q = 1.0f; + float scale_k = 1.0f; + float scale_v = 1.0f; + + // scaling factor to quantize O + float inv_scale_o = 1.0f; + }; + + struct Params { + typename Load::Params load; + + float scale_softmax; + float scale_softmax_log2; + + float scale_output; + }; + + template + static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static Params to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, + void* workspace) { + float scale_softmax = args.scale_softmax; + if (scale_softmax == 0.0f) { + scale_softmax = 1.0f / (float)std::sqrt(get<2>(problem_shape)); + } + float log2_e = static_cast(std::log2(std::exp(1.0))); + + return Params{Load::to_underlying_arguments(problem_shape, args.load, workspace), + args.scale_q * args.scale_k * scale_softmax, + args.scale_q * args.scale_k * log2_e * scale_softmax, + args.scale_v * args.inv_scale_o}; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + Load::prefetch_tma_descriptors(params.load); + } + + template + CUTLASS_DEVICE void load(BlkCoord const& blk_coord, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, PipelineQ& pipeline_q, + typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineK& pipeline_k, + typename PipelineK::PipelineState& pipeline_k_producer_state, + PipelineV& pipeline_v, + typename PipelineV::PipelineState& pipeline_v_producer_state) { + Load load; + load.load(blk_coord, problem_shape, params.load, params_problem_shape, storage, 
pipeline_q, + pipeline_q_producer_state, pipeline_k, pipeline_k_producer_state, pipeline_v, + pipeline_v_producer_state); + } + + template + CUTLASS_DEVICE auto mma( + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_shape, + TensorStorage& storage, PipelineQ& pipeline_q, + typename PipelineQ::PipelineState& pipeline_q_consumer_state, PipelineK& pipeline_k, + typename PipelineK::PipelineState& pipeline_k_consumer_state, PipelineV& pipeline_v, + typename PipelineV::PipelineState& pipeline_v_consumer_state, PipelineS& pipeline_s0, + typename PipelineS::PipelineState& pipeline_s0_producer_state, PipelineS& pipeline_s1, + typename PipelineS::PipelineState& pipeline_s1_producer_state, PipelineO& pipeline_corr, + typename PipelineO::PipelineState& pipeline_corr_producer_state) { + auto pipeline_q_release_state = pipeline_q_consumer_state; + auto pipeline_k_release_state = pipeline_k_consumer_state; + auto pipeline_v_release_state = pipeline_v_consumer_state; + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + typename CollectiveMmaQK::TiledMma mma_qk; + ThrMMA thr_mma_qk = mma_qk.get_slice(0); + + typename CollectiveMmaPV::TiledMma mma_pv; + TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv); + ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0); + + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ); + Tensor tSrK = thr_mma_qk.make_fragment_B(sK); + Tensor tOrV = thr_mma_pv.make_fragment_B(sV); + + // tmem layout is + // S0 S1`O0 O1 + // sequential in memory, where S overlaps with P and V + + Tensor tStS = partition_fragment_C(mma_qk, select<0, 1>(TileShapeQK{})); + Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0, 1>(TileShapePV{})); + + Tensor tStS0 = tStS; + tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0); + Tensor tStS1 = tStS; + tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1); + + Tensor tOtO0 = tOtO; + tOtO0.data() = tOtO.data().get() + uint32_t(TmemAllocation::O0); + Tensor tOtO1 = tOtO; + tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1); + + Tensor sP = + make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{}); + Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging + + Tensor tOrP0 = tOrP; + tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0); + Tensor tOrP1 = tOrP; + tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1); + + int k_index = 0; + int v_index = 0; + int q_index = 0; + + // wait for Q1 + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + + Tensor tSrQ0 = tSrQ(_, _, _, q_index); + + // wait for K1 + k_index = pipeline_k_consumer_state.index(); + pipeline_k.consumer_wait(pipeline_k_consumer_state); + ++pipeline_k_consumer_state; + + // gemm Q1 * K1 -> S1 + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_, _, _, k_index), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + // release K1 + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_k.consumer_release(pipeline_k_release_state); + ++pipeline_k_release_state; + } + + // wait for Q2 + if constexpr 
(get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) { + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + } + + Tensor tSrQ1 = tSrQ(_, _, _, q_index); + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = pipeline_k_consumer_state.index(); + pipeline_k.consumer_wait(pipeline_k_consumer_state); + ++pipeline_k_consumer_state; + } + + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + // gemm Q2 * K1 -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_, _, _, k_index), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release K1 + pipeline_k.consumer_release(pipeline_k_release_state); + ++pipeline_k_release_state; + + // wait for V1 + v_index = pipeline_v_consumer_state.index(); + pipeline_v.consumer_wait(pipeline_v_consumer_state); + ++pipeline_v_consumer_state; + + // this acquire returns the ownership of all of S0 to the mma warp + // including the P0 part + // acquire corr first to take it out of the critical + // path since softmax takes longer + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + // gemm P1 * V1 -> O1 + gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_, _, _, v_index), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_v.consumer_release(pipeline_v_release_state); + ++pipeline_v_release_state; + } + + mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + // wait for Ki + k_index = (pipeline_k_consumer_state.index()); + pipeline_k.consumer_wait(pipeline_k_consumer_state); + ++pipeline_k_consumer_state; + + // gemm Q1 * Ki -> S1 + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_, _, _, k_index), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_k.consumer_release(pipeline_k_release_state); + ++pipeline_k_release_state; + } + + // gemm P2 * V(i-1) -> O2 + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_v_consumer_state.index(); + pipeline_v.consumer_wait(pipeline_v_consumer_state); + ++pipeline_v_consumer_state; + } + + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_, _, _, v_index), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release V(i-1) + pipeline_v.consumer_release(pipeline_v_release_state); + ++pipeline_v_release_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = (pipeline_k_consumer_state.index()); + pipeline_k.consumer_wait(pipeline_k_consumer_state); + ++pipeline_k_consumer_state; + } + + // gemm Q2 * Ki -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_, _, _, k_index), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release Ki + pipeline_k.consumer_release(pipeline_k_release_state); + ++pipeline_k_release_state; + + // wait for Vi + v_index = (pipeline_v_consumer_state.index()); + pipeline_v.consumer_wait(pipeline_v_consumer_state); + ++pipeline_v_consumer_state; + + // gemm P1 * Vi -> O1 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + + 
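+      // as with the first tile, acquiring S0 here returns ownership of the S0/P0
+      // tmem region to the mma warp, so the P*V gemm below may read P0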
pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_, _, _, v_index), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_v.consumer_release(pipeline_v_release_state); + ++pipeline_v_release_state; + } + } + + // release Q1 + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + + // release Q2 + if constexpr (get<0>(ThreadShape{}) > 1) { + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + } + + // wait for Vi + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_v_consumer_state.index(); + pipeline_v.consumer_wait(pipeline_v_consumer_state); + ++pipeline_v_consumer_state; + } + + // gemm P2 * Vi -> O2 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_, _, _, v_index), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release Vi + pipeline_v.consumer_release(pipeline_v_release_state); + ++pipeline_v_release_state; + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 + // S11 B2, ... Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * + // K3 , S22 * K2 , ... + } + + template + CUTLASS_DEVICE auto softmax_step(float& row_max, float& row_sum, Stage stage, bool final_call, + BlkCoord const& blk_coord, CountingTensor const& cS, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, + typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, + typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS = + partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0, 1>(TileShapeQK{})); + tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + auto tilePlikeFP32 = get<1>(TileShapeQK{}) / Int{} * Int{}; + Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? 
TmemAllocation::P0 : TmemAllocation::P1)); + Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + + // Each thread owns a single row + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem + using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS); + Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS); + + auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v); + auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx); + + Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v); + Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v); + + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P); + tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get()); + Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P); + + // wait on tensor core pipe + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + // read all of S from tmem into reg mem + Tensor tTMEM_LOADrS = make_tensor(shape(tTMEM_LOADcS)); + copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS); + + if constexpr (need_apply_mask) { + Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape); + } + + ElementQK old_row_max = row_max; + { + // compute rowmax + float row_max_0 = row_max; + float row_max_1 = row_max; + float row_max_2 = row_max; + float row_max_3 = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 4) { + row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i)); + row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i + 1)); + row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i + 2)); + row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i + 3)); + } + row_max = ::fmax(row_max_0, row_max_1); + row_max = ::fmax(row_max, row_max_2); + row_max = ::fmax(row_max, row_max_3); + } + + ElementQK row_max_safe = row_max == -INFINITY ? 
0 : row_max; + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max; + tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + // notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's) + + ElementQK scale = params.scale_softmax_log2; + ElementQK row_max_scale = row_max_safe * scale; + + float2 scale_fp32x2 = make_float2(scale, scale); + float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale); + + Tensor tTMEM_STORErS_x4 = make_tensor(shape(tTMEM_STOREcS)); + + constexpr int kConversionsPerStep = 2; + + Tensor tTMEM_STORErS_x4_e = recast>(tTMEM_STORErS_x4); + + NumericArrayConverter convert; + + const int kReleasePipeCount = 10; // must be multiple of 2 + + order_s.wait(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 2) { + float2 in = make_float2(tTMEM_LOADrS(i + 0), tTMEM_LOADrS(i + 1)); + float2 out; + cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2); + tTMEM_LOADrS(i + 0) = out.x; + tTMEM_LOADrS(i + 1) = out.y; + + tTMEM_LOADrS(i + 0) = ::exp2f(tTMEM_LOADrS(i + 0)); + tTMEM_LOADrS(i + 1) = ::exp2f(tTMEM_LOADrS(i + 1)); + + Array in_conv; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kConversionsPerStep; j++) { + in_conv[j] = tTMEM_LOADrS(i + j); + } + tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); + + if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { + order_s.arrive(); + } + + // this prevents register spills in fp16 + if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { + if (i == size(tTMEM_LOADrS) - 6) { + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); + } + } + } + + // tmem_store(reg_S8) -> op_P + CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); + CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), + tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); + + cutlass::arch::fence_view_async_tmem_store(); + + // notify tensor core warp that P is ready + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); + row_sum *= acc_scale; + // row_sum = sum(reg_S) + float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum); + float2 local_row_sum_1 = make_float2(0, 0); + float2 local_row_sum_2 = make_float2(0, 0); + float2 local_row_sum_3 = make_float2(0, 0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 8) { + // row_sum += tTMEM_LOADrS(i); + float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i + 1)); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in); + + in = make_float2(tTMEM_LOADrS(i + 2), tTMEM_LOADrS(i + 2 + 1)); + cute::add(local_row_sum_1, local_row_sum_1, in); + + in = make_float2(tTMEM_LOADrS(i + 4), tTMEM_LOADrS(i + 4 + 1)); + cute::add(local_row_sum_2, local_row_sum_2, in); + + in = make_float2(tTMEM_LOADrS(i + 6), tTMEM_LOADrS(i + 6 + 1)); + cute::add(local_row_sum_3, local_row_sum_3, in); + } + + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1); + cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2); + float local_row_sum = 
local_row_sum_f32x2.x + local_row_sum_f32x2.y; + + row_sum = local_row_sum; + + if (final_call) { + // re-acquire the S part in the final step + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxFinalRowMax) = row_max; + tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + } + } + + template + CUTLASS_DEVICE auto softmax(Stage stage, BlkCoord const& blk_coord, Params const& params, + ProblemShape const& problem_shape, PipelineS& pipeline_s, + typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, + typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + int mask_tile_count = Mask{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_shape); + + ElementQK row_max = -INFINITY; + ElementQK row_sum = 0; + + Tensor cS_base = make_identity_tensor(select<0, 1>(TileShapeQK{})); + auto logical_offset = make_coord(get<0>(blk_coord) * get<0>(TileShape{}) + + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}), + 0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{})); + Tensor cS = domain_offset(logical_offset, cS_base); + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + softmax_step( + row_max, row_sum, stage, + (mask_tile_count == 1) && + (Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape) == 0), + blk_coord, cS, params, problem_shape, pipeline_s, pipeline_s_consumer_state, pipeline_c, + pipeline_c_producer_state, order_s); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + // Masked iterations + mask_tile_count = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape); + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + softmax_step( + row_max, row_sum, stage, mask_tile_count == 1, blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, pipeline_c, pipeline_c_producer_state, order_s); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + // empty step to sync against pipe s + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + } + + template + CUTLASS_DEVICE auto correction_epilogue(float scale, Stage stage, TensorO const& sO_01) { + using ElementOut = typename TensorO::value_type; + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor sO = sO_01(_, _, stage); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + const int kCorrectionTileSize = 32 / sizeof(ElementOut); + + using TMEM_LOAD = + std::conditional_t; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0, 1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0, 1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + Tensor tOsO = mma.get_slice(0).partition_C(sO); + + Tensor tOtO_i = + logical_divide(tOtO, make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = + logical_divide(tOcO, make_layout(make_shape(_128{}, Int{}))); + 
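+    // the logical_divide above and below splits the 128-row accumulator into
+    // kCorrectionTileSize-column chunks so that each TMEM_LOAD / scale / convert /
+    // store-to-smem pass in the loop below fits in registers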
Tensor tOsO_i = + logical_divide(tOsO, make_layout(make_shape(_128{}, Int{}))); + + if constexpr (decltype(stage == _0{})::value) { + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O0); + } else { + static_assert(decltype(stage == _1{})::value, "stage is either 0 or 1"); + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O1); + } + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{})); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _)); + + float2 scale_f32x2 = make_float2(scale, scale); + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < get<1>(TileShapePV{}) / kCorrectionTileSize; i++) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO(_, _0{}, _0{}, i); + Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i); + + Tensor tTMrO = make_tensor(shape(tTMEM_LOADcO(_, _0{}, _0{}, i))); + + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO); + +#ifndef ONLY_SOFTMAX + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO); j += 2) { + float2 in = make_float2(tTMrO(j), tTMrO(j + 1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO(j) = out.x; + tTMrO(j + 1) = out.y; + } +#endif + + constexpr int N = 4 / sizeof(ElementOut); + NumericArrayConverter convert; + + Tensor tSMrO = make_tensor_like(tTMrO); + + Tensor tCs = recast(tTMrO); + Tensor tCd = recast(tSMrO); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tCs); j++) { + tCd(j) = convert.convert(tCs(j)); + } + + Tensor tSMsO_i = recast(tTMEM_LOADsO_i); + Tensor tSMrO_i = recast(tSMrO); + + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMsO_i); + } + + cutlass::arch::fence_view_async_shared(); + } + + CUTLASS_DEVICE auto correction_rescale(float scale, uint32_t tmem_O) { + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + const int kCorrectionTileSize = 16; + + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0, 1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0, 1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + + tOtO_i.data() = tOtO_i.data().get() + tmem_O; + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); + Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i); + Tensor tTMEM_STOREcO = thr_tmem_store.partition_S(tOcO_i); + static_assert(shape(tTMEM_STOREcO) == shape(tTMEM_LOADcO)); + + float2 scale_f32x2 = make_float2(scale, scale); + + Tensor tTMrO = + 
make_tensor(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{})); + + auto copy_in = [&](int i) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO; + tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i); + }; + + auto copy_out = [&](int i) { + Tensor tTMEM_STOREtO_i = tTMEM_STOREtO; + tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i); + }; + + // sequence: LLMSLMSLMSS + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + copy_in(0); + + int count = get<1>(TileShapePV{}) / kCorrectionTileSize; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < count; i++) { + if (i != count - 1) { + copy_in(i + 1); + } + + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO_i); j += 2) { + float2 in = make_float2(tTMrO_i(j), tTMrO_i(j + 1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO_i(j) = out.x; + tTMrO_i(j + 1) = out.y; + } + + copy_out(i); + } + } + + template + CUTLASS_DEVICE auto correction( + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_shape, + TensorStorageEpi& shared_storage_epi, PipelineC& pipeline_s0_c, + typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, PipelineC& pipeline_s1_c, + typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, PipelineO& pipeline_o, + typename PipelineO::PipelineState& pipeline_o_consumer_state, PipelineE& pipeline_epi, + typename PipelineE::PipelineState& pipeline_epi_producer_state) { + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor tStS = + partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0, 1>(TileShapeQK{})); + + Tensor cS = make_identity_tensor(select<0, 1>(TileShapeQK{})); + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v); + auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); + + Tensor tTMEM_LOADVtS = thr_tmem_loadv.partition_S(tStS_v); + Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v); + + Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS; + tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0); + Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS; + tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1); + + // ignore first signal from softmax as no correction is required + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // handle the last iteration differently (i.e. 
tmem_load/stsm for epi) + mask_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + + // read row_wise new global max + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + // e^(scale * (old_max - new_max) + float scale = ::exp2f(params.scale_softmax_log2 * + (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O0)); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + scale = ::exp2f(params.scale_softmax_log2 * + (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O1)); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + } + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + // do the final correction to O1 + // better to somehow special-case it in the loop above + // doesn't matter for non-persistent code, but if it were + // persistent we do not want to release O too early + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + // read from V0 + // read row_sum and final row_max here + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + // store to epi smem + + // loop: + // TMEM_LOAD + // FMUL2 scale = 1 / global_sum * out_quant_scale + // F2FP + // store to smem + Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), + typename TensorStorageEpi::SmemLayoutO{}); + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO); + // correction_epilogue(params.scale_output, _0{}, sO); + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // load from V1 + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO); + // correction_epilogue(params.scale_output, _1{}, sO); + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + 
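+    // for reference, the scale passed to correction_epilogue above is the usual
+    // flash-attention epilogue normalization (scalar sketch):
+    //   o[d] = scale_output * (sum_j p_j * v_j[d]) / row_sum
+    // where row_sum is the softmax denominator written to tmem at kIdxFinalRowSum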
pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + } +}; + +} // namespace cutlass::fmha::collective diff --git a/include/flashinfer/attention/blackwell/collective/sm100_fmha_gen_epilogue_warpspecialized.hpp b/include/flashinfer/attention/blackwell/collective/sm100_fmha_gen_epilogue_warpspecialized.hpp new file mode 100644 index 0000000000..4eb97b540b --- /dev/null +++ b/include/flashinfer/attention/blackwell/collective/sm100_fmha_gen_epilogue_warpspecialized.hpp @@ -0,0 +1,82 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/layout.hpp" +#include "cutlass/cutlass.h" + +namespace cutlass::fmha::collective { + +template +struct Sm100FmhaGenEpilogueWarpspecialized { + using Pipeline = cutlass::PipelineAsync<2>; + + using SmemLayoutO = Layout>; + using SmemLayoutO_ = SmemLayoutO; + using Element = Element_; + using StrideOOrig = StrideO_; + using StrideO = decltype(replace<0>(StrideOOrig{}, 0)); + + struct TensorStorage { + using SmemLayoutO = SmemLayoutO_; + cute::array_aligned> smem_o; + }; + + struct Arguments { + Element* ptr_o; + StrideO dO; + }; + + using Params = Arguments; + + const Params& params; + + CUTLASS_DEVICE Sm100FmhaGenEpilogueWarpspecialized(const Params& params) : params(params) {} + + template + static Params to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, + void* workspace = nullptr) { + return args; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { /* no-op */ } + + template + CUTLASS_DEVICE auto store(BlkCoord const& blk_coord_in, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& shared_storage, Pipeline& pipeline, + typename Pipeline::PipelineState& pipeline_consumer_state) { + /* no-op */ + } +}; + +} // namespace cutlass::fmha::collective diff --git a/include/flashinfer/attention/blackwell/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp b/include/flashinfer/attention/blackwell/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp new file mode 100644 index 0000000000..8a96c1cd0f --- /dev/null +++ b/include/flashinfer/attention/blackwell/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp @@ -0,0 +1,1064 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" +#include "collective/sm100_fmha_load_cpasync_warpspecialized.hpp" +#include "cute/arch/simd_sm100.hpp" +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template > +struct Sm100FmhaGenMainloopWarpspecialized { + using Element = Element_; + using ElementQK = ElementQK_; + using ElementPV = ElementPV_; + using ElementAcc = ElementPV_; + using ElementOut = ElementOut_; + using TileShape = TileShape_; + using StrideQOrig = StrideQ_; + using StrideQ = decltype(replace<0>(StrideQ_{}, 0)); + using StrideNewK = StrideNewK_; + using StrideNewV = StrideNewV_; + using StrideCacheK = StrideK_; + using StrideCacheV = StrideV_; + using StrideK = StrideK_; + using StrideV = StrideV_; + using StrideOOrig = StrideO_; + using StrideO = decltype(replace<0>(StrideO_{}, 0)); + using Mask = Mask_; + + static constexpr int StageCountQ = get<1>(TileShape{}) == 256 ? 
1 : 2; + static constexpr int StageCountKV = 256 * 11 / get<1>(TileShape{}); + + using StagesQ = cutlass::gemm::collective::StageCount; + using StagesKV = cutlass::gemm::collective::StageCount; + + using ClusterShape = Shape<_1, _1, _1>; + + static const int Alignment = 128 / sizeof_bits_v; + + using TileShapeQK = decltype(shape_div(TileShape{}, ThreadShape{})); + + using TileShapePV = decltype(select<0, 2, 1>(TileShapeQK{})); + + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, Element, StrideQ, Alignment, Element, + StrideK, Alignment, ElementQK, TileShapeQK, ClusterShape, + cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // the stride for A does not matter since we do not load from smem at all + Element, StrideK, Alignment, Element, decltype(select<1, 0, 2>(StrideV{})), Alignment, + ElementPV, TileShapePV, ClusterShape, + cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using SmemLayoutQ = + decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK = + decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int{})); + using SmemLayoutV = + decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int{})); + + struct TensorStorage { + cute::array_aligned> smem_q; + union { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + }; + + enum class TmemAllocation : uint32_t { + kSizeS = 128, + kSizeO = 128, + kSizeP = 32, + S0 = 0, + S1 = S0 + kSizeS, + V0 = S0, // stats storage from softmax to correction + V1 = S1, + P0 = S0 + kSizeP, + P1 = S1 + kSizeP, + O0 = S1 + kSizeS, + O1 = O0 + kSizeO, + kEnd = O1 + kSizeO + }; + + // indices for V0 / V1 + enum : int { kIdxOldRowMax = 0, kIdxNewRowMax = 1, kIdxFinalRowSum = 0, kIdxFinalRowMax = 1 }; + + // from load to mma warp, protects q in smem + using PipelineQ = + cutlass::PipelineUmmaConsumerAsync; + + // from load to mma warp, protects k/v in smem + using PipelineKV = + cutlass::PipelineUmmaConsumerAsync; + + // from mma to softmax0/1 warp, protects S in tmem + // (not sure yet about the reverse direction) + // there is one pipe per softmax warp, and the mma warp alternates between them + using PipelineS = cutlass::PipelineUmmaAsync<1>; + + // from softmax0/1/ to correction wg + using PipelineC = cutlass::PipelineAsync<1>; + + // from mma to correction + using PipelineO = cutlass::PipelineUmmaAsync<2>; + + // from corr to epilogue + using PipelineE = cutlass::PipelineAsync<2>; + + using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier< + /*stages*/ 1, /*groups*/ 2>; + + static_assert( + cutlass::bits_to_bytes(cosize(take<0, 3>(SmemLayoutK{})) * cute::sizeof_bits_v) == + cutlass::bits_to_bytes(cosize(take<0, 3>(SmemLayoutV{})) * cute::sizeof_bits_v), + "K and V smem layouts must be of equal size"); + + using Load = + Sm100FmhaLoadCpAsyncWarpspecialized; + + struct Arguments { + typename Load::Arguments load; + + // if zero, defaults to 1/sqrt(D) + float scale_softmax = 0.0f; + + // scaling factors to dequantize QKV + float scale_q = 1.0f; + float scale_k = 1.0f; + float scale_v = 1.0f; + + // scaling factor to quantize O + float inv_scale_o = 1.0f; + }; + + 
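+  // How these arguments are folded into Params (see to_underlying_arguments below),
+  // as a scalar sketch, assuming get<2>(problem_shape) is the QK head dimension D:
+  //   softmax_scale      = (scale_softmax != 0) ? scale_softmax : 1/sqrt(D)
+  //   scale_softmax      = scale_q * scale_k * softmax_scale
+  //   scale_softmax_log2 = scale_softmax * log2(e)
+  //   scale_output       = scale_v * inv_scale_o
+  // since exp(s * x) == exp2(s * log2(e) * x), folding log2(e) in here lets the
+  // softmax warps use the cheaper exp2f instruction.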
struct Params { + typename Load::Params load; + + float scale_softmax; + float scale_softmax_log2; + + float scale_output; + }; + + template + static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static Params to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, + void* workspace) { + float scale_softmax = args.scale_softmax; + if (scale_softmax == 0.0f) { + scale_softmax = 1.0f / (float)std::sqrt(get<2>(problem_shape)); + } + float log2_e = static_cast(std::log2(std::exp(1.0))); + + return Params{Load::to_underlying_arguments(problem_shape, args.load, workspace), + args.scale_q * args.scale_k * scale_softmax, + args.scale_q * args.scale_k * log2_e * scale_softmax, + args.scale_v * args.inv_scale_o}; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + Load::prefetch_tma_descriptors(params.load); + } + + template + CUTLASS_DEVICE void load(BlkCoord const& blk_coord, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, PipelineQ& pipeline_q, + typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, + typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + Load load; + load.load(blk_coord, problem_shape, params.load, params_problem_shape, storage, pipeline_q, + pipeline_q_producer_state, pipeline_kv, pipeline_kv_producer_state); + } + + template + CUTLASS_DEVICE auto mma( + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_shape, + TensorStorage& storage, PipelineQ& pipeline_q, + typename PipelineQ::PipelineState& pipeline_q_consumer_state, PipelineKV& pipeline_kv, + typename PipelineKV::PipelineState& pipeline_kv_consumer_state, PipelineS& pipeline_s0, + typename PipelineS::PipelineState& pipeline_s0_producer_state, PipelineS& pipeline_s1, + typename PipelineS::PipelineState& pipeline_s1_producer_state, PipelineO& pipeline_corr, + typename PipelineO::PipelineState& pipeline_corr_producer_state) { + auto pipeline_q_release_state = pipeline_q_consumer_state; + auto pipeline_kv_release_state = pipeline_kv_consumer_state; + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + typename CollectiveMmaQK::TiledMma mma_qk; + ThrMMA thr_mma_qk = mma_qk.get_slice(0); + + typename CollectiveMmaPV::TiledMma mma_pv; + TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv); + ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0); + + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ); + Tensor tSrK = thr_mma_qk.make_fragment_B(sK); + Tensor tOrV = thr_mma_pv.make_fragment_B(sV); + + // tmem layout is + // S0 S1`O0 O1 + // sequential in memory, where S overlaps with P and V + + Tensor tStS = partition_fragment_C(mma_qk, select<0, 1>(TileShapeQK{})); + Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0, 1>(TileShapePV{})); + + Tensor tStS0 = tStS; + tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0); + Tensor tStS1 = tStS; + tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1); + + Tensor tOtO0 = tOtO; + tOtO0.data() = tOtO.data().get() + uint32_t(TmemAllocation::O0); + Tensor tOtO1 = tOtO; + tOtO1.data() = tOtO.data().get() + 
uint32_t(TmemAllocation::O1); + + Tensor sP = + make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{}); + Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging + + Tensor tOrP0 = tOrP; + tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0); + Tensor tOrP1 = tOrP; + tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1); + + int k_index = 0; + int v_index = 0; + int q_index = 0; + + // wait for Q1 + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + + Tensor tSrQ0 = tSrQ(_, _, _, q_index); + + // wait for K1 + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * K1 -> S1 + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_, _, _, k_index), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + // release K1 + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // wait for Q2 + if constexpr (get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) { + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + } + + Tensor tSrQ1 = tSrQ(_, _, _, q_index); + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + // gemm Q2 * K1 -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_, _, _, k_index), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release K1 + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for V1 + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // this acquire returns the ownership of all of S0 to the mma warp + // including the P0 part + // acquire corr first to take it out of the critical + // path since softmax takes longer + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + // gemm P1 * V1 -> O1 + gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_, _, _, v_index), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + // wait for Ki + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * Ki -> S1 + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_, _, _, k_index), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // gemm P2 * V(i-1) -> O2 + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = 
pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_, _, _, v_index), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release V(i-1) + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm Q2 * Ki -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_, _, _, k_index), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release Ki + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for Vi + v_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm P1 * Vi -> O1 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_, _, _, v_index), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + } + + // release Q1 + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + + // release Q2 + if constexpr (get<0>(ThreadShape{}) > 1) { + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + } + + // wait for Vi + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm P2 * Vi -> O2 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_, _, _, v_index), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release Vi + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 + // S11 B2, ... Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * + // K3 , S22 * K2 , ... 
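+    // i.e. the mma warp ping-pongs between the two softmax warpgroups: while one
+    // warpgroup exponentiates S for K-tile i, the tensor cores already compute
+    // Q*K for K-tile i+1 into the other S buffer and drain the previous P*V into O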
+ } + + template + CUTLASS_DEVICE auto softmax_step(float& row_max, float& row_sum, Stage stage, bool final_call, + BlkCoord const& blk_coord, CountingTensor const& cS, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, + typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, + typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS = + partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0, 1>(TileShapeQK{})); + tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + auto tilePlikeFP32 = get<1>(TileShapeQK{}) / Int{} * Int{}; + Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); + Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + + // Each thread owns a single row + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem + using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS); + Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS); + + auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v); + auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx); + + Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v); + Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v); + + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P); + tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get()); + Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P); + + // wait on tensor core pipe + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + // read all of S from tmem into reg mem + Tensor tTMEM_LOADrS = make_tensor(shape(tTMEM_LOADcS)); + copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS); + + if constexpr (need_apply_mask) { + Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape); + } + + ElementQK old_row_max = row_max; + { + // compute rowmax + float row_max_0 = row_max; + float row_max_1 = row_max; + float row_max_2 = row_max; + float row_max_3 = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 4) { + row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i)); + row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i + 1)); + row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i + 2)); + row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i + 3)); + } + row_max = ::fmax(row_max_0, row_max_1); + row_max = ::fmax(row_max, row_max_2); + row_max = ::fmax(row_max, row_max_3); + } + + ElementQK row_max_safe = row_max == -INFINITY ? 
0 : row_max; + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max; + tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + // notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's) + + ElementQK scale = params.scale_softmax_log2; + ElementQK row_max_scale = row_max_safe * scale; + + float2 scale_fp32x2 = make_float2(scale, scale); + float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale); + + Tensor tTMEM_STORErS_x4 = make_tensor(shape(tTMEM_STOREcS)); + + constexpr int kConversionsPerStep = 2; + + Tensor tTMEM_STORErS_x4_e = recast>(tTMEM_STORErS_x4); + + NumericArrayConverter convert; + + const int kReleasePipeCount = 10; // must be multiple of 2 + + order_s.wait(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 2) { + float2 in = make_float2(tTMEM_LOADrS(i + 0), tTMEM_LOADrS(i + 1)); + float2 out; + cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2); + tTMEM_LOADrS(i + 0) = out.x; + tTMEM_LOADrS(i + 1) = out.y; + + tTMEM_LOADrS(i + 0) = ::exp2f(tTMEM_LOADrS(i + 0)); + tTMEM_LOADrS(i + 1) = ::exp2f(tTMEM_LOADrS(i + 1)); + + Array in_conv; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kConversionsPerStep; j++) { + in_conv[j] = tTMEM_LOADrS(i + j); + } + tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); + + if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { + order_s.arrive(); + } + + // this prevents register spills in fp16 + if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { + if (i == size(tTMEM_LOADrS) - 6) { + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); + } + } + } + + // tmem_store(reg_S8) -> op_P + CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); + CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), + tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); + + cutlass::arch::fence_view_async_tmem_store(); + + // notify tensor core warp that P is ready + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); + row_sum *= acc_scale; + // row_sum = sum(reg_S) + float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum); + float2 local_row_sum_1 = make_float2(0, 0); + float2 local_row_sum_2 = make_float2(0, 0); + float2 local_row_sum_3 = make_float2(0, 0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 8) { + // row_sum += tTMEM_LOADrS(i); + float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i + 1)); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in); + + in = make_float2(tTMEM_LOADrS(i + 2), tTMEM_LOADrS(i + 2 + 1)); + cute::add(local_row_sum_1, local_row_sum_1, in); + + in = make_float2(tTMEM_LOADrS(i + 4), tTMEM_LOADrS(i + 4 + 1)); + cute::add(local_row_sum_2, local_row_sum_2, in); + + in = make_float2(tTMEM_LOADrS(i + 6), tTMEM_LOADrS(i + 6 + 1)); + cute::add(local_row_sum_3, local_row_sum_3, in); + } + + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1); + cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2); + float local_row_sum = 
local_row_sum_f32x2.x + local_row_sum_f32x2.y; + + row_sum = local_row_sum; + + if (final_call) { + // re-acquire the S part in the final step + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxFinalRowMax) = row_max; + tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + } + } + + template + CUTLASS_DEVICE auto softmax(Stage stage, BlkCoord const& blk_coord, Params const& params, + ProblemShape const& problem_shape, PipelineS& pipeline_s, + typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, + typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + int mask_tile_count = Mask{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_shape); + + ElementQK row_max = -INFINITY; + ElementQK row_sum = 0; + + Tensor cS_base = make_identity_tensor(select<0, 1>(TileShapeQK{})); + auto logical_offset = make_coord(get<0>(blk_coord) * get<0>(TileShape{}) + + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}), + 0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{})); + Tensor cS = domain_offset(logical_offset, cS_base); + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + softmax_step( + row_max, row_sum, stage, + (mask_tile_count == 1) && + (Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape) == 0), + blk_coord, cS, params, problem_shape, pipeline_s, pipeline_s_consumer_state, pipeline_c, + pipeline_c_producer_state, order_s); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + // Masked iterations + mask_tile_count = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape); + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + softmax_step( + row_max, row_sum, stage, mask_tile_count == 1, blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, pipeline_c, pipeline_c_producer_state, order_s); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + // empty step to sync against pipe s + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + } + + template + CUTLASS_DEVICE auto correction_epilogue(float scale_softmax_log2, float scale_out, + Vector const& v0, Vector const& v1, GTensor& gO, + CTensor const& cO, Shape const& g_shape, + Epilogue const& epilogue) { + using ElementOut = typename GTensor::value_type; + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + const int kCorrectionTileSize = 32 / sizeof(ElementOut); + + using TMEM_LOAD = + std::conditional_t; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor tOtO = partition_fragment_C(mma, select<0, 1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + Tensor tOgO = mma.get_slice(0).partition_C(gO); + + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor 
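+    // v0/v1 hold the final (row_max, row_sum) pairs the softmax stages stored to TMEM,
+    // and O0/O1 are the matching partial accumulators. The arithmetic below is the usual
+    // two-way softmax-state merge; as a sketch of what it computes per row:
+    //   m     = max(m0, m1)
+    //   adj_i = exp2f(scale_softmax_log2 * (m_i - m))
+    //   l     = adj0 * l0 + adj1 * l1
+    //   O     = scale_out * (adj0 * O0 + adj1 * O1) / l
+    // which is exactly how scale0 and scale1 are formed below.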
tOgO_i = tOgO.compose(make_layout(make_shape(_128{}, Int{}))); + + Tensor tOtO0 = tOtO_i; + tOtO0.data() = tOtO0.data().get() + uint32_t(TmemAllocation::O0); + Tensor tOtO1 = tOtO_i; + tOtO1.data() = tOtO1.data().get() + uint32_t(TmemAllocation::O1); + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtO0 = thr_tmem_load.partition_S(tOtO0); + Tensor tTMEM_LOADtO1 = thr_tmem_load.partition_S(tOtO1); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); + Tensor tTMEM_LOADgO = thr_tmem_load.partition_D(tOgO_i); + + float row_max = std::max(v0(kIdxFinalRowMax), v1(kIdxFinalRowMax)); + float adj0 = ::exp2f(scale_softmax_log2 * (v0(kIdxFinalRowMax) - row_max)); + float adj1 = ::exp2f(scale_softmax_log2 * (v1(kIdxFinalRowMax) - row_max)); + float row_sum = adj0 * v0(kIdxFinalRowSum) + adj1 * v1(kIdxFinalRowSum); + float scale0 = scale_out * adj0 / row_sum; + float scale1 = scale_out * adj1 / row_sum; + + float2 scale0_f32x2 = make_float2(scale0, scale0); + float2 scale1_f32x2 = make_float2(scale1, scale1); + + // loop: + // TMEM_LOAD, TMEM_LOAD, FMUL2, FFMA2, STG + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 128 / kCorrectionTileSize; i++) { + Tensor tTMEM_LOADtO0_i = tTMEM_LOADtO0; + tTMEM_LOADtO0_i.data() = tTMEM_LOADtO0_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMEM_LOADtO1_i = tTMEM_LOADtO1; + tTMEM_LOADtO1_i.data() = tTMEM_LOADtO1_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMEM_LOADgO_i = tTMEM_LOADgO; + tTMEM_LOADgO_i.data() = tTMEM_LOADgO_i.data().get() + i * kCorrectionTileSize * stride<1>(gO); + + Tensor tTMrO0 = make_tensor(shape(tTMEM_LOADcO)); + Tensor tTMrO1 = make_tensor(shape(tTMEM_LOADcO)); + + copy(tiled_tmem_load, tTMEM_LOADtO0_i, tTMrO0); + copy(tiled_tmem_load, tTMEM_LOADtO1_i, tTMrO1); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO0); j += 2) { + float2 in0 = make_float2(tTMrO0(j), tTMrO0(j + 1)); + float2 in1 = make_float2(tTMrO1(j), tTMrO1(j + 1)); + float2 out; + cute::mul(out, scale0_f32x2, in0); + cute::fma(out, scale1_f32x2, in1, out); + tTMrO0(j) = out.x; + tTMrO0(j + 1) = out.y; + } + + constexpr int N = 4 / sizeof(ElementOut); + NumericArrayConverter convert; + + Tensor tSMrO = make_tensor_like(tTMrO0); + + Tensor tCs = recast(tTMrO0); + Tensor tCd = recast(tSMrO); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tCs); j++) { + tCd(j) = convert.convert(tCs(j)); + } + + Tensor tSMgO_i = recast(tTMEM_LOADgO_i); + Tensor tSMrO_i = recast(tSMrO); + + // could use masking do this right for smaller D + if (get<0>(tTMEM_LOADcO(_0{})) < get<0>(g_shape)) { + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMgO_i); + } + } + } + + CUTLASS_DEVICE auto correction_rescale(float scale, uint32_t tmem_O) { + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + const int kCorrectionTileSize = 32; + + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 64 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0, 1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0, 1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + + Tensor tOtO_i = 
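+    // correction_rescale multiplies the partial output accumulator held in TMEM by a
+    // single factor; its callers pass scale = exp2f(scale_softmax_log2 * (old_max - new_max)),
+    // i.e. the online-softmax down-weighting applied whenever the running row maximum
+    // grows. The loop below streams the accumulator through registers in
+    // kCorrectionTileSize-column slices (load, FMUL2, store), issuing the next slice's
+    // load before scaling the current one.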
tOtO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + + tOtO_i.data() = tOtO_i.data().get() + tmem_O; + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); + Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i); + Tensor tTMEM_STOREcO = thr_tmem_store.partition_S(tOcO_i); + static_assert(shape(tTMEM_STOREcO) == shape(tTMEM_LOADcO)); + + float2 scale_f32x2 = make_float2(scale, scale); + + Tensor tTMrO = + make_tensor(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{})); + + auto copy_in = [&](int i) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO; + tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i); + }; + + auto copy_out = [&](int i) { + Tensor tTMEM_STOREtO_i = tTMEM_STOREtO; + tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i); + }; + + // sequence: LLMSLMSLMSS + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + copy_in(0); + + int count = get<2>(TileShape{}) / kCorrectionTileSize; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < count; i++) { + if (i != count - 1) { + copy_in(i + 1); + } + + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO_i); j += 2) { + float2 in = make_float2(tTMrO_i(j), tTMrO_i(j + 1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO_i(j) = out.x; + tTMrO_i(j + 1) = out.y; + } + + copy_out(i); + } + } + + template + CUTLASS_DEVICE auto correction( + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_shape, + TensorStorageEpi& shared_storage_epi, PipelineC& pipeline_s0_c, + typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, PipelineC& pipeline_s1_c, + typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, PipelineO& pipeline_o, + typename PipelineO::PipelineState& pipeline_o_consumer_state, PipelineE& pipeline_epi, + typename PipelineE::PipelineState& pipeline_epi_producer_state, Epilogue const& epilogue) { + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor tStS = + partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0, 1>(TileShapeQK{})); + + Tensor cS = make_identity_tensor(select<0, 1>(TileShapeQK{})); + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v); + auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); + + Tensor tTMEM_LOADVtS = thr_tmem_loadv.partition_S(tStS_v); + Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v); + + Tensor tTMEM_LOADVtS0 = 
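+    // V0/V1 act as small per-row mailboxes in TMEM written by the two softmax warps:
+    // during the main loop they carry (old_row_max, new_row_max), and after the last
+    // iteration (final row_max, row_sum). The correction warp reads them here to decide
+    // how strongly to rescale O0/O1 and, at the end, how to normalize the output.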
tTMEM_LOADVtS; + tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0); + Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS; + tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1); + + // ignore first signal from softmax as no correction is required + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // handle the last iteration differently (i.e. tmem_load/stsm for epi) + mask_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + + // read row_wise new global max + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + // e^(scale * (old_max - new_max) + float scale = ::exp2f(params.scale_softmax_log2 * + (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O0)); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + scale = ::exp2f(params.scale_softmax_log2 * + (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O1)); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + } + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + // do the final correction to O1 + // better to somehow special-case it in the loop above + // doesn't matter for non-persistent code, but if it were + // persistent we do not want to release O too early + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + // read from V0 + // read row_sum and final row_max here + Tensor tTMEM_LOADVrS0 = make_tensor(shape(tTMEM_LOADVcS)); + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS0); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // load from V1 + Tensor tTMEM_LOADVrS1 = make_tensor(shape(tTMEM_LOADVcS)); + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS1); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + auto pipeline_o_release_state = pipeline_o_consumer_state; + pipeline_o.consumer_wait(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + pipeline_o.consumer_wait(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + // store to epi smem + + // loop: + // TMEM_LOAD + // FMUL2 scale = 1 / global_sum * out_quant_scale + // F2FP + // store to smem + + Tensor cO = make_identity_tensor(select<0, 1>(TileShapePV{})); + auto g_shape = select<0, 2>(problem_shape); + auto mO = make_tensor(make_gmem_ptr(epilogue.params.ptr_o), + append<3>(select<0, 
1>(TileShapePV{}), get<3>(problem_shape)), + epilogue.params.dO); + auto gO = mO(_, _, get<2>(blk_coord)); + + correction_epilogue(params.scale_softmax_log2, params.scale_output, tTMEM_LOADVrS0, + tTMEM_LOADVrS1, gO, cO, g_shape, epilogue); + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_release_state); + ++pipeline_o_release_state; + + pipeline_o.consumer_release(pipeline_o_release_state); + ++pipeline_o_release_state; + } +}; + +} // namespace cutlass::fmha::collective diff --git a/include/flashinfer/attention/blackwell/collective/sm100_fmha_load_cpasync_warpspecialized.hpp b/include/flashinfer/attention/blackwell/collective/sm100_fmha_load_cpasync_warpspecialized.hpp new file mode 100644 index 0000000000..0d35a60b83 --- /dev/null +++ b/include/flashinfer/attention/blackwell/collective/sm100_fmha_load_cpasync_warpspecialized.hpp @@ -0,0 +1,362 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template +struct Sm100FmhaLoadCpAsyncWarpspecialized { + using TileShapeQK = typename CollectiveMmaQK::TileShape; + using TileShapePV = typename CollectiveMmaPV::TileShape; + + struct Arguments { + const int* cache_batch_idx; + + const Element* ptr_q; + StrideQ dQ; + + const Element* ptr_new_k; + StrideNewK dNewK; + const Element* ptr_new_v; + StrideNewV dNewV; + + Element* ptr_cache_k; + StrideCacheK dCacheK; + Element* ptr_cache_v; + StrideCacheV dCacheV; + }; + + using Params = Arguments; + + template + static Params to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, + void* workspace) { + return args; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) {} + + template + CUTLASS_DEVICE auto constexpr transpose(Tensor const& t) { + CUTE_STATIC_ASSERT_V(rank(t) == _2{}); + return t.compose( + make_layout(make_shape(size<1>(t), size<0>(t)), make_stride(size<0>(t), _1{}))); + } + + template + CUTLASS_DEVICE void copy_with_limit(TiledCopy const& tiled_copy, + CountTensor const& c, CountLimit const& l, + SrcTensor const& src, DstTensor&& dst) { + // copy(tiled_copy, src, dst); +#if 1 + auto c_f = make_tensor(c.data(), flatten(c.layout())); + auto src_f = make_tensor(src.data(), flatten(src.layout())); + auto dst_f = make_tensor(dst.data(), flatten(dst.layout())); + auto c_v = group_modes<1, rank_v>(c_f); + auto src_v = group_modes<1, rank_v>(src_f); + auto dst_v = group_modes<1, rank_v>(dst_f); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(src_v); i++) { + if (elem_less(c_v(_0{}, i), l)) { + copy(CAtom{}, src_v(_, i), dst_v(_, i)); + } else { + clear(dst_v(_, i)); + } + } +#endif + } + + template + CUTLASS_DEVICE void load(BlkCoord const& blk_coord, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, PipelineQ& pipeline_q, + typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, + typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + mask_tile_count *= 2; + + int warp_idx = (threadIdx.x / 32) % 2; + int thread_idx = warp_idx * 32 + (threadIdx.x % 32); + + using X = Underscore; + + // this one is only executed by one thread, no need to elect_one + auto blk_coord_cache = blk_coord; + if (params.cache_batch_idx != nullptr) { + get<2, 1>(blk_coord_cache) = params.cache_batch_idx[get<2, 1>(blk_coord_cache)]; + } + + // Q1, K1, K2, V1, K3, V2, ... 
Kn, Vn-1, Vn + // two pipes: Q and KV + auto cQ = make_identity_tensor(select<0, 2>(TileShape{})); + auto mQ = make_tensor(make_gmem_ptr(params.ptr_q), + append<3>(select<0, 2>(TileShapeQK{}), get<3>(problem_shape)), params.dQ); + auto gQ = mQ(_, _, get<2>(blk_coord)); + auto sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + + typename CollectiveMmaQK::TiledMma mma_qk; + ThrMMA thr_mma_qk = mma_qk.get_slice(0); + auto tSgQ = thr_mma_qk.partition_A(gQ); + auto tScQ = thr_mma_qk.partition_A(cQ); + + auto atom_q_tv = Layout, Shape<_16, _16>>, + Stride, Stride<_1, _1024>>>{}; + auto atom_kv_tv = Layout, Shape<_16, _4>>, + Stride, Stride<_1, _1024>>>{}; + + auto tiled_copy_q = make_cotiled_copy( + Copy_Atom, Element>{}, atom_q_tv, + make_layout(shape(tSgQ), + replace<0>(stride(tSgQ), replace<0>(stride<0>(tSgQ), get<2>(TileShape{}))))); + + auto thr_copy_q = tiled_copy_q.get_slice(thread_idx); + + auto tQsQ = thr_copy_q.partition_D(sQ); + auto tQgQ = thr_copy_q.partition_S(tSgQ); + auto tQcQ = thr_copy_q.partition_S(tScQ); + + auto limitQ = append<2>(get<0>(problem_shape), _128{}); + + // Q1 + int q0_index = get<0>(blk_coord); + + auto load_q = [&](int q_index, auto& state) { + pipeline_q.producer_acquire(state); + + // q is always loaded masked + using Vec = uint128_t; + Vec vzero = uint128_t(0, 0); + auto src = recast(tQgQ(_, _, _, _)); + auto dst = recast(tQsQ(_, _, _, _, state.index())); + auto c = tQcQ(_, _, _, _); + int vlen = sizeof(Vec) / sizeof(Element); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src); i++) { + auto cc = c(vlen * i); + Vec* dst_ptr = &dst(i); + const Vec* src_ptr = &src(i); + bool guard = elem_less(cc, limitQ); + cutlass::arch::cp_async_zfill<16, cutlass::arch::CacheOperation::Always>(dst_ptr, src_ptr, + guard); + } + + pipeline_q.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + }; + + load_q(q0_index, pipeline_q_producer_state); + ++pipeline_q_producer_state; + + auto cK_t = make_identity_tensor(select<1, 2>(TileShapeQK{})); + auto cK = make_tensor(cK_t.data(), + make_layout(get<0>(cK_t.layout()), get<1>(cK_t.layout()), + make_layout(_2{}, get<1>(TileShapeQK{}) * stride<0>(cK_t)))); + auto mK = make_tensor(make_gmem_ptr(params.ptr_cache_k), select<1, 2, 3>(problem_shape), + params.dCacheK); + auto gK = local_tile(mK(_, _, get<2>(blk_coord_cache)), TileShapeQK{}, make_coord(_, _, _0{}), + Step{}); + auto sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + + auto tSgK = thr_mma_qk.partition_B(gK); + auto tScK = thr_mma_qk.partition_B(cK); + + auto tSlK = thr_mma_qk.partition_B(make_tensor( + (Element*)nullptr, make_ordered_layout(select<1, 2>(TileShapeQK{}), Step<_1, _0>{}))); + auto tiled_copy_k = make_cotiled_copy( + Copy_Atom, Element>{}, atom_kv_tv, tSlK.layout()); + + auto thr_copy_k = tiled_copy_k.get_slice(thread_idx); + + auto tKsK = thr_copy_k.partition_D(sK); + auto tKgK = thr_copy_k.partition_S(tSgK); + auto tKcK = thr_copy_k.partition_S(tScK); + + int seqlen_cache_kv = get<1>(problem_shape) - ((params.ptr_new_k != nullptr) ? 
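+    // When ptr_new_k / ptr_new_v are set, the problem shape appears to include one newly
+    // appended K/V token that is not yet in the cache: seqlen_cache_kv excludes it here,
+    // the guarded tile loads below fetch the row at index seqlen_cache_kv from gNewK /
+    // gNewV instead of the cache, and the tail loop at the end of load() writes that
+    // token back into the cache tensors.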
1 : 0); + auto limitK = append<2>(seqlen_cache_kv, _128{}); + + auto cV_t = make_identity_tensor(select<1, 2>(TileShapePV{})); + auto cV = make_tensor(cV_t.data(), + make_layout(get<0>(cV_t.layout()), get<1>(cV_t.layout()), + make_layout(_2{}, get<2>(TileShapePV{}) * stride<1>(cV_t)))); + auto mV = make_tensor(make_gmem_ptr(params.ptr_cache_v), select<2, 1, 3>(problem_shape), + select<1, 0, 2>(params.dCacheV)); + auto gV = local_tile(mV(_, _, get<2>(blk_coord_cache)), TileShapePV{}, make_coord(_, _0{}, _), + Step{}); + auto sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + typename CollectiveMmaPV::TiledMma mma_pv; + ThrMMA thr_mma_pv = mma_pv.get_slice(0); + auto tOgV = thr_mma_pv.partition_B(gV); + auto tOcV = thr_mma_pv.partition_B(cV); + auto tOlV = thr_mma_pv.partition_B( + make_tensor((Element*)nullptr, make_layout(select<1, 2>(TileShapePV{})))); + + auto tiled_copy_v = make_cotiled_copy( + Copy_Atom, Element>{}, atom_kv_tv, tOlV.layout()); + + auto thr_copy_v = tiled_copy_v.get_slice(thread_idx); + + auto tVsV = thr_copy_v.partition_D(sV); + auto tVgV = thr_copy_v.partition_S(tOgV); + auto tVcV = thr_copy_v.partition_S(tOcV); + + auto limitV = select<1, 0>(limitK); + + int full_tiles_cache = seqlen_cache_kv / get<1>(TileShapeQK{}); + + bool has_new = params.ptr_new_k != nullptr; + Tensor mNewK = + make_tensor(make_gmem_ptr(params.ptr_new_k), select<1, 2, 3>(problem_shape), params.dNewK); + Tensor mNewV = + make_tensor(make_gmem_ptr(params.ptr_new_v), select<1, 2, 3>(problem_shape), params.dNewV); + Tensor gNewK = mNewK(_, _, get<2>(blk_coord)); + Tensor gNewV = mNewV(_, _, get<2>(blk_coord)); + + auto load_k = [&](int k_index, auto& state) { + pipeline_kv.producer_acquire(state); + + if (k_index < full_tiles_cache) { + copy(tiled_copy_k, tKgK(_, _, _, _, k_index), tKsK(_, _, _, _, state.index())); + pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + } else { + using Vec = uint128_t; + Vec vzero = uint128_t(0, 0); + auto src = recast(tKgK(_, _, _, _, k_index)); + auto dst = recast(tKsK(_, _, _, _, state.index())); + auto src2 = recast(gNewK); + auto c = tKcK(_, _, _, _, k_index); + int vlen = sizeof(Vec) / sizeof(Element); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src); i++) { + auto cc = c(vlen * i); + Vec* dst_ptr = &dst(i); + const Vec* src_ptr = &src(i); + bool guard = elem_less(cc, limitK); + if (get<0>(cc) == seqlen_cache_kv && has_new) { + src_ptr = &src2(_0{}, get<1>(cc) / vlen); + guard = true; + } + cutlass::arch::cp_async_zfill<16, cutlass::arch::CacheOperation::Global>(dst_ptr, src_ptr, + guard); + } + + pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + } + }; + + auto load_v = [&](int v_index, auto& state) { + pipeline_kv.producer_acquire(state); + + if (v_index < full_tiles_cache) { + copy(tiled_copy_v, tVgV(_, _, _, _, v_index), tVsV(_, _, _, _, state.index())); + pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + } else { + using Vec = uint128_t; + Vec vzero = uint128_t(0, 0); + auto src = recast(tVgV(_, _, _, _, v_index)); + auto dst = recast(tVsV(_, _, _, _, state.index())); + auto src2 = recast(gNewV); + int vlen = sizeof(Vec) / sizeof(Element); + auto c = tVcV(_, _, _, _, v_index); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src); i++) { + auto cc = c(vlen * i); + Vec* dst_ptr = &dst(i); + const Vec* src_ptr = &src(i); + bool guard = elem_less(cc, limitV); + if (get<1>(cc) == seqlen_cache_kv && has_new) { + src_ptr = &src2(_0{}, get<0>(cc) / 
vlen); + guard = true; + } + cutlass::arch::cp_async_zfill<16, cutlass::arch::CacheOperation::Global>(dst_ptr, src_ptr, + guard); + } + + pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + } + }; + + // K1 + int k_index = 0; + int v_index = 0; + + load_k(k_index, pipeline_kv_producer_state); + + ++pipeline_kv_producer_state; + k_index += 1; + + mask_tile_count -= 1; + + for (; mask_tile_count > 0; mask_tile_count -= 1) { + load_k(k_index, pipeline_kv_producer_state); + + ++pipeline_kv_producer_state; + k_index += 1; + + load_v(v_index, pipeline_kv_producer_state); + + ++pipeline_kv_producer_state; + v_index += 1; + } + + // V1 + + load_v(v_index, pipeline_kv_producer_state); + + ++pipeline_kv_producer_state; + v_index += 1; + + if (has_new) { + for (int i = thread_idx; i < get<2>(TileShape{}); i += 64) { + gK(seqlen_cache_kv, i, 0) = gNewK(0, i); + gV(i, seqlen_cache_kv, 0) = gNewV(0, i); + } + } + } +}; + +} // namespace cutlass::fmha::collective diff --git a/include/flashinfer/attention/blackwell/collective/sm100_fmha_load_tma_warpspecialized.hpp b/include/flashinfer/attention/blackwell/collective/sm100_fmha_load_tma_warpspecialized.hpp new file mode 100644 index 0000000000..4524b61b4c --- /dev/null +++ b/include/flashinfer/attention/blackwell/collective/sm100_fmha_load_tma_warpspecialized.hpp @@ -0,0 +1,265 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "../../../cutlass_utils.cuh" +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "fmha_common.hpp" +#include "fmha_fusion.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template +struct Sm100FmhaLoadTmaWarpspecialized { + using TileShapeQK = typename CollectiveMmaQK::TileShape; + using TileShapePV = typename CollectiveMmaPV::TileShape; + + using GmemTiledCopyQ = cute::SM90_TMA_LOAD; + using GmemTiledCopyKV = cute::SM90_TMA_LOAD; + static constexpr uint32_t NumStagesQ = PipelineQ::Stages; + + // (N, D, (H_R, H_G)) + using ShapeT = cute::Shape>; + // (N, D, (H_R, H_G)) + using StrideQ = cute::Shape>; + using StrideK = cute::Shape>; + using StrideV = cute::Shape<_1, int32_t, cute::Shape<_0, int32_t>>; + using LayoutQ = cute::Layout; + using LayoutK = cute::Layout; + using LayoutV = cute::Layout; + struct Arguments { + const Element* ptr_Q; + LayoutQ layout_Q; + const Element* ptr_K; + LayoutK layout_K; + const Element* ptr_V; + LayoutV layout_V; + }; + + // using ShapeLseT = cute::Shape; + // using StrideLseT = cute::Shape<_1, int64_t>; + // using LayoutLseT = cute::Layout; + + using ClusterLayout_VMNK = + decltype(tiled_divide(make_layout(Shape<_1, _1, _1>{}), + make_tile(typename CollectiveMmaQK::TiledMma::AtomThrID{}))); + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaPV::Params::TMA_B; + + struct Params { + TMA_Q tma_load_Q; + LayoutQ layout_Q; + TMA_K tma_load_K; + LayoutK layout_K; + TMA_V tma_load_V; + LayoutV layout_V; + }; + + template + static Params to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, + void* workspace) { + static_assert(is_variable_length_v>); + static_assert(is_variable_length_v>); + auto ptr_Q = args.ptr_Q; + auto ptr_K = args.ptr_K; + auto ptr_V = args.ptr_V; + LayoutQ layout_Q = args.layout_Q; + LayoutK layout_K = args.layout_K; + LayoutV layout_V = args.layout_V; + + auto mQ = make_tensor(make_gmem_ptr(ptr_Q), layout_Q); + auto mK = make_tensor(make_gmem_ptr(ptr_K), layout_K); + auto mV = make_tensor(make_gmem_ptr(ptr_V), layout_V); + + auto cluster_layout_vmnk = + tiled_divide(make_layout(Shape<_1, _1, _1>{}), + make_tile(typename CollectiveMmaQK::TiledMma::AtomThrID{})); + TMA_Q tma_load_Q = make_tma_atom_A_sm100( + GmemTiledCopyQ{}, mQ, SmemLayoutQ{}(_, _, _, _0{}), TileShapeQK{}, + typename CollectiveMmaQK::TiledMma{}, cluster_layout_vmnk); + TMA_K tma_load_K = make_tma_atom_B_sm100( + GmemTiledCopyKV{}, mK, SmemLayoutK{}(_, _, _, _0{}), TileShapeQK{}, + typename CollectiveMmaQK::TiledMma{}, cluster_layout_vmnk); + TMA_V tma_load_V = make_tma_atom_B_sm100( + GmemTiledCopyKV{}, mV, SmemLayoutV{}(_, _, _, _0{}), TileShapePV{}, + typename CollectiveMmaPV::TiledMma{}, cluster_layout_vmnk); + + return Params{tma_load_Q, layout_Q, tma_load_K, layout_K, tma_load_V, layout_V}; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE void load(BlkCoord const& 
blk_coord, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, PipelineQ& pipeline_q, + typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineK& pipeline_k, + typename PipelineK::PipelineState& pipeline_k_producer_state, + PipelineV& pipeline_v, + typename PipelineV::PipelineState& pipeline_v_producer_state) { + int qo_tile_idx = get<0>(blk_coord); + int qo_head_idx = get<2, 0>(blk_coord); + int batch_idx = get<2, 1>(blk_coord); + int qo_len = get<0>(problem_shape); + int kv_len = get<1>(problem_shape); + int qo_segment_offset = get<0>(params_problem_shape).segment_offsets[batch_idx]; + int kv_segment_offset = get<1>(params_problem_shape).segment_offsets[batch_idx]; + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + using X = Underscore; + + // this one is only executed by one thread, no need to elect_one + + // Q1, K1, Q2, V1, K2, V2, K3, V3, ... + // two pipes: Q and KV + // from Memory (prod) to TensorCore (cons) + + Tensor mQ = params.tma_load_Q.get_tma_tensor(params.layout_Q.shape()); + Tensor mK = params.tma_load_K.get_tma_tensor(params.layout_K.shape()); + Tensor mV = params.tma_load_V.get_tma_tensor(params.layout_V.shape()); + + ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0); + ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0); + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + auto gQ = get_local_tile_tensor(mQ, select<0, 2>(TileShapeQK{}), qo_head_idx, qo_segment_offset, + qo_len); // (Q, D, _) + auto gK = get_local_tile_tensor(mK, select<1, 2>(TileShapeQK{}), qo_head_idx, kv_segment_offset, + kv_len); // (K, D, _) + auto gV = + get_local_tile_t_tensor(mV, select<1, 2>(TileShapePV{}), qo_head_idx, kv_segment_offset, + kv_len); // (K, D, _) + + int warp_idx = cutlass::canonical_warp_idx_sync(); + Tensor tSgQ_qdl = mma_qk.partition_A(gQ); + Tensor tSgK_kdl = mma_qk.partition_B(gK); + Tensor tOgV_dkl = mma_pv.partition_B(gV); + auto [tQgQ, tQsQ] = tma_partition(params.tma_load_Q, _0{}, Layout<_1>{}, group_modes<0, 3>(sQ), + group_modes<0, 3>(tSgQ_qdl)); // (TMA, q), (TMA, PIPE) + auto [tKgK, tKsK] = tma_partition(params.tma_load_K, _0{}, Layout<_1>{}, group_modes<0, 3>(sK), + group_modes<0, 3>(tSgK_kdl)); // (TMA, k), (TMA, PIPE) + auto [tVgV, tVsV] = tma_partition(params.tma_load_V, _0{}, Layout<_1>{}, group_modes<0, 3>(sV), + group_modes<0, 3>(tOgV_dkl)); // (TMA, k), (TMA, PIPE) + + // blk_coord in decomposed in terms of TileShape, not TileShapeQK + // As such, it needs to be transformed as + // (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1) + // b -> 2*a (Ki i even) 2*a+1 (Ki i odd) + + uint32_t lane_predicate = cute::elect_one_sync(); + + // Q1 + int q0_index = 2 * get<0>(blk_coord); + int q1_index = 2 * get<0>(blk_coord) + 1; + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_Q.with(*tma_barrier, 0), tQgQ(_, q0_index), + tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + // K1 + int k_index = 0; + pipeline_k.producer_acquire(pipeline_k_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_k.producer_get_barrier(pipeline_k_producer_state); + 
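+      // Producer pattern used for every tile in this loader: producer_acquire reserves
+      // the pipeline stage, the single elected lane issues the TMA copy bound to that
+      // stage's barrier via producer_get_barrier, and the consumer's wait is satisfied
+      // when the TMA transaction arrives on the barrier; this is why, unlike the
+      // cp.async loader, no explicit producer_commit appears here.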
copy(params.tma_load_K.with(*tma_barrier, 0), tKgK(_, k_index), + tKsK(_, pipeline_k_producer_state.index())); + } + ++pipeline_k_producer_state; + k_index += 1; + + // Q2 + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_Q.with(*tma_barrier, 0), tQgQ(_, q1_index), + tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + // V1 + int v_index = 0; + pipeline_v.producer_acquire(pipeline_v_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_v.producer_get_barrier(pipeline_v_producer_state); + copy(params.tma_load_V.with(*tma_barrier, 0), tVgV(_, v_index), + tVsV(_, pipeline_v_producer_state.index())); + } + ++pipeline_v_producer_state; + v_index += 1; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + // Ki + pipeline_k.producer_acquire(pipeline_k_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_k.producer_get_barrier(pipeline_k_producer_state); + copy(params.tma_load_K.with(*tma_barrier, 0), tKgK(_, k_index), + tKsK(_, pipeline_k_producer_state.index())); + } + ++pipeline_k_producer_state; + k_index += 1; + + // Vi + pipeline_v.producer_acquire(pipeline_v_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_v.producer_get_barrier(pipeline_v_producer_state); + copy(params.tma_load_V.with(*tma_barrier, 0), tVgV(_, v_index), + tVsV(_, pipeline_v_producer_state.index())); + } + ++pipeline_v_producer_state; + v_index += 1; + } + } +}; + +} // namespace cutlass::fmha::collective diff --git a/include/flashinfer/attention/blackwell/common/pow_2.hpp b/include/flashinfer/attention/blackwell/common/pow_2.hpp new file mode 100644 index 0000000000..64542402c8 --- /dev/null +++ b/include/flashinfer/attention/blackwell/common/pow_2.hpp @@ -0,0 +1,89 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + +#include +#include + +namespace cutlass::fmha { + +struct Pow2 { + int n; + int log2_n; + + explicit CUTE_DEVICE Pow2(int n) : n(n) { +#ifdef __CUDA_ARCH__ + log2_n = __ffs(n) - 1; +#endif + } + + template + CUTE_HOST_DEVICE T operator*(T const& b) const { + return n * b; + } + + template + CUTE_HOST_DEVICE auto operator*(Int const&) const { + if constexpr (N & (N - 1) == 0) { + return Pow2{n * N}; + } + return n * N; + } +}; + +template +CUTE_HOST_DEVICE auto operator/(T const& a, Pow2 const& b) { + return a >> b.log2_n; +} + +template +CUTE_HOST_DEVICE auto operator%(T const& a, Pow2 const& b) { + return a & (b.n - 1); +} + +template +CUTE_HOST_DEVICE bool operator<(T const& a, Pow2 const& b) { + return a < b.n; +} + +CUTE_HOST_DEVICE void print(Pow2 const& a) { printf("2^%d", a.log2_n); } + +} // end namespace cutlass::fmha + +namespace cute { + +template <> +struct is_integral : true_type {}; + +} // end namespace cute diff --git a/include/flashinfer/attention/blackwell/device/fmha.hpp b/include/flashinfer/attention/blackwell/device/fmha.hpp new file mode 100644 index 0000000000..14fa867b99 --- /dev/null +++ b/include/flashinfer/attention/blackwell/device/fmha.hpp @@ -0,0 +1,251 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief An universal device layer for cutlass 3.x-style kernels. +*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" + +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class FMHA { + public: + using Kernel = Kernel_; + + static int const kThreadCount = Kernel::MaxThreadsPerBlock; + + /// Argument structure: User API + using Arguments = typename Kernel::Arguments; + /// Argument structure: Kernel API + using Params = typename Kernel::Params; + + private: + /// Kernel API parameters object + Params params_; + + bool is_initialized(bool set = false) { + static bool initialized = false; + if (set) initialized = true; + return initialized; + } + + public: + /// Access the Params structure + Params const& params() const { return params_; } + + /// Determines whether the GEMM can execute the given problem. 
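+  // Typical host-side flow (an illustrative sketch only; `ConcreteKernel`, `args`,
+  // `workspace` and `stream` stand for values the caller provides):
+  //
+  //   using Operation = cutlass::fmha::device::FMHA<ConcreteKernel>;
+  //   Operation op;
+  //   if (Operation::can_implement(args) != Status::kSuccess) { /* fall back */ }
+  //   size_t workspace_bytes = Operation::get_workspace_size(args);
+  //   // ... allocate workspace_bytes of device memory into `workspace` ...
+  //   if (op.initialize(args, workspace, stream) == Status::kSuccess) {
+  //     op.run(stream);   // or op(args, workspace, stream) in one call
+  //   }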
+ static Status can_implement(Arguments const& args) { + if (Kernel::can_implement(args)) { + return Status::kSuccess; + } else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + workspace_bytes += Kernel::get_workspace_size(args); + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Params const& params) { return Kernel::get_grid_shape(params); } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("FMHA::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = Kernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute(device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, device_kernel, Kernel::MaxThreadsPerBlock, smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FMHA::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = Kernel::initialize_workspace(args, workspace, stream); + if (status != Status::kSuccess) { + return status; + } + + // Initialize the Params structure + params_ = Kernel::to_underlying_arguments(args, workspace); + + if (is_initialized()) return Status::kSuccess; + + // account for dynamic smem capacity if needed + int smem_size = Kernel::SharedStorageSize; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + is_initialized(true); + + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("FMHA()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + params_ = Kernel::to_underlying_arguments(args, workspace); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own + /// params. 
Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FMHA::run()"); + dim3 const block = Kernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + + // configure smem size and carveout + int smem_size = Kernel::SharedStorageSize; + + Status launch_result; + // Use extended launch API only for mainloops that use it + if constexpr (Kernel::ArchTag::kMinComputeCapability >= 90) { + dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}), + cute::size<1>(typename Kernel::ClusterShape{}), + cute::size<2>(typename Kernel::ClusterShape{})); + void const* kernel = (void const*)device_kernel; + void* kernel_params[] = {¶ms}; + launch_result = + ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params); + } else { + launch_result = Status::kSuccess; + device_kernel<<>>(params); + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { + return Status::kSuccess; + } else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel + // handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status operator()(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr) { + return run(args, workspace, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params + /// struct. + Status run(cudaStream_t stream = nullptr) { return run(params_, stream); } + + /// Overload that allows a user to re-launch the same kernel without updating internal params + /// struct. + Status operator()(cudaStream_t stream = nullptr) { return run(params_, stream); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/flashinfer/attention/blackwell/device/sm100_mla.hpp b/include/flashinfer/attention/blackwell/device/sm100_mla.hpp new file mode 100644 index 0000000000..90210eb8b4 --- /dev/null +++ b/include/flashinfer/attention/blackwell/device/sm100_mla.hpp @@ -0,0 +1,335 @@ +/*************************************************************************************************** + * Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. 
+ * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief An universal device layer for cutlass 3.x-style kernels. +*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" + +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +#include "kernel/sm100_fmha_mla_reduction.hpp" +#include "kernel/sm100_fmha_mla_tma_warpspecialized.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +using namespace cute; +using namespace cutlass::fmha::kernel; + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class MLA { + public: + using Kernel = Kernel_; + + using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel< + typename Kernel::ElementOut, typename Kernel::ElementAcc, typename Kernel::ElementAcc, + Kernel::TileShapeH::value, Kernel::TileShapeL::value, 256 /*Max split*/ + >; + + /// Argument structure: User API + using KernelArguments = typename Kernel::Arguments; + using ReductionArguments = typename ReductionKernel::Arguments; + + using Arguments = KernelArguments; + + /// Argument structure: Kernel API + using KernelParams = typename Kernel::Params; + using ReductionParams = typename ReductionKernel::Params; + struct Params { + KernelParams fmha_params; + ReductionParams reduction_params; + }; + + private: + /// Kernel API parameters object + Params params_; + + bool is_initialized(bool set = false) { + static bool initialized = false; + if (set) initialized = true; + return initialized; + } + + static ReductionArguments to_reduction_args(Arguments const& args) { + auto [H, K, D, B] = args.problem_shape; + return ReductionArguments{nullptr, + args.epilogue.ptr_o, + nullptr, + args.epilogue.ptr_lse, + args.mainloop.softmax_scale, + B, + args.split_kv, + K, + args.mainloop.ptr_seq, + args.ptr_split_kv, + Kernel::TileShapeS::value}; + } + + public: + /// Access the Params structure + Params const& params() const { return params_; } + + static void set_split_kv(KernelArguments& args) { + if (args.split_kv >= 1) return; + auto [H, K, D, B] = args.problem_shape; + int sm_count = args.hw_info.sm_count; + int max_splits = 
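+    // Split-KV heuristic: allow at most one split per 128-token chunk of the KV length,
+    // never more splits than the SMs available to each batch entry, then shrink to the
+    // smallest split count that still covers all chunks in the same number of rounds so
+    // the per-split work stays balanced. Worked example with assumed values K = 4096,
+    // B = 8, sm_count = 148:
+    //   max_splits = ceil(4096 / 128) = 32, sms_per_batch = 148 / 8 = 18 (integer div),
+    //   split_heur = min(32, 18) = 18, k_waves = ceil(32 / 18) = 2,
+    //   split_kv = ceil(32 / 2) = 16.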
ceil_div(K, 128); + int sms_per_batch = max(1, sm_count / B); + int split_heur = min(max_splits, sms_per_batch); + int waves = ceil_div(B * split_heur, sm_count); + int k_waves = ceil_div(max_splits, split_heur); + int split_wave_aware = ceil_div(max_splits, k_waves); + args.split_kv = split_wave_aware; + } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const& args) { + if (!Kernel::can_implement(args)) { + return Status::kInvalid; + } + if (!ReductionKernel::can_implement(to_reduction_args(args))) { + return Status::kInvalid; + } + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + workspace_bytes += Kernel::get_workspace_size(args); + workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args)); + return workspace_bytes; + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = Kernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute(device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, device_kernel, Kernel::MaxThreadsPerBlock, smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("MLA::initialize() - workspace " + << workspace << ", stream: " << (stream ? 
"non-null" : "null")); + + // Initialize the workspace + Status status = Kernel::initialize_workspace(args, workspace, stream); + if (status != Status::kSuccess) { + return status; + } + status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream); + if (status != Status::kSuccess) { + return status; + } + KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace); + + ReductionArguments reduction_args = to_reduction_args(args); + if (reduction_args.split_kv > 1) { + reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc; + reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc; + } + ReductionParams reduction_params = + ReductionKernel::to_underlying_arguments(reduction_args, workspace); + // Initialize the Params structure + params_ = Params{kernel_params, reduction_params}; + + if (is_initialized()) return Status::kSuccess; + + // account for dynamic smem capacity if needed + // no dynamic smem is needed for reduction kernel + int smem_size = Kernel::SharedStorageSize; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + is_initialized(true); + + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + auto fmha_params = Kernel::to_underlying_arguments(args, workspace); + + ReductionArguments reduction_args = to_reduction_args(args); + if (reduction_args.split_kv > 1) { + reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc; + reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc; + } + ReductionParams reduction_params = + ReductionKernel::to_underlying_arguments(reduction_args, workspace); + // Initialize the Params structure + params_ = Params{fmha_params, reduction_params}; + + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own + /// params. 
Supplied params struct must be constructed by calling Kernel::to_underlying_arguments()
+  static Status run(Params& params, cudaStream_t stream = nullptr) {
+    CUTLASS_TRACE_HOST("MLA::run()");
+    dim3 const block = Kernel::get_block_shape();
+    dim3 const grid = Kernel::get_grid_shape(params.fmha_params);
+
+    // configure smem size and carveout
+    int smem_size = Kernel::SharedStorageSize;
+
+    Status launch_result;
+    // Use extended launch API only for mainloops that use it
+    if constexpr (Kernel::ArchTag::kMinComputeCapability >= 90) {
+      dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
+                   cute::size<1>(typename Kernel::ClusterShape{}),
+                   cute::size<2>(typename Kernel::ClusterShape{}));
+      void const* kernel = (void const*)device_kernel<Kernel>;
+      void* kernel_params[] = {&params.fmha_params};
+      launch_result =
+          ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
+    } else {
+      launch_result = Status::kSuccess;
+      device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
+    }
+
+    cudaError_t result = cudaGetLastError();
+    if (cudaSuccess != result or Status::kSuccess != launch_result) {
+      // return Status::kSuccess;
+      CUTLASS_TRACE_HOST("  Kernel launch failed. Reason: " << result);
+      return Status::kErrorInternal;
+    }
+    if (params.reduction_params.split_kv > 1) {
+      // launch reduction kernel
+      dim3 const block = ReductionKernel::get_block_shape();
+      dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
+      device_kernel<ReductionKernel><<<grid, block, 0, stream>>>(params.reduction_params);
+      cudaError_t result = cudaGetLastError();
+      if (cudaSuccess == result) {
+        return Status::kSuccess;
+      } else {
+        CUTLASS_TRACE_HOST("  Kernel launch failed. Reason: " << result);
+        return Status::kErrorInternal;
+      }
+    } else {
+      return Status::kSuccess;
+    }
+  }
+
+  //
+  // Non-static launch overloads that first create and set the internal params struct of this
+  // kernel handle.
+  //
+
+  /// Launches the kernel after first constructing Params internal state from supplied arguments.
+  Status run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
+    Status status = initialize(args, workspace, stream);
+    if (Status::kSuccess == status) {
+      status = run(params_, stream);
+    }
+    return status;
+  }
+
+  /// Launches the kernel after first constructing Params internal state from supplied arguments.
+  Status operator()(Arguments const& args, void* workspace = nullptr,
+                    cudaStream_t stream = nullptr) {
+    return run(args, workspace, stream);
+  }
+
+  /// Overload that allows a user to re-launch the same kernel without updating internal params
+  /// struct.
+  Status run(cudaStream_t stream = nullptr) { return run(params_, stream); }
+
+  /// Overload that allows a user to re-launch the same kernel without updating internal params
+  /// struct.
+  Status operator()(cudaStream_t stream = nullptr) { return run(params_, stream); }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace cutlass::fmha::device
+
+////////////////////////////////////////////////////////////////////////////////
diff --git a/include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh b/include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh
new file mode 100644
index 0000000000..460b6dfc93
--- /dev/null
+++ b/include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh
@@ -0,0 +1,181 @@
+/*
+ * Copyright (c) 2023 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <iostream>
+#include <optional>
+
+#include "collective/fmha_fusion.hpp"
+#include "collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp"
+#include "collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp"
+#include "cute/tensor.hpp"
+#include "cutlass/cutlass.h"
+#include "cutlass/kernel_hardware_info.h"
+#include "cutlass/util/command_line.h"
+#include "cutlass/util/device_memory.h"
+#include "cutlass/util/distribution.h"
+#include "cutlass/util/reference/device/tensor_fill.h"
+#include "device/fmha.hpp"
+#include "kernel/fmha_tile_scheduler.hpp"
+#include "kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp"
+#include "pytorch_extension_utils.h"
+
+namespace flashinfer {
+
+using namespace cute;
+using namespace cutlass::fmha::collective;
+using namespace cutlass::fmha::kernel;
+using namespace cutlass::fmha::device;
+
+template <typename DTypeIn, typename DTypeOut, typename TileShapeQK, typename TileShapePV,
+          typename ActiveMask>
+struct FwdRunner {
+  using Element = DTypeIn;
+  using ElementAccumulatorQK = float;
+  using ElementAccumulatorPV = float;
+  using ElementOut = DTypeOut;
+
+  // Q K D ((H_R, H_KV), B)
+  using ProblemShapeVarlen =
+      cute::tuple<VariableLength, VariableLength, int, cute::tuple<cute::tuple<int, int>, int>>;
+
+  using StrideQ = cute::tuple<int, _1, cute::tuple<int, int>>;  // Q D (H_G H_R)
+  using StrideK = cute::tuple<int, _1, cute::tuple<_0, int>>;   // K D (H_G H_R)
+  using StrideV = cute::tuple<_1, int, cute::tuple<_0, int>>;   // D V (H_G H_R)
+  // NOTE(Zihao): use markus's trick for tma store
+  using StrideO =
+      cute::tuple<int, _1, cute::tuple<cute::tuple<int, int>, int>>;  // Q D (H_G H_R) CUMULATIVE_Q
+  using StrideLSE = cute::tuple<_1, cute::tuple<int, int>>;           // Q (H_G H_R)
+
+  using Mainloop = cutlass::fmha::collective::Sm100FmhaFwdMainloopTmaWarpspecialized<
+      Element, ElementAccumulatorQK, ElementAccumulatorPV, TileShapeQK, TileShapePV, StrideQ,
+      StrideK, StrideV, ActiveMask>;
+  using Epilogue = cutlass::fmha::collective::Sm100FmhaFwdEpilogueTmaWarpspecialized<
+      ElementOut, ElementAccumulatorPV, typename Mainloop::TileShapePV>;
+  using Operation =
+      cutlass::fmha::device::FMHA::value,
+      cutlass::fmha::kernel::NaiveTileScheduler,
+      cutlass::fmha::kernel::PersistentTileScheduler>::type>>;
+  using LayoutQ = typename Mainloop::LayoutQ;
+  using LayoutK = typename Mainloop::LayoutK;
+  using LayoutV = typename Mainloop::LayoutV;
+  using LayoutO = typename Epilogue::LayoutO;
+  using LayoutLSE = typename Epilogue::LayoutLSE;
+
+  static void run(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor qo_lens, at::Tensor kv_lens,
+                  at::Tensor qo_segment_offsets, at::Tensor kv_segment_offsets, at::Tensor o,
+                  std::optional<at::Tensor> maybe_lse, int mask_mode_code, double sm_scale,
+                  int num_qo_heads, int num_kv_heads, int head_dim_qk, int head_dim_vo,
+                  int batch_size, int total_qo_len, int total_kv_len, int max_qo_len,
+                  int max_kv_len) {
+    cutlass::KernelHardwareInfo hw_info;
+    hw_info.device_id = 0;
+    hw_info.sm_count =
+        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
+
+    StrideQ stride_Q;
+    StrideK stride_K;
+    StrideV stride_V;
+    StrideO stride_O;
+    StrideLSE stride_LSE;
+
+    int h_r = num_qo_heads / num_kv_heads;
+    assert(num_qo_heads % num_kv_heads == 0);
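+    // Layout note (illustrative): q/k/v are packed as [total_len, num_heads, head_dim], and the
+    // head mode is split into (h_r, num_kv_heads) so the h_r query heads that share one KV head
+    // are adjacent. For example, num_qo_heads = 32, num_kv_heads = 8, head_dim_qk = 128 gives
+    // h_r = 4 and stride_Q = (4096, 1, (128, 512)): heads within a group are 128 elements apart,
+    // groups are 512 apart. K/V use a _0 stride on the h_r mode, so every query head in a group
+    // reads the same KV data.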
+    ProblemShapeVarlen problem_shape = cute::make_tuple(
+        VariableLength{max_qo_len, static_cast<int*>(qo_segment_offsets.data_ptr()),
+                       static_cast<int*>(qo_lens.data_ptr())},
+        VariableLength{max_kv_len, static_cast<int*>(kv_segment_offsets.data_ptr()),
+                       static_cast<int*>(kv_lens.data_ptr())},
+        head_dim_qk, cute::make_tuple(cute::make_tuple(h_r, num_kv_heads), batch_size));
+
+    stride_Q =
+        make_stride(num_qo_heads * head_dim_qk, _1{}, make_stride(head_dim_qk, h_r * head_dim_qk));
+    stride_O = make_stride(
+        num_qo_heads * head_dim_vo, _1{},
+        make_stride(make_stride(head_dim_vo, h_r * head_dim_vo), num_qo_heads * head_dim_vo));
+    stride_K = make_stride(num_kv_heads * head_dim_qk, _1{}, make_stride(_0{}, head_dim_qk));
+    stride_V = make_stride(_1{}, num_kv_heads * head_dim_vo, make_stride(_0{}, head_dim_vo));
+    stride_LSE = make_stride(_1{}, make_stride(total_qo_len, total_qo_len * h_r));
+
+    auto shape_Q = make_shape(total_qo_len, head_dim_qk, make_shape(h_r, num_kv_heads));
+    auto shape_O = make_shape(max_qo_len, head_dim_vo,
+                              make_shape(make_shape(h_r, num_kv_heads), max_qo_len + total_qo_len));
+    auto shape_K = make_shape(total_kv_len, head_dim_qk, make_shape(h_r, num_kv_heads));
+    auto shape_V = make_shape(head_dim_vo, total_kv_len, make_shape(h_r, num_kv_heads));
+    auto shape_LSE = make_shape(total_qo_len, make_shape(h_r, num_kv_heads));
+
+    LayoutQ layout_Q = make_layout(shape_Q, stride_Q);
+    LayoutK layout_K = make_layout(shape_K, stride_K);
+    LayoutV layout_V = make_layout(shape_V, stride_V);
+    LayoutO layout_O = make_layout(shape_O, stride_O);
+    LayoutLSE layout_LSE = make_layout(shape_LSE, stride_LSE);
+
+    typename Operation::Arguments arguments{
+        problem_shape,
+        {static_cast<Element*>(q.data_ptr()), layout_Q, static_cast<Element*>(k.data_ptr()),
+         layout_K, static_cast<Element*>(v.data_ptr()), layout_V},
+        {static_cast<ElementOut*>(o.data_ptr()) - max_qo_len * get<0>(stride_O), layout_O,
+         static_cast<float*>(maybe_lse.value().data_ptr()), layout_LSE},
+        hw_info};
+
+    Operation op;
+
+    size_t workspace_size = 0;
+    workspace_size = Operation::get_workspace_size(arguments);
+    cutlass::DeviceAllocation<uint8_t> workspace(workspace_size);
+
+    cutlass::Status status = cutlass::Status::kSuccess;
+    status = op.can_implement(arguments);
+    if (status != cutlass::Status::kSuccess) {
+      std::cerr << "This kernel is not supported. Last CUDA error is: "
+                << cudaGetErrorString(cudaGetLastError()) << std::endl;
+    }
+
+    status = op.initialize(arguments, workspace.get());
+    if (status != cutlass::Status::kSuccess) {
+      std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
+                << cudaGetErrorString(cudaGetLastError()) << std::endl;
+    }
+
+    // Run
+    status = op.run();
+    if (status != cutlass::Status::kSuccess) {
+      std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
+                << cudaGetErrorString(cudaGetLastError()) << std::endl;
+    }
+
+    cudaError_t result = cudaDeviceSynchronize();
+    if (result != cudaSuccess) {
+      std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
+                << cudaGetErrorString(result) << std::endl;
+    }
+  }
+};
+
+template <typename DTypeIn, typename DTypeOut, typename TileShapeQK, typename TileShapePV,
+          typename ActiveMask>
+void run_fmha_fwd(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor qo_lens, at::Tensor kv_lens,
+                  at::Tensor qo_segment_offsets, at::Tensor kv_segment_offsets, at::Tensor o,
+                  std::optional<at::Tensor> maybe_lse, int mask_mode_code, double sm_scale,
+                  int num_qo_heads, int num_kv_heads, int head_dim_qk, int head_dim_vo,
+                  int batch_size, int total_qo_len, int total_kv_len, int max_qo_len,
+                  int max_kv_len) {
+  FwdRunner<DTypeIn, DTypeOut, TileShapeQK, TileShapePV, ActiveMask>::run(
+      q, k, v, qo_lens, kv_lens, qo_segment_offsets, kv_segment_offsets, o, maybe_lse,
+      mask_mode_code, sm_scale, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, batch_size,
+      total_qo_len, total_kv_len, max_qo_len, max_kv_len);
+}
+
+};  // namespace flashinfer
diff --git a/include/flashinfer/attention/blackwell/kernel/fmha_options.hpp b/include/flashinfer/attention/blackwell/kernel/fmha_options.hpp
new file mode 100644
index 0000000000..5b32c079cf
--- /dev/null
+++ b/include/flashinfer/attention/blackwell/kernel/fmha_options.hpp
@@ -0,0 +1,78 @@
+/***************************************************************************************************
+ * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" + +namespace cutlass::fmha::kernel { + +template +struct find_option; + +template +struct find_option { + using option_value = Default; +}; + +template +struct find_option + : std::conditional_t > {}; + +template +using find_option_t = typename find_option::option_value; + +enum class Tag { + kIsPersistent, + kNumMmaWarpGroups, + kLoadsQSeparately, + + kIsMainloopLocked, + kIsEpilogueLocked, + + kStagesQ, + kStagesKV, + + kEpilogueKind, + + kBlocksPerSM, + kClusterM, + + kAccQK +}; + +template +struct Option { + static constexpr auto tag = kTag; + using option_value = Value; +}; + +} // namespace cutlass::fmha::kernel diff --git a/include/flashinfer/attention/blackwell/kernel/fmha_tile_scheduler.hpp b/include/flashinfer/attention/blackwell/kernel/fmha_tile_scheduler.hpp new file mode 100644 index 0000000000..c2bad93ecd --- /dev/null +++ b/include/flashinfer/attention/blackwell/kernel/fmha_tile_scheduler.hpp @@ -0,0 +1,204 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::fmha::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +struct IndividualTileScheduler { + struct Params { + dim3 grid; + }; + + bool valid_ = true; + + CUTLASS_DEVICE + IndividualTileScheduler(Params const&) {} + + template + static Params to_underlying_arguments(ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, + TileShape const& tile_shape) { + using namespace cute; + dim3 grid( + round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), + size<3, 0>(problem_size), size<3, 1>(problem_size)); + return Params{grid}; + } + + static dim3 get_grid_shape(Params const& params) { return params.grid; } + + CUTLASS_DEVICE + bool is_valid() { return valid_; } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return make_coord(blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z)); + } + + CUTLASS_DEVICE + IndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct NaiveTileScheduler { + struct Params { + int num_qo_tiles; + int batch_size; + int num_qo_heads; + }; + + int qo_tile_idx; + int batch_idx; + int qo_head_idx; + bool is_valid_tile; + + CUTLASS_DEVICE + NaiveTileScheduler(Params const& params) + : qo_tile_idx(blockIdx.x), + batch_idx(blockIdx.y), + qo_head_idx(blockIdx.z), + is_valid_tile(true) {} + + template + static Params to_underlying_arguments(ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, + TileShape const& tile_shape) { + return Params{ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<3, 0>(problem_size), + size<3, 1>(problem_size)}; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(params.num_qo_tiles, params.batch_size, params.num_qo_heads); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { return is_valid_tile; } + + CUTLASS_DEVICE + auto get_block_coord() { + return make_coord(qo_tile_idx, _0{}, make_coord(batch_idx, qo_head_idx)); + } + + CUTLASS_DEVICE + NaiveTileScheduler& operator++() { + is_valid_tile = false; + return *this; + } +}; + +struct PersistentTileScheduler { + struct Params { + int num_blocks; + FastDivmod divmod_m_block; + FastDivmod divmod_b; + FastDivmod divmod_h; + + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + PersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} + + template + static Params to_underlying_arguments(ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, + TileShape const& tile_shape) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST( + " WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM " + "count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " + << sm_count); + hw_info.sm_count = 
sm_count; + + int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), + size<0>(cluster_shape)); + int num_blocks = num_m_blocks * size<3, 0>(problem_size) * size<3, 1>(problem_size); + + return Params{num_blocks, + {num_m_blocks}, + {size<3, 0>(problem_size)}, + {size<3, 1>(problem_size)}, + hw_info}; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { return block_idx < params.num_blocks; } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int m_block, bidb, bidh; + params.divmod_m_block(block_decode, m_block, block_decode); + params.divmod_b(block_decode, bidb, block_decode); + params.divmod_h(block_decode, bidh, block_decode); + return make_coord(m_block, _0{}, make_coord(bidb, bidh)); + } + + CUTLASS_DEVICE + PersistentTileScheduler& operator++() { + block_idx += gridDim.x; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp b/include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp new file mode 100644 index 0000000000..e2c782b507 --- /dev/null +++ b/include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp @@ -0,0 +1,506 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include "../collective/fmha_common.hpp" +#include "../collective/fmha_fusion.hpp" +#include "cute/arch/tmem_allocator_sm100.hpp" +#include "cute/layout.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "fmha_options.hpp" +#include "fmha_tile_scheduler.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; +using namespace cutlass::fmha::collective; + +struct Sm100FmhaCtxKernelWarpspecializedSchedule { + enum class WarpRole { Softmax0, Softmax1, Correction, MMA, Load, Epilogue, Empty }; + + static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + int wg_idx = warp_idx / 4; // warp_idx + if (wg_idx == 0) return WarpRole::Softmax0; // 0 - 3 + if (wg_idx == 1) return WarpRole::Softmax1; // 4 - 7 + if (wg_idx == 2) return WarpRole::Correction; // 8 - 11 + if (warp_idx == 12) return WarpRole::MMA; // 12 + if (warp_idx == 13) return WarpRole::Load; // 13 + if (warp_idx == 14) return WarpRole::Epilogue; // 14 + return WarpRole::Empty; // 15 + } + + static const int NumWarpsSoftmax = 4; + static const int NumWarpsCorrection = 4; + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + static const bool kDebugUsingPrintf = false; + static const int NumRegsSoftmax = 192; + static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0); + static const int NumRegsOther = 32 + (kDebugUsingPrintf ? 16 : 0); + static const int NumRegsEmpty = 24; + + static const int NumWarps = 16; +}; + +template +struct Sm100FmhaFwdKernelTmaWarpspecialized { + using TileShape = typename CollectiveMainloop::TileShape; + using ProblemShape = ProblemShapeIn; + + using WarpRole = typename KernelSchedule::WarpRole; + + constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + return KernelSchedule::warp_idx_to_WarpRole(warp_idx); + } + + static const int NumWarpsSoftmax = KernelSchedule::NumWarpsSoftmax; + static const int NumWarpsCorrection = KernelSchedule::NumWarpsCorrection; + static const int NumWarpsEpilogue = KernelSchedule::NumWarpsEpilogue; + static const int NumWarpsLoad = KernelSchedule::NumWarpsLoad; + + static const int NumRegsSoftmax = KernelSchedule::NumRegsSoftmax; + static const int NumRegsCorrection = KernelSchedule::NumRegsCorrection; + static const int NumRegsOther = KernelSchedule::NumRegsOther; + static const int NumRegsEmpty = 24; + + static const int NumWarps = KernelSchedule::NumWarps; + + using ClusterShape = typename CollectiveMainloop::ClusterShape; + + using TmemAllocator = cute::TMEM::Allocator1Sm; + + struct SharedStorage { + union { + typename CollectiveMainloop::TensorStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue; + }; + + struct PipelineStorage { + alignas(16) typename CollectiveMainloop::PipelineQ::SharedStorage load_q; + alignas(16) typename CollectiveMainloop::PipelineK::SharedStorage load_k; + alignas(16) typename CollectiveMainloop::PipelineV::SharedStorage load_v; + alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s0; + alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s1; + alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s0_corr; + alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s1_corr; + alignas(16) typename CollectiveMainloop::PipelineO::SharedStorage mma_corr; + alignas(16) typename 
CollectiveMainloop::PipelineE::SharedStorage corr_epi; + alignas(16) typename CollectiveMainloop::OrderBarrierSoftmax::SharedStorage order_s01; + } pipelines; + + cutlass::arch::ClusterBarrier barrier_O; + uint32_t tmem_base_ptr; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + struct Arguments { + ProblemShape problem_shape; + typename CollectiveMainloop::Arguments mainloop; + typename CollectiveEpilogue::Arguments epilogue; + cutlass::KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_shape; + typename CollectiveMainloop::Params mainloop; + typename CollectiveEpilogue::Params epilogue; + typename TileScheduler::Params tile_scheduler; + }; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = NumWarps * cutlass::NumThreadsPerWarp; + using ArchTag = cutlass::arch::Sm100; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static bool can_implement(Arguments const& args) { + return CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return Params{ + args.problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, + TileShape{})}; + } + + CUTLASS_DEVICE auto apply_batch(const Params& params, ProblemShape const& problem_shape, + int batch_idx) { + return apply_variable_length(params.problem_shape, batch_idx); + } + + CUTLASS_DEVICE void operator()(const Params& params, char* smem) { + TileScheduler tile_scheduler{params.tile_scheduler}; + + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_WarpRole(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + + if (role == WarpRole::Epilogue && lane_predicate) { + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + typename CollectiveMainloop::PipelineQ::Params pipeline_load_q_params; + if (role == WarpRole::Load) { + pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Producer; + } + if (role == WarpRole::MMA) { + pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Consumer; + } + pipeline_load_q_params.is_leader = lane_predicate && (role == WarpRole::Load); + pipeline_load_q_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadQ; + typename CollectiveMainloop::PipelineQ pipeline_load_q( + shared_storage.pipelines.load_q, pipeline_load_q_params, ClusterShape{}, cute::true_type{}, + /*mask calc*/ cute::false_type{}); + + typename CollectiveMainloop::PipelineK::Params pipeline_load_k_params; + typename CollectiveMainloop::PipelineV::Params pipeline_load_v_params; + if (role == WarpRole::Load) { + pipeline_load_k_params.role = 
CollectiveMainloop::PipelineK::ThreadCategory::Producer; + pipeline_load_v_params.role = CollectiveMainloop::PipelineV::ThreadCategory::Producer; + } + if (role == WarpRole::MMA) { + pipeline_load_k_params.role = CollectiveMainloop::PipelineK::ThreadCategory::Consumer; + pipeline_load_v_params.role = CollectiveMainloop::PipelineV::ThreadCategory::Consumer; + } + pipeline_load_k_params.is_leader = lane_predicate && (role == WarpRole::Load); + pipeline_load_v_params.is_leader = lane_predicate && (role == WarpRole::Load); + pipeline_load_k_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadK; + pipeline_load_v_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadV; + typename CollectiveMainloop::PipelineK pipeline_load_k( + shared_storage.pipelines.load_k, pipeline_load_k_params, ClusterShape{}, + /*barrier init*/ cute::true_type{}, /*mask calc*/ cute::false_type{}); + typename CollectiveMainloop::PipelineV pipeline_load_v( + shared_storage.pipelines.load_v, pipeline_load_v_params, ClusterShape{}, + /*barrier init*/ cute::true_type{}, /*mask calc*/ cute::false_type{}); + + typename CollectiveMainloop::PipelineS::Params pipeline_mma_s0_params; + if (role == WarpRole::MMA) { + pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::Softmax0) { + pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s0_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineS pipeline_mma_s0( + shared_storage.pipelines.mma_s0, pipeline_mma_s0_params, ClusterShape{}, + /*barrier init*/ cute::true_type{}, /*mask calc*/ cute::false_type{}); + + typename CollectiveMainloop::PipelineS::Params pipeline_mma_s1_params; + if (role == WarpRole::MMA) { + pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::Softmax1) { + pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s1_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineS pipeline_mma_s1( + shared_storage.pipelines.mma_s1, pipeline_mma_s1_params, ClusterShape{}, + /*barrier init*/ cute::true_type{}, /*mask calc*/ cute::false_type{}); + + typename CollectiveMainloop::PipelineC::Params pipeline_s0_corr_params; + if (role == WarpRole::Softmax0) { + pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + } + pipeline_s0_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s0_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineC pipeline_s0_corr(shared_storage.pipelines.s0_corr, + pipeline_s0_corr_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::PipelineC::Params pipeline_s1_corr_params; + if (role == WarpRole::Softmax1) { + pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + } + pipeline_s1_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s1_corr_params.consumer_arv_count = NumWarpsCorrection * 
cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineC pipeline_s1_corr(shared_storage.pipelines.s1_corr, + pipeline_s1_corr_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::PipelineO::Params pipeline_mma_corr_params; + if (role == WarpRole::MMA) { + pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Consumer; + } + pipeline_mma_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineO pipeline_mma_corr( + shared_storage.pipelines.mma_corr, pipeline_mma_corr_params, ClusterShape{}, + /*barrier init*/ cute::true_type{}, /*mask calc*/ cute::false_type{}); + + typename CollectiveMainloop::PipelineE::Params pipeline_corr_epi_params; + if (role == WarpRole::Correction) { + pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Producer; + } + if (role == WarpRole::Epilogue) { + pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer; + } + pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + pipeline_corr_epi_params.consumer_arv_count = NumWarpsEpilogue * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineE pipeline_corr_epi(shared_storage.pipelines.corr_epi, + pipeline_corr_epi_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::OrderBarrierSoftmax::Params params_order_s01; + params_order_s01.group_id = role == WarpRole::Softmax1 ? 1 : 0; + params_order_s01.group_size = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::OrderBarrierSoftmax order_s01(shared_storage.pipelines.order_s01, + params_order_s01); + + TmemAllocator tmem_allocator; + + if (role == WarpRole::Load && lane_predicate) { + shared_storage.barrier_O.init(/*num_threads=*/1); + } + + __syncthreads(); + + pipeline_load_q.init_masks(ClusterShape{}); + pipeline_load_k.init_masks(ClusterShape{}); + pipeline_load_v.init_masks(ClusterShape{}); + pipeline_mma_s0.init_masks(ClusterShape{}); + pipeline_mma_s1.init_masks(ClusterShape{}); + pipeline_mma_corr.init_masks(ClusterShape{}); + + typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_consumer_state; + typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_producer_state = + cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineK::PipelineState pipeline_load_k_consumer_state; + typename CollectiveMainloop::PipelineK::PipelineState pipeline_load_k_producer_state = + cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineV::PipelineState pipeline_load_v_consumer_state; + typename CollectiveMainloop::PipelineV::PipelineState pipeline_load_v_producer_state = + cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_producer_state = + cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_producer_state = + cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState 
pipeline_s0_corr_producer_state = + cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_producer_state = + cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_consumer_state; + typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_producer_state = + cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_consumer_state; + typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = + cutlass::make_producer_start_state(); + + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue; + + if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = + apply_batch(params, params.problem_shape, get<2, 1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + bool is_softmax_0 = role == WarpRole::Softmax0; + + mainloop.softmax( + is_softmax_0 ? 0 : 1, blk_coord, params.mainloop, logical_problem_shape, + is_softmax_0 ? pipeline_mma_s0 : pipeline_mma_s1, + is_softmax_0 ? pipeline_mma_s0_consumer_state : pipeline_mma_s1_consumer_state, + is_softmax_0 ? pipeline_s0_corr : pipeline_s1_corr, + is_softmax_0 ? pipeline_s0_corr_producer_state : pipeline_s1_corr_producer_state, + order_s01); + } + } else if (role == WarpRole::Correction) { + cutlass::arch::warpgroup_reg_dealloc(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = + apply_batch(params, params.problem_shape, get<2, 1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + mainloop.correction( + blk_coord, params.mainloop, logical_problem_shape, shared_storage.epilogue, + pipeline_s0_corr, pipeline_s0_corr_consumer_state, pipeline_s1_corr, + pipeline_s1_corr_consumer_state, pipeline_mma_corr, pipeline_mma_corr_consumer_state, + pipeline_corr_epi, pipeline_corr_epi_producer_state); + } + + if constexpr (NumWarpsEpilogue == 0) { + static_assert(NumWarpsCorrection == 1); + + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } else if (role == WarpRole::MMA) { + warpgroup_reg_set(); + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, + &shared_storage.tmem_base_ptr); + __syncwarp(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = + apply_batch(params, params.problem_shape, get<2, 1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + mainloop.mma( + blk_coord, params.mainloop, logical_problem_shape, shared_storage.mainloop, + pipeline_load_q, pipeline_load_q_consumer_state, pipeline_load_k, + pipeline_load_k_consumer_state, pipeline_load_v, pipeline_load_v_consumer_state, + pipeline_mma_s0, pipeline_mma_s0_producer_state, pipeline_mma_s1, + pipeline_mma_s1_producer_state, pipeline_mma_corr, 
pipeline_mma_corr_producer_state); + } + } else if (role == WarpRole::Load) { + warpgroup_reg_set(); + + int work_idx = 0; + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + shared_storage.barrier_O.wait((work_idx + 1) % 2); + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = + apply_batch(params, params.problem_shape, get<2, 1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + mainloop.load(blk_coord, logical_problem_shape, params.mainloop, params.problem_shape, + shared_storage.mainloop, pipeline_load_q, pipeline_load_q_producer_state, + pipeline_load_k, pipeline_load_k_producer_state, pipeline_load_v, + pipeline_load_v_producer_state); + + work_idx++; + } + } else if (role == WarpRole::Epilogue) { + warpgroup_reg_set(); + + int work_idx = 0; + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + if (work_idx != 0) { + if (lane_predicate) { + shared_storage.barrier_O.arrive(0, lane_predicate); + } + } + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = + apply_batch(params, params.problem_shape, get<2, 1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + epilogue.store(blk_coord, logical_problem_shape, params.epilogue, params.problem_shape, + shared_storage.epilogue, pipeline_corr_epi, + pipeline_corr_epi_consumer_state); + work_idx++; + } + + static_assert(NumWarpsEpilogue <= 1); + if constexpr (NumWarpsEpilogue == 1) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } else if (role == WarpRole::Empty) { + warpgroup_reg_set(); + + /* no-op, donate regs and exit */ + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/include/flashinfer/attention/blackwell/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp b/include/flashinfer/attention/blackwell/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp new file mode 100644 index 0000000000..1e9a77502d --- /dev/null +++ b/include/flashinfer/attention/blackwell/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp @@ -0,0 +1,530 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "collective/fmha_fusion.hpp" +#include "cute/arch/tmem_allocator_sm100.hpp" +#include "cute/layout.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "kernel/fmha_options.hpp" +#include "kernel/fmha_tile_scheduler.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; +using namespace cutlass::fmha::collective; + +struct Sm100FmhaGenKernelWarpspecializedSchedule { + enum class WarpRole { Softmax0, Softmax1, Correction, MMA, Load, Epilogue, Empty }; + + static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + if (warp_idx == 0) return WarpRole::Softmax0; // 0 - 3 + if (warp_idx == 1) return WarpRole::MMA; // 12 + if (warp_idx == 2 || warp_idx == 3) return WarpRole::Load; // 13 + if (warp_idx == 4) return WarpRole::Softmax1; // 4 - 7 + if (warp_idx == 8) return WarpRole::Correction; // 8 - 11 + return WarpRole::Empty; // 15 + } + + static const int NumWarpsSoftmax = 1; + static const int NumWarpsCorrection = 1; + static const int NumWarpsEpilogue = 0; + static const int NumWarpsLoad = 2; + + static const int NumRegsSoftmax = 192; + static const int NumRegsCorrection = 104; + static const int NumRegsOther = 248; + static const int NumRegsEmpty = 24; + + static const int NumWarps = 12; +}; + +template +struct Sm100FmhaGenKernelWarpspecialized { + using TileShape = typename CollectiveMainloop::TileShape; + using ProblemShape = decltype(replace<0>(ProblemShapeIn{}, 0)); + + using WarpRole = typename KernelSchedule::WarpRole; + + constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + return KernelSchedule::warp_idx_to_WarpRole(warp_idx); + } + + static const int NumWarpsSoftmax = KernelSchedule::NumWarpsSoftmax; + static const int NumWarpsCorrection = KernelSchedule::NumWarpsCorrection; + static const int NumWarpsEpilogue = KernelSchedule::NumWarpsEpilogue; + static const int NumWarpsLoad = KernelSchedule::NumWarpsLoad; + + static const int NumRegsSoftmax = KernelSchedule::NumRegsSoftmax; + static const int NumRegsCorrection = KernelSchedule::NumRegsCorrection; + static const int NumRegsOther = KernelSchedule::NumRegsOther; + static const int NumRegsEmpty = 24; + + static const int NumWarps = KernelSchedule::NumWarps; + + using ClusterShape = typename CollectiveMainloop::ClusterShape; + + using TmemAllocator = cute::TMEM::Allocator1Sm; + + struct SharedStorage { + typename CollectiveMainloop::TensorStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue; + + struct PipelineStorage { + alignas(16) typename CollectiveMainloop::PipelineQ::SharedStorage load_q; + alignas(16) typename CollectiveMainloop::PipelineKV::SharedStorage load_kv; + alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s0; + alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s1; + alignas(16) typename 
CollectiveMainloop::PipelineC::SharedStorage s0_corr; + alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s1_corr; + alignas(16) typename CollectiveMainloop::PipelineO::SharedStorage mma_corr; + alignas(16) typename CollectiveMainloop::PipelineE::SharedStorage corr_epi; + alignas(16) typename CollectiveMainloop::OrderBarrierSoftmax::SharedStorage order_s01; + } pipelines; + + uint32_t tmem_base_ptr; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + using StrideQOrig = typename CollectiveMainloop::StrideQOrig; + using StrideOOrig = typename CollectiveMainloop::StrideOOrig; + using StrideQ = typename CollectiveMainloop::StrideQ; + using StrideO = typename CollectiveMainloop::StrideO; + using StrideCacheK = typename CollectiveMainloop::StrideCacheK; + using StrideCacheV = typename CollectiveMainloop::StrideCacheV; + using StrideNewK = typename CollectiveMainloop::StrideNewK; + using StrideNewV = typename CollectiveMainloop::StrideNewV; + using Element = typename CollectiveMainloop::Element; + using ElementAcc = typename CollectiveMainloop::ElementAcc; + using ElementOut = typename CollectiveMainloop::ElementOut; + + struct Arguments { + // _1, max_seqlen_k, head_dim, ((h_g, h_kv), b) + ProblemShapeIn problem_shape; + const int* seqlen_kv; + const int* cache_batch_idx; + + const Element* ptr_q; // 1 x D x (H x B) + StrideQOrig dQ; + const Element* ptr_new_k; // 1 x D x (H x B) + StrideNewK dNewK; + const Element* ptr_new_v; // 1 x D x (H x B) + StrideNewV dNewV; + + Element* ptr_cache_k; // seqlen_max x D x (H x B) + StrideCacheK dCacheK; + Element* ptr_cache_v; // seqlen_max x D x (H x B) + StrideCacheV dCacheV; + ElementOut* ptr_o; // 1 x D x (H x B) + StrideOOrig dO; + + cutlass::KernelHardwareInfo hw_info; + + ElementAcc scale_softmax = 0.0f; + }; + + struct Params { + ProblemShape problem_shape; + const int* seqlen_kv; + typename CollectiveMainloop::Params mainloop; + typename CollectiveEpilogue::Params epilogue; + typename TileScheduler::Params tile_scheduler; + }; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = NumWarps * cutlass::NumThreadsPerWarp; + using ArchTag = cutlass::arch::Sm100; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static bool can_implement(Arguments const& args) { return true; } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + ProblemShape problem_shape = + replace<0>(args.problem_shape, static_cast(get<0>(args.problem_shape))); + CUTE_STATIC_ASSERT_V(get<0>(args.problem_shape) == _1{}); + StrideQ dQ = replace<0>(args.dQ, 0); + StrideO dO = replace<0>(args.dO, 0); + get<0>(problem_shape) = get<3, 0, 0>(args.problem_shape); + get<3, 0, 0>(problem_shape) = 1; + get<0>(dQ) = get<2, 0, 0>(dQ); + get<0>(dO) = get<2, 0, 0>(dO); + + typename CollectiveMainloop::Arguments mainloop_args{{ + args.cache_batch_idx, + args.ptr_q, + dQ, + args.ptr_new_k, + args.dNewK, + args.ptr_new_v, + args.dNewV, + args.ptr_cache_k, + args.dCacheK, + args.ptr_cache_v, + args.dCacheV, + }, + args.scale_softmax}; + + typename CollectiveEpilogue::Arguments epilogue_args{ + args.ptr_o, + dO, + }; + + 
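+    // Note on the decode layout (as set up above): the grouped query heads (h_g) are moved into
+    // the sequence-length mode of the problem shape and the head-group extent is set to 1, so the
+    // single new token's h_g heads are tiled as if they were query positions; dQ and dO are
+    // re-derived accordingly before building the mainloop/epilogue params below.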
return Params{ + problem_shape, args.seqlen_kv, + CollectiveMainloop::to_underlying_arguments(problem_shape, mainloop_args, workspace), + CollectiveEpilogue::to_underlying_arguments(problem_shape, epilogue_args, workspace), + TileScheduler::to_underlying_arguments(problem_shape, args.hw_info, ClusterShape{}, + TileShape{})}; + } + + CUTLASS_DEVICE auto apply_batch(const Params& params, ProblemShape const& problem_shape, + int batch_idx) { + ProblemShape result = problem_shape; + get<1>(result) = params.seqlen_kv[batch_idx]; + if (params.mainloop.load.ptr_new_k != nullptr) { + get<1>(result) += 1; + } + return result; + } + + CUTLASS_DEVICE void operator()(const Params& params, char* smem) { + TileScheduler tile_scheduler{params.tile_scheduler}; + + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_WarpRole(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + + if (role == WarpRole::Epilogue && lane_predicate) { + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + typename CollectiveMainloop::PipelineQ::Params pipeline_load_q_params; + if (role == WarpRole::Load) { + pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Producer; + } + if (role == WarpRole::MMA) { + pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Consumer; + } + pipeline_load_q_params.producer_arv_count = NumWarpsLoad * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineQ pipeline_load_q( + shared_storage.pipelines.load_q, pipeline_load_q_params, ClusterShape{}, cute::true_type{}, + /*mask calc*/ cute::false_type{}); + + typename CollectiveMainloop::PipelineKV::Params pipeline_load_kv_params; + if (role == WarpRole::Load) { + pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Producer; + } + if (role == WarpRole::MMA) { + pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Consumer; + } + pipeline_load_kv_params.producer_arv_count = NumWarpsLoad * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineKV pipeline_load_kv( + shared_storage.pipelines.load_kv, pipeline_load_kv_params, ClusterShape{}, + /*barrier init*/ cute::true_type{}, /*mask calc*/ cute::false_type{}); + + typename CollectiveMainloop::PipelineS::Params pipeline_mma_s0_params; + if (role == WarpRole::MMA) { + pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::Softmax0) { + pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s0_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineS pipeline_mma_s0( + shared_storage.pipelines.mma_s0, pipeline_mma_s0_params, ClusterShape{}, + /*barrier init*/ cute::true_type{}, /*mask calc*/ cute::false_type{}); + + typename CollectiveMainloop::PipelineS::Params pipeline_mma_s1_params; + if (role == WarpRole::MMA) { + pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::Softmax1) { + pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s1_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineS pipeline_mma_s1( + 
shared_storage.pipelines.mma_s1, pipeline_mma_s1_params, ClusterShape{}, + /*barrier init*/ cute::true_type{}, /*mask calc*/ cute::false_type{}); + + typename CollectiveMainloop::PipelineC::Params pipeline_s0_corr_params; + if (role == WarpRole::Softmax0) { + pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + } + pipeline_s0_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s0_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineC pipeline_s0_corr(shared_storage.pipelines.s0_corr, + pipeline_s0_corr_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::PipelineC::Params pipeline_s1_corr_params; + if (role == WarpRole::Softmax1) { + pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + } + pipeline_s1_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s1_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineC pipeline_s1_corr(shared_storage.pipelines.s1_corr, + pipeline_s1_corr_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::PipelineO::Params pipeline_mma_corr_params; + if (role == WarpRole::MMA) { + pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Consumer; + } + pipeline_mma_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineO pipeline_mma_corr( + shared_storage.pipelines.mma_corr, pipeline_mma_corr_params, ClusterShape{}, + /*barrier init*/ cute::true_type{}, /*mask calc*/ cute::false_type{}); + + typename CollectiveMainloop::PipelineE::Params pipeline_corr_epi_params; + if (role == WarpRole::Correction) { + pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Producer; + } + if (role == WarpRole::Epilogue) { + pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer; + } + pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + pipeline_corr_epi_params.consumer_arv_count = NumWarpsEpilogue * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineE pipeline_corr_epi(shared_storage.pipelines.corr_epi, + pipeline_corr_epi_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::OrderBarrierSoftmax::Params params_order_s01; + params_order_s01.group_id = role == WarpRole::Softmax1 ? 
1 : 0; + params_order_s01.group_size = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::OrderBarrierSoftmax order_s01(shared_storage.pipelines.order_s01, + params_order_s01); + + TmemAllocator tmem_allocator; + + __syncthreads(); + + pipeline_load_q.init_masks(ClusterShape{}); + pipeline_load_kv.init_masks(ClusterShape{}); + pipeline_mma_s0.init_masks(ClusterShape{}); + pipeline_mma_s1.init_masks(ClusterShape{}); + pipeline_mma_corr.init_masks(ClusterShape{}); + + typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_consumer_state; + typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_producer_state = + cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_consumer_state; + typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_producer_state = + cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_producer_state = + cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_producer_state = + cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_producer_state = + cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_producer_state = + cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_consumer_state; + typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_producer_state = + cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_consumer_state; + typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = + cutlass::make_producer_start_state(); + + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue(params.epilogue); + + if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = + apply_batch(params, params.problem_shape, get<2, 1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + bool is_softmax_0 = role == WarpRole::Softmax0; + + mainloop.softmax( + is_softmax_0 ? 0 : 1, blk_coord, params.mainloop, logical_problem_shape, + is_softmax_0 ? pipeline_mma_s0 : pipeline_mma_s1, + is_softmax_0 ? pipeline_mma_s0_consumer_state : pipeline_mma_s1_consumer_state, + is_softmax_0 ? pipeline_s0_corr : pipeline_s1_corr, + is_softmax_0 ? 
pipeline_s0_corr_producer_state : pipeline_s1_corr_producer_state, + order_s01); + } + } else if (role == WarpRole::Correction) { + cutlass::arch::warpgroup_reg_dealloc(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = + apply_batch(params, params.problem_shape, get<2, 1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + mainloop.correction( + blk_coord, params.mainloop, logical_problem_shape, shared_storage.epilogue, + pipeline_s0_corr, pipeline_s0_corr_consumer_state, pipeline_s1_corr, + pipeline_s1_corr_consumer_state, pipeline_mma_corr, pipeline_mma_corr_consumer_state, + pipeline_corr_epi, pipeline_corr_epi_producer_state, epilogue); + } + + if constexpr (NumWarpsEpilogue == 0) { + static_assert(NumWarpsCorrection == 1); + + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } else if (role == WarpRole::MMA) { + warpgroup_reg_set(); + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, + &shared_storage.tmem_base_ptr); + __syncwarp(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = + apply_batch(params, params.problem_shape, get<2, 1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + mainloop.mma(blk_coord, params.mainloop, logical_problem_shape, shared_storage.mainloop, + pipeline_load_q, pipeline_load_q_consumer_state, pipeline_load_kv, + pipeline_load_kv_consumer_state, pipeline_mma_s0, + pipeline_mma_s0_producer_state, pipeline_mma_s1, + pipeline_mma_s1_producer_state, pipeline_mma_corr, + pipeline_mma_corr_producer_state); + } + } else if (role == WarpRole::Load) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = + apply_batch(params, params.problem_shape, get<2, 1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + mainloop.load(blk_coord, logical_problem_shape, params.mainloop, params.problem_shape, + shared_storage.mainloop, pipeline_load_q, pipeline_load_q_producer_state, + pipeline_load_kv, pipeline_load_kv_producer_state); + } + } else if (role == WarpRole::Epilogue) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = + apply_batch(params, params.problem_shape, get<2, 1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + epilogue.store(blk_coord, logical_problem_shape, params.epilogue, params.problem_shape, + shared_storage.epilogue, pipeline_corr_epi, + pipeline_corr_epi_consumer_state); + } + + static_assert(NumWarpsEpilogue <= 1); + if constexpr (NumWarpsEpilogue == 1) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } else if (role == WarpRole::Empty) { + warpgroup_reg_set(); + + /* no-op, donate regs and exit */ + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git 
a/include/flashinfer/attention/blackwell/kernel/sm100_fmha_mla_reduction.hpp b/include/flashinfer/attention/blackwell/kernel/sm100_fmha_mla_reduction.hpp new file mode 100644 index 0000000000..40d843fa63 --- /dev/null +++ b/include/flashinfer/attention/blackwell/kernel/sm100_fmha_mla_reduction.hpp @@ -0,0 +1,195 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" + +namespace cutlass::fmha::kernel { + +using namespace cute; +template +struct Sm100FmhaMlaReductionKernel { + static const int SharedStorageSize = 0; + static const int MaxThreadsPerBlock = 128; + static const int MinBlocksPerMultiprocessor = 1; + + using ArchTag = cutlass::arch::Sm100; + + static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0); + struct Arguments { + ElementAcc* ptr_oaccum = nullptr; + ElementOut* ptr_o = nullptr; + ElementAcc* ptr_lseaccum = nullptr; + ElementAcc* ptr_lse = nullptr; + ElementScale scale = 1.f; + int num_batches = 0; + int split_kv = -1; + int dim_k = -1; + int* ptr_seq = nullptr; + int* ptr_split_kv = nullptr; + int tile_shape_s = 128; + }; + using Params = Arguments; + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse, + args.scale, args.num_batches, args.split_kv, args.dim_k, + args.ptr_seq, args.ptr_split_kv, args.tile_shape_s}; + } + + static size_t get_workspace_size(Arguments const& /*args*/) { return 0; } + + static Status initialize_workspace(Arguments const& /*args*/, void* /*ws*/, + cudaStream_t /*stream*/) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const& params) { + return dim3(kNumHeads, 1, params.num_batches); + } + + static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); } + + static bool can_implement(Arguments const& args) { + if (args.num_batches <= 0) return false; + if (args.split_kv <= 0) return false; + return true; + } + + CUTLASS_DEVICE void operator()(Params const& params, char* smem_raw) { + if (params.split_kv <= 1) return; + auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z); + + __shared__ ElementAcc sLseScale[kMaxSplits]; + const size_t offset_lseaccum = + get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord); + const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord); + + Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum), + make_shape(params.split_kv), Stride>{}); + + Tensor gLSE = + make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse), Shape<_1>{}, Stride<_1>{}); + + auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)]; + auto local_split_kv = + params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)]; + auto k_tile_total = ceil_div(dim_k, params.tile_shape_s); + auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv); + local_split_kv = ceil_div(k_tile_total, k_tile_per_cta); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == 0) { + constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32); + + ElementAcc local_lse[kNLsePerThread]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + threadIdx.x; + local_lse[i] = split < local_split_kv ? 
gLSEaccum(split) + : -std::numeric_limits::infinity(); + } + + ElementAcc lse_max = -std::numeric_limits::infinity(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + lse_max = max(lse_max, local_lse[i]); + } + CUTLASS_PRAGMA_UNROLL + for (int offset = 16; offset >= 1; offset /= 2) { + lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset)); + } + lse_max = lse_max == -std::numeric_limits::infinity() + ? 0.0f + : lse_max; // In case all local LSEs are -inf + lse_max = __shfl_sync(0xffffffff, lse_max, 0); + + ElementAcc sum_lse = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + sum_lse = sum_lse + expf(local_lse[i] - params.scale * lse_max); + } + + CUTLASS_PRAGMA_UNROLL + for (int offset = 16; offset >= 1; offset /= 2) { + sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset); + } + + sum_lse = __shfl_sync(0xffffffff, sum_lse, 0); + + ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) + ? std::numeric_limits::infinity() + : logf(sum_lse) + params.scale * lse_max; + if (threadIdx.x == 0 and params.ptr_lse != nullptr) { + gLSE(0) = global_lse; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + threadIdx.x; + if (split < local_split_kv) { + sLseScale[split] = expf(local_lse[i] - global_lse); + } + } + } + __syncthreads(); + + constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock; + const size_t offset_oaccum = + kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord)); + Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum), + Shape>{}, Stride<_1>{}); + ElementAcc local_val[Elements] = {0}; + for (int split = 0; split < local_split_kv; ++split) { + ElementAcc lse_scale = sLseScale[split]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Elements; ++i) { + local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i); + } + gOaccum.data() = gOaccum.data() + kHeadDimLatent; + } + auto ptr_o_local = + params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent; + Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape>{}, Stride<_1>{}); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Elements; ++i) { + gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast(local_val[i]); + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/include/flashinfer/attention/blackwell/kernel/sm100_fmha_mla_tma_warpspecialized.hpp b/include/flashinfer/attention/blackwell/kernel/sm100_fmha_mla_tma_warpspecialized.hpp new file mode 100644 index 0000000000..c761258494 --- /dev/null +++ b/include/flashinfer/attention/blackwell/kernel/sm100_fmha_mla_tma_warpspecialized.hpp @@ -0,0 +1,1941 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "common/pow_2.hpp" +#include "cute/arch/simd_sm100.hpp" +#include "cute/tensor.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "gather_tensor.hpp" // from examples/common + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template +struct Sm100FmhaMlaKernelTmaWarpspecialized { + + using Element = Element_; + using ElementAcc = ElementAcc_; + using ElementOut = ElementOut_; + using ElementLSE = ElementLSE_; + + // only 2Sm mode is supported + static const bool kIs2Sm = true; + static const int MaxThreadsPerBlock = 256; + static const int MinBlocksPerMultiprocessor = 1; + static const int TotalSNum = 2; + static const int TotalPNum = 2; + using ArchTag = cutlass::arch::Sm100; + + using ClusterShape = cute::conditional_t, Shape<_1, _1, _1>>; + + using TileShapeH = tuple_element_t<0, TileShape>; + using TileShapeS = tuple_element_t<1, TileShape>; + using TileShapeD = tuple_element_t<2, TileShape>; + + using TileShapeL = tuple_element_t<0, TileShapeD>; + using TileShapeR = tuple_element_t<1, TileShapeD>; + static_assert(TileShapeL{} % TileShapeR{} == 0, "Rope head dim must divide latent head dim"); + + using ProblemShape = Shape; + using TensorStride = Stride; + using TmemAllocator = + cute::conditional_t; + + static_assert(TileShapeH{} == 128); + static const int kWarpsInN = kIs2Sm ? 2 : 1; + + static const int kNumComputeWarps = 4; + static const int kNumLoadWarps = kIsCpAsync ? 2 : 1; + + enum class WarpRole { + kMma = 0x1, + kLoad = 0x2, + kCompute = 0x3, + kLoadPageTable = 0x4, + kEmpty = 0x0 + }; + + static const long long unsigned int kWarpAssignment = + kIsCpAsync ? 
0x4221'3333ull : 0x0021'3333ull; + + static CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) { + return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); + } + + static const int Alignment = 128 / sizeof_bits_v; + static const int AlignmentOut = 128 / sizeof_bits_v; + + using TileShapeQK = Shape; + static const int StagesQK = 24 / sizeof(Element); // free parameter + static const int IterationsQKLatent = decltype(TileShapeL{} / get<2>(TileShapeQK{}))::value; + static const int IterationsQKRope = decltype(TileShapeR{} / get<2>(TileShapeQK{}))::value; + static const int IterationsQK = IterationsQKLatent + IterationsQKRope; + + using Schedule = cute::conditional_t; + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, Element, TensorStride, Alignment, + Element, TensorStride, Alignment, ElementAcc, TileShapeQK, ClusterShape, + cutlass::gemm::collective::StageCount, Schedule>::CollectiveOp; + using TiledMmaQK = typename CollectiveMmaQK::TiledMma; + using CtaShapeQK = typename CollectiveMmaQK::CtaShape_MNK; + + // chosen for unified smem staging between K and V + using TileShapePV = Shape; + using TransposeTensorStride = decltype(select<1, 0, 2>(TensorStride{})); + static const int StagesPV = StagesQK; // not sure why, but must be at least two. check pipes + static const int IterationsPV_K = decltype(TileShapeS{} / get<2>(TileShapePV{}))::value; + static const int IterationsPV_N = decltype(TileShapeL{} / get<1>(TileShapePV{}))::value; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, Element, TensorStride, Alignment, + Element, TransposeTensorStride, Alignment, ElementAcc, TileShapePV, ClusterShape, + cutlass::gemm::collective::StageCount, Schedule>::CollectiveOp; + using CtaShapePV = typename CollectiveMmaPV::CtaShape_MNK; + static_assert(std::is_same_v); + + using TiledMmaPV = typename CollectiveMmaPV::TiledMma; + + using AtomThrShapeMNK = typename CollectiveMmaQK::AtomThrShapeMNK; + static_assert(typename CollectiveMmaQK::AtomThrShapeMNK{} == + typename CollectiveMmaPV::AtomThrShapeMNK{}, + "schedule must match"); + + static const int StagesPageTable = kIsCpAsync ? 
StagesPV : 1; + + // pipelines from load to mma, PipelineTmaUmmaAsync, stages tbd + // use expect_tx for Q load + using PipelineLoadQK = + cute::conditional_t, + PipelineTmaUmmaAsync>; + using PipelineLoadPV = PipelineLoadQK; + // pipeline from mma (Q@K) to softmax, PipelineUmmaAsync, 2 stages + using PipelineS = PipelineUmmaAsync; + // pipeline from softmax (P) to mma (bmm2), PipelineUmmaAsync, 2 stages + using PipelineP = PipelineUmmaConsumerAsync; + // pipeline from mma to softmax (for rescale), PipelineUmmaAsync, 1 stage + using PipelineO = PipelineUmmaAsync<1, AtomThrShapeMNK>; + + using PipelinePT = PipelineAsync; + + struct PipelineStorage { + alignas(16) typename PipelineLoadQK::SharedStorage load_qk; + alignas(16) typename PipelineS::SharedStorage mma_s; + alignas(16) typename PipelineP::SharedStorage p_mma; + alignas(16) typename PipelineO::SharedStorage mma_o; + alignas(16) typename PipelinePT::SharedStorage load_page_table; + }; + + template + static CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, _, make_layout(stages))); + } + + using SmemLayoutQ = + decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutKC = typename CollectiveMmaQK::SmemLayoutB; + using SmemLayoutVC = typename CollectiveMmaPV::SmemLayoutB; + using SmemLayoutP = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutA{}, + make_shape(Int{}, _2{}))); + + static const int kBytesLoadQ = + size(AtomThrShapeMNK{}) * + cutlass::bits_to_bytes(cosize(take<0, 3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static const int kBytesLoadKC = + size(AtomThrShapeMNK{}) * + cutlass::bits_to_bytes(cosize(take<0, 3>(SmemLayoutKC{})) * cute::sizeof_bits_v); + static const int kBytesLoadVC = + size(AtomThrShapeMNK{}) * + cutlass::bits_to_bytes(cosize(take<0, 3>(SmemLayoutVC{})) * cute::sizeof_bits_v); + // pre-condition for overlapped smem staging + static_assert(kBytesLoadKC == kBytesLoadVC); + static_assert(StagesQK == StagesPV); + + static const int kTransactionsBytesLoadQK = kBytesLoadKC; + static const int kTransactionsBytesLoadExtraQ = kBytesLoadQ; + static const int kTransactionsBytesLoadPV = kBytesLoadVC; + + static const int kNamedBarrierExchange = + (int)cutlass::arch::ReservedNamedBarriers::TransformBarrier; + // This Named Barrier is introduced to solve Q tile loading overwritten issue when enable + // persistent tile scheduler for FP8 MLA. 
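  // Reviewer note (not part of this patch): I read the comment above as follows: with a
  // persistent tile scheduler, the load warp may begin staging the Q tile of the next work
  // unit while the previous tile's output is still being drained, so the epilogue barrier
  // declared below is reused to order those two phases. This is my interpretation of the
  // fix, not an authoritative description from the authors.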
+ static const int kNamedBarrierEpilogue = + (int)cutlass::arch::ReservedNamedBarriers::EpilogueBarrier; + // + static const int kNamedBarrierTmemDealloc = + (int)cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier; + + enum class TmemAllocation : uint32_t { + kSizeS = TileShapeS::value / kWarpsInN, + // Overall + kSizeO = TileShapeL::value / kWarpsInN, + // Between accumulators we loop over + kSizeAccO = decltype(get<1>(TileShapePV{}))::value / kWarpsInN, + kNumS = TotalSNum, + kNumP = TotalPNum, + kNumO = 1, + kS0 = 0, + kS1 = kS0 + kSizeS, + kO0 = kS1 + kSizeS, + kTotal = kO0 + kSizeO + }; + + static_assert(static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, + "using too much tmem"); + + struct TensorStorage { + // to communicate max and row_sum + cute::array smem_exchange; + cute::array smem_page_table; + alignas(2048) cute::array> smem_q; + union { + alignas(2048) cute::array> smem_kc; + alignas(2048) cute::array> smem_vc; + }; + alignas(2048) cute::array> smem_p; + }; + + struct SharedStorage { + PipelineStorage pipelines; + TensorStorage tensors; + uint32_t tmem_base_ptr; + }; + + static const int SharedStorageSize = sizeof(SharedStorage); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, + "using too much smem"); + + struct MainloopArguments { + ElementAcc softmax_scale; + + // all tensors strides are (num_heads or seqlen, head_dim, batch) + // head_dim stride is always 1 + Element* ptr_q_latent; + TensorStride stride_q_latent; + Element* ptr_q_rope; + TensorStride stride_q_rope; + + Element* ptr_c_latent; + TensorStride stride_c_latent; + Element* ptr_k_rope; + TensorStride stride_k_rope; + + // for paged attention, we interpret what was previously [batch, seqlen] + // as [page_count, page_size], and index according to page_table + int* ptr_seq = nullptr; + int* ptr_page_table = nullptr; + // page table is [batch, seqlen or similar] + Stride<_1, int> stride_page_table = {}; + int page_count = 0; + int page_size = TileShapeS{}; // powers of two if kIsCpAsync, otherwise TileShapeS + }; + + struct EpilogueArguments { + ElementOut* ptr_o = nullptr; + TensorStride stride_o; + ElementLSE* ptr_lse = nullptr; + Stride<_1, int> stride_lse; + ElementAcc output_scale = 1.0f; + }; + + struct Arguments { + // (num_heads=128, seqlen, (d_latent=512, d_rope=64), batch_count) + // for paged attention, seqlen is max seqlen + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + int split_kv = -1; + int* ptr_split_kv = nullptr; + }; + + using TmaLoadQLatent = typename CollectiveMmaQK::Params::TMA_A; + using TmaLoadQRope = typename CollectiveMmaQK::Params::TMA_A; + using TmaLoadCLatent = typename CollectiveMmaQK::Params::TMA_B; + using TmaLoadKRope = typename CollectiveMmaQK::Params::TMA_B; + using TmaLoadCLatentTranspose = typename CollectiveMmaPV::Params::TMA_B; + + struct MainloopParams { + TmaLoadQLatent tma_load_q_latent; + TmaLoadQRope tma_load_q_rope; + TmaLoadCLatent tma_load_c_latent; + TmaLoadKRope tma_load_k_rope; + TmaLoadCLatentTranspose tma_load_c_latent_transpose; + }; + + struct EpilogueParams { + ElementOut* ptr_o = nullptr; + ElementAcc* ptr_o_acc = nullptr; + TensorStride stride_o; + TensorStride stride_o_acc; + ElementLSE* ptr_lse = nullptr; + ElementLSE* ptr_lse_acc = nullptr; + Stride<_1, int> stride_lse; + Stride<_1, int> stride_lse_acc; + ElementAcc output_scale = 1.0f; + }; + + struct Params { + ProblemShape problem_shape; + MainloopArguments mainloop; 
+ EpilogueParams epilogue; + MainloopParams mainloop_params; + typename TileScheduler::Params tile_scheduler; + int split_kv = -1; + int* ptr_split_kv = nullptr; + }; + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + // workspace = nullptr; // let's get an error if one of these needs workspace + + auto [H, K, D, B] = args.problem_shape; + auto [L, R] = D; + + int paged_B = B; + int paged_K = K; + if (args.mainloop.ptr_page_table != nullptr) { + paged_B = args.mainloop.page_count; + paged_K = args.mainloop.page_size; + } + + auto params_qk_latent = + CollectiveMmaQK::to_underlying_arguments(make_shape(H, K, L, B), + typename CollectiveMmaQK::Arguments{ + args.mainloop.ptr_q_latent, + args.mainloop.stride_q_latent, + args.mainloop.ptr_c_latent, + args.mainloop.stride_c_latent, + }, + nullptr); + + auto params_qk_latent_paged = + CollectiveMmaQK::to_underlying_arguments(make_shape(H, paged_K, L, paged_B), + typename CollectiveMmaQK::Arguments{ + args.mainloop.ptr_q_latent, + args.mainloop.stride_q_latent, + args.mainloop.ptr_c_latent, + args.mainloop.stride_c_latent, + }, + nullptr); + + auto params_qk_rope = + CollectiveMmaQK::to_underlying_arguments(make_shape(H, K, R, B), + typename CollectiveMmaQK::Arguments{ + args.mainloop.ptr_q_rope, + args.mainloop.stride_q_rope, + args.mainloop.ptr_k_rope, + args.mainloop.stride_k_rope, + }, + nullptr); + + auto params_qk_rope_paged = + CollectiveMmaQK::to_underlying_arguments(make_shape(H, paged_K, R, paged_B), + typename CollectiveMmaQK::Arguments{ + args.mainloop.ptr_q_rope, + args.mainloop.stride_q_rope, + args.mainloop.ptr_k_rope, + args.mainloop.stride_k_rope, + }, + nullptr); + + auto stride_c_latent_transpose = select<1, 0, 2>(args.mainloop.stride_c_latent); + auto params_pv_latent = CollectiveMmaPV::to_underlying_arguments( + make_shape(H, L, paged_K, paged_B), + typename CollectiveMmaPV::Arguments{ + args.mainloop.ptr_q_latent, + args.mainloop.stride_q_latent, // dummy, never used + args.mainloop.ptr_c_latent, + stride_c_latent_transpose, + }, + nullptr); + + MainloopParams mainloop_params{params_qk_latent.tma_load_a, params_qk_rope.tma_load_a, + params_qk_latent_paged.tma_load_b, + params_qk_rope_paged.tma_load_b, params_pv_latent.tma_load_b}; + + EpilogueParams epilogue_params; + + epilogue_params.ptr_o = args.epilogue.ptr_o; + epilogue_params.stride_o = args.epilogue.stride_o; + epilogue_params.ptr_lse = args.epilogue.ptr_lse; + epilogue_params.stride_lse = args.epilogue.stride_lse; + epilogue_params.output_scale = args.epilogue.output_scale; + + if (args.split_kv > 1) { + ElementAcc* ptr_o_acc = reinterpret_cast(workspace); + ElementLSE* ptr_lse_acc = + reinterpret_cast(ptr_o_acc + H * L * args.split_kv * B); + epilogue_params.ptr_o_acc = ptr_o_acc; + epilogue_params.ptr_lse_acc = ptr_lse_acc; + + epilogue_params.stride_o_acc = make_tuple(static_cast(0 + L) * args.split_kv, _1{}, + static_cast(0 + H * L) * args.split_kv); + epilogue_params.stride_lse_acc = make_tuple(_1{}, (0 + H) * args.split_kv); + } + + return {args.problem_shape, + args.mainloop, + epilogue_params, + mainloop_params, + TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, + args.split_kv), + args.split_kv, + args.ptr_split_kv}; + } + + static size_t get_workspace_size(Arguments const& args) { + ProblemShape problem_shape = args.problem_shape; + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + auto split_kv = args.split_kv; + return (sizeof(ElementAcc) * D_latent + 
sizeof(ElementLSE)) * H * split_kv * B; + } + static Status initialize_workspace(Arguments const& /*args*/, void* /*ws*/, + cudaStream_t /*stream*/) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static bool can_implement(Arguments const& args) { + if (kIsCpAsync) { + if ((args.mainloop.page_size & (args.mainloop.page_size - 1)) != 0) { + return false; + } + if (args.mainloop.page_size > TileShapeS{}) { + return false; + } + } else { + if (args.mainloop.ptr_page_table != nullptr && args.mainloop.page_size != TileShapeS{}) { + return false; + } + } + if (get<0>(args.problem_shape) != 128) { + return false; + } + if (get<1>(args.problem_shape) <= 0) { + return false; + } + if (args.split_kv <= 0) { + return false; + } + return true; + } + + CUTLASS_DEVICE void operator()(Params const& params, char* smem_raw) { + TileScheduler tile_scheduler(params.tile_scheduler); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_role(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); + int cta_coord_v = cta_rank_in_cluster % size<0>(AtomThrShapeMNK{}); + bool is_mma_leader_cta = cta_coord_v == 0; + + if (role == WarpRole::kLoad && lane_predicate && !kIsCpAsync) { + prefetch_tma_descriptor(params.mainloop_params.tma_load_q_latent.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_c_latent.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_q_rope.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_k_rope.get_tma_descriptor()); + prefetch_tma_descriptor( + params.mainloop_params.tma_load_c_latent_transpose.get_tma_descriptor()); + } + SharedStorage& shared_storage = *reinterpret_cast(smem_raw); + + typename PipelineLoadQK::Params pipeline_load_qk_params; + if (role == WarpRole::kLoad) { + pipeline_load_qk_params.role = PipelineLoadQK::ThreadCategory::Producer; + } + if (role == WarpRole::kMma) { + pipeline_load_qk_params.role = PipelineLoadQK::ThreadCategory::Consumer; + } + if constexpr (kIsCpAsync) { + // we can make our life easier by unconditionally loading blocks + // since we know it'll always be legal + pipeline_load_qk_params.producer_arv_count = + kNumLoadWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + } else { + pipeline_load_qk_params.is_leader = + lane_predicate && (role == WarpRole::kLoad) && is_mma_leader_cta; + pipeline_load_qk_params.transaction_bytes = kTransactionsBytesLoadQK; + } + pipeline_load_qk_params.initializing_warp = 0; + PipelineLoadQK pipeline_load_qk(shared_storage.pipelines.load_qk, pipeline_load_qk_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, + /*mask calc*/ cute::false_type{}); + + typename PipelineS::Params pipeline_mma_s_params; + if (role == WarpRole::kMma) { + pipeline_mma_s_params.role = PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::kCompute) { + pipeline_mma_s_params.role = PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s_params.consumer_arv_count = + kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_mma_s_params.initializing_warp = 1; + PipelineS pipeline_mma_s(shared_storage.pipelines.mma_s, pipeline_mma_s_params, ClusterShape{}, + /*barrier init*/ cute::true_type{}, /*mask 
calc*/ cute::false_type{}); + + typename PipelineP::Params pipeline_p_mma_params; + if (role == WarpRole::kMma) { + pipeline_p_mma_params.role = PipelineP::ThreadCategory::Consumer; + } + if (role == WarpRole::kCompute) { + pipeline_p_mma_params.role = PipelineP::ThreadCategory::Producer; + } + pipeline_p_mma_params.producer_arv_count = + kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_p_mma_params.consumer_arv_count = 1; + pipeline_p_mma_params.initializing_warp = 2; + PipelineP pipeline_p_mma(shared_storage.pipelines.p_mma, pipeline_p_mma_params, ClusterShape{}, + /*barrier init*/ cute::true_type{}, /*mask calc*/ cute::false_type{}); + + typename PipelineO::Params pipeline_mma_o_params; + if (role == WarpRole::kMma) { + pipeline_mma_o_params.role = PipelineO::ThreadCategory::Producer; + } + if (role == WarpRole::kCompute) { + pipeline_mma_o_params.role = PipelineO::ThreadCategory::Consumer; + } + pipeline_mma_o_params.consumer_arv_count = + kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_mma_o_params.initializing_warp = 3; + PipelineO pipeline_mma_o(shared_storage.pipelines.mma_o, pipeline_mma_o_params, ClusterShape{}, + /*barrier init*/ cute::true_type{}, /*mask calc*/ cute::false_type{}); + + typename PipelinePT::Params pipeline_pt_params; + if (role == WarpRole::kLoad) { + pipeline_pt_params.role = PipelinePT::ThreadCategory::Consumer; + } + if (role == WarpRole::kLoadPageTable) { + pipeline_pt_params.role = PipelinePT::ThreadCategory::Producer; + } + pipeline_pt_params.consumer_arv_count = kNumLoadWarps * cutlass::NumThreadsPerWarp; + pipeline_pt_params.producer_arv_count = cutlass::NumThreadsPerWarp; + pipeline_pt_params.initializing_warp = 4; + PipelinePT pipeline_page_table(shared_storage.pipelines.load_page_table, pipeline_pt_params); + + TmemAllocator tmem_allocator; + + pipeline_init_arrive_relaxed(size(ClusterShape{})); + + pipeline_load_qk.init_masks(ClusterShape{}); // do we need an update here for 2Sm? 
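    // Reviewer note (not part of this patch): each pipeline in this kernel is driven by a
    // pair of PipelineState cursors, one for the producer warp role and one for the consumer
    // warp role (declared just below). As I understand the CUTLASS convention, such a cursor
    // is simply a (stage index, phase bit) pair that wraps over the stage count, with the
    // producer starting on the opposite phase so its first acquire succeeds. The following is
    // a minimal host-side sketch of that bookkeeping with hypothetical names; it is not the
    // real cutlass::PipelineState API and is only meant to illustrate the pattern (C++17).
    struct SimplePipelineState {
      int stages;     // number of buffer stages in the ring
      int index = 0;  // stage the next acquire/wait refers to
      int phase = 0;  // flips each time the cursor wraps around the ring
      void advance() {
        if (++index == stages) {
          index = 0;
          phase ^= 1;  // a full trip around the ring toggles the expected barrier phase
        }
      }
    };
    // Producer cursors start with the phase flipped; consumer cursors start with phase 0.
    inline SimplePipelineState make_producer_cursor(int stages) { return {stages, 0, 1}; }
    inline SimplePipelineState make_consumer_cursor(int stages) { return {stages, 0, 0}; }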
+ pipeline_mma_s.init_masks(ClusterShape{}); + pipeline_p_mma.init_masks(ClusterShape{}); + pipeline_mma_o.init_masks(ClusterShape{}); + + typename PipelineLoadQK::PipelineState pipeline_load_qk_consumer_state; + typename PipelineLoadQK::PipelineState pipeline_load_qk_producer_state = + cutlass::make_producer_start_state(); + + typename PipelineS::PipelineState pipeline_mma_s_consumer_state; + typename PipelineS::PipelineState pipeline_mma_s_producer_state = + cutlass::make_producer_start_state(); + + typename PipelineP::PipelineState pipeline_p_mma_consumer_state; + typename PipelineP::PipelineState pipeline_p_mma_producer_state = + cutlass::make_producer_start_state(); + + typename PipelineO::PipelineState pipeline_mma_o_consumer_state; + typename PipelineO::PipelineState pipeline_mma_o_producer_state = + cutlass::make_producer_start_state(); + + typename PipelinePT::PipelineState pipeline_pt_consumer_state; + typename PipelinePT::PipelineState pipeline_pt_producer_state = + cutlass::make_producer_start_state(); + + pipeline_init_wait(size(ClusterShape{})); + + if (role == WarpRole::kLoadPageTable) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) continue; + load_page_table(blk_coord, problem_shape, params.mainloop, shared_storage.tensors, + pipeline_page_table, pipeline_pt_producer_state, local_split_kv); + } + } else if (role == WarpRole::kLoad) { + if constexpr (kIsCpAsync) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) continue; + load_cpasync(blk_coord, problem_shape, params.mainloop, params.mainloop_params, + shared_storage.tensors, pipeline_load_qk, pipeline_load_qk_producer_state, + local_split_kv, + /* must be shared pipe */ + pipeline_page_table, pipeline_pt_consumer_state); + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, + kNamedBarrierEpilogue) + .arrive_and_wait(); + } + } else { + if (params.mainloop.ptr_page_table != nullptr) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) continue; + load_tma( + blk_coord, problem_shape, params.mainloop, params.mainloop_params, + shared_storage.tensors, pipeline_load_qk, pipeline_load_qk_producer_state, + pipeline_load_qk, pipeline_load_qk_producer_state, local_split_kv); + 
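          // Reviewer note (not part of this patch): as in the cp.async path above, this
          // barrier appears to keep the load warp from advancing to the next scheduled tile
          // (and overwriting the shared Q staging buffer) until the compute/epilogue warps
          // are done with the current one; see the kNamedBarrierEpilogue comment earlier in
          // this file. This is my reading of the intent only.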
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, + kNamedBarrierEpilogue) + .arrive_and_wait(); + } + } else { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) continue; + load_tma(blk_coord, problem_shape, params.mainloop, params.mainloop_params, + shared_storage.tensors, pipeline_load_qk, + pipeline_load_qk_producer_state, pipeline_load_qk, + pipeline_load_qk_producer_state, local_split_kv); + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, + kNamedBarrierEpilogue) + .arrive_and_wait(); + } + } + } + } else if (role == WarpRole::kMma) { + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, + &shared_storage.tmem_base_ptr); + __syncwarp(); + + if (is_mma_leader_cta) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) continue; + mma(blk_coord, problem_shape, shared_storage.tensors, pipeline_load_qk, + pipeline_load_qk_consumer_state, pipeline_load_qk, pipeline_load_qk_consumer_state, + pipeline_mma_s, pipeline_mma_s_producer_state, pipeline_p_mma, + pipeline_p_mma_consumer_state, pipeline_mma_o, pipeline_mma_o_producer_state, + local_split_kv); + } + } + + // cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, + // kNamedBarrierTmemDealloc).arrive_and_wait(); + + // uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + // tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } else if (role == WarpRole::kCompute) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto split_kv = params.split_kv; + auto local_split_kv = split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) continue; + compute(blk_coord, problem_shape, + params.mainloop, // for softmax_scale + params.epilogue, + shared_storage.tensors, // for smem_comm + pipeline_mma_s, pipeline_mma_s_consumer_state, pipeline_p_mma, + pipeline_p_mma_producer_state, pipeline_mma_o, pipeline_mma_o_consumer_state, + local_split_kv); + } + + // cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, + // kNamedBarrierTmemDealloc).arrive(); + } + + cute::cluster_sync(); + cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, + kNamedBarrierTmemDealloc) + .arrive(); + if (role == WarpRole::kMma) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + 
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + } + + template + CUTLASS_DEVICE void load_page_table( + BlkCoord const& blk_coord, ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, TensorStorage& shared_tensors, + PipelinePT& pipeline_page_table, + typename PipelinePT::PipelineState& pipeline_pt_producer_state, int const& split_kv) { + auto [H, K, D, B] = problem_shape; + int batch_coord = get<2>(blk_coord); + + auto mPT_l = + make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), + make_shape(mainloop_args.page_count, B), mainloop_args.stride_page_table); + auto mPT = mPT_l(_, batch_coord); + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + auto page_size = Pow2{mainloop_args.page_size}; + auto pages_per_tile = Pow2{TileShapeS{} / page_size}; + int thread_idx = threadIdx.x % cutlass::NumThreadsPerWarp; + +#if 1 + for (; k_tile_count > 0; ++k_index, --k_tile_count) { + pipeline_page_table.producer_acquire(pipeline_pt_producer_state); + + // assume a single warp + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < TileShapeS{}; i += cutlass::NumThreadsPerWarp) { + int idx = i + thread_idx; + bool guard = idx < pages_per_tile; + int smem_idx = pipeline_pt_producer_state.index() * TileShapeS::value + idx; + int pt_idx = pages_per_tile * k_index + idx; + + cutlass::arch::cp_async_zfill( + &shared_tensors.smem_page_table[smem_idx], &mPT(pt_idx), guard); + } + + pipeline_page_table.producer_commit(pipeline_pt_producer_state, + cutlass::arch::cpasync_barrier_arrive); + ++pipeline_pt_producer_state; + } +#endif + } + + struct Gather { + int& page_table_stage; + Pow2 pages_per_tile; + const int* __restrict__ smem_page_table; + + CUTLASS_DEVICE int operator()(int idx) const { + return smem_page_table[page_table_stage * TileShapeS::value + idx % pages_per_tile]; + } + + CUTLASS_DEVICE friend void print(Gather const&) { printf(""); } + }; + + template + CUTLASS_DEVICE void load_cpasync( + BlkCoord const& blk_coord, ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, PipelineLoadQK& pipeline_load, + typename PipelineLoadQK::PipelineState& pipeline_load_producer_state, int const& split_kv, + PipelinePT& pipeline_page_table, + typename PipelinePT::PipelineState& pipeline_pt_consumer_state) { + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + using X = Underscore; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + // partition all tensors + auto mQL = make_tensor(make_gmem_ptr(mainloop_args.ptr_q_latent), make_shape(H, D_latent, B), + mainloop_args.stride_q_latent); + auto mQR = make_tensor(make_gmem_ptr(mainloop_args.ptr_q_rope), make_shape(H, D_rope, B), + mainloop_args.stride_q_rope); + + int paged_B = mainloop_args.page_count; + auto paged_K = Pow2{mainloop_args.page_size}; + auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), make_shape(paged_B, B), + mainloop_args.stride_page_table); + + int batch_coord = 
get<2>(blk_coord); + auto mPT = mPT_l(_, batch_coord); + + auto gQL = local_tile(mQL, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); + auto gQR = local_tile(mQR, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); + + ThrMMA cta_mma_qk = TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + ThrMMA cta_mma_pv = TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + + auto tSgQL = cta_mma_qk.partition_A(gQL); + auto tSgQR = cta_mma_qk.partition_A(gQR); + + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + + auto make_copy_for = [](auto sT) { + auto rT_a = sT.layout()(_, _, _, _0{}); + auto rT = make_ordered_layout(shape(rT_a), stride(rT_a)); + auto threads = Int{}; + auto values = Int{}; + return make_cotiled_copy( + Copy_Atom, Element>{}, + make_ordered_layout(make_shape(threads, values), make_stride(_1{}, _0{})), rT); + }; + + // like cute::copy, but makes sure we do all page table lookups first + auto copy_split = [](auto atom, auto src, auto dst) { + auto src_v = group_modes<1, rank_v>(src); + auto dst_v = group_modes<1, rank_v>(dst); + + auto src_v_ptrs = make_tensor(size<1>(src_v)); + for (int i = 0; i < size<1>(src_v); i++) { + src_v_ptrs(i) = &src_v(_0{}, i); + } + + for (int i = 0; i < size<1>(src_v); i++) { + auto src_v_i = make_tensor(make_gmem_ptr(src_v_ptrs(i)), make_shape(shape<0>(src_v)), + make_stride(make_stride(_1{}, _0{}))); + atom.call(src_v_i, dst_v(_, i)); + } + }; + + auto tiled_copy_q = make_copy_for(sQ); + auto tiled_copy_kc = make_copy_for(sKC); + auto tiled_copy_vc = make_copy_for(sVC); + + auto thr_copy_q = + tiled_copy_q.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + auto thr_copy_kc = + tiled_copy_kc.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + auto thr_copy_vc = + tiled_copy_vc.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + + auto tQsQ = thr_copy_q.partition_D(sQ); + auto tQgQL = thr_copy_q.partition_S(tSgQL); + auto tQgQR = thr_copy_q.partition_S(tSgQR); + + auto tKCsKC = thr_copy_kc.partition_D(sKC); + auto tVCsVC = thr_copy_vc.partition_D(sVC); + + auto pipeline_pt_release_state = pipeline_pt_consumer_state; + + int page_table_stage = -1; + Pow2 pages_per_tile{TileShapeS{} / paged_K}; + const int* __restrict__ smem_page_table = shared_tensors.smem_page_table.begin(); + Gather gather{page_table_stage, pages_per_tile, smem_page_table}; + + auto mCL = make_tensor( + make_gmem_ptr(mainloop_args.ptr_c_latent), + ComposedLayout{ + make_layout(make_shape(make_shape(paged_K, paged_B), _1{}), + make_stride(make_stride(get<0>(mainloop_args.stride_c_latent), + example::CustomStride( + gather, get<2>(mainloop_args.stride_c_latent))), + get<1>(mainloop_args.stride_c_latent))), + make_coord(_0{}, _0{}), make_identity_layout(make_shape(paged_K * paged_B, D_latent))}); + + auto mKR = make_tensor( + make_gmem_ptr(mainloop_args.ptr_k_rope), + ComposedLayout{ + make_layout(make_shape(make_shape(paged_K, paged_B), _1{}), + make_stride(make_stride(get<0>(mainloop_args.stride_k_rope), + example::CustomStride( + gather, get<2>(mainloop_args.stride_k_rope))), + get<1>(mainloop_args.stride_k_rope))), + make_coord(_0{}, _0{}), make_identity_layout(make_shape(paged_K * paged_B, D_latent))}); + + auto mCLT = make_tensor( + 
make_gmem_ptr(mainloop_args.ptr_c_latent), + ComposedLayout{ + make_layout( + make_shape(_1{}, make_shape(paged_K, paged_B)), + make_stride(get<1>(mainloop_args.stride_c_latent), + make_stride(get<0>(mainloop_args.stride_c_latent), + example::CustomStride( + gather, get<2>(mainloop_args.stride_c_latent))))), + make_coord(_0{}, _0{}), make_identity_layout(make_shape(D_latent, paged_K * paged_B))}); + + auto gCL = local_tile(mCL, TileShapeQK{}, make_coord(_, _, _), Step{}); + auto gKR = local_tile(mKR, TileShapeQK{}, make_coord(_, _, _), Step{}); + auto gCLT = local_tile(mCLT, TileShapePV{}, make_coord(_, _, _), Step{}); + + auto tSgCL = cta_mma_qk.partition_B(gCL); + auto tSgKR = cta_mma_qk.partition_B(gKR); + auto tOgCLT = cta_mma_pv.partition_B(gCLT); + + auto tKCgCL = thr_copy_kc.partition_S(tSgCL); + auto tKCgKR = thr_copy_kc.partition_S(tSgKR); + auto tVCgCLT = thr_copy_vc.partition_S(tOgCLT); + + // latent is first in memory, so let's load it first always + // startup: alternate Q and K, set tx count appropriately, for k_idx = 0 + auto& pipeline_acquire_state = pipeline_load_producer_state; + auto pipeline_commit_state = pipeline_acquire_state; + int pipeline_offset = 0; + + for (int i = 0; i < StagesPV; i++) { + cutlass::arch::cp_async_fence(); + } + + auto load_stage = [&](auto fn) { + pipeline_load.producer_acquire(pipeline_acquire_state); + fn(pipeline_acquire_state.index()); + cutlass::arch::cp_async_fence(); + + ++pipeline_acquire_state; + ++pipeline_offset; + + if (pipeline_offset == StagesPV - 1) { + cutlass::arch::cp_async_wait(); + pipeline_load.producer_commit(pipeline_commit_state); + ++pipeline_commit_state; + --pipeline_offset; + } + }; + + pipeline_page_table.consumer_wait(pipeline_pt_consumer_state); + page_table_stage = pipeline_pt_consumer_state.index(); + ++pipeline_pt_consumer_state; + + // each Q/K tile consists of rope and latent + for (int i = 0; i < IterationsQKLatent; i++) { + load_stage([&](int index) { + cute::copy(tiled_copy_q, tQgQL(_, _, _, _, _0{}, i, batch_coord), tQsQ(_, _, _, _, i)); + copy_split(tiled_copy_kc, tKCgCL(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + for (int i = 0; i < IterationsQKRope; i++) { + load_stage([&](int index) { + cute::copy(tiled_copy_q, tQgQR(_, _, _, _, _0{}, i, batch_coord), + tQsQ(_, _, _, _, IterationsQKLatent + i)); + copy_split(tiled_copy_kc, tKCgKR(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + k_index += 1; + k_tile_count -= 1; + + // assume k_tile_count >= 1 + // perform K+Q load here + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + pipeline_page_table.consumer_wait(pipeline_pt_consumer_state); + page_table_stage = pipeline_pt_consumer_state.index(); + ++pipeline_pt_consumer_state; + + for (int i = 0; i < IterationsQKLatent; i++) { + load_stage([&](int index) { + copy_split(tiled_copy_kc, tKCgCL(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + for (int i = 0; i < IterationsQKRope; i++) { + load_stage([&](int index) { + copy_split(tiled_copy_kc, tKCgKR(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + page_table_stage = pipeline_pt_release_state.index(); + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + load_stage([&](int index) { + copy_split(tiled_copy_vc, tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), + tVCsVC(_, _, _, _, index)); + }); + } + } + + pipeline_page_table.consumer_release(pipeline_pt_release_state); + ++pipeline_pt_release_state; + + k_index += 1; + k_tile_count 
-= 1; + } + + page_table_stage = pipeline_pt_release_state.index(); + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + load_stage([&](int index) { + copy_split(tiled_copy_vc, tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), + tVCsVC(_, _, _, _, index)); + }); + } + } + + pipeline_page_table.consumer_release(pipeline_pt_release_state); + ++pipeline_pt_release_state; + + while (pipeline_offset > 0) { + cutlass::arch::cp_async_fence(); + + cutlass::arch::cp_async_wait(); + pipeline_load.producer_commit(pipeline_commit_state); + ++pipeline_commit_state; + --pipeline_offset; + } + + cutlass::arch::cp_async_wait<0>(); + } + + template + CUTLASS_DEVICE void load_tma( + BlkCoord const& blk_coord, ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, PipelineLoadQK& pipeline_load_qk, + typename PipelineLoadQK::PipelineState& pipeline_load_qk_producer_state, + PipelineLoadPV& pipeline_load_pv, + typename PipelineLoadPV::PipelineState& pipeline_load_pv_producer_state, + int const& split_kv) { + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + using X = Underscore; + + // partition all tensors + auto mQL = mainloop_params.tma_load_q_latent.get_tma_tensor(make_shape(H, D_latent, B)); + auto mQR = mainloop_params.tma_load_q_rope.get_tma_tensor(make_shape(H, D_rope, B)); + + int paged_B = B; + int paged_K = K; + if constexpr (kIsPaged) { + paged_B = mainloop_args.page_count; + paged_K = mainloop_args.page_size; + } + auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), make_shape(paged_B, B), + mainloop_args.stride_page_table); + + auto mCL = + mainloop_params.tma_load_c_latent.get_tma_tensor(make_shape(paged_K, D_latent, paged_B)); + auto mKR = mainloop_params.tma_load_k_rope.get_tma_tensor(make_shape(paged_K, D_rope, paged_B)); + + auto mCLT = mainloop_params.tma_load_c_latent_transpose.get_tma_tensor( + make_shape(D_latent, paged_K, paged_B)); + + auto gQL = local_tile(mQL, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); + auto gQR = local_tile(mQR, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); + + auto gCL = local_tile(mCL, TileShapeQK{}, make_coord(_, _, _), Step{}); + auto gKR = local_tile(mKR, TileShapeQK{}, make_coord(_, _, _), Step{}); + auto gCLT = local_tile(mCLT, TileShapePV{}, make_coord(_, _, _), Step{}); + + ThrMMA cta_mma_qk = TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + ThrMMA cta_mma_pv = TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + + auto tSgQL = cta_mma_qk.partition_A(gQL); + auto tSgQR = cta_mma_qk.partition_A(gQR); + + auto tSgCL = cta_mma_qk.partition_B(gCL); + auto tSgKR = cta_mma_qk.partition_B(gKR); + + auto tOgCLT = cta_mma_pv.partition_B(gCLT); + + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + + auto [tQLgQL_mkl, tQsQ] = + tma_partition(mainloop_params.tma_load_q_latent, _0{}, make_layout(_1{}), + group_modes<0, 
3>(sQ), group_modes<0, 3>(tSgQL)); + + auto [tQRgQR_mkl, tQsQ_ignore] = + tma_partition(mainloop_params.tma_load_q_rope, _0{}, make_layout(_1{}), + group_modes<0, 3>(sQ), group_modes<0, 3>(tSgQR)); + + auto [tCLgCL_nkl, tKCsKC] = + tma_partition(mainloop_params.tma_load_c_latent, _0{}, make_layout(_1{}), + group_modes<0, 3>(sKC), group_modes<0, 3>(tSgCL)); + + auto [tKRgKR_nkl, tKCsKC_ignore] = + tma_partition(mainloop_params.tma_load_k_rope, _0{}, make_layout(_1{}), + group_modes<0, 3>(sKC), group_modes<0, 3>(tSgKR)); + + auto [tCLTgCLT_nkl, tVCsVC] = + tma_partition(mainloop_params.tma_load_c_latent_transpose, _0{}, make_layout(_1{}), + group_modes<0, 3>(sVC), group_modes<0, 3>(tOgCLT)); + + uint16_t mcast_mask = 0; + + int batch_coord = get<2>(blk_coord); + Tensor tQLgQL = tQLgQL_mkl(_, _, _, batch_coord); + Tensor tQRgQR = tQRgQR_mkl(_, _, _, batch_coord); + + auto mPT = mPT_l(_, batch_coord); + + Tensor tCLgCL = tCLgCL_nkl(_, _, _, _); + Tensor tKRgKR = tKRgKR_nkl(_, _, _, _); + + // careful: stage and k are swapped here! + Tensor tCLTgCLT = tCLTgCLT_nkl(_, _, _, _); + + // latent is first in memory, so let's load it first always + // startup: alternate Q and K, set tx count appropriately, for k_idx = 0 + + // each Q/K tile consists of rope and latent + for (int i = 0; i < IterationsQKLatent; i++) { + pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, + kTransactionsBytesLoadExtraQ); + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // expect the extra bytes + // load_qk ql + cute::copy(mainloop_params.tma_load_q_latent.with(*tma_barrier, mcast_mask), + tQLgQL(_, _0{}, i), tQsQ(_, i)); + // load_qk cl + if constexpr (kIsPaged) { + cute::copy(mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } else { + cute::copy(mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + for (int i = 0; i < IterationsQKRope; i++) { + pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, + kTransactionsBytesLoadExtraQ); + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // expect the extra bytes + // load_qk ql + cute::copy(mainloop_params.tma_load_q_rope.with(*tma_barrier, mcast_mask), + tQRgQR(_, _0{}, i), tQsQ(_, i + IterationsQKLatent)); + // load_qk cl + if constexpr (kIsPaged) { + cute::copy(mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } else { + cute::copy(mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + k_index += 1; + k_tile_count -= 1; + + // assume k_tile_count >= 1 + // perform K+Q load here + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // perform K load + for (int i = 0; i < IterationsQKLatent; i++) { + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = 
pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // load_qk cl + if constexpr (kIsPaged) { + cute::copy(mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } else { + cute::copy(mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + for (int i = 0; i < IterationsQKRope; i++) { + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // load_qk cl + if constexpr (kIsPaged) { + cute::copy(mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } else { + cute::copy(mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + // prefetch next K load to keep busy while we transpose-load from cache + const int kPrefetchDistance = 1; + for (int i = 0; i < IterationsQKLatent; i++) { + if (cute::elect_one_sync()) { + if constexpr (kIsPaged) { + if (k_tile_count > kPrefetchDistance) { + cute::prefetch(mainloop_params.tma_load_c_latent, + tCLgCL(_, _0{}, i, mPT(k_index + kPrefetchDistance))); + } + } else { + cute::prefetch(mainloop_params.tma_load_c_latent, + tCLgCL(_, k_index + kPrefetchDistance, i, batch_coord)); + } + } + } + + for (int i = 0; i < IterationsQKRope; i++) { + if (cute::elect_one_sync()) { + if constexpr (kIsPaged) { + if (k_tile_count > kPrefetchDistance) { + cute::prefetch(mainloop_params.tma_load_k_rope, + tKRgKR(_, _0{}, i, mPT(k_index + kPrefetchDistance))); + } + } else { + cute::prefetch(mainloop_params.tma_load_k_rope, + tKRgKR(_, k_index + kPrefetchDistance, i, batch_coord)); + } + } + } + + // perform V load (k_idx - 1) + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state); + auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state); + + if (cute::elect_one_sync()) { + // load_pv cl + // note the transpose in indices! 
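+            // (the same latent pages are re-read here through the transposed TMA
+            //  descriptor, since C_latent also serves as the V operand of the PV
+            //  MMA; issuing the tile (k_index - 1) copy while the K side is
+            //  already on tile k_index lets it overlap with the softmax that
+            //  produces P for that tile)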
+ // note we are off-by-one on k_index + if constexpr (kIsPaged) { + cute::copy(mainloop_params.tma_load_c_latent_transpose.with( + *tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, i, mPT(k_index - 1)), + tVCsVC(_, pipeline_load_pv_producer_state.index())); + } else { + cute::copy(mainloop_params.tma_load_c_latent_transpose.with( + *tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord), + tVCsVC(_, pipeline_load_pv_producer_state.index())); + } + } + ++pipeline_load_pv_producer_state; + } + } + + k_index += 1; + k_tile_count -= 1; + } + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state); + auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state); + + if (cute::elect_one_sync()) { + // load_pv cl + // note the transpose in indices + // note we are off-by-one on k_index + + if constexpr (kIsPaged) { + cute::copy(mainloop_params.tma_load_c_latent_transpose.with( + *tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, i, mPT(k_index - 1)), + tVCsVC(_, pipeline_load_pv_producer_state.index())); + } else { + cute::copy(mainloop_params.tma_load_c_latent_transpose.with( + *tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord), + tVCsVC(_, pipeline_load_pv_producer_state.index())); + } + } + ++pipeline_load_pv_producer_state; + } + } + } + + template + CUTLASS_DEVICE void mma(BlkCoord const& blk_coord, ProblemShape const& problem_shape, + TensorStorage& shared_tensors, PipelineLoadQK& pipeline_load_qk, + typename PipelineLoadQK::PipelineState& pipeline_load_qk_consumer_state, + PipelineLoadPV& pipeline_load_pv, + typename PipelineLoadPV::PipelineState& pipeline_load_pv_consumer_state, + PipelineS& pipeline_mma_s, + typename PipelineS::PipelineState& pipeline_mma_s_producer_state, + PipelineP& pipeline_p_mma, + typename PipelineP::PipelineState& pipeline_p_mma_consumer_state, + PipelineO& pipeline_mma_o, + typename PipelineO::PipelineState& pipeline_mma_o_producer_state, + int const& split_kv) { + auto [H, K, D, B] = problem_shape; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + // mma init + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + Tensor sP = make_tensor(make_smem_ptr((Element*)shared_tensors.smem_p.begin()), SmemLayoutP{}); + + Tensor tSrQ = TiledMmaQK::make_fragment_A(sQ); + Tensor tSrKC = TiledMmaQK::make_fragment_B(sKC); + Tensor tOrP = TiledMmaPV::make_fragment_A(sP); + Tensor tOrVC = TiledMmaPV::make_fragment_B(sVC); + + TiledMmaQK tiled_mma_qk; + TiledMmaPV tiled_mma_pv; + + Tensor tStS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShapeQK{})); + Tensor tOtO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShapePV{})); + + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::Zero; + + pipeline_mma_s.producer_acquire(pipeline_mma_s_producer_state); + + // Mma 
S0 S1 O0 S2 O1 ... Sn On-1 On + // S0 ownership -- ----- -- -- + // S1 ownership -- ----- ---- + // O ownership -- -- ---- -- + + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + for (int i = 0; i < IterationsQK; i++) { + pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state); + int read_stage = pipeline_load_qk_consumer_state.index(); + + tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 + : TmemAllocation::kS1); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) { + cute::gemm(tiled_mma_qk, tSrQ(_, _, k_block, i), tSrKC(_, _, k_block, read_stage), tStS); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_qk.consumer_release(pipeline_load_qk_consumer_state); + ++pipeline_load_qk_consumer_state; + } + + pipeline_mma_s.producer_commit(pipeline_mma_s_producer_state); + ++pipeline_mma_s_producer_state; + + k_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + pipeline_mma_s.producer_acquire(pipeline_mma_s_producer_state); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + for (int i = 0; i < IterationsQK; i++) { + pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state); + int read_stage = pipeline_load_qk_consumer_state.index(); + + tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 + : TmemAllocation::kS1); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) { + cute::gemm(tiled_mma_qk, tSrQ(_, _, k_block, i), tSrKC(_, _, k_block, read_stage), tStS); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_qk.consumer_release(pipeline_load_qk_consumer_state); + ++pipeline_load_qk_consumer_state; + } + + pipeline_mma_s.producer_commit(pipeline_mma_s_producer_state); + ++pipeline_mma_s_producer_state; + + pipeline_mma_o.producer_acquire(pipeline_mma_o_producer_state); + pipeline_p_mma.consumer_wait(pipeline_p_mma_consumer_state); + + for (int i = 0; i < IterationsPV_K; i++) { + auto acc_flag = tiled_mma_pv.accumulate_; + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.consumer_wait(pipeline_load_pv_consumer_state); + + int read_stage = pipeline_load_pv_consumer_state.index(); + + tOtO.data() = uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO); + tiled_mma_pv.accumulate_ = acc_flag; + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) { + cute::gemm(tiled_mma_pv, + tOrP(_, _, k_block, make_coord(i, pipeline_p_mma_consumer_state.index())), + tOrVC(_, _, k_block, read_stage), tOtO); + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_pv.consumer_release(pipeline_load_pv_consumer_state); + ++pipeline_load_pv_consumer_state; + } + } + + pipeline_p_mma.consumer_release(pipeline_p_mma_consumer_state); + ++pipeline_p_mma_consumer_state; + pipeline_mma_o.producer_commit(pipeline_mma_o_producer_state); + ++pipeline_mma_o_producer_state; + + --k_tile_count; + } + + pipeline_mma_o.producer_acquire(pipeline_mma_o_producer_state); + pipeline_p_mma.consumer_wait(pipeline_p_mma_consumer_state); + + for (int i = 0; i < IterationsPV_K; i++) { + auto acc_flag = tiled_mma_pv.accumulate_; + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.consumer_wait(pipeline_load_pv_consumer_state); + + int read_stage = pipeline_load_pv_consumer_state.index(); + + tOtO.data() = uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO); + tiled_mma_pv.accumulate_ = acc_flag; + + 
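+        // tail of the PV MMA for the final tile: accumulate P (staged in shared
+        // memory by the compute warps' softmax) against C_latent^T into the O
+        // accumulator held in TMEM at offset kO0 + j * kSizeAccO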
CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) { + cute::gemm(tiled_mma_pv, + tOrP(_, _, k_block, make_coord(i, pipeline_p_mma_consumer_state.index())), + tOrVC(_, _, k_block, read_stage), tOtO); + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_pv.consumer_release(pipeline_load_pv_consumer_state); + ++pipeline_load_pv_consumer_state; + } + } + + pipeline_p_mma.consumer_release(pipeline_p_mma_consumer_state); + ++pipeline_p_mma_consumer_state; + pipeline_mma_o.producer_commit(pipeline_mma_o_producer_state); + ++pipeline_mma_o_producer_state; + } + + template + CUTLASS_DEVICE void softmax(IsLastTile const& is_last_tile, ElementAcc& row_max, + ElementAcc& row_sum, ElementAcc& correction_factor, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, TensorStorage& shared_tensors, + int k_index, uint32_t tmem_s, int smem_p_index) { + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + + TiledMmaQK tiled_mma_qk; + + Tensor tStS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShapeQK{})); + tStS.data() = tmem_s; + + CUTE_STATIC_ASSERT_V(shape<1>(tStS) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tStS) == _1{}); + Tensor tAcc = tStS(make_coord(_, _), _0{}, _0{}); + + Tensor cS = make_identity_tensor(take<0, 2>(CtaShapeQK{})); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_cS = thread_t2r.partition_D(cS); + Tensor tTR_rAcc = make_tensor(shape(tTR_cS)); + + Tensor tTR_rS_frag = make_tensor(shape(tTR_rAcc)); + const int AlignmentS = 4; + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + Tensor tTR_rAcc_vec = recast>(tTR_rAcc); + Tensor tTR_rS_vec = recast>(tTR_rS_frag); + + // load s + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + if (is_last_tile) { + for (int i = 0; i < size(tTR_rAcc); i++) { + if (get<1>(tTR_cS(i)) + TileShapeS{} * k_index >= get<1>(problem_shape)) { + tTR_rAcc(i) = -std::numeric_limits::infinity(); + } + } + } + + // max + ElementAcc row_max_new = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i += 1) { + row_max_new = ::fmax(row_max_new, tTR_rAcc(i)); + } + + // for 2x2 dp, reduce here + if constexpr (kWarpsInN > 1) { + shared_tensors.smem_exchange[threadIdx.x] = row_max_new; + cutlass::arch::NamedBarrier(kNumComputeWarps * NumThreadsPerWarp, kNamedBarrierExchange) + .sync(); + // (64, 2) shape + int peer_index = (threadIdx.x + 64) % 128; + row_max_new = cutlass::max(row_max_new, shared_tensors.smem_exchange[peer_index]); + } + +#ifndef B2B + // find correction factor + ElementAcc softmax_scale_log2 = mainloop_args.softmax_scale * static_cast(M_LOG2E); + correction_factor = ::exp2f(softmax_scale_log2 * (row_max - row_max_new)); + row_max = row_max_new; + + // softmax + ElementAcc row_max_scale_log2 = row_max * softmax_scale_log2; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rAcc(i) = ::exp2f(softmax_scale_log2 * tTR_rAcc(i) - row_max_scale_log2); + } +#endif + + // quantize + cutlass::NumericArrayConverter epilogue_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc_vec); i++) { + tTR_rS_vec(i) = epilogue_op(tTR_rAcc_vec(i)); + } + + Tensor sP = make_tensor(make_smem_ptr((Element*)shared_tensors.smem_p.begin()), SmemLayoutP{})( + _, _, _, make_coord(_, smem_p_index)); + + Tensor tOcP = TiledMmaPV{}.get_slice(_0{}).partition_A(cS); + + // have a mapping for each thread to coord + // find identical mapping to 
coords for the MMA + auto l = make_ordered_layout( + make_shape(make_shape(_64{}, _2{}), make_shape(_16{}, TileShapeS{} / _32{})), + make_stride(make_stride(_0{}, _3{}), make_stride(_1{}, _2{}))); + auto sP_ = as_position_independent_swizzle_tensor(sP); + copy_aligned(tTR_rS_frag, sP_.compose(l)(threadIdx.x, _)); + + // sum + row_sum *= correction_factor; + + static_assert(cute::is_same_v); + auto tTR_rAcc_float2 = recast(tTR_rAcc); + auto sums = make_tensor(_4{}); + static_assert(size(tTR_rAcc_float2) % size(sums) == 0); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(sums); i++) { + sums(i) = tTR_rAcc_float2(i); + } + CUTLASS_PRAGMA_UNROLL + for (int i = size(sums); i < size(tTR_rAcc_float2); i += size(sums)) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(sums); j++) { + cute::add(sums(j), sums(j), tTR_rAcc_float2(i + j)); + } + } + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < size(sums); i *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(sums); j += 2 * i) { + cute::add(sums(j), sums(j), sums(j + i)); + } + } + row_sum += sums(0).x + sums(0).y; + } + + CUTLASS_DEVICE void rescale(ElementAcc correction_factor, uint32_t tmem_o) { + // for b2b gemm, do nothing +#ifndef B2B + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + auto store_op = TMEM::tmem_load_to_store(load_op); + + TiledMmaPV tiled_mma_pv; + + Tensor tOtO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShapePV{})); + tOtO.data() = tmem_o; + + CUTE_STATIC_ASSERT_V(shape<1>(tOtO) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tOtO) == _1{}); + Tensor tAcc = tOtO(make_coord(_, _), _0{}, _0{}); + + auto cta_tiler_pv = take<0, 2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = make_tensor(make_gmem_ptr((ElementAcc*)nullptr), cta_tiler_pv, make_stride(0, 0)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto tiled_r2t = make_tmem_copy(store_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + auto thread_r2t = tiled_r2t.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + // load o + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + // multiply by correction factor + float2 correction_factor_vec = make_float2(correction_factor, correction_factor); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i += 2) { + float2 in = make_float2(tTR_rAcc(i + 0), tTR_rAcc(i + 1)); + float2 out; + cute::mul(out, in, correction_factor_vec); + tTR_rAcc(i + 0) = out.x; + tTR_rAcc(i + 1) = out.y; + } + + // store o + copy(tiled_r2t, tTR_rAcc, tTR_tAcc); +#endif + } + + template + CUTLASS_DEVICE void epilogue(ElementAcc& row_max, ElementAcc& row_sum, BlkCoord const& cta_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueParams const& epilogue_args, TensorStorage& shared_tensors, + uint32_t tmem_o, int const& split_kv) { + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + + TiledMmaPV tiled_mma_pv; + + Tensor tOtO = + TiledMmaPV::make_fragment_C(partition_shape_C(TiledMmaPV{}, take<0, 2>(TileShapePV{}))); + tOtO.data() = tmem_o; + + CUTE_STATIC_ASSERT_V(shape<1>(tOtO) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tOtO) == _1{}); + Tensor tAcc = tOtO(make_coord(_, _), _0{}, _0{}); + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + if (epilogue_args.ptr_o_acc != nullptr) { + using ElementOutAcc = ElementAcc; + constexpr auto AlignmentOutAcc = 128 / 
cute::sizeof_bits_v; + Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o_acc + get<3>(cta_coord) * D_latent), + make_shape(H, D_latent, B), epilogue_args.stride_o_acc); + auto cta_tiler_pv = take<0, 2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = local_tile(mO, cta_tiler_pv, take<0, 3>(cta_coord)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_rO_frag = make_tensor(shape(tTR_rAcc)); + Tensor tTR_rO_src = recast>(coalesce(tTR_rO_frag)); + Tensor tR2G_rO_dst = recast>(coalesce(tTR_gO)); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + cutlass::epilogue::thread::LinearCombination< + ElementOutAcc, 1, ElementAcc, ElementAcc, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling> + epilogue_op({epilogue_args.output_scale / row_sum}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rO_frag(i) = epilogue_op(tTR_rAcc(i)); + } + + copy(tTR_rO_src, tR2G_rO_dst); + +#ifndef B2B + + // compute LSE + ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max; + + // store LSE + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_acc + H * get<3>(cta_coord)), + make_shape(H, B), epilogue_args.stride_lse_acc); + Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0, 3>(cta_coord), + Step<_1, Underscore, _1>{}); + // for 2x2 dp, this must be conditional and the index is wrong + if (!kIs2Sm || (threadIdx.x < 64)) { + gLSE(threadIdx.x) = lse; + } +#endif + } else { + Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o), make_shape(H, D_latent, B), + epilogue_args.stride_o); + auto cta_tiler_pv = take<0, 2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = local_tile(mO, cta_tiler_pv, take<0, 3>(cta_coord)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_rO_frag = make_tensor(shape(tTR_rAcc)); + Tensor tTR_rO_src = recast>(coalesce(tTR_rO_frag)); + Tensor tR2G_rO_dst = recast>(coalesce(tTR_gO)); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + cutlass::epilogue::thread::LinearCombination< + ElementOut, 1, ElementAcc, ElementAcc, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling> + epilogue_op({epilogue_args.output_scale / row_sum}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rO_frag(i) = epilogue_op(tTR_rAcc(i)); + } + + copy(tTR_rO_src, tR2G_rO_dst); + +#ifndef B2B + if (epilogue_args.ptr_lse != nullptr) { + // compute LSE + ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max; + + // store LSE + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), + epilogue_args.stride_lse); + Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0, 3>(cta_coord), + Step<_1, Underscore, _1>{}); + + // for 2x2 dp, this must be conditional and the index is wrong + if (!kIs2Sm || (threadIdx.x < 64)) { + gLSE(threadIdx.x) = lse; + } + } +#endif + } + } + + template + CUTLASS_DEVICE void compute( + CtaCoord const& cta_coord, ProblemShape const& problem_shape, + MainloopArguments 
const& mainloop_args, EpilogueParams const& epilogue_args, + TensorStorage& shared_tensors, PipelineS& pipeline_mma_s, + typename PipelineS::PipelineState& pipeline_mma_s_consumer_state, PipelineP& pipeline_p_mma, + typename PipelineP::PipelineState& pipeline_p_mma_producer_state, PipelineO& pipeline_mma_o, + typename PipelineO::PipelineState& pipeline_mma_o_consumer_state, int const& split_kv) { + auto [H, K, D, B] = problem_shape; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(cta_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + // if we return early, we have to make sure we release the load warp + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, + kNamedBarrierEpilogue) + .arrive(); + + return; + } + int k_index_final = k_tile_total - 1; + + ElementAcc row_max = -std::numeric_limits::infinity(); + ElementAcc row_sum = 0; + ElementAcc correction_factor = 1; + + pipeline_p_mma.producer_acquire(pipeline_p_mma_producer_state); + pipeline_mma_s.consumer_wait(pipeline_mma_s_consumer_state); + + auto dispatch_bool = [](bool b, auto fn) { + if (b) { + fn(cute::true_type{}); + } else { + fn(cute::false_type{}); + } + }; + + // softmax s0 -> p0 + dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) { + softmax(is_last_tile, row_max, row_sum, correction_factor, problem_shape, mainloop_args, + shared_tensors, k_index, + uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? TmemAllocation::kS0 + : TmemAllocation::kS1), + pipeline_p_mma_producer_state.index()); + }); + + k_index += 1; + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::fence_view_async_shared(); + pipeline_mma_s.consumer_release(pipeline_mma_s_consumer_state); + ++pipeline_mma_s_consumer_state; + pipeline_p_mma.producer_commit(pipeline_p_mma_producer_state); + ++pipeline_p_mma_producer_state; + + k_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + pipeline_p_mma.producer_acquire(pipeline_p_mma_producer_state); + pipeline_mma_s.consumer_wait(pipeline_mma_s_consumer_state); + + // softmax s1 -> p1 + dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) { + softmax(is_last_tile, row_max, row_sum, correction_factor, problem_shape, mainloop_args, + shared_tensors, k_index, + uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? 
TmemAllocation::kS0 + : TmemAllocation::kS1), + pipeline_p_mma_producer_state.index()); + }); + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::fence_view_async_shared(); + pipeline_mma_s.consumer_release(pipeline_mma_s_consumer_state); + ++pipeline_mma_s_consumer_state; + pipeline_p_mma.producer_commit(pipeline_p_mma_producer_state); + ++pipeline_p_mma_producer_state; + + pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state); + + // rescale + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < IterationsPV_N; j++) { + rescale(correction_factor, + uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO)); + } + + cutlass::arch::fence_view_async_tmem_store(); + pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state); + ++pipeline_mma_o_consumer_state; + + --k_tile_count; + k_index += 1; + } + + pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state); + +#ifdef B2B + row_sum = 1; +#else + if constexpr (kWarpsInN > 1) { + // reduce row_sum if needed (for 2x2 dp) + shared_tensors.smem_exchange[threadIdx.x] = row_sum; + cutlass::arch::NamedBarrier(kNumComputeWarps * NumThreadsPerWarp, kNamedBarrierExchange) + .sync(); + // (64, 2) shape + int peer_index = (threadIdx.x + 64) % 128; + row_sum += shared_tensors.smem_exchange[peer_index]; + } +#endif + + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, + kNamedBarrierEpilogue) + .arrive(); + + // epilogue + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < IterationsPV_N; j++) { + epilogue(row_max, row_sum, replace<1>(cta_coord, j), problem_shape, mainloop_args, + epilogue_args, shared_tensors, + uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO), split_kv); + } + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state); + ++pipeline_mma_o_consumer_state; + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/include/flashinfer/attention/blackwell/kernel/sm100_mla_tile_scheduler.hpp b/include/flashinfer/attention/blackwell/kernel/sm100_mla_tile_scheduler.hpp new file mode 100644 index 0000000000..f096c6d537 --- /dev/null +++ b/include/flashinfer/attention/blackwell/kernel/sm100_mla_tile_scheduler.hpp @@ -0,0 +1,151 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::fmha::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +struct Sm100MlaIndividualTileScheduler { + struct Params { + dim3 grid; + }; + + bool valid_ = true; + + CUTLASS_DEVICE + Sm100MlaIndividualTileScheduler(Params const&) {} + + template + static Params to_underlying_arguments(ProblemShape const& problem_shape, + KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, int const& split_kv) { + using namespace cute; + dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, + split_kv /*Maximum Split KV*/); + return Params{grid}; + } + + static dim3 get_grid_shape(Params const& params) { return params.grid; } + + CUTLASS_DEVICE + bool is_valid() { return valid_; } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z); + } + + CUTLASS_DEVICE + Sm100MlaIndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct Sm100MlaPersistentTileScheduler { + struct Params { + int num_blocks; + FastDivmod divmod_m_block; + FastDivmod divmod_b; + FastDivmod divmod_split_kv; + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + Sm100MlaPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} + + template + static Params to_underlying_arguments(ProblemShape const& problem_shape, + KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, int const& split_kv) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 1 || sm_count % size<0>(cluster_shape) != 0) { + CUTLASS_TRACE_HOST( + " WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM " + "count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " + << sm_count); + hw_info.sm_count = sm_count; + + int num_m_blocks = size<0>(cluster_shape); + int num_blocks = num_m_blocks * get<3>(problem_shape) /* Batch */; + num_blocks *= split_kv; /* Maximum Split KV*/ + + return Params{num_blocks, {num_m_blocks}, {get<3>(problem_shape)}, {split_kv}, hw_info}; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { return block_idx < params.num_blocks; } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; 
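+    // decode the flat persistent block index into (m_block, batch, split_kv),
+    // mirroring the factor order used to build num_blocks in
+    // to_underlying_arguments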
+ int block_decode = block_idx; + int m_block, bidb, n_split_kv; + params.divmod_m_block(block_decode, m_block, block_decode); + params.divmod_b(block_decode, bidb, block_decode); + params.divmod_split_kv(block_decode, n_split_kv, block_decode); + return make_coord(m_block, _0{}, bidb, n_split_kv); + } + + CUTLASS_DEVICE + Sm100MlaPersistentTileScheduler& operator++() { + block_idx += gridDim.x; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/tests/test_blackwell_fmha.py b/tests/test_blackwell_fmha.py new file mode 100644 index 0000000000..74286aea68 --- /dev/null +++ b/tests/test_blackwell_fmha.py @@ -0,0 +1,165 @@ +import math + +import pytest +import torch + +import flashinfer +import flashinfer.triton +from flashinfer.utils import is_sm100a_supported + + +def attention_ref( + batch_size, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + causal: bool, + sm_scale: float, +) -> torch.Tensor: + qo_len = q.shape[0] // batch_size + kv_len = k.shape[0] // batch_size + num_qo_heads = q.shape[1] + head_dim_qk = q.shape[2] + head_dim_vo = v.shape[2] + logits = ( + torch.einsum( + "bmhd,bnhd->bhmn", + q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(), + k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(), + ) + * sm_scale + ) + + if causal: + mask = torch.arange(kv_len - qo_len, kv_len, device=q.device).unsqueeze( + 1 + ) >= torch.arange(0, kv_len, device=q.device).unsqueeze(0) + else: + mask = torch.ones(qo_len, kv_len, device=q.device) + + logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf")) + lse_ref = torch.logsumexp(logits, -1).transpose(-1, -2) + p = torch.softmax(logits, dim=-1) + o_ref = ( + torch.einsum( + "bhmn,bnhd->bmhd", + p, + v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(), + ) + .contiguous() + .view(batch_size * qo_len, num_qo_heads, head_dim_vo) + .to(q) + ) + + return o_ref, lse_ref * math.log2(math.e) + + +@pytest.mark.parametrize("batch_size", [1, 2, 3, 17]) +@pytest.mark.parametrize("qo_len", [1, 17, 177, 377, 977]) +@pytest.mark.parametrize("kv_len", [1, 17, 544, 977, 1999]) +@pytest.mark.parametrize("num_qo_heads", [32]) +@pytest.mark.parametrize("num_kv_heads", [4, 32]) +@pytest.mark.parametrize("head_dim_qk", [192, 128]) +@pytest.mark.parametrize("head_dim_vo", [128]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("dtype", [torch.half]) +def test_blackwell_cutlass_fmha( + batch_size, + qo_len, + kv_len, + num_qo_heads, + num_kv_heads, + head_dim_qk, + head_dim_vo, + causal, + dtype, +): + if qo_len > kv_len and causal: + pytest.skip("qo_len > kv_len and causal is not supported") + + if not is_sm100a_supported(torch.device("cuda")): + pytest.skip("SM100A is not supported on this device") + torch.manual_seed(42) + q = torch.randn( + batch_size * qo_len, num_qo_heads, head_dim_qk, dtype=dtype, device="cuda" + ) + qo_indptr = ( + torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qo_len + ) + + k = torch.randn( + batch_size * kv_len, num_kv_heads, head_dim_qk, dtype=dtype, device="cuda" + ) + v = torch.randn( + batch_size * kv_len, num_kv_heads, head_dim_vo, dtype=dtype, device="cuda" + ) + kv_indptr = ( + torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * kv_len + ) + + wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( + torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8), + kv_layout="NHD", + 
backend="cutlass", + ) + sm_scale = 1.0 / (head_dim_qk**0.5) + wrapper.plan( + qo_indptr, + kv_indptr, + num_qo_heads, + num_kv_heads, + head_dim_qk, + head_dim_vo=head_dim_vo, + causal=causal, + sm_scale=sm_scale, + q_data_type=dtype, + kv_data_type=dtype, + ) + o, lse = wrapper.run(q, k, v, return_lse=True) + + # gqa_group_ratio = num_qo_heads // num_kv_heads + # k_repeated = torch.repeat_interleave(k, gqa_group_ratio, dim=1) + # v_repeated = torch.repeat_interleave(v, gqa_group_ratio, dim=1) + # o_ref, lse_ref = attention_ref( + # batch_size, q, k_repeated, v_repeated, causal, sm_scale + # ) + + # lse_ref = lse_ref.flatten(0, 1) + # if dtype == torch.half: + # torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + # else: + # torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2) + + # torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3) + + # test with pre-allocated output + # o_buffer = torch.empty_like(o) + # lse_buffer = torch.empty_like(lse) + # flashinfer.prefill.fmha( + # q, k, v, qo_lens, kv_lens, out=o_buffer, lse=lse_buffer, causal=causal + # ) + # torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3) + # torch.testing.assert_close(lse, lse_buffer, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + test_blackwell_cutlass_fmha( + 1, + 1, + 1, + 32, + 4, + 192, + 128, + False, + torch.bfloat16, + # 3, + # 999, + # 999, + # 16, + # 8, + # 128, + # 128, + # False, + # torch.bfloat16, + ) diff --git a/tests/test_format_conversion.py b/tests/test_format_conversion.py deleted file mode 100644 index dd9dff4525..0000000000 --- a/tests/test_format_conversion.py +++ /dev/null @@ -1,163 +0,0 @@ -import numpy as np -import pytest -import torch - -from flashinfer.triton import pack_ragged_tensor, pad_ragged_tensor_to_multiple_of - - -def pad_ragged_tensor_to_multiple_of_pytorch_fill_zeros( - ragged_tensor, indptr, multiple_of -): - """PyTorch baseline implementation of pad_ragged_tensor_to_multiple_of.""" - n_rows = indptr.shape[0] - 1 - dim = ragged_tensor.shape[1] - - # Compute padded lengths for each row - row_lengths = indptr[1:] - indptr[:-1] - padded_lengths = ((row_lengths + multiple_of - 1) // multiple_of) * multiple_of - - # Compute padded indptr - padded_indptr = torch.zeros_like(indptr) - padded_indptr[1:] = torch.cumsum(padded_lengths, dim=0) - - # Allocate padded tensor - total_padded_length = padded_indptr[-1].item() - padded_ragged_tensor = torch.zeros( - (total_padded_length, dim), - dtype=ragged_tensor.dtype, - device=ragged_tensor.device, - ) - - # Copy data from original tensor to padded tensor - for i in range(n_rows): - row_start = indptr[i].item() - row_end = indptr[i + 1].item() - row_length = row_end - row_start - - padded_row_start = padded_indptr[i].item() - - # Copy the original data - padded_ragged_tensor[padded_row_start : padded_row_start + row_length] = ( - ragged_tensor[row_start:row_end] - ) - - return padded_ragged_tensor, padded_indptr - - -@pytest.mark.parametrize("dtype", [torch.float16]) -@pytest.mark.parametrize("n_rows", [1, 2, 5, 10]) -@pytest.mark.parametrize("dim", [64, 128, 1024, 4096, 16384]) -@pytest.mark.parametrize("multiple_of", [8, 16, 32, 64, 128]) -def test_pad_ragged_tensor_to_multiple_of(dtype, n_rows, dim, multiple_of): - device = torch.device("cuda:0") - torch.manual_seed(42) - - # Create random row lengths - row_lengths = torch.randint(1, 100, (n_rows,), device=device) - - # Create indptr - indptr = torch.zeros(n_rows + 1, dtype=torch.int32, device=device) - indptr[1:] = torch.cumsum(row_lengths, 
dim=0) - - # Create ragged tensor - nnz = indptr[-1].item() - ragged_tensor = torch.randn(nnz, dim, dtype=dtype, device=device) - - # Run both implementations - padded_ragged_tensor, padded_indptr = pad_ragged_tensor_to_multiple_of( - ragged_tensor, indptr, multiple_of, fill_zeros=True - ) - - padded_ragged_tensor_ref, padded_indptr_ref = ( - pad_ragged_tensor_to_multiple_of_pytorch_fill_zeros( - ragged_tensor, indptr, multiple_of - ) - ) - - # Check shapes - assert padded_ragged_tensor.shape == padded_ragged_tensor_ref.shape - assert padded_indptr.shape == padded_indptr_ref.shape - - # Check indptr values - assert torch.allclose(padded_indptr, padded_indptr_ref) - - # Check tensor values - assert torch.allclose( - padded_ragged_tensor, padded_ragged_tensor_ref, rtol=1e-3, atol=1e-3 - ) - - -def pack_ragged_tensor_pytorch( - padded_tensor: torch.Tensor, - padded_indptr: torch.Tensor, - original_indptr: torch.Tensor, -) -> torch.Tensor: - """PyTorch reference implementation of pack_ragged_tensor.""" - n_rows = padded_indptr.shape[0] - 1 - dim = padded_tensor.shape[1] - original_nnz = original_indptr[-1].item() - - packed_tensor = torch.empty( - (original_nnz, dim), - dtype=padded_tensor.dtype, - device=padded_tensor.device, - ) - - for i in range(n_rows): - padded_row_start = padded_indptr[i].item() - original_row_start = original_indptr[i].item() - original_row_length = original_indptr[i + 1].item() - original_row_start - - # Copy the original data (without padding) - packed_tensor[original_row_start : original_row_start + original_row_length] = ( - padded_tensor[padded_row_start : padded_row_start + original_row_length] - ) - - return packed_tensor - - -@pytest.mark.parametrize("dtype", [torch.float16]) -@pytest.mark.parametrize("n_rows", [1, 2, 5, 10]) -@pytest.mark.parametrize("dim", [64, 128, 1024, 4096, 16384]) -@pytest.mark.parametrize("multiple_of", [8, 16, 32, 64, 128]) -def test_pack_ragged_tensor(dtype, n_rows, dim, multiple_of): - device = torch.device("cuda:0") - torch.manual_seed(42) - - # Create random row lengths - row_lengths = torch.randint(1, 100, (n_rows,), device=device) - - # Create indptr - original_indptr = torch.zeros(n_rows + 1, dtype=torch.int32, device=device) - original_indptr[1:] = torch.cumsum(row_lengths, dim=0) - - # Create ragged tensor - nnz = original_indptr[-1].item() - original_tensor = torch.randn(nnz, dim, dtype=dtype, device=device) - - # First pad the tensor - padded_tensor, padded_indptr = pad_ragged_tensor_to_multiple_of( - original_tensor, original_indptr, multiple_of, fill_zeros=True - ) - - # Now unpad (pack) the tensor - packed_tensor = pack_ragged_tensor(padded_tensor, padded_indptr, original_indptr) - - # Reference implementation - packed_tensor_ref = pack_ragged_tensor_pytorch( - padded_tensor, padded_indptr, original_indptr - ) - - # Check shapes - assert packed_tensor.shape == original_tensor.shape - assert packed_tensor.shape == packed_tensor_ref.shape - - # Check tensor values - assert torch.allclose(packed_tensor, original_tensor, rtol=1e-3, atol=1e-3) - assert torch.allclose(packed_tensor, packed_tensor_ref, rtol=1e-3, atol=1e-3) - - -if __name__ == "__main__": - test_pad_ragged_tensor_to_multiple_of( - dtype=torch.float16, n_rows=100, dim=1024, multiple_of=128 - )