From 186930b67e6bae329cbc951bbde95107327ba9e5 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Fri, 21 Feb 2025 14:54:28 +0000
Subject: [PATCH 1/6] Remove Torch CPP backend and update execution backend options
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Remove TorchCPPKernelAdapter and related code from JIT modules
- Update execution backend options in jit/__init__.py, kernel.py, and adapter/__init__.py
- Remove "torch_cpp" from supported execution backend literals
- Simplify backend validation and remove unused torch_cpp-related code
---
 tilelang/jit/__init__.py         |  10 +--
 tilelang/jit/adapter/__init__.py |   1 -
 tilelang/jit/adapter/torchcpp.py | 128 ------------------------------
 tilelang/jit/core.py             | 123 -----------------------------
 tilelang/jit/env.py              |  19 +++--
 tilelang/jit/kernel.py           |  19 +----
 6 files changed, 20 insertions(+), 280 deletions(-)
 delete mode 100644 tilelang/jit/adapter/torchcpp.py
 delete mode 100644 tilelang/jit/core.py

diff --git a/tilelang/jit/__init__.py b/tilelang/jit/__init__.py
index a6d1c725f..8d74d910c 100644
--- a/tilelang/jit/__init__.py
+++ b/tilelang/jit/__init__.py
@@ -24,7 +24,7 @@ def jit(
     func: Callable = None,
     *,  # Enforce keyword-only arguments from here
     out_idx: Union[List[int], int] = None,
-    execution_backend: Literal["dlpack", "torch_cpp", "ctypes"] = "dlpack",
+    execution_backend: Literal["dlpack", "ctypes"] = "dlpack",
     target: Union[str, Target] = "auto",
     verbose: bool = False,
 ) -> BaseKernelAdapter:
@@ -42,9 +42,9 @@
     out_idx : Union[List[int], int], optional
         The index (or list of indices) of the function outputs. This can be used
         to specify which outputs from the compiled function will be returned.
-    execution_backend : Literal["dlpack", "torch_cpp", "ctypes"], optional
+    execution_backend : Literal["dlpack", "ctypes"], optional
         The wrapper type to use for the kernel adapter. Currently, only "dlpack"
-        and "torch_cpp" are supported.
+        and "ctypes" are supported.
     target : Union[str, Target], optional
         The compilation target for TVM. If set to "auto", an appropriate target
         will be inferred automatically. Otherwise, must be one of the supported
@@ -69,7 +69,7 @@

         target = Target(target)

-    assert execution_backend in ["dlpack", "torch_cpp", "ctypes"], "Invalid execution backend."
+    assert execution_backend in ["dlpack", "ctypes", "cython"], "Invalid execution backend."
def _compile_and_create_adapter(tilelang_func: PrimFunc) -> BaseKernelAdapter: """ @@ -110,7 +110,7 @@ def real_decorator(tilelang_func: PrimFunc) -> BaseKernelAdapter: def compile( func: PrimFunc = None, out_idx: Union[List[int], int] = None, - execution_backend: Literal["dlpack", "torch_cpp", "ctypes", "cython"] = "cython", + execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython", target: Union[str, Target] = "auto", target_host: Union[str, Target] = None, verbose: bool = False, diff --git a/tilelang/jit/adapter/__init__.py b/tilelang/jit/adapter/__init__.py index 5d7b2acd4..c3a3f276f 100644 --- a/tilelang/jit/adapter/__init__.py +++ b/tilelang/jit/adapter/__init__.py @@ -3,6 +3,5 @@ from .base import BaseKernelAdapter # noqa: F401 from .dlpack import TorchDLPackKernelAdapter # noqa: F401 -from .torchcpp import TorchCPPKernelAdapter # noqa: F401 from .ctypes import CtypesKernelAdapter # noqa: F401 from .cython import CythonKernelAdapter # noqa: F401 diff --git a/tilelang/jit/adapter/torchcpp.py b/tilelang/jit/adapter/torchcpp.py deleted file mode 100644 index 0b5360fce..000000000 --- a/tilelang/jit/adapter/torchcpp.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""The profiler and convert to torch utils""" - -import torch -from typing import List, Union -from .base import BaseKernelAdapter -from pathlib import Path -from tvm.relay import TensorType -from tilelang.jit.core import load_cuda_ops -from tilelang.jit.env import (TILELANG_JIT_WORKSPACE_DIR) - - -def torch_cpp_cuda_compile(code, target, verbose): - # TODO(lei): This is not fully implemented yet - # TODO(lei): extract name and magic number from module - name: str = "matmul" - magic_number = 0x9f - full_kernel_dir = TILELANG_JIT_WORKSPACE_DIR / Path(f"{name}_{magic_number}") - full_kernel_dir.mkdir(parents=True, exist_ok=True) - - sources: List[Union[str, Path]] = [] - - tmp_cuda_kernel_file = (full_kernel_dir / "kernel.cu") - - code = ( - code + r""" - void kenrel_interface(void* A, void *B, void *C, int64_t cuda_stream) { - cudaStream_t stream = reinterpret_cast(cuda_stream); - main_kernel<<>>((half_t *)A, (half_t *)B, (half_t *)C); - } - """) - with open(tmp_cuda_kernel_file, "w") as f: - f.write(code) - - print(tmp_cuda_kernel_file) - - sources.append(tmp_cuda_kernel_file) - - tmp_host_file = (full_kernel_dir / "host.cpp") - - host_code = r""" - #include - #include - #include - - void kenrel_interface(void* A, void *B, void *C, int64_t cuda_stream); - - int dispather(at::Tensor& A, at::Tensor& B, at::Tensor& C, int64_t cuda_stream) { - kenrel_interface( - A.data_ptr(), - B.data_ptr(), - C.data_ptr(), - cuda_stream - ); - return 0; - } - - int dispather(at::Tensor& A, at::Tensor& B, at::Tensor& C, int64_t cuda_stream); - - PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("matmul", &dispather, "matmul"); - printf("Registering matmul\n"); - } - """ - with open(tmp_host_file, "w") as f: - f.write(host_code) - - sources.append(tmp_host_file) - module = load_cuda_ops(name=name, sources=sources, verbose=verbose) - return module.matmul - - -class TorchCPPKernelAdapter(BaseKernelAdapter): - - target = "cuda" - prim_func = None - - def __init__(self, - mod, - params: List[TensorType], - result_idx: List[int], - target, - prim_func, - verbose: bool = False): - self.target = target - self.prim_func = prim_func - self.verbose = verbose - super().__init__(mod, params, result_idx) - - def _convert_torch_func(self) -> callable: - - target = self.target - verbose = 
self.verbose - code = self.get_kernel_source() - torch_module = torch_cpp_cuda_compile(code, target, verbose) - - # raise NotImplementedError("Please implement this function") - - def func(*ins: List[torch.Tensor]): - if len(ins) + len(self.result_idx) != len(self.params): - raise ValueError( - f"Expected {len(self.params)} inputs, got {len(ins) + len(self.result_idx)} with {len(ins)} inputs and {len(self.result_idx)} outputs" - ) - ins_idx = 0 - args = [] - - # use the device of the first input tensor if available - device = ins[0].device if len(ins) > 0 else torch.cuda.current_device() - - for i in range(len(self.params)): - if i in self.result_idx: - dtype = torch.__getattribute__(str(self.params[i].dtype)) - shape = list(map(int, self.params[i].shape)) - tensor = torch.empty(*shape, dtype=dtype, device=device) - else: - tensor = ins[ins_idx] - ins_idx += 1 - args.append(tensor) - - torch_module(*args, 0) - - if len(self.result_idx) == 1: - return args[self.result_idx[0]] - else: - return [args[i] for i in self.result_idx] - - return func diff --git a/tilelang/jit/core.py b/tilelang/jit/core.py deleted file mode 100644 index 5b51d0d3f..000000000 --- a/tilelang/jit/core.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# This file is modified from the original version, -# which is part of the flashinfer project -# (https://github.com/flashinfer-ai/flashinfer). - -import logging -import os -from pathlib import Path -from typing import List, Union - -import torch.utils.cpp_extension as torch_cpp_ext -from filelock import FileLock -from .env import CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH, TILELANG_JIT_DIR -from contextlib import suppress - - -class TileLangJITLogger(logging.Logger): - - def __init__(self, name): - super().__init__(name) - self.setLevel(logging.INFO) - # Add a StreamHandler for console output - stream_handler = logging.StreamHandler() - stream_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) - self.addHandler(stream_handler) - - def info(self, msg): - super().info("tilelang.jit: " + msg) - - -logger = TileLangJITLogger("tilelang.jit") - - -def check_cuda_arch(): - # cuda arch check for fp8 at the moment. 
- for cuda_arch_flags in torch_cpp_ext._get_cuda_arch_flags(): # noqa: B007 - pass - - -def remove_unwanted_pytorch_nvcc_flags(): - REMOVE_NVCC_FLAGS = [ - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - ] - for flag in REMOVE_NVCC_FLAGS: - try: - torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag) - except ValueError: - suppress(ValueError) - - -remove_unwanted_pytorch_nvcc_flags() - -sm90a_nvcc_flags = ["-gencode", "arch=compute_90a,code=sm_90a"] - - -def load_cuda_ops( - name: str, - sources: List[Union[str, Path]], - extra_cflags: List[str] = None, - extra_cuda_cflags: List[str] = None, - extra_ldflags=None, - extra_include_paths=None, - verbose=False, -): - if extra_cflags is None: - extra_cflags = [] - - if extra_cuda_cflags is None: - extra_cuda_cflags = [] - - cflags = ["-O3", "-Wno-switch-bool"] - cuda_cflags = [ - "-O3", - "-std=c++17", - "-use_fast_math", - ] - cflags += extra_cflags - cuda_cflags += extra_cuda_cflags - check_cuda_arch() - build_directory = TILELANG_JIT_DIR / name - os.makedirs(build_directory, exist_ok=True) - if extra_include_paths is None: - extra_include_paths = [ - CUTLASS_INCLUDE_DIR, - TILELANG_TEMPLATE_PATH, - ] - - lock = FileLock(TILELANG_JIT_DIR / f"{name}.lock", thread_local=False) - with lock: - module = torch_cpp_ext.load( - name, - list(map(lambda _: str(_), sources)), - extra_cflags=cflags, - extra_cuda_cflags=cuda_cflags, - extra_ldflags=extra_ldflags, - extra_include_paths=list(map(lambda _: str(_), extra_include_paths)), - build_directory=build_directory, - verbose=verbose, - with_cuda=True, - keep_intermediates=False, - ) - logger.info(f"Finished loading JIT ops: {name}") - return module diff --git a/tilelang/jit/env.py b/tilelang/jit/env.py index 7b3bb9444..5e4dc93d9 100644 --- a/tilelang/jit/env.py +++ b/tilelang/jit/env.py @@ -27,7 +27,6 @@ import re import warnings -from torch.utils.cpp_extension import _get_cuda_arch_flags from tilelang.env import ( CUTLASS_INCLUDE_DIR, # noqa: F401 TILELANG_TEMPLATE_PATH, # noqa: F401 @@ -51,19 +50,23 @@ def _initialize_torch_cuda_arch_flags(): def _get_workspace_dir_name() -> pathlib.Path: try: - with warnings.catch_warnings(): - # Ignore the warning for TORCH_CUDA_ARCH_LIST not set - warnings.filterwarnings("ignore", r".*TORCH_CUDA_ARCH_LIST.*", module="torch") - flags = _get_cuda_arch_flags() - arch = "_".join(sorted(set(re.findall(r"compute_(\d+)", "".join(flags))))) + from tilelang.contrib import nvcc + from tilelang.utils.target import determine_target + + target = determine_target(return_object=True) + # create tmp source file for torch cpp extension + compute_version = "".join(nvcc.get_target_compute_version(target).split(".")) + # set TORCH_CUDA_ARCH_LIST + major = compute_version[0] + minor = compute_version[1] + arch = f"{major}_{minor}" except Exception: arch = "noarch" # e.g.: $HOME/.cache/tilelang/75_80_89_90/ return pathlib.Path.home() / ".cache" / "tilelang" / arch -# use pathlib -_initialize_torch_cuda_arch_flags() +# _initialize_torch_cuda_arch_flags() TILELANG_JIT_WORKSPACE_DIR = _get_workspace_dir_name() TILELANG_JIT_DIR = TILELANG_JIT_WORKSPACE_DIR / "cached_ops" TILELANG_GEN_SRC_DIR = TILELANG_JIT_WORKSPACE_DIR / "generated" diff --git a/tilelang/jit/kernel.py b/tilelang/jit/kernel.py index 1ae6f892a..09e7b4453 100644 --- a/tilelang/jit/kernel.py +++ b/tilelang/jit/kernel.py @@ -7,7 +7,7 @@ from tilelang import tvm as tvm from tvm.tir import PrimFunc -from tilelang.jit.adapter import 
TorchCPPKernelAdapter, TorchDLPackKernelAdapter, BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter +from tilelang.jit.adapter import TorchDLPackKernelAdapter, BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter from tilelang.utils.target import determine_target, AVALIABLE_TARGETS from tilelang.profiler import Profiler, TensorSupplyType @@ -34,7 +34,7 @@ def __init__( self, func: PrimFunc = None, out_idx: Union[List[int], int] = None, - execution_backend: Literal["dlpack", "torch_cpp", "ctypes", "cython"] = "cython", + execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython", target: Union[str, Target] = "auto", target_host: Union[str, Target] = None, verbose: bool = False, @@ -48,7 +48,7 @@ def __init__( The TileLang TIR function to compile and wrap. out_idx : Union[List[int], int], optional Index(es) of the output tensors to return (default: None). - execution_backend : Literal["dlpack", "torch_cpp", "ctypes"], optional + execution_backend : Literal["dlpack", "ctypes"], optional Execution backend to use for kernel execution (default: "dlpack"). target : Union[str, Target], optional Compilation target, either as a string or a TVM Target object (default: "auto"). @@ -73,7 +73,7 @@ def __init__( target = Target(target) # Validate the execution backend. - assert execution_backend in ["dlpack", "torch_cpp", "ctypes", + assert execution_backend in ["dlpack", "ctypes", "cython"], f"Invalid execution backend. {execution_backend}" if execution_backend == "cython": from tilelang.contrib.cc import get_cplus_compiler @@ -137,17 +137,6 @@ def _compile_and_create_adapter(self, tilelang_func: PrimFunc) -> BaseKernelAdap if execution_backend == "dlpack": # Use TorchDLPackKernelAdapter for interoperability with PyTorch via DLPack. adapter = TorchDLPackKernelAdapter(rt_mod, params=params, result_idx=out_idx) - elif execution_backend == "torch_cpp": - # Torch CPP backend adapter (not fully implemented yet). 
- adapter = TorchCPPKernelAdapter( - rt_mod, - params=params, - result_idx=out_idx, - target=target, - prim_func=tilelang_func, - verbose=verbose, - ) - raise NotImplementedError("Torch CPP backend is not fully implemented.") elif execution_backend == "ctypes": adapter = CtypesKernelAdapter( rt_mod, From 06884e15c0f34e34639353604472e4d456772869 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 21 Feb 2025 14:54:49 +0000 Subject: [PATCH 2/6] lint fix --- tilelang/jit/env.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tilelang/jit/env.py b/tilelang/jit/env.py index 5e4dc93d9..f9738e61f 100644 --- a/tilelang/jit/env.py +++ b/tilelang/jit/env.py @@ -24,8 +24,6 @@ """ import pathlib -import re -import warnings from tilelang.env import ( CUTLASS_INCLUDE_DIR, # noqa: F401 From cfa7095c6e2290fad3309640d706d2bb63f48290 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 22 Feb 2025 18:05:40 +0000 Subject: [PATCH 3/6] Add block sparse attention implementations for TileLang and Triton - Implement block sparse attention kernels for TileLang and Triton - Add example scripts for block sparse attention with top-k and threshold-based masking - Include utility functions for generating sparse attention masks - Demonstrate causal attention with block-level sparsity - Add test cases to validate sparse attention implementations against PyTorch reference --- .../block_sparse_attn_tilelang.py | 218 +++++++++++ .../block_sparse_attn_triton.py | 359 ++++++++++++++++++ .../flash_attention/example_mha_fwd_bhsd.py | 227 +++++++++++ ...example_mha.py => example_mha_fwd_bshd.py} | 0 src/tl_templates/cuda/debug.h | 90 ++++- 5 files changed, 874 insertions(+), 20 deletions(-) create mode 100644 examples/blocksparse_attention/block_sparse_attn_tilelang.py create mode 100644 examples/blocksparse_attention/block_sparse_attn_triton.py create mode 100644 examples/flash_attention/example_mha_fwd_bhsd.py rename examples/flash_attention/{example_mha.py => example_mha_fwd_bshd.py} (100%) diff --git a/examples/blocksparse_attention/block_sparse_attn_tilelang.py b/examples/blocksparse_attention/block_sparse_attn_tilelang.py new file mode 100644 index 000000000..0237eec9c --- /dev/null +++ b/examples/blocksparse_attention/block_sparse_attn_tilelang.py @@ -0,0 +1,218 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+import math +import torch + +import tilelang +import tilelang.language as T +import torch.nn.functional as F + +def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): + bsz, num_head, downsample_len, _ = x.shape + # N_CTX = downsample_len * BLOCK + sparse_index = torch.topk(x, topk, dim=-1).indices + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) + dense_mask.scatter_(-1, sparse_index, True) + if use_dense_for_last_block: + dense_mask[:, :,-2:,:] = True + dense_mask.tril_() + return dense_mask + + +def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): + dense_mask = x > threshold + if use_dense_for_last_block: + dense_mask[:, :,-2:,:] = True + dense_mask.tril_() + return dense_mask + + +def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal): + block_M = 64 + block_N = 64 + num_stages = 0 + threads = 128 + scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + shape = [batch, heads, seq_len, dim] + block_mask_shape = [batch, heads, downsample_len, downsample_len] + + dtype = "float16" + accum_dtype = "float" + block_mask_dtype = "int8" + + def kernel_func(block_M, block_N, num_stages, threads): + + @T.macro + def MMA0( + K: T.Buffer(shape, dtype), + Q_shared: T.Buffer([block_M, dim], dtype), + K_shared: T.Buffer([block_N, dim], dtype), + acc_s: T.Buffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, + -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Buffer(shape, dtype), + V_shared: T.Buffer([block_M, dim], dtype), + acc_s_cast: T.Buffer([block_M, block_N], dtype), + acc_o: T.Buffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.Buffer([block_M, block_N], accum_dtype), + acc_s_cast: T.Buffer([block_M, block_N], dtype), + scores_max: T.Buffer([block_M], accum_dtype), + scores_max_prev: T.Buffer([block_M], accum_dtype), + scores_scale: T.Buffer([block_M], accum_dtype), + scores_sum: T.Buffer([block_M], accum_dtype), + logsum: T.Buffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. 
+ acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.Buffer([block_M, dim], accum_dtype), + scores_scale: T.Buffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Buffer(shape, dtype), + K: T.Buffer(shape, dtype), + V: T.Buffer(shape, dtype), + BlockSparseMask: T.Buffer(block_mask_shape, block_mask_dtype), + Output: T.Buffer(shape, dtype), + ): + with T.Kernel( + T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + block_mask = T.alloc_local([downsample_len], block_mask_dtype) + + T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + for vj in T.serial(downsample_len): + block_mask[vj] = BlockSparseMask[bz, by, bx, vj] + + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv( + (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + if block_mask[k] != 0: + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, + scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + + return main + + return kernel_func(block_M, block_N, num_stages, threads) + +def test_topk_sparse_attention(): + # Config + BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64 + TOPK = 2 # Keep top 8 elements per row + BLOCK = 64 + torch.manual_seed(0) + + # Create inputs + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + + sm_scale = 1.0 / (D_HEAD ** 0.5) + + # Create sparse mask (downsampled to block level) + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device='cuda', dtype=torch.bfloat16) + x_ds[:,:,:,0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + + # Run Triton kernel + program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + kernel = tilelang.compile(program, out_idx=[4]) + print(kernel.get_kernel_source()) + tilelang_output = kernel(q, k, v, block_mask) + + # Compute reference + # Expand block mask to 
full attention matrix + full_mask = torch.kron(block_mask.float(), + torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() + full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal + + # PyTorch reference implementation + attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float('-inf')) + attn = F.softmax(attn, dim=-1) + ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + + print("ref_output", ref_output) + print("tilelang_output", tilelang_output) + + + # Verify accuracy + assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), \ + "TileLang output doesn't match reference" + print("Pass topk sparse attention test with qlen == klen") + +if __name__ == "__main__": + test_topk_sparse_attention() diff --git a/examples/blocksparse_attention/block_sparse_attn_triton.py b/examples/blocksparse_attention/block_sparse_attn_triton.py new file mode 100644 index 000000000..e459800c5 --- /dev/null +++ b/examples/blocksparse_attention/block_sparse_attn_triton.py @@ -0,0 +1,359 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import math +import torch + +import triton +import triton.language as tl +import torch.nn.functional as F + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): + bsz, num_head, downsample_len, _ = x.shape + # N_CTX = downsample_len * BLOCK + sparse_index = torch.topk(x, topk, dim=-1).indices + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) + dense_mask.scatter_(-1, sparse_index, True) + if use_dense_for_last_block: + dense_mask[:, :,-2:,:] = True + dense_mask.tril_() + return dense_mask + + +def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): + dense_mask = x > threshold + if use_dense_for_last_block: + dense_mask[:, :,-2:,:] = True + dense_mask.tril_() + return dense_mask + + + + +@triton.jit +def _fwd_kernel_inner( + acc, l_i, m_i, + q, + k_block_col_idx, + block_mask_ptr, + k_ptrs, v_ptrs, + offs_m, offs_n, + stride_kt, stride_vt, stride_bmask_n, + sm_scale, + seqlen_k, + past_len, + LAST_K_BLOCK: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + + mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n) + # print + + if k_block_col_idx == 3: + print("mask_val", mask_val) + if mask_val == True: + start_n = k_block_col_idx * BLOCK_N + # -- compute qk ---- + + k = tl.load(k_ptrs + start_n * stride_kt) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + + qk *= sm_scale + + # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N + if LAST_K_BLOCK : + qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float('-inf')) + + + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + p = tl.exp(qk) + l_ij = tl.sum(p, 1) + alpha = tl.exp(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + + # update acc + v = tl.load(v_ptrs + start_n * stride_vt) + + p = p.to(v.type.element_ty) + + acc += tl.dot(p, v) + # update m_i and l_i + m_i = m_ij + return acc, l_i, m_i + + + + +@triton.jit +def _fwd_kernel( + Q, K, V, sm_scale, + block_mask_ptr, + Out, + stride_qz, stride_qh, stride_qm, stride_qd, + stride_kz, stride_kh, stride_kn, stride_kd, + stride_vz, stride_vh, stride_vn, stride_vd, + stride_bmz, 
stride_bmh, stride_bmm, stride_bmn, + stride_oz, stride_oh, stride_om, stride_od, + H, N_CTX, + PAST_LEN, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + Q_LEN = N_CTX - PAST_LEN + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_h = off_hz % H + off_z = off_hz // H + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_kz + off_h * stride_kh + V += off_z * stride_vz + off_h * stride_vh + block_mask_ptr += off_z * stride_bmz + off_h * stride_bmh + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + # off_k = offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd + off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd + off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd + # Initialize pointers to Q, K, V + q_ptrs = Q + off_q + k_ptrs = K + off_k + v_ptrs = V + off_v + mask_ptrs = block_mask_ptr + start_m * stride_bmm + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + q = tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN) + + k_block_start = 0 + k_block_end = tl.cdiv((start_m + 1) * BLOCK_M, BLOCK_N) + + # loop over k, v and update accumulator + for col_idx in range(k_block_start, k_block_end): + acc, l_i, m_i = _fwd_kernel_inner( + acc, l_i, m_i, + q, + col_idx, + mask_ptrs, + k_ptrs, v_ptrs, + offs_m, offs_n, + stride_kn, stride_vn, stride_bmn, + sm_scale, + N_CTX, + PAST_LEN, + col_idx == k_block_end - 1, + BLOCK_M, + BLOCK_N, + ) + + m_i += tl.math.log(l_i) + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + acc = acc.to(Out.dtype.element_ty) + + + off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX) + +def _forward( + ctx, + q, + k, + v, + block_sparse_mask, + sm_scale, + BLOCK_M=64, + BLOCK_N=64, + num_warps=None, + num_stages=1, + out=None + ): + + + assert q.shape[-1] == k.shape[-1] == v.shape[-1] + assert k.shape[2] == v.shape[2] + o = out if out is not None else torch.empty_like(q).contiguous() + grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1]) + + assert q.shape[-1] in [64, 128] + BLOCK_DMODEL = q.shape[-1] + + if is_hip(): + num_warps, num_stages = 8, 1 + else: + num_warps, num_stages = 4, 2 + + N_CTX = k.shape[2] + PAST_LEN = N_CTX - q.shape[2] + + + H = q.shape[1] + + _fwd_kernel[grid]( + q, k, v, sm_scale, + block_sparse_mask, + o, + *q.stride(), + *k.stride(), + *v.stride(), + *block_sparse_mask.stride(), + *o.stride(), + H, N_CTX, + PAST_LEN, + BLOCK_M, + BLOCK_N, + BLOCK_DMODEL, + num_warps=num_warps, + num_stages=num_stages, + ) + + return o + + + + +class _sparse_attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, block_sparse_dense, sm_scale): + # shape constraints + return _forward(ctx, q, k, v, block_sparse_dense, sm_scale) + + @staticmethod + def backward(ctx, do): + # No gradient propagation. 
+ raise NotImplementedError("It does not support gradient propagation yet") + return None, None, None, None, None + +block_sparse_triton_fn = _sparse_attention.apply + + + +def test_topk_sparse_attention(): + # Config + BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64 + TOPK = 2 # Keep top 8 elements per row + BLOCK = 64 + torch.manual_seed(0) + + # Create inputs + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) + sm_scale = 1.0 / (D_HEAD ** 0.5) + + # Create sparse mask (downsampled to block level) + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + print("downsample_len", downsample_len) + + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device='cuda', dtype=torch.bfloat16) + x_ds[:,:,:,0] = 100 + print("x_ds.shape", x_ds.shape) + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=downsample_len) + # print("block_mask", block_mask) + print("block_mask.shape", block_mask.shape) + + # Run Triton kernel + triton_output = block_sparse_triton_fn( + q, k, v, + block_mask, + sm_scale + ) + + # Compute reference + # Expand block mask to full attention matrix + full_mask = torch.kron(block_mask.float(), + torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() + full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal + + # PyTorch reference implementation + attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float('-inf')) + attn = F.softmax(attn, dim=-1) + ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + + # print("ref_output", ref_output) + # print("triton_output", triton_output) + + + # Verify accuracy + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \ + "Triton output doesn't match reference" + print("Pass topk sparse attention test with qlen == klen") + + + +# def test_topk_sparse_attention_qlt_kl(): +# BATCH, N_HEADS = 2, 4 +# Q_LEN, K_LEN, D_HEAD = 128, 256, 64 # qlen < klen; here, past_len = 256 - 128 = 128. +# TOPK = 1 +# BLOCK = 64 # block size used in downsampling +# torch.manual_seed(0) + +# # Create inputs. +# q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) +# k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) +# v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) +# # softmax scale +# sm_scale = 1.0 / (D_HEAD ** 0.5) + +# downsample_factor = BLOCK +# print("downsample_factor", downsample_factor) +# downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension +# print("downsample_len", downsample_len) +# x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, +# device='cuda', dtype=torch.bfloat16) +# # Force the first column to be high so that the first block is always selected. +# x_ds[:, :, :, 0] = 100 +# block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) +# print("block_mask", block_mask) +# print("block_mask.shape", block_mask.shape) +# # Run Triton kernel. 
+# triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale) + +# past_len = K_LEN - Q_LEN + +# attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale + +# full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool() +# full_mask_full = full_mask_full[..., :K_LEN, :K_LEN] + +# effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN) + + +# i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1) +# j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN) +# causal_mask = (j_global <= i_global) # shape: (Q_LEN, K_LEN) + +# final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN) + +# attn = attn.masked_fill(~final_mask, float('-inf')) +# attn = F.softmax(attn, dim=-1) +# ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + +# # Verify accuracy. +# assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \ +# "Triton output doesn't match reference when qlen < klen" + +# print("Pass topk sparse attention test with qlen < klen") + + +if __name__ == "__main__": + test_topk_sparse_attention() + # test_topk_sparse_attention_qlt_kl() \ No newline at end of file diff --git a/examples/flash_attention/example_mha_fwd_bhsd.py b/examples/flash_attention/example_mha_fwd_bhsd.py new file mode 100644 index 000000000..f4b873ef8 --- /dev/null +++ b/examples/flash_attention/example_mha_fwd_bhsd.py @@ -0,0 +1,227 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.nn.functional as F +import tilelang +from tilelang import Profiler +from tilelang.autotuner import * +import tilelang.language as T +import itertools +import argparse +from functools import partial + + +def get_configs(): + block_M = [128] + block_N = [128] + num_stages = [2] + threads = [256] + _configs = list(itertools.product(block_M, block_N, num_stages, threads)) + + configs = [{ + 'block_M': c[0], + 'block_N': c[1], + 'num_stages': c[2], + 'threads': c[3] + } for c in _configs] + return configs + + +def flashattn(batch, heads, seq_len, dim, is_causal, tune=False): + scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + shape = [batch, heads, seq_len, dim] + dtype = "float16" + accum_dtype = "float" + + def kernel_func(block_M, block_N, num_stages, threads): + + @T.macro + def MMA0( + K: T.Buffer(shape, dtype), + Q_shared: T.Buffer([block_M, dim], dtype), + K_shared: T.Buffer([block_N, dim], dtype), + acc_s: T.Buffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, + -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Buffer(shape, dtype), + V_shared: T.Buffer([block_M, dim], dtype), + acc_s_cast: T.Buffer([block_M, block_N], dtype), + acc_o: T.Buffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.Buffer([block_M, block_N], accum_dtype), + acc_s_cast: T.Buffer([block_M, block_N], dtype), + scores_max: T.Buffer([block_M], accum_dtype), + scores_max_prev: 
T.Buffer([block_M], accum_dtype), + scores_scale: T.Buffer([block_M], accum_dtype), + scores_sum: T.Buffer([block_M], accum_dtype), + logsum: T.Buffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.Buffer([block_M, dim], accum_dtype), + scores_scale: T.Buffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Buffer(shape, dtype), + K: T.Buffer(shape, dtype), + V: T.Buffer(shape, dtype), + Output: T.Buffer(shape, dtype), + ): + with T.Kernel( + T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv( + (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, + scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + + return main + + if tune: + + @autotune( + configs=get_configs(), + keys=["block_M", "block_N", "num_stages", "threads"], + warmup=10, + rep=10) + @jit( + out_idx=[3], + supply_type=tilelang.TensorSupplyType.Integer, + ref_prog=None, + profiler="auto") + def kernel(block_M=None, block_N=None, num_stages=None, threads=None): + return kernel_func(block_M, block_N, num_stages, threads) + + return kernel() + else: + + def kernel(block_M, block_N, 
num_stages, threads): + return kernel_func(block_M, block_N, num_stages, threads) + + return kernel + + +def ref_program(Q, K, V, is_causal): + dim = Q.size(-1) + scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float('-inf')) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V) + return output + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--batch', type=int, default=8, help='batch size') + parser.add_argument('--heads', type=int, default=32, help='heads') + parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') + parser.add_argument('--dim', type=int, default=128, help='dim') + parser.add_argument('--is_causal', action='store_true', help='causal') + parser.add_argument('--tune', action='store_true', help='tune configs') + args = parser.parse_args() + batch, heads, seq_len, dim, is_causal = args.batch, args.heads, args.seq_len, args.dim, args.is_causal + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + if (not args.tune): + program = flashattn( + batch, heads, seq_len, dim, is_causal, tune=args.tune)( + block_M=128, block_N=128, num_stages=1, threads=128) + ref_program = partial(ref_program, is_causal=is_causal) + mod, params = tilelang.lower(program) + mod = Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal) + mod.assert_allclose(ref_program, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = mod.do_bench(ref_program, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = mod.do_bench(mod.func, warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + best_latency, best_config, _ = flashattn( + batch, heads, seq_len, dim, is_causal, tune=args.tune) + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") diff --git a/examples/flash_attention/example_mha.py b/examples/flash_attention/example_mha_fwd_bshd.py similarity index 100% rename from examples/flash_attention/example_mha.py rename to examples/flash_attention/example_mha_fwd_bshd.py diff --git a/src/tl_templates/cuda/debug.h b/src/tl_templates/cuda/debug.h index 4818f14ba..0cb939627 100644 --- a/src/tl_templates/cuda/debug.h +++ b/src/tl_templates/cuda/debug.h @@ -4,10 +4,31 @@ #include // Template declaration for device-side debug printing (variable only) -template __device__ void debug_print_var(char *msg, T var); +template __device__ void debug_print_var(const char *msg, T var); + +// Specialization for signed char type +template <> +__device__ void debug_print_var(const char *msg, signed char var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=signed " + "char " + "value=%d\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, var); +} + +// Specialization for unsigned char type +template <> +__device__ void debug_print_var(const char *msg, + unsigned char var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " + 
"dtype=unsigned char " + "value=%d\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, var); +} // Specialization for integer type -template <> __device__ void debug_print_var(char *msg, int var) { +template <> __device__ void debug_print_var(const char *msg, int var) { printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=int " "value=%d\n", msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, @@ -15,7 +36,7 @@ template <> __device__ void debug_print_var(char *msg, int var) { } // Specialization for float type -template <> __device__ void debug_print_var(char *msg, float var) { +template <> __device__ void debug_print_var(const char *msg, float var) { printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=float " "value=%f\n", msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, @@ -23,7 +44,7 @@ template <> __device__ void debug_print_var(char *msg, float var) { } // Specialization for half type -template <> __device__ void debug_print_var(char *msg, half var) { +template <> __device__ void debug_print_var(const char *msg, half var) { printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=half " "value=%f\n", msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, @@ -31,7 +52,8 @@ template <> __device__ void debug_print_var(char *msg, half var) { } // Specialization for half_t type -template <> __device__ void debug_print_var(char *msg, half_t var) { +template <> +__device__ void debug_print_var(const char *msg, half_t var) { printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=half_t " "value=%f\n", msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, @@ -40,7 +62,7 @@ template <> __device__ void debug_print_var(char *msg, half_t var) { // Specialization for bfloat16_t type template <> -__device__ void debug_print_var(char *msg, bfloat16_t var) { +__device__ void debug_print_var(const char *msg, bfloat16_t var) { printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " "dtype=bfloat16_t value=%f\n", msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, @@ -48,7 +70,8 @@ __device__ void debug_print_var(char *msg, bfloat16_t var) { } // Specialization for double type -template <> __device__ void debug_print_var(char *msg, double var) { +template <> +__device__ void debug_print_var(const char *msg, double var) { printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=double " "value=%lf\n", msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, @@ -62,13 +85,36 @@ template <> __device__ void debug_print_var(char *msg, double var) { // Template declaration for device-side debug printing (buffer only) template -__device__ void debug_print_buffer_value(char *msg, char *buf_name, int index, - T var); +__device__ void debug_print_buffer_value(const char *msg, const char *buf_name, + int index, T var); + +// Specialization for signed char type +template <> +__device__ void +debug_print_buffer_value(const char *msg, const char *buf_name, + int index, signed char var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=signed char value=%d\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, buf_name, index, var); +} + +// Specialization for unsiged char type +template <> +__device__ void debug_print_buffer_value(const char *msg, + const char *buf_name, int index, + char var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), 
ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=char value=%d\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, buf_name, index, var); +} // Specialization for integer type template <> -__device__ void debug_print_buffer_value(char *msg, char *buf_name, - int index, int var) { +__device__ void debug_print_buffer_value(const char *msg, + const char *buf_name, int index, + int var) { printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " "index=%d, dtype=int value=%d\n", msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, @@ -77,8 +123,9 @@ __device__ void debug_print_buffer_value(char *msg, char *buf_name, // Specialization for float type template <> -__device__ void debug_print_buffer_value(char *msg, char *buf_name, - int index, float var) { +__device__ void debug_print_buffer_value(const char *msg, + const char *buf_name, int index, + float var) { printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " "index=%d, dtype=float value=%f\n", msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, @@ -87,8 +134,9 @@ __device__ void debug_print_buffer_value(char *msg, char *buf_name, // Specialization for half type template <> -__device__ void debug_print_buffer_value(char *msg, char *buf_name, - int index, half var) { +__device__ void debug_print_buffer_value(const char *msg, + const char *buf_name, int index, + half var) { printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " "index=%d, dtype=half value=%f\n", msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, @@ -97,7 +145,8 @@ __device__ void debug_print_buffer_value(char *msg, char *buf_name, // Specialization for half_t type template <> -__device__ void debug_print_buffer_value(char *msg, char *buf_name, +__device__ void debug_print_buffer_value(const char *msg, + const char *buf_name, int index, half_t var) { printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " "index=%d, dtype=half_t value=%f\n", @@ -107,9 +156,9 @@ __device__ void debug_print_buffer_value(char *msg, char *buf_name, // Specialization for bfloat16_t type template <> -__device__ void debug_print_buffer_value(char *msg, char *buf_name, - int index, - bfloat16_t var) { +__device__ void +debug_print_buffer_value(const char *msg, const char *buf_name, + int index, bfloat16_t var) { printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " "index=%d, dtype=bfloat16_t value=%f\n", msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, @@ -118,7 +167,8 @@ __device__ void debug_print_buffer_value(char *msg, char *buf_name, // Specialization for double type template <> -__device__ void debug_print_buffer_value(char *msg, char *buf_name, +__device__ void debug_print_buffer_value(const char *msg, + const char *buf_name, int index, double var) { printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " "index=%d, dtype=double value=%lf\n", From 26d6d6f6f6c8f3a492d81ce01aab287cbff63578 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 23 Feb 2025 07:15:16 +0000 Subject: [PATCH 4/6] Bump version to 0.1.1 --- MANIFEST.in | 1 + VERSION | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/MANIFEST.in b/MANIFEST.in index ba3120225..88b206825 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,6 +3,7 @@ include CMakeLists.txt include requirements.txt include requirements-test.txt include requirements-dev.txt +include 
tilelang/jit/adapter/cython/cython_wrapper.pyx
recursive-include src *
recursive-include 3rdparty *
recursive-exclude 3rdparty/clang* *

diff --git a/VERSION b/VERSION
index 6c6aa7cb0..6da28dde7 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.1.0
\ No newline at end of file
+0.1.1
\ No newline at end of file

From 892f6409c7d1fc23ea6b1e031e090d076355c545 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Sun, 23 Feb 2025 09:09:38 +0000
Subject: [PATCH 5/6] Refactor block sparse attention examples for improved
 code quality

- Apply consistent code formatting and style in TileLang and Triton block
  sparse attention implementations
- Add ruff linter ignore comment for specific line in Triton implementation
- Improve readability by adjusting indentation and line breaks
- Standardize sparse mask generation and test function implementations
- Minor optimizations in test case configurations
---
 .../block_sparse_attn_tilelang.py |  36 ++--
 .../block_sparse_attn_triton.py   | 184 ++++++++++--------
 2 files changed, 125 insertions(+), 95 deletions(-)

diff --git a/examples/blocksparse_attention/block_sparse_attn_tilelang.py b/examples/blocksparse_attention/block_sparse_attn_tilelang.py
index 0237eec9c..912ec7b96 100644
--- a/examples/blocksparse_attention/block_sparse_attn_tilelang.py
+++ b/examples/blocksparse_attention/block_sparse_attn_tilelang.py
@@ -7,24 +7,28 @@
 import tilelang.language as T
 import torch.nn.functional as F
 
+
 def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape
     # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
-    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
+    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
+                            False,
+                            dtype=torch.bool,
+                            device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
-        dense_mask[:, :,-2:,:] = True
+        dense_mask[:, :, -2:, :] = True
     dense_mask.tril_()
-    return dense_mask 
+    return dense_mask
 
 
 def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False):
-    dense_mask = x > threshold 
+    dense_mask = x > threshold
     if use_dense_for_last_block:
-        dense_mask[:, :,-2:,:] = True
+        dense_mask[:, :, -2:, :] = True
     dense_mask.tril_()
-    return dense_mask 
+    return dense_mask
 
 
 def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):
@@ -136,7 +140,7 @@ def main(
             scores_sum = T.alloc_fragment([block_M], accum_dtype)
             logsum = T.alloc_fragment([block_M], accum_dtype)
             block_mask = T.alloc_local([downsample_len], block_mask_dtype)
-            
+
             T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
@@ -165,6 +169,7 @@ def main(
 
     return kernel_func(block_M, block_N, num_stages, threads)
 
+
 def test_topk_sparse_attention():
     # Config
     BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64
@@ -177,13 +182,15 @@ def test_topk_sparse_attention():
     k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
     v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
 
-    sm_scale = 1.0 / (D_HEAD ** 0.5)
+    sm_scale = 1.0 / (D_HEAD**0.5)
 
     # Create sparse mask (downsampled to block level)
     downsample_factor = BLOCK
     downsample_len = math.ceil(SEQ_LEN / downsample_factor)
-    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device='cuda', dtype=torch.bfloat16)
-    x_ds[:,:,:,0] = 100
+    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
+                       device='cuda',
+                       dtype=torch.bfloat16)
+    x_ds[:, :, :, 0] = 100
     block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
 
     # Run Triton kernel
@@ -194,25 +201,24 @@ def test_topk_sparse_attention():
 
     # Compute reference
     # Expand block mask to full attention matrix
-    full_mask = torch.kron(block_mask.float(),
-                           torch.ones(BLOCK, BLOCK, device='cuda'))
+    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
     full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
     full_mask = full_mask & torch.tril(torch.ones_like(full_mask))  # Apply causal
-    
+
     # PyTorch reference implementation
     attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
     attn = attn.masked_fill(~full_mask, float('-inf'))
     attn = F.softmax(attn, dim=-1)
     ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
-    
+
     print("ref_output", ref_output)
     print("tilelang_output", tilelang_output)
-
     # Verify accuracy
     assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), \
         "TileLang output doesn't match reference"
 
     print("Pass topk sparse attention test with qlen == klen")
 
+
 if __name__ == "__main__":
     test_topk_sparse_attention()

diff --git a/examples/blocksparse_attention/block_sparse_attn_triton.py b/examples/blocksparse_attention/block_sparse_attn_triton.py
index e459800c5..e79f00e9b 100644
--- a/examples/blocksparse_attention/block_sparse_attn_triton.py
+++ b/examples/blocksparse_attention/block_sparse_attn_triton.py
@@ -1,5 +1,6 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
+# ruff: noqa: E712
 
 import math
 import torch
@@ -7,6 +8,7 @@
 import triton.language as tl
 import torch.nn.functional as F
 
+
 def is_hip():
     return triton.runtime.driver.active.get_current_target().backend == "hip"
 
@@ -15,33 +17,40 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape
     # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
-    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
+    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
+                            False,
+                            dtype=torch.bool,
+                            device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
-        dense_mask[:, :,-2:,:] = True
+        dense_mask[:, :, -2:, :] = True
     dense_mask.tril_()
-    return dense_mask 
+    return dense_mask
 
 
 def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False):
-    dense_mask = x > threshold 
+    dense_mask = x > threshold
    if use_dense_for_last_block:
-        dense_mask[:, :,-2:,:] = True
+        dense_mask[:, :, -2:, :] = True
     dense_mask.tril_()
-    return dense_mask 
-
-
+    return dense_mask
 
 
 @triton.jit
 def _fwd_kernel_inner(
-    acc, l_i, m_i,
+    acc,
+    l_i,
+    m_i,
     q,
     k_block_col_idx,
     block_mask_ptr,
-    k_ptrs, v_ptrs,
-    offs_m, offs_n,
-    stride_kt, stride_vt, stride_bmask_n,
+    k_ptrs,
+    v_ptrs,
+    offs_m,
+    offs_n,
+    stride_kt,
+    stride_vt,
+    stride_bmask_n,
     sm_scale,
     seqlen_k,
     past_len,
@@ -51,8 +60,8 @@ def _fwd_kernel_inner(
 ):
 
     mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)
-    # print 
-    
+    # print
+
     if k_block_col_idx == 3:
         print("mask_val", mask_val)
     if mask_val == True:
@@ -67,9 +76,9 @@ def _fwd_kernel_inner(
         qk *= sm_scale
 
         # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
-        if LAST_K_BLOCK :
-            qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float('-inf'))
-        
+        if LAST_K_BLOCK:
+            qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0,
+                           float('-inf'))
 
         m_ij = tl.maximum(m_i, tl.max(qk, 1))
         qk -= m_ij[:, None]
@@ -78,7 +87,7 @@ def _fwd_kernel_inner(
         alpha = tl.exp(m_i - m_ij)
         l_i = l_i * alpha + l_ij
         acc = acc * alpha[:, None]
-        
+
         # update acc
         v = tl.load(v_ptrs + start_n * stride_vt)
 
@@ -90,21 +99,38 @@
     return acc, l_i, m_i
 
 
-
-
 @triton.jit
 def _fwd_kernel(
-    Q, K, V, sm_scale,
+    Q,
+    K,
+    V,
+    sm_scale,
     block_mask_ptr,
     Out,
-    stride_qz, stride_qh, stride_qm, stride_qd,
-    stride_kz, stride_kh, stride_kn, stride_kd,
-    stride_vz, stride_vh, stride_vn, stride_vd,
-    stride_bmz, stride_bmh, stride_bmm, stride_bmn,
-    stride_oz, stride_oh, stride_om, stride_od,
-    H, N_CTX,
+    stride_qz,
+    stride_qh,
+    stride_qm,
+    stride_qd,
+    stride_kz,
+    stride_kh,
+    stride_kn,
+    stride_kd,
+    stride_vz,
+    stride_vh,
+    stride_vn,
+    stride_vd,
+    stride_bmz,
+    stride_bmh,
+    stride_bmm,
+    stride_bmn,
+    stride_oz,
+    stride_oh,
+    stride_om,
+    stride_od,
+    H,
+    N_CTX,
     PAST_LEN,
-    BLOCK_M: tl.constexpr, 
+    BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
 ):
@@ -144,13 +170,19 @@ def _fwd_kernel(
     # loop over k, v and update accumulator
     for col_idx in range(k_block_start, k_block_end):
         acc, l_i, m_i = _fwd_kernel_inner(
-            acc, l_i, m_i,
+            acc,
+            l_i,
+            m_i,
             q,
             col_idx,
             mask_ptrs,
-            k_ptrs, v_ptrs,
-            offs_m, offs_n,
-            stride_kn, stride_vn, stride_bmn,
+            k_ptrs,
+            v_ptrs,
+            offs_m,
+            offs_n,
+            stride_kn,
+            stride_vn,
+            stride_bmn,
             sm_scale,
             N_CTX,
             PAST_LEN,
@@ -162,27 +194,25 @@ def _fwd_kernel(
     m_i += tl.math.log(l_i)
     l_recip = 1 / l_i[:, None]
     acc = acc * l_recip
-    acc = acc.to(Out.dtype.element_ty) 
+    acc = acc.to(Out.dtype.element_ty)
 
-
-    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
+    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[
+        None, :] * stride_od
     out_ptrs = Out + off_o
     tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX)
 
 
-def _forward(
-    ctx,
-    q,
-    k,
-    v,
-    block_sparse_mask,
-    sm_scale,
-    BLOCK_M=64,
-    BLOCK_N=64,
-    num_warps=None,
-    num_stages=1,
-    out=None
-    ):
+def _forward(ctx,
+             q,
+             k,
+             v,
+             block_sparse_mask,
+             sm_scale,
+             BLOCK_M=64,
+             BLOCK_N=64,
+             num_warps=None,
+             num_stages=1,
+             out=None):
 
     assert q.shape[-1] == k.shape[-1] == v.shape[-1]
     assert k.shape[2] == v.shape[2]
@@ -200,19 +230,22 @@ def _forward(
 
     N_CTX = k.shape[2]
     PAST_LEN = N_CTX - q.shape[2]
-
     H = q.shape[1]
 
     _fwd_kernel[grid](
-        q, k, v, sm_scale,
+        q,
+        k,
+        v,
+        sm_scale,
         block_sparse_mask,
         o,
-        *q.stride(), 
-        *k.stride(), 
-        *v.stride(), 
-        *block_sparse_mask.stride(), 
+        *q.stride(),
+        *k.stride(),
+        *v.stride(),
+        *block_sparse_mask.stride(),
         *o.stride(),
-        H, N_CTX,
+        H,
+        N_CTX,
         PAST_LEN,
         BLOCK_M,
         BLOCK_N,
@@ -224,8 +257,6 @@
 
     return o
 
-
-
 class _sparse_attention(torch.autograd.Function):
 
     @staticmethod
@@ -239,8 +270,8 @@
     def backward(ctx, do):
         raise NotImplementedError("It does not support gradient propagation yet")
         return None, None, None, None, None
 
-block_sparse_triton_fn = _sparse_attention.apply
 
+block_sparse_triton_fn = _sparse_attention.apply
 
 def test_topk_sparse_attention():
@@ -254,55 +285,50 @@ def test_topk_sparse_attention():
     q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
     k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
     v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
 
-    sm_scale = 1.0 / (D_HEAD ** 0.5)
+    sm_scale = 1.0 / (D_HEAD**0.5)
 
     # Create sparse mask (downsampled to block level)
     downsample_factor = BLOCK
     downsample_len = math.ceil(SEQ_LEN / downsample_factor)
     print("downsample_len", downsample_len)
-    
-    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device='cuda', dtype=torch.bfloat16)
-    x_ds[:,:,:,0] = 100
+
+    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
+                       device='cuda',
+                       dtype=torch.bfloat16)
+    x_ds[:, :, :, 0] = 100
     print("x_ds.shape", x_ds.shape)
-    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=downsample_len)
+    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
     # print("block_mask", block_mask)
     print("block_mask.shape", block_mask.shape)
 
     # Run Triton kernel
-    triton_output = block_sparse_triton_fn(
-        q, k, v,
-        block_mask,
-        sm_scale
-    )
+    triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale)
 
     # Compute reference
     # Expand block mask to full attention matrix
-    full_mask = torch.kron(block_mask.float(),
-                           torch.ones(BLOCK, BLOCK, device='cuda'))
+    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
     full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
     full_mask = full_mask & torch.tril(torch.ones_like(full_mask))  # Apply causal
-    
+
     # PyTorch reference implementation
     attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
     attn = attn.masked_fill(~full_mask, float('-inf'))
     attn = F.softmax(attn, dim=-1)
     ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
-    
+
     # print("ref_output", ref_output)
     # print("triton_output", triton_output)
-
     # Verify accuracy
     assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \
         "Triton output doesn't match reference"
 
     print("Pass topk sparse attention test with qlen == klen")
 
-
 # def test_topk_sparse_attention_qlt_kl():
 #     BATCH, N_HEADS = 2, 4
 #     Q_LEN, K_LEN, D_HEAD = 128, 256, 64  # qlen < klen; here, past_len = 256 - 128 = 128.
-#     TOPK = 1 
+#     TOPK = 1
 #     BLOCK = 64  # block size used in downsampling
 
 #     torch.manual_seed(0)
@@ -317,7 +343,7 @@ def test_topk_sparse_attention():
 #     print("downsample_factor", downsample_factor)
 #     downsample_len = math.ceil(K_LEN / downsample_factor)  # number of blocks along one dimension
 #     print("downsample_len", downsample_len)
-#     x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, 
+#     x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len,
 #                        device='cuda', dtype=torch.bfloat16)
 #     # Force the first column to be high so that the first block is always selected.
 #     x_ds[:, :, :, 0] = 100
@@ -336,7 +362,6 @@ def test_topk_sparse_attention():
 #     effective_mask = full_mask_full[..., past_len:K_LEN, :]  # shape: (B, H, Q_LEN, K_LEN)
 
-
 #     i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1)  # shape: (Q_LEN, 1)
 #     j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0)  # shape: (1, K_LEN)
 #     causal_mask = (j_global <= i_global)  # shape: (Q_LEN, K_LEN)
 
@@ -350,10 +375,9 @@ def test_topk_sparse_attention():
 
 #     # Verify accuracy.
 #     assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \
 #         "Triton output doesn't match reference when qlen < klen"
-
-#     print("Pass topk sparse attention test with qlen < klen") 
+#     print("Pass topk sparse attention test with qlen < klen")
 
 if __name__ == "__main__":
     test_topk_sparse_attention()
-    # test_topk_sparse_attention_qlt_kl()
\ No newline at end of file
+    # test_topk_sparse_attention_qlt_kl()

From b10259799d706f8bbe98379b43ee9703ad2f0ebd Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Sun, 23 Feb 2025 09:11:19 +0000
Subject: [PATCH 6/6] lint

---
 .../blocksparse_attention/block_sparse_attn_triton.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/examples/blocksparse_attention/block_sparse_attn_triton.py b/examples/blocksparse_attention/block_sparse_attn_triton.py
index 7df1cfafe..907d42d9c 100644
--- a/examples/blocksparse_attention/block_sparse_attn_triton.py
+++ b/examples/blocksparse_attention/block_sparse_attn_triton.py
@@ -337,14 +337,14 @@ def test_topk_sparse_attention_qlt_kl():
     k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
     v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
     # softmax scale
-    sm_scale = 1.0 / (D_HEAD ** 0.5)
+    sm_scale = 1.0 / (D_HEAD**0.5)
 
     downsample_factor = BLOCK
     print("downsample_factor", downsample_factor)
     downsample_len = math.ceil(K_LEN / downsample_factor)  # number of blocks along one dimension
     print("downsample_len", downsample_len)
-    x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, 
-                       device='cuda', dtype=torch.bfloat16)
+    x_ds = torch.randn(
+        BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.bfloat16)
     # Force the first column to be high so that the first block is always selected.
     x_ds[:, :, :, 0] = 100
     block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
@@ -363,7 +363,7 @@ def test_topk_sparse_attention_qlt_kl():
     effective_mask = full_mask_full[..., past_len:K_LEN, :]  # shape: (B, H, Q_LEN, K_LEN)
 
     i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1)  # shape: (Q_LEN, 1)
-    j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0)      # shape: (1, K_LEN)
+    j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0)  # shape: (1, K_LEN)
     causal_mask = (j_global <= i_global)  # shape: (Q_LEN, K_LEN)
 
     final_mask = effective_mask & causal_mask  # shape: (B, H, Q_LEN, K_LEN)
@@ -378,6 +378,7 @@ def test_topk_sparse_attention_qlt_kl():
 
     print("Pass topk sparse attention test with qlen < klen")
 
+
 if __name__ == "__main__":
     test_topk_sparse_attention()
     test_topk_sparse_attention_qlt_kl()
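
For readers skimming the patch series, the following is a minimal, self-contained sketch of the block-sparse masking scheme that both example tests above exercise. It is illustration only and not part of the patches; the helper names topk_block_mask and dense_reference are invented for this sketch, and it assumes every query block keeps at least one visible key block (the examples guarantee this by forcing the first block column to a large score).

    import math

    import torch
    import torch.nn.functional as F


    def topk_block_mask(scores, topk):
        # Keep the top-k key blocks per query block, then apply the causal
        # constraint at block granularity (lower-triangular block mask).
        mask = torch.zeros_like(scores, dtype=torch.bool)
        mask.scatter_(-1, torch.topk(scores, topk, dim=-1).indices, True)
        return mask.tril_()


    def dense_reference(q, k, v, block_mask, block_size):
        # Expand the block-level mask to token granularity with a Kronecker
        # product, re-apply the token-level causal mask, then run ordinary
        # softmax attention as the reference the kernels are compared against.
        seq_len = q.shape[-2]
        full = torch.kron(block_mask.float(),
                          torch.ones(block_size, block_size, device=q.device))
        causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device))
        full = full[..., :seq_len, :seq_len].bool() & causal
        attn = torch.einsum('bhsd,bhtd->bhst', q, k) / math.sqrt(q.shape[-1])
        attn = F.softmax(attn.masked_fill(~full, float('-inf')), dim=-1)
        return torch.einsum('bhst,bhtd->bhsd', attn, v)

Both test files follow this shape: build the block mask, run the Triton or TileLang kernel with it, and compare the kernel output against the dense reference with torch.allclose at atol=rtol=1e-2.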