diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt
index 5d59b246e1..f82207b920 100644
--- a/example/ck_tile/01_fmha/CMakeLists.txt
+++ b/example/ck_tile/01_fmha/CMakeLists.txt
@@ -1,7 +1,27 @@
-# generate a list of kernels, but not actually emit files at config stage
+# validate user-specified fmha_fwd API list
+set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv")
+set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
+    "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
+if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
+  set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS})
+endif()
+
+foreach(api ${FMHA_FWD_ENABLE_APIS})
+  if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS)
+    message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.")
+  endif()
+endforeach()
+
+# "fwd" is a must-have api for the fmha_fwd example, add it if not specified
+if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS)
+  list(APPEND FMHA_FWD_ENABLE_APIS "fwd")
+endif()
+
+string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
+# generate a list of kernels, but not actually emit files at config stage
 execute_process(
   COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
-  --api fwd,fwd_splitkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
+  --api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
 )
 
 execute_process(
@@ -17,7 +37,7 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
 add_custom_command(
   OUTPUT ${FMHA_FWD_GEN_BLOBS}
   COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
-  --api fwd,fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR}
+  --api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR}
 )
 
 add_custom_command(
@@ -60,6 +80,20 @@ else()
 endif()
 list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
 
+# conditionally enable call to the fwd_splitkv API in fmha_fwd example
+if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS)
+  list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1)
+else()
+  list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0)
+endif()
+
+# conditionally enable call to the fwd_appendkv API in fmha_fwd example
+if("fwd_appendkv" IN_LIST FMHA_FWD_ENABLE_APIS)
+  list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1)
+else()
+  list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0)
+endif()
+
 # Allow comparing floating points directly in order to check sentinel values
 list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
 list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal)
diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
index a5862ad5d9..66691356ab 100644
--- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
+++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
@@ -82,6 +82,18 @@ def get_mask_check_map(mask : str):
     "dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
 }
 
+ROPE_MAP = {
+    "no"    : "ck_tile::RotaryEmbeddingEnum::NONE",
+    "inter" : "ck_tile::RotaryEmbeddingEnum::INTERLEAVED",
+    "half"  : "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED"
+}
+
+ROPE_CHECK_MAP = {
+    "no"    : "rope_enum::none",
+    "inter" : "rope_enum::interleaved",
+    "half"  : "rope_enum::half_rotated"
+}
+
 MODE_MAP = {
     "batch" : "false",
     "group" : "true"
 }
@@ -105,4 +117,4 @@ def get_mask_check_map(mask : str):
 BOOL_MAP = {
     "t" : "true",
     "f" : "false"
-}
\ No newline at end of file
+}
diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py
new file mode 100644
index 0000000000..cfd1d01c91
--- /dev/null
+++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py
@@ -0,0 +1,355 @@
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+# generate kernel instances to speed up compilation
+
+import copy
+from dataclasses import dataclass
+import fnmatch
+import itertools
+from pathlib import Path
+from typing import List, Optional, Tuple
+
+from codegen.cmake_config import *
+from codegen.cpp_symbol_map import *
+
+from codegen.ops.fmha_fwd import (
+    FmhaFwdApiTrait,
+    DTYPE_BITS,
+    FMHA_FWD_KERNEL_HEADER,
+    FMHA_FWD_API_PER_DTYPE,
+    FMHA_FWD_API_PER_HDIM_CASE,
+)
+
+
+FMHA_FWD_APPENDKV_KERNEL_BODY="""
+using fmha_dtype_{F_idx} = {F_dtype};
+
+using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad},
+                                                              {F_skpad},
+                                                              {F_dpad},
+                                                              {F_dvpad},
+                                                              {F_occupancy}>;
+
+using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProblem<
+    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
+    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
+    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
+    {F_bs},
+    {F_bsk},
+    {F_bd},
+    {F_bdv},
+    {F_vlayout},
+    {F_rope},
+    {F_pagedkv},
+    fmha_trait_{F_idx}>;
+
+using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline<
+    fmha_pipeline_problem_{F_idx}>;
+
+using fmha_kernel_{F_idx} =
+    ck_tile::FmhaFwdAppendKVKernel<ck_tile::FmhaFwdAppendKVTilePartitioner<{F_bs}, {F_bsk}>,
+                                   fmha_pipeline_{F_idx}>;
+
+using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
+                        {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
+
+#include <iostream>
+
+template<>
+float fmha_fwd_appendkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a)
+{{
+    using k_ = fmha_kernel_{F_idx};
+    if(s.log_level_ > 0)
+        std::cout << ", " << k_::GetName() << std::flush;
+    auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids<k_>(a);
+    constexpr dim3 blocks = k_::BlockSize();
+    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
+    return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
+}}
+"""
+
+FMHA_FWD_APPENDKV_API_FILENAME="fmha_fwd_appendkv_api.cpp"
+FMHA_FWD_APPENDKV_API="""
+float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s){{
+    float r = -1;
+{F_dispatch}
+    return r;
+}}
+"""
+
+FMHA_FWD_APPENDKV_API_INNER_DISPATCH="""    {F_if}((t.is_v_rowmajor == {F_vlayout}) &&
+            ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) &&
+            ((a.block_table_ptr != nullptr) == {F_pagedkv})) {{
+        using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
+        return fmha_fwd_appendkv_<trait_>(s, a);
+    }}
+"""
+
+@dataclass
+class FmhaFwdAppendKVApiTrait:
+    # sync with fmha_fwd_traits<>, to generate fallback calls
+    hdim : str
+    dtype : str # data type
+    bs : int # tile size along q seqlen
+    bsk : int # tile size along k seqlen
+    bd : int # tile size along qk gemm unroll
+    bdv : int # tile size along kv gemm unroll
+    vlayout : str
+    spad : str
+    skpad : str
+    dpad : str
+    dvpad : str
+    rope : str # key from ROPE_MAP
+    pagedkv : str
+
+    @property
+    def name(self) -> str:
+        return
f'{self.hdim}-{self.dtype}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-{self.vlayout}-'+\ + f'{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-{self.pagedkv}' + + @property + def scheck(self) -> str: + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bs} != 0*/' + else : return f'a.seqlen_q % {self.bs} == 0' + + @property + def skcheck(self) -> str: + # we do not check all the values in a.seqlen_k_ptr + return 'true' + + @property + def dcheck(self) -> str: + if self.dpad == 't': return f'true /*a.hdim_q % {self.bd} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {self.bd} == 0' + + @property + def dvcheck(self) -> str: + if self.dvpad == 't': return f'true /*a.hdim_v % {self.bdv} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_v % {self.bdv} == 0' + +@dataclass +class FmhaFwdAppendKVPipeline: + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_rope : str # key from ROPE_MAP + F_pagedkv : str # t/f + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + if self.F_rope != 'no': n += f'_{self.F_rope}' + if self.F_pagedkv == 't': n += '_pagedkv' + return n + +class FmhaFwdAppendKVApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout], + F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope], + F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdAppendKVTileSize: + F_bs : int # tile size along q seqlen + F_bsk : int # tile size along k seqlen + F_bd : int # tile size along qk gemm unroll + F_bdv : int # tile size along kv gemm unroll + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property + def 
name(self) -> str:
+        return f"b{self.F_bs}x{self.F_bsk}x{self.F_bd}x{self.F_bdv}" +\
+            ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
+
+@dataclass
+class FmhaFwdAppendKVKernel:
+    F_idx : int # this is not a tunable, but a counter to differentiate symbol
+    F_hdim : int # hdim
+    F_dtype : str # data type
+    F_tile : FmhaFwdAppendKVTileSize
+    F_pipeline : FmhaFwdAppendKVPipeline
+    mask_impl : str
+
+    @property
+    def template(self) -> str:
+        kernel_body = str()
+        return FMHA_FWD_KERNEL_HEADER + \
+            FMHA_FWD_APPENDKV_KERNEL_BODY.format(
+                F_idx = self.F_idx,
+                F_hdim = self.F_hdim,
+                F_dtype = DTYPE_MAP[self.F_dtype],
+                F_bs = self.F_tile.F_bs,
+                F_bsk = self.F_tile.F_bsk,
+                F_bd = self.F_tile.F_bd,
+                F_bdv = self.F_tile.F_bdv,
+                F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
+                F_spad = BOOL_MAP[self.F_pipeline.F_spad],
+                F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
+                F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
+                F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
+                F_rope = ROPE_MAP[self.F_pipeline.F_rope],
+                F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv],
+                F_occupancy = self.F_tile.F_occupancy)
+
+    @property
+    def name(self) -> str:
+        # TODO: we don't encode idx here
+        return f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_" + \
+            self.F_tile.name + '_' + self.F_pipeline.name
+
+    @property
+    def filename(self) -> str:
+        return self.name + ".cpp"
+
+    def api_trait(self) -> FmhaFwdAppendKVApiTrait:
+        return FmhaFwdAppendKVApiTrait(
+            hdim=str(self.F_hdim),
+            dtype=self.F_dtype,
+            bs=self.F_tile.F_bs,
+            bsk=self.F_tile.F_bsk,
+            bd=self.F_tile.F_bd,
+            bdv=self.F_tile.F_bdv,
+            vlayout=self.F_pipeline.F_vlayout,
+            spad=self.F_pipeline.F_spad,
+            skpad=self.F_pipeline.F_skpad,
+            dpad=self.F_pipeline.F_dpad,
+            dvpad=self.F_pipeline.F_dvpad,
+            rope=self.F_pipeline.F_rope,
+            pagedkv=self.F_pipeline.F_pagedkv)
+
+# TODO: design a more practical way to do it
+# this is the currently supported tile size per hdim
+def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
+    if dtype == 'fp16' or dtype == 'bf16':
+        return {
+            '32'  : FmhaFwdAppendKVTileSize(64, 64,  32,  32, -1),
+            '64'  : FmhaFwdAppendKVTileSize(64, 64,  64,  64, -1),
+            '128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
+            '256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
+        }
+    elif dtype == 'fp8' or dtype == 'bf8':
+        return {
+            '64'  : FmhaFwdAppendKVTileSize(64, 64,  64,  64, -1),
+            '128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
+            '256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1)
+        }
+    else:
+        return None
+
+def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
+    # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
+    #       support this in future
+    def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]:
+        # this function will populate a list of possible pipelines
+        # TODO: the order of the list matters! entries later in this list will also be checked later
+        # TODO: currently for qr pipeline, let 't' padding appear later!!
+        # TODO: how to design this more generic?
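The padding flags ('s'/'sk'/'d'/'dv'), RoPE mode and paged-KV flag enumerated by this function are also what `FmhaFwdAppendKVPipeline.name` (earlier in this file) folds into the generated kernel and file names. A minimal standalone sketch of that naming scheme — the `pipeline_suffix` helper is illustrative only, not part of the generator:

```python
# Standalone sketch of FmhaFwdAppendKVPipeline.name's suffix logic.
def pipeline_suffix(vlayout, spad, skpad, dpad, dvpad, rope, pagedkv):
    # concatenate one tag per enabled padding flag, in declaration order
    pads = ''.join(tag for flag, tag in
                   [(spad, 's'), (skpad, 'sk'), (dpad, 'd'), (dvpad, 'dv')]
                   if flag == 't')
    name = f'v{vlayout[0]}'        # 'vr' for row-major V, 'vc' for col-major V
    if pads:
        name += f'_p{pads}'        # e.g. '_psskddv' when every dim is padded
    if rope != 'no':
        name += f'_{rope}'         # 'inter' or 'half' rotary embedding
    if pagedkv == 't':
        name += '_pagedkv'
    return name

# col-major V, fully padded, half-rotated RoPE, contiguous (non-paged) cache:
assert pipeline_suffix('col', 't', 't', 't', 't', 'half', 'f') == 'vc_psskddv_half'
```

Combined with `FmhaFwdAppendKVTileSize.name` and the `fmha_fwd_appendkv_d{hdim}_{dtype}_` prefix, this yields file names such as `fmha_fwd_appendkv_d128_fp16_b64x64x128x128_vc_psskddv_half.cpp`.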
+ squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + # NOTICE: it will be very complicated if we consider all the hdim_q padding cases while + # applying rotary embedding, so I just use 't' in inter/half pipelines + for vlayout in ['row', 'col']: + for pagedkv in ["t", "f"]: + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 'f', 'f', 'no', pagedkv)) + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'no', pagedkv)) + + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'inter', pagedkv)) + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'inter', pagedkv)) + + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'half', pagedkv)) + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'half', pagedkv)) + elif dtype in ['fp8', 'bf8']: + # rope/paged-kv is not supported + pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f')) + else: + assert False + return pipelines + + gen = list() + api_pool = FmhaFwdAppendKVApiPool(mask_impl) + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype) + if d == None: + continue + for hdim_str in d.keys(): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + k = FmhaFwdAppendKVKernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + if not cond: + continue + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + write_single_kernel(kernel, output_dir) + write_fwd_appendkv_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + with file_path.open('a') as f: + _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n") \ No newline at end of file diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 5093945095..ba826c8fb3 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -21,6 +21,14 @@ ) +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8" : 8, + "bf8" : 8 +} + FMHA_FWD_SPLITKV_PIPELINE_MAP = { "qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS", "qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync", @@ -51,8 +59,8 @@ {F_bias}, false, {F_lse}, - {F_dropout}, {F_squant}, + {F_pagedkv}, kHasUnevenSplits, {F_occupancy}>; @@ -63,7 +71,6 @@ typename FmhaFwdTypeConfig::SaccDataType, typename 
FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
-    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
@@ -86,7 +93,7 @@
     fmha_pipeline,
     fmha_epilogue>;
 
-static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
+static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     using k_ = fmha_kernel;
     auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids<k_>(a);
@@ -97,16 +104,21 @@
     }};
 }}
 
-using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
-    {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
+    {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
+    {F_dvpad}>;
 
 #include <iostream>
 
 template<>
-void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
+void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     if constexpr({F_mode} == false) {{ // batch mode
-        if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
+        // we don't check every seqlen_k value for kvcache
+        if (a.seqlen_k_ptr != nullptr) {{
+            kernel_runner<true>::run(s, a);
+        // make sure F_bn0 is divisible by F_bk1
+        }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
             kernel_runner<false>::run(s, a);
         }} else {{
             kernel_runner<true>::run(s, a);
@@ -160,7 +172,7 @@
     fmha_pipeline,
     fmha_epilogue>;
 
-static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
+static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     using k_ = fmha_kernel;
     auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids<k_>(a);
@@ -177,7 +189,7 @@
 #include <iostream>
 
 template<>
-void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
+void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     if (a.num_splits <= 16) {{
         kernel_runner<4>::run(s, a);
@@ -203,7 +215,7 @@
 #include <iostream>
 
 template
-float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a)
+float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     if(s.log_level_ > 0)
         std::cout
@@ -217,22 +229,96 @@
     );
 }}
 
-float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
+float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s){{
     float r = -1;
 {F_dispatch}
     return r;
 }}
 """
 
-FMHA_FWD_SPLITKV_API_INNER_DISPATCH="""    {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
-            ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
-        using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+FMHA_FWD_SPLITKV_API_INNER_DISPATCH="""    {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
+            ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
+        using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
         using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
         return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
     }}
 """
 
+@dataclass
+class FmhaFwdSplitKVApiTrait:
+    pipeline_tag : str
+    # sync with fmha_fwd_traits<>, to generate fallback calls
+    hdim : str
+    dtype : str # data type
+    mode : str # value from MODE_MAP
+    bm0 : int # tile size along q seqlen (block size)
+    bn0 : int # tile size along qk seqlen
+    bk0 : int # tile size along qk gemm unroll
+    bn1 : int # tile size along v head_dim
+    bk1 : int # tile size along kv gemm unroll
+    bk0blen : int
+    vlayout : str
+    mask : str
+    bias : str #
+    lse : str #
+    squant : str #
+    spad : str
+    skpad : str
+    dpad : str
+    dvpad : str
+    pagedkv : str
+
+    @property
+    def name(self) -> str:
+        return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0blen}-'+\
+            f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\
+            f'{self.dvpad}-{self.pagedkv}'
+
+    @property
+    def scheck(self) -> str:
+        if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generates spad/skpad == true
+        if self.pipeline_tag == 'qr_async':
+            if self.spad == 't' : return 'true' # always support
+            else :                return 'true'
+        elif self.pipeline_tag in ['qr']:
+            if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
+            else :                return f'a.seqlen_q % {self.bm0} == 0'
+        else: assert False
+
+    @property
+    def skcheck(self) -> str:
+        if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generates spad/skpad == true
+        if self.pipeline_tag == 'qr_async':
+            if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
+            else :                 return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
+        elif self.pipeline_tag in ['qr', 'qr_fp8']:
+            if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
+            else :                 return f'a.seqlen_k % {self.bn0} == 0'
+        else: assert False
+
+    @property
+    def dcheck(self) -> str:
+        if self.pipeline_tag == 'qr_async':
+            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
+            if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
+            else :               assert False
+        elif self.pipeline_tag in ['qr']:
+            if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
+            else :               return f'a.hdim_q % {self.bk0blen} == 0'
+        else: assert False
+
+    @property
+    def dvcheck(self) -> str:
+        if self.pipeline_tag == 'qr_async':
+            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
+            if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
+            else :                assert False
+        elif self.pipeline_tag in ['qr']:
+            if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters!
(ugly) + else : return f'a.hdim_v % {self.bk0blen} == 0' + else: assert False + @dataclass class FmhaFwdSplitKVPipeline: tag : str @@ -244,8 +330,8 @@ class FmhaFwdSplitKVPipeline: F_dvpad : str # F_bias : str # true/false F_lse : str # - F_dropout : str # F_squant : str # + F_pagedkv : str # t/f F_mask : str # value from MASK_MAP @property @@ -267,8 +353,8 @@ def pad_name() -> str: else: if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' if self.F_lse == 't' : n += '_lse' - if self.F_dropout == 't' : n += '_dropout' if self.F_squant == 't' : n += '_squant' + if self.F_pagedkv == 't' : n += '_pagedkv' return n @dataclass @@ -300,7 +386,7 @@ def __init__(self, mask_impl): self.pool = dict() self.mask_impl = mask_impl - def register_traits(self, trait : FmhaFwdApiTrait) -> None: + def register_traits(self, trait : FmhaFwdSplitKVApiTrait) -> None: # TODO: do we need to check duplication? if trait.dtype not in self.pool.keys(): self.pool[trait.dtype] = dict() @@ -322,8 +408,8 @@ def api(self) -> str: inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], + F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) @@ -383,8 +469,8 @@ def template(self) -> str: F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_lse = BOOL_MAP[self.F_pipeline.F_lse], - F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv], F_occupancy = self.F_tile.F_occupancy, F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], @@ -401,8 +487,8 @@ def name(self) -> str: def filename(self) -> str: return self.name + ".cpp" - def api_trait(self) -> FmhaFwdApiTrait: - return FmhaFwdApiTrait( + def api_trait(self) -> FmhaFwdSplitKVApiTrait: + return FmhaFwdSplitKVApiTrait( pipeline_tag=self.F_pipeline.tag, hdim=str(self.F_hdim), dtype=self.F_dtype, @@ -417,8 +503,8 @@ def api_trait(self) -> FmhaFwdApiTrait: mask=self.F_pipeline.F_mask, bias=self.F_pipeline.F_bias, lse=self.F_pipeline.F_lse, - dropout=self.F_pipeline.F_dropout, squant=self.F_pipeline.F_squant, + pagedkv=self.F_pipeline.F_pagedkv, spad=self.F_pipeline.F_spad, skpad=self.F_pipeline.F_skpad, dpad=self.F_pipeline.F_dpad, @@ -460,29 +546,6 @@ def name(self) -> str: def filename(self) -> str: return self.name + ".cpp" - def api_trait(self) -> FmhaFwdApiTrait: - return FmhaFwdApiTrait( - pipeline_tag=self.F_pipeline.tag, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bk0=self.F_tile.F_bk0, - 
bn1=self.F_tile.F_bn1,
-            bk1=self.F_tile.F_bk1,
-            bk0blen=self.F_tile.F_bk0blen,
-            vlayout=self.F_pipeline.F_vlayout,
-            mask=self.F_pipeline.F_mask,
-            bias=self.F_pipeline.F_bias,
-            lse=self.F_pipeline.F_lse,
-            dropout=self.F_pipeline.F_dropout,
-            squant=self.F_pipeline.F_squant,
-            spad=self.F_pipeline.F_spad,
-            skpad=self.F_pipeline.F_skpad,
-            dpad=self.F_pipeline.F_dpad,
-            dvpad=self.F_pipeline.F_dvpad)
-
 # TODO: design a more practical way to do it
 # this is current supported tile size per hdim
 def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
@@ -533,27 +596,27 @@ def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVPipeline]:
         squant = 't' if dtype == 'fp8' else 'f'
         pipelines = []
         if dtype in ['fp16', 'bf16']:
-            # splitkv kernel donot support dropout
-            for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["f"]):
-                if hdim == 256:
+            for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
+                # TODO: use async pipeline when compiler is more stable
+                if hdim == 256 or hdim in [32, 64, 128]:
                 # if True:
-                    pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
+                    pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask))
 
-                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
+                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
                 else:
-                    pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
+                    pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
                 if receipt == 1:
-                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
-                    pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
+                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitrary hdim
+                    pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitrary hdim
         elif dtype in ['fp8', 'bf8']:
-            # no need lse/dropout kernels
+            # no need lse/paged-kv kernels
             for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
-                pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
+
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, 'f', mask)) else: assert False return pipelines @@ -574,6 +637,9 @@ def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVPipeline]: if pipeline.F_spad != 't' or pipeline.F_skpad != 't': # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue + if pipeline.F_pagedkv == 't': + # we only use batch mode kernels to handle (paged-) kvcache problems + continue k = Kernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 4cc3e77c75..723546a452 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -4,6 +4,7 @@ #include "fmha_fwd.hpp" #include "ck_tile/host.hpp" #include "mask.hpp" +#include "rotary.hpp" #include "utils.hpp" #include @@ -16,6 +17,10 @@ #include #include +#if CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API +#error "we should enable fmha_fwd_splitkv() api in order to cooperate with fmha_fwd_appendkv()" +#endif + template std::ostream& operator<<(std::ostream& os, const std::vector& v) { @@ -50,7 +55,11 @@ auto create_args(int argc, char* argv[]) "seqlen_q. if group-mode, means the average value of seqlen_q\n" "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n" "also with \"-s=s0,s1,s2...\" comma seperated int to set per batch seqlen(group-mode)") - .insert("s_k", "-1", "seqlen_k, -1 means equal to s") + .insert("s_k", "-1", "seqlen_k (including new key/value), -1 means equal to s") + .insert("s_knew", + "0", + "seqlen_k for new key/value, 0 means not to use this at all; " + "-1 to choose s_knew in [1, s] randomly.") .insert("s_kpad", "-1", "seqlen_k stride between 2 tokens, currently used in group-mode only\n" @@ -114,9 +123,14 @@ auto create_args(int argc, char* argv[]) .insert("drop_seed", "1", "seed for random number generator") .insert("drop_offset", "0", "offset for random number generator") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert( + "rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all") + .insert("rotary_interleaved", "1", "whether to apply interleaved RoPE") .insert("num_splits", "1", "# of splits for key/value. 0 to determine actual number by heuristic") + .insert("page_block_size", "0", "paged-kvcache block size. 
0 means not to use paged-kvcache")
+        .insert("cache_batch_idx", "0", "whether to use an index map into the kvcache")
         .insert("warmup", "5", "number of iterations before benchmark the kernel")
         .insert("repeat", "20", "number of iterations to benchmark the kernel");
 
@@ -244,20 +258,6 @@ int override_num_splits_if_necessary(
     return num_splits;
 }
 
-float fmha_fwd_dispatch(fmha_fwd_traits traits,
-                        fmha_fwd_args args,
-                        const ck_tile::stream_config& config)
-{
-    if(1 < args.num_splits)
-    {
-        return fmha_fwd_splitkv(traits, args, config);
-    }
-    else
-    {
-        return fmha_fwd(traits, args, config);
-    }
-}
-
 template
 bool run(const ck_tile::ArgParser& arg_parser)
 {
@@ -276,11 +276,114 @@ bool run(const ck_tile::ArgParser& arg_parser)
         return false;
     }
 
-    auto [seqlen_qs, seqlen_ks, seqlen_kpads] = decode_seqlen(mode,
-                                                              batch,
-                                                              arg_parser.get_str("s"),
-                                                              arg_parser.get_str("s_k"),
-                                                              arg_parser.get_str("s_kpad"));
+    std::optional seed = arg_parser.get_uint32("seed");
+    if(*seed == 0)
+    {
+        seed.reset();
+    }
+
+    ck_tile::index_t hdim_q = arg_parser.get_int("d");
+    ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
+    if(hdim_v < 0)
+        hdim_v = hdim_q;
+
+    ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew");
+#if !CK_TILE_FMHA_FWD_APPENDKV_API
+    if(seqlen_knew != 0)
+    {
+        std::cerr << "kvcache is not supported. ignoring the 's_knew' option" << std::endl;
+        seqlen_knew = 0;
+    }
+#endif
+    if(seqlen_knew < 0)
+    {
+        seqlen_knew = randint(1, arg_parser.get_int("s"), seed);
+    }
+
+    ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
+    if constexpr(!(std::is_same_v ||
+                   std::is_same_v))
+    {
+        if(0 < rotary_dim)
+        {
+            std::cerr << "rotary embedding is only available for data type=fp16|bf16" << std::endl;
+            return false;
+        }
+    }
+#if !CK_TILE_FMHA_FWD_APPENDKV_API
+    else if(0 < rotary_dim)
+    {
+        std::cerr << "rotary embedding is not supported. ignoring the 'rotary_dim' option"
+                  << std::endl;
+        rotary_dim = 0;
+    }
+#endif
+    if(!(rotary_dim <= hdim_q))
+    {
+        std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl;
+        return false;
+    }
+    else if(!(rotary_dim % 16 == 0))
+    {
+        std::cerr << "only rotary dimensions divisible by 16 are currently supported" << std::endl;
+        return false;
+    }
+
+    ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size");
+#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API
+    if(0 < page_block_size)
+    {
+        std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option"
+                  << std::endl;
+        page_block_size = 0;
+    }
+#endif
+    if(!(page_block_size % 128 == 0))
+    {
+        std::cerr << "only paged-kvcache block sizes divisible by 128 are currently supported"
+                  << std::endl;
+        return false;
+    }
+
+    bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx");
+#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API
+    if(use_cache_batch_idx)
+    {
+        std::cerr << "split-kv is not supported. ignoring the 'cache_batch_idx' option"
+                  << std::endl;
+        use_cache_batch_idx = false;
+    }
+#endif
+    if(0 < page_block_size && use_cache_batch_idx)
+    {
+        std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
+                     "'cache_batch_idx' option"
+                  << std::endl;
+        use_cache_batch_idx = false;
+    }
+
+    // the input tensor layout for kvcache is same as batch mode
+    const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim);
+    const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size);
+    if(use_kvcache && mode != mode_enum::batch)
+    {
+        std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl;
+        mode = mode_enum::batch;
+    }
+
+    auto [seqlen_qs, seqlen_ks, seqlen_kpads] =
+        decode_seqlen(mode,
+                      batch,
+                      arg_parser.get_str("s"),
+                      arg_parser.get_str("s_k"),
+                      arg_parser.get_str("s_kpad"),
+                      /*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0,
+                      use_kvcache);
+    // compute kvcache seqlen_k (before appending knew/vnew)
+    auto cache_seqlen_ks = seqlen_ks;
+    std::transform(cache_seqlen_ks.begin(),
+                   cache_seqlen_ks.end(),
+                   cache_seqlen_ks.begin(),
+                   [&](auto seqlen_k) { return seqlen_k - seqlen_knew; });
 
 #if 0
     // clang-format off
@@ -290,11 +393,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
     // clang-format on
 #endif
 
-    ck_tile::index_t hdim_q = arg_parser.get_int("d");
-    ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
-    if(hdim_v < 0)
-        hdim_v = hdim_q;
-
     bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
     bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
 
@@ -356,14 +454,18 @@
         s_randval = true;
     }
 
-    std::string init_method = arg_parser.get_str("init");
-    std::optional seed = arg_parser.get_uint32("seed");
-    if(*seed == 0)
+    std::string init_method = arg_parser.get_str("init");
+
+    const bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved");
+
+    ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
+#if !CK_TILE_FMHA_FWD_SPLITKV_API
+    if(num_splits != 1)
     {
-        seed.reset();
+        std::cerr << "split-kv is not supported. ignoring the 'num_splits' option" << std::endl;
+        num_splits = 1;
     }
-
-    int num_splits = arg_parser.get_int("num_splits");
+#endif
 
     int stream_warmup = arg_parser.get_int("warmup");
     int stream_repeat = arg_parser.get_int("repeat");
@@ -425,6 +527,11 @@
         }
     }
 
+    const ck_tile::index_t max_num_page_blocks =
+        (0 < page_block_size
+             ? batch * std::max(1, ck_tile::integer_divide_ceil(max_seqlen_k, page_block_size))
+             : 0);
+
     // legalize num_splits according to other options
    if(num_splits < 1)
    {
@@ -436,6 +543,14 @@
         std::cerr << "num_splits greater than 128 is not supported" << std::endl;
         return false;
     }
+#if CK_TILE_FMHA_FWD_SPLITKV_API
+    if(0 < p_drop && (1 < num_splits || use_kvcache))
+    {
+        std::cerr << "dropout is not supported by split-kv kernels. ignoring the 'p_drop' option"
+                  << std::endl;
+        p_drop = 0.0f;
+    }
+#endif
 
     auto get_lengths = [&](bool permute,
                            ck_tile::index_t b /*batch*/,
@@ -462,11 +577,26 @@
     ck_tile::HostTensor q_host(
         get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
     ck_tile::HostTensor k_host(
-        get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
+        0 < page_block_size
+            ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
+            : get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
+    /// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
+    ck_tile::HostTensor knew_host(
+        0 < seqlen_knew
+            ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_q)
+            : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */);
     ck_tile::HostTensor v_host(
-        is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
-                      : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k));
-
+        0 < page_block_size
+            ? (is_v_rowmajor
                   ?
get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_v) + : get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size)) + : (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) + : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k))); + ck_tile::HostTensor vnew_host( + 0 < seqlen_knew + ? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v) + : get_lengths(i_perm, batch, nhead_k, hdim_v, seqlen_knew)) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor bias_host( bias.type == bias_enum::elementwise_bias ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k) @@ -478,12 +608,15 @@ bool run(const ck_tile::ArgParser& arg_parser) : std::array{batch, nhead}) : std::array{1, 1}); + auto [rotary_cos_host, rotary_sin_host] = generate_rotary_cos_sin( + std::max(shape_seqlen_q, shape_seqlen_k), rotary_dim, seed); + ck_tile::HostTensor lse_acc_host( - 1 < num_splits + 1 < num_splits || use_kvcache ? std::array{num_splits, shape_batch, nhead, shape_seqlen_q} : std::array{1, 1, 1, 1}); ck_tile::HostTensor o_acc_host( - 1 < num_splits + 1 < num_splits || use_kvcache ? std::array{num_splits, batch, nhead, max_seqlen_q, hdim_v} : std::array{1, 1, 1, 1, 1}); @@ -500,39 +633,57 @@ bool run(const ck_tile::ArgParser& arg_parser) p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) : std::array{1, 1, 1, 1}); + ck_tile::HostTensor block_table_host( + 0 < page_block_size ? std::array{batch, max_num_page_blocks / batch} + : std::array{1, 1}); + + ck_tile::HostTensor cache_batch_idx_host(use_cache_batch_idx + ? std::array{batch} + : std::array{1}); + if(init_method == "ui" || init_method == "0") { ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(q_host); ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(knew_host); ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(v_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(vnew_host); ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(bias_host); } else if(init_method == "ni") { ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(q_host); ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(k_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(knew_host); ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(v_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(vnew_host); ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(bias_host); } else if(init_method == "uf" || init_method == "1") { ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(knew_host); ck_tile::FillUniformDistribution{0.f, 1.f, seed}(v_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(vnew_host); ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); } else if(init_method == "nf") { ck_tile::FillNormalDistribution{0.f, 3.f, seed}(q_host); ck_tile::FillNormalDistribution{0.f, 3.f, seed}(k_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(knew_host); ck_tile::FillNormalDistribution{0.f, 3.f, seed}(v_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(vnew_host); ck_tile::FillNormalDistribution{0.f, 3.f, seed}(bias_host); } else if(init_method == "tf" || init_method == "2") { 
ck_tile::FillTrigValue{}(q_host);
         ck_tile::FillTrigValue{}(k_host);
+        ck_tile::FillTrigValue{}(knew_host);
         ck_tile::FillTrigValue{}(v_host);
+        ck_tile::FillTrigValue{}(vnew_host);
         ck_tile::FillTrigValue{}(bias_host);
     }
     else if(init_method == "ufq" || init_method == "uf:q" ||
@@ -540,7 +691,9 @@
     {
         ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(q_host);
         ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(k_host);
+        ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(knew_host);
         ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(v_host);
+        ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(vnew_host);
 
         // bias_fp8 = qscale_bias * bias_fp32
         float qscale_bias = (dtype_max / range_q) * (dtype_max / range_k);
@@ -550,7 +703,7 @@
     if(bias.type == bias_enum::alibi)
     {
         auto slopes = ck_tile::get_alibi_slopes(nhead);
-        assert(slopes.size() == nhead);
+        assert(slopes.size() == static_cast<std::size_t>(nhead));
         if(bias.rank_info == 0)
         {
             // alibi in 1*h
@@ -565,10 +718,14 @@
             }
         }
    }
+    iota_shuffle(block_table_host.begin(), block_table_host.end(), 0);
+    iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0);
 
     ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
@@ -576,27 +733,41 @@
     ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
     ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
-    ck_tile::DeviceMem seqlen_k_buf(seqlen_kpads[0] < 0 ? 0 : seqlen_ks.size() * sizeof(int32_t));
+    ck_tile::DeviceMem seqlen_k_buf(
+        use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) : 0);
+    ck_tile::DeviceMem cache_seqlen_k_buf(
+        need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
+    ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes());
 
     q_buf.ToDevice(q_host.data());
     k_buf.ToDevice(k_host.data());
+    knew_buf.ToDevice(knew_host.data());
     v_buf.ToDevice(v_host.data());
+    vnew_buf.ToDevice(vnew_host.data());
     bias_buf.ToDevice(bias_host.data());
     seqstart_q.ToDevice(seqstart_q_host.data());
     seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data()
                                             : seqstart_k_with_padding_host.data());
-    seqlen_k_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr : seqlen_ks.data());
+    seqlen_k_buf.ToDevice(use_kvcache || 0 <= seqlen_kpads[0] ?
seqlen_ks.data() : nullptr); + cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr); + rotary_cos_buf.ToDevice(rotary_cos_host.data()); + rotary_sin_buf.ToDevice(rotary_sin_host.data()); alibi_slope_buf.ToDevice(alibi_slope_host.data()); + block_table_buf.ToDevice(block_table_host.data()); + cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data()); // clang-format off auto layout_str = [&](bool permute){ - if (permute) return std::string("bhsd"); + if(permute) return std::string("bhsd"); else return std::string("bshd"); }; auto io_layout = [&](bool iperm_, bool operm_) { - if (iperm_ == operm_) return layout_str(iperm_); + if(iperm_ == operm_) return layout_str(iperm_); else return layout_str(iperm_) + std::string("-") + layout_str(operm_); }; // clang-format on @@ -609,51 +780,77 @@ bool run(const ck_tile::ArgParser& arg_parser) << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant << ", mask:" << mask << ", v:" << vlayout; +#if CK_TILE_FMHA_FWD_APPENDKV_API + if(0 < rotary_dim) + { + std::cout << ", rotary_dim:" << rotary_dim << "(" + << (is_rotary_interleaved ? "inter" : "half") << ")"; + } +#endif +#if CK_TILE_FMHA_FWD_SPLITKV_API if(1 < num_splits) { std::cout << ", num_splits:" << num_splits; } + if(0 < page_block_size) + { + std::cout << ", page_block_size:" << page_block_size; + } + if(use_cache_batch_idx) + { + std::cout << ", cache_batch_idx:" << use_cache_batch_idx; + } +#endif std::cout << std::flush; - auto fmha_traits = fmha_fwd_traits{hdim_q, - hdim_v, - data_type, - mode == mode_enum::group, - is_v_rowmajor, - mask.type, - bias.type, - lse, - p_drop > 0.0f, - squant}; + const auto init_traits = [&](auto& traits) { + traits.hdim_q = hdim_q; + traits.hdim_v = hdim_v; + traits.data_type = data_type; + traits.is_v_rowmajor = is_v_rowmajor; - auto p_compute_element_func = [&]() { - if constexpr(std::is_same_v) - return ck_tile::scales{scale_p}; - else - return ck_tile::identity{}; - }(); + if constexpr(std::is_same_v>) + { + traits.rope_type = (0 < rotary_dim ? (is_rotary_interleaved ? rope_enum::interleaved + : rope_enum::half_rotated) + : rope_enum::none); + } + else // fmha_fwd_traits or fmha_splitkv_traits + { + traits.is_group_mode = (mode == mode_enum::group); + traits.mask_type = mask.type; + traits.bias_type = bias.type; + traits.has_lse = lse; + traits.do_fp8_static_quant = squant; - auto oacc_element_func = [&]() { - if constexpr(std::is_same_v) - return ck_tile::composes(ck_tile::saturates{}, - ck_tile::scales{scale_o}); - else - return ck_tile::identity{}; - }(); + if constexpr(std::is_same_v>) + { + traits.has_dropout = (p_drop > 0.0f); + } + } + }; - auto fmha_args = [&, k_paddings_ = seqlen_kpads]() { + const auto init_args = [&, k_paddings_ = seqlen_kpads](auto& args) { assert(nhead % nhead_k == 0); /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, /// seqlen_k] in this example, hence both the 'batch_stride_bias' & /// 'nhead_stride_bias' are 0. // setup stride_* arguments - const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); - const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); - const ck_tile::index_t stride_v = [&]() { + const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_knew = (i_perm ? 
hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? hdim_v : nhead_k * hdim_v; + else + return 0 < page_block_size ? (i_perm ? page_block_size : nhead_k * page_block_size) + : (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); + }(); + const ck_tile::index_t stride_vnew = [&]() { if(is_v_rowmajor) return i_perm ? hdim_v : nhead_k * hdim_v; else - return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k; + return i_perm ? seqlen_knew : nhead_k * seqlen_knew; }(); const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); const ck_tile::index_t stride_randval = (max_seqlen_k); @@ -661,12 +858,23 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); // setup nhead_stride_* arguments const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); - const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); - const ck_tile::index_t nhead_stride_v = [&]() { + const ck_tile::index_t nhead_stride_k = + (0 < page_block_size ? (i_perm ? page_block_size * hdim_q : hdim_q) + : (i_perm ? shape_seqlen_k * hdim_q : hdim_q)); + const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_v = [&]() { if(is_v_rowmajor) - return i_perm ? shape_seqlen_k * hdim_v : hdim_v; + return 0 < page_block_size ? (i_perm ? page_block_size * hdim_v : hdim_v) + : (i_perm ? shape_seqlen_k * hdim_v : hdim_v); else - return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k; + return 0 < page_block_size ? (i_perm ? hdim_v * page_block_size : page_block_size) + : (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k); + }(); + const ck_tile::index_t nhead_stride_vnew = [&]() { + if(is_v_rowmajor) + return i_perm ? seqlen_knew * hdim_v : hdim_v; + else + return i_perm ? hdim_v * seqlen_knew : seqlen_knew; }(); const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); @@ -676,88 +884,194 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); - const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); - const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k); + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = + (0 < page_block_size ? (nhead_k * page_block_size * hdim_q) + : (nhead_k * shape_seqlen_k * hdim_q)); + const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q); + const ck_tile::index_t batch_stride_v = + (0 < page_block_size ? 
(nhead_k * hdim_v * page_block_size) + : (nhead_k * hdim_v * shape_seqlen_k)); + const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q); const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); // setup split_stride_* arguments (only used in split-kv kernel) const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q); const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v); - return fmha_fwd_args{q_buf.GetDeviceBuffer(), - k_buf.GetDeviceBuffer(), - v_buf.GetDeviceBuffer(), - bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() - : bias_buf.GetDeviceBuffer(), - randval_buf.GetDeviceBuffer(), - lse_acc_buf.GetDeviceBuffer(), - o_acc_buf.GetDeviceBuffer(), - lse_buf.GetDeviceBuffer(), - o_buf.GetDeviceBuffer(), - seqstart_q.GetDeviceBuffer(), - seqstart_k.GetDeviceBuffer(), - k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer(), - shape_seqlen_q, - shape_seqlen_k, - batch, - max_seqlen_q, - hdim_q, - hdim_v, - nhead, - nhead_k, - num_splits, - scale_s, - scale_p, - scale_o, - stride_q, - stride_k, - stride_v, - bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) - : stride_bias, - stride_randval, - stride_o_acc, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_lse_acc, - nhead_stride_o_acc, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_randval, - batch_stride_lse, - batch_stride_lse_acc, - batch_stride_o_acc, - batch_stride_o, - split_stride_lse_acc, - split_stride_o_acc, - mask.left, - mask.right, - static_cast(mask.type), - p_drop, - s_randval, - {drop_seed, drop_offset}}; + args.q_ptr = q_buf.GetDeviceBuffer(); + args.k_ptr = k_buf.GetDeviceBuffer(); + args.v_ptr = v_buf.GetDeviceBuffer(); + + args.batch = batch; + args.seqlen_q = shape_seqlen_q; // unused in group mode + args.hdim_q = hdim_q; + args.hdim_v = hdim_v; + args.nhead_q = nhead; + args.nhead_k = nhead_k; + + args.stride_q = stride_q; + args.stride_k = stride_k; + args.stride_v = stride_v; + args.nhead_stride_q = nhead_stride_q; + args.nhead_stride_k = nhead_stride_k; + args.nhead_stride_v = nhead_stride_v; + args.batch_stride_q = batch_stride_q; + args.batch_stride_k = batch_stride_k; + args.batch_stride_v = batch_stride_v; + + if constexpr(std::is_same_v>) + { + args.knew_ptr = knew_buf.GetDeviceBuffer(); + args.vnew_ptr = vnew_buf.GetDeviceBuffer(); + args.seqlen_knew = seqlen_knew; + + args.seqlen_k_ptr = cache_seqlen_k_buf.GetDeviceBuffer(); + + args.rotary_cos_ptr = (0 < rotary_dim ? rotary_cos_buf.GetDeviceBuffer() : nullptr); + args.rotary_sin_ptr = (0 < rotary_dim ? rotary_sin_buf.GetDeviceBuffer() : nullptr); + args.rotary_dim = rotary_dim; + args.has_mask = (mask.type != mask_enum::no_mask); + + args.block_table_ptr = + (0 < page_block_size ? 
block_table_buf.GetDeviceBuffer() : nullptr); + args.batch_stride_block_table = batch_stride_block_table; + args.page_block_size = page_block_size; + + args.cache_batch_idx = + (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr); + + args.stride_knew = stride_knew; + args.stride_vnew = stride_vnew; + args.nhead_stride_knew = nhead_stride_knew; + args.nhead_stride_vnew = nhead_stride_vnew; + args.batch_stride_knew = batch_stride_knew; + args.batch_stride_vnew = batch_stride_vnew; + } + else // fmha_fwd_args or fmha_fwd_splitkv_args + { + args.bias_ptr = bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() + : bias_buf.GetDeviceBuffer(); + args.lse_ptr = lse_buf.GetDeviceBuffer(); + args.o_ptr = o_buf.GetDeviceBuffer(); + + args.seqstart_q_ptr = + (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr); + args.seqstart_k_ptr = + (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr); + args.seqlen_k_ptr = + (use_kvcache || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr); + + args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled) + args.max_seqlen_q = max_seqlen_q; + + args.scale_s = scale_s; + args.scale_p = scale_p; + args.scale_o = scale_o; + + args.stride_bias = + (bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias); + args.stride_o = stride_o; + args.nhead_stride_bias = nhead_stride_bias; + args.nhead_stride_lse = nhead_stride_lse; + args.nhead_stride_o = nhead_stride_o; + args.batch_stride_bias = batch_stride_bias; + args.batch_stride_lse = batch_stride_lse; + args.batch_stride_o = batch_stride_o; + + args.window_size_left = mask.left; + args.window_size_right = mask.right; + args.mask_type = static_cast(mask.type); + + if constexpr(std::is_same_v>) + { + args.rand_val_ptr = randval_buf.GetDeviceBuffer(); + + args.stride_randval = stride_randval; + args.nhead_stride_randval = nhead_stride_randval; + args.batch_stride_randval = batch_stride_randval; + + args.p_drop = p_drop; + args.s_randval = s_randval; + args.drop_seed_offset = std::tie(drop_seed, drop_offset); + } + else if constexpr(std::is_same_v>) + { + args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer(); + args.o_acc_ptr = o_acc_buf.GetDeviceBuffer(); + + args.block_table_ptr = + (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr); + args.batch_stride_block_table = batch_stride_block_table; + args.page_block_size = page_block_size; + + args.cache_batch_idx = + (use_cache_batch_idx ? 
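// [editor's note] cache_batch_idx is an optional indirection: query batch b reads K/V
// cache slot cache_batch_idx[b] instead of slot b (compare cache_b_idx in the host
// verification below), which lets the tests exercise shuffled cache slots without copying.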
cache_batch_idx_buf.GetDeviceBuffer() : nullptr); + + args.num_splits = num_splits; + + args.stride_o_acc = stride_o_acc; + args.nhead_stride_lse_acc = nhead_stride_lse_acc; + args.nhead_stride_o_acc = nhead_stride_o_acc; + args.batch_stride_lse_acc = batch_stride_lse_acc; + args.batch_stride_o_acc = batch_stride_o_acc; + args.split_stride_lse_acc = split_stride_lse_acc; + args.split_stride_o_acc = split_stride_o_acc; + } + } + }; + + const float appendkv_ave_time = [&] { +#if CK_TILE_FMHA_FWD_APPENDKV_API + if(need_append_kvcache) + { + fmha_fwd_appendkv_traits fwd_appendkv_traits; + init_traits(fwd_appendkv_traits); + + fmha_fwd_appendkv_args fwd_appendkv_args; + init_args(fwd_appendkv_args); + + return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, stream_config); + } +#endif + return 0.0f; }(); - float ave_time = fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config); + const float fwd_ave_time = [&] { +#if CK_TILE_FMHA_FWD_SPLITKV_API + if(1 < num_splits || use_kvcache) + { + fmha_fwd_splitkv_traits fmha_splitkv_traits; + init_traits(fmha_splitkv_traits); + + fmha_fwd_splitkv_args fmha_splitkv_args; + init_args(fmha_splitkv_args); + + return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config); + } +#endif + fmha_fwd_traits fmha_traits; + init_traits(fmha_traits); + + fmha_fwd_args fmha_args; + init_args(fmha_args); - if(ave_time < 0) + return fmha_fwd(fmha_traits, fmha_args, stream_config); + }(); + + if(appendkv_ave_time < 0.0f || fwd_ave_time < 0.0f) { std::cout << ", not supported yet" << std::flush << std::endl; return false; } + const float ave_time = (appendkv_ave_time + fwd_ave_time); + float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time; @@ -775,36 +1089,46 @@ bool run(const ck_tile::ArgParser& arg_parser) o_buf.FromDevice(o_host.data()); lse_buf.FromDevice(lse_host.data()); randval_buf.FromDevice(randval_host.data()); + + auto p_compute_element_func = [&]() { + if constexpr(std::is_same_v) + return ck_tile::scales{scale_p}; + else + return ck_tile::identity{}; + }(); + + auto oacc_element_func = [&]() { + if constexpr(std::is_same_v) + return ck_tile::composes(ck_tile::saturates{}, + ck_tile::scales{scale_o}); + else + return ck_tile::identity{}; + }(); + float p_undrop = 1.0 - p_drop; uint8_t p_undrop_in_uint8_t = uint8_t(std::floor(p_undrop * std::numeric_limits::max())); float rp_undrop = 1.0 / p_undrop; bool pass = true; - for(ck_tile::index_t wb = 0; wb < batch; ++wb) { const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; // adjust matrix index according to the mode - const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); + const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0); + const ck_tile::index_t cache_b_idx = + (use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx); const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] : seqstart_k_with_padding_host[wb])); - const auto v_host_ref_lengths = - std::array{nhead, hdim_v, real_seqlen_k}; - const auto v_host_ref_strides = - is_v_rowmajor - ? 
std::array{hdim_v * real_seqlen_k, 1, hdim_v} - : std::array{hdim_v * real_seqlen_k, real_seqlen_k, 1}; - ck_tile::HostTensor q_host_ref({nhead, real_seqlen_q, hdim_q}); ck_tile::HostTensor k_host_ref({nhead, real_seqlen_k, hdim_q}); - ck_tile::HostTensor v_host_ref(v_host_ref_lengths, v_host_ref_strides); + ck_tile::HostTensor v_host_ref({nhead, hdim_v, real_seqlen_k}); ck_tile::HostTensor o_host_ref({nhead, real_seqlen_q, hdim_v}); ck_tile::HostTensor s_host_ref({nhead, real_seqlen_q, real_seqlen_k}); @@ -815,22 +1139,138 @@ bool run(const ck_tile::ArgParser& arg_parser) // clang-format off // permute - if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); }); - else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); }); + if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[0], i[1] + query_offset, i[2]); }); + else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[1] + query_offset, i[0], i[2]); }); + +#if CK_TILE_FMHA_FWD_APPENDKV_API + // optionally apply RoPE to the q_host_ref + if(0 < rotary_dim) + { + decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths()); + + auto [rotary_cos_slice, rotary_sin_slice] = + slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q); + + ck_tile::reference_batched_rotary_position_embedding( + q_host_ref, rotary_cos_slice, rotary_sin_slice, is_rotary_interleaved, q_host_ref_ro, + /*use_1_row_sin_cos=*/mask.type == mask_enum::no_mask); - if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); }); - else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); }); + q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); }); + } +#endif +#if CK_TILE_FMHA_FWD_SPLITKV_API + if(0 < page_block_size) { + if(i_perm) { + k_host_ref.ForEach([&](auto& self, auto i) { + self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]); + }); + } else { + k_host_ref.ForEach([&](auto& self, auto i) { + self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]); + }); + } + } else +#endif + { + if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); }); + else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); }); + } + +#if CK_TILE_FMHA_FWD_APPENDKV_API + // copy Knew to the end of K + if(0 < seqlen_knew) + { + ck_tile::HostTensor knew_host_ref({nhead, seqlen_knew, hdim_q}); + if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[0] / nr, i[1], i[2]); }); + else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[1], i[0] / nr, i[2]); }); + + // optionally apply RoPE to the knew_host_ref + auto* real_knew_host_ref = &knew_host_ref; + std::optional knew_host_ref_ro; + if(0 < rotary_dim) + { + knew_host_ref_ro.emplace(knew_host_ref.get_lengths()); + + auto [rotary_cos_slice, rotary_sin_slice] = + slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew); + + ck_tile::reference_batched_rotary_position_embedding( + knew_host_ref, + rotary_cos_slice, + rotary_sin_slice, + is_rotary_interleaved, + knew_host_ref_ro.value()); + + real_knew_host_ref = &knew_host_ref_ro.value(); + 
} - if (is_v_rowmajor) { - // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] - if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); }); - // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] - else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); }); + (*real_knew_host_ref).ForEach([&](auto& self, auto i) { + k_host_ref(i[0], i[1] + cache_seqlen_ks[wb], i[2]) = self(i); + }); } - else { - if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[1], i[2] + key_offset); }); - else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[1], i[0] / nr, i[2] + key_offset); }); +#endif +#if CK_TILE_FMHA_FWD_SPLITKV_API + if(0 < page_block_size) { + if(is_v_rowmajor) { + if(i_perm) { + v_host_ref.ForEach([&](auto& self, auto i) { + self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]); + }); + } else { + v_host_ref.ForEach([&](auto& self, auto i) { + self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[2] % page_block_size, i[0] / nr, i[1]); + }); + } + } + else + { + if(i_perm) { + v_host_ref.ForEach([&](auto& self, auto i) { + self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size); + }); + } else { + v_host_ref.ForEach([&](auto& self, auto i) { + self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[1], i[0] / nr, i[2] % page_block_size); + }); + } + } + } else +#endif + { + if(is_v_rowmajor) { + // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[2] + key_offset, i[1]); }); + // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[2] + key_offset, i[0] / nr, i[1]); }); + } + else + { + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[1], i[2] + key_offset); }); + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[1], i[0] / nr, i[2] + key_offset); }); + } } + +#if CK_TILE_FMHA_FWD_APPENDKV_API + // copy Vnew to the end of V + if(0 < seqlen_knew) + { + ck_tile::HostTensor vnew_host_ref({nhead, hdim_v, seqlen_knew}); + if(is_v_rowmajor) + { + if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[2], i[1]); }); + else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[2], i[0] / nr, i[1]); }); + } + else + { + if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[1], i[2]); }); + else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[1], i[0] / nr, i[2]); }); + } + + vnew_host_ref.ForEach([&](auto& self, auto i) { + v_host_ref(i[0], i[1], i[2] + cache_seqlen_ks[wb]) = self(i); + }); + } +#endif // clang-format on // reference @@ -959,7 +1399,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor randval_host_ref( {nhead, real_seqlen_q, real_seqlen_k}); randval_host_ref.ForEach([&](auto& self, auto idx) { - self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]); + self(idx) = randval_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); }); ck_tile::reference_batched_dropout( p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); @@ -976,8 +1416,8 @@ bool run(const 
ck_tile::ArgParser& arg_parser) ck_tile::HostTensor o_host_result({nhead, real_seqlen_q, hdim_v}); // clang-format off // permute - if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[0], idx[1] + query_offset, idx[2]); }); - else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[1] + query_offset, idx[0], idx[2]); }); + if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); }); + else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); }); // clang-format on auto [rtol, atol] = get_elimit(init_method); @@ -999,7 +1439,7 @@ bool run(const ck_tile::ArgParser& arg_parser) { ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); lse_host_result.ForEach([&](auto& self, auto idx) { - self(idx) = lse_host(b, idx[0], idx[1] + query_offset); + self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset); }); cur_pass = ck_tile::check_err(lse_host_result, diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index c4c951c43a..183475064a 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -5,10 +5,13 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" -#include "ck_tile/ops/fmha.hpp" #include "ck_tile/ops/epilogue.hpp" -#include "mask.hpp" +#include "ck_tile/ops/fmha.hpp" + #include "bias.hpp" +#include "mask.hpp" +#include "rotary.hpp" + #include template @@ -93,13 +96,86 @@ struct fmha_fwd_args const void* v_ptr; const void* bias_ptr; // bias or alibi_slope pointer void* rand_val_ptr; + void* lse_ptr; + void* o_ptr; + + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* + seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + float scale_s; + float scale_p; + float scale_o; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_randval; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + + float p_drop; + bool s_randval; + std::tuple drop_seed_offset; +}; + +struct fmha_fwd_splitkv_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer void* lse_acc_ptr; void* o_acc_ptr; void* lse_ptr; void* o_ptr; + + void* block_table_ptr; + ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr + ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr + + const void* 
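// [editor's note] unlike fmha_fwd_args, this split-kv argument struct carries no dropout
// state (no rand_val_ptr / p_drop / drop_seed_offset); the same patch also strips those
// fields from the split-kv kargs below, so dropout remains reachable only through the
// plain fwd path.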
cache_batch_idx; + + // the real seqlen_q & seqlen_k are decided as follows: + // batch mode: seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.seqlen_k + // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] + // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] + // kvcache mode (use same kernel as batch mode): + // seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] const void* seqstart_q_ptr; const void* seqstart_k_ptr; const void* seqlen_k_ptr; + ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; ck_tile::index_t batch; @@ -109,21 +185,21 @@ struct fmha_fwd_args ck_tile::index_t nhead_q; ck_tile::index_t nhead_k; ck_tile::index_t num_splits; + float scale_s; float scale_p; float scale_o; + ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_v; ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 - ck_tile::index_t stride_randval; ck_tile::index_t stride_o_acc; ck_tile::index_t stride_o; ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_bias; - ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_lse; ck_tile::index_t nhead_stride_lse_acc; ck_tile::index_t nhead_stride_o_acc; @@ -132,19 +208,62 @@ struct fmha_fwd_args ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_bias; - ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_lse_acc; ck_tile::index_t batch_stride_o_acc; ck_tile::index_t batch_stride_o; ck_tile::index_t split_stride_lse_acc; ck_tile::index_t split_stride_o_acc; + ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; ck_tile::index_t mask_type; - float p_drop; - bool s_randval; - std::tuple drop_seed_offset; +}; + +struct fmha_fwd_appendkv_args +{ + void* q_ptr; + void* k_ptr; + const void* knew_ptr; + void* v_ptr; + const void* vnew_ptr; + + const void* seqlen_k_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_knew; + ck_tile::index_t batch; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + const void* rotary_cos_ptr; // only used if 'rotary_dim' > 0 + const void* rotary_sin_ptr; // only used if 'rotary_dim' > 0 + ck_tile::index_t rotary_dim; + bool has_mask; + + void* block_table_ptr; + ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr + ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr + + const void* cache_batch_idx; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_knew; + ck_tile::index_t stride_v; + ck_tile::index_t stride_vnew; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_knew; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_vnew; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_knew; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_vnew; }; template @@ -244,7 +363,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) } template -auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) +auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) { assert(args.nhead_q % args.nhead_k == 0); auto kargs = [&] { @@ -255,11 +374,9 @@ auto 
fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) args.k_ptr, args.v_ptr, args.bias_ptr, - args.rand_val_ptr, args.lse_acc_ptr, args.o_acc_ptr, args.batch, - args.max_seqlen_q, args.seqstart_q_ptr, args.seqstart_k_ptr, args.seqlen_k_ptr, @@ -274,24 +391,22 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) args.stride_k, args.stride_v, args.stride_bias, - args.stride_randval, args.stride_o_acc, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, args.nhead_stride_bias, - args.nhead_stride_randval, args.nhead_stride_lse_acc, args.nhead_stride_o_acc, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_lse_acc, args.batch_stride_o_acc, args.split_stride_lse_acc, args.split_stride_o_acc, args.window_size_left, args.window_size_right, - args.mask_type, - args.p_drop, - args.s_randval, - args.drop_seed_offset); + args.mask_type); } else { // create batch mode kernel arguments @@ -299,48 +414,45 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) args.k_ptr, args.v_ptr, args.bias_ptr, - args.rand_val_ptr, args.lse_acc_ptr, args.o_acc_ptr, args.batch, - args.max_seqlen_q, args.seqlen_q, args.seqlen_k, + args.seqlen_k_ptr, args.hdim_q, args.hdim_v, args.nhead_q, args.nhead_q / args.nhead_k, args.num_splits, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.cache_batch_idx, args.scale_s, args.scale_p, args.stride_q, args.stride_k, args.stride_v, args.stride_bias, - args.stride_randval, args.stride_o_acc, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, args.nhead_stride_bias, - args.nhead_stride_randval, args.nhead_stride_lse_acc, args.nhead_stride_o_acc, args.batch_stride_q, args.batch_stride_k, args.batch_stride_v, args.batch_stride_bias, - args.batch_stride_randval, args.batch_stride_lse_acc, args.batch_stride_o_acc, args.split_stride_lse_acc, args.split_stride_o_acc, args.window_size_left, args.window_size_right, - args.mask_type, - args.p_drop, - args.s_randval, - args.drop_seed_offset); + args.mask_type); } }(); @@ -351,7 +463,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) } template -auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args) +auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args) { assert(args.nhead_q % args.nhead_k == 0); auto kargs = [&] { @@ -410,6 +522,51 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args) return ck_tile::make_tuple(kargs, grids); } +template +auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = Kernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.knew_ptr, + args.v_ptr, + args.vnew_ptr, + args.seqlen_q, + args.seqlen_k_ptr, + args.seqlen_knew, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.rotary_cos_ptr, + args.rotary_sin_ptr, + args.rotary_dim, + args.has_mask, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.cache_batch_idx, + args.stride_q, + args.stride_k, + args.stride_knew, + args.stride_v, + args.stride_vnew, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_knew, + args.nhead_stride_v, + args.nhead_stride_vnew, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_knew, + args.batch_stride_v, + args.batch_stride_vnew); + + dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew); + + return ck_tile::make_tuple(kargs, grids); +} 
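// [editor's note] a minimal host-side sketch (assumption, not part of the patch) of how
// the example chains the public entry points declared below for a decode step with
// kvcache: append the new tokens first, then attend via the split-kv path. Traits/args
// setup is elided:
//
//   float run_decode_step(fmha_fwd_appendkv_traits ta, fmha_fwd_appendkv_args aa,
//                         fmha_fwd_splitkv_traits ts, fmha_fwd_splitkv_args sa,
//                         const ck_tile::stream_config& sc)
//   {
//       const float t_append = fmha_fwd_appendkv(ta, aa, sc); // RoPE + copy Knew/Vnew into cache
//       if(t_append < 0)
//           return -1.f; // a negative return means "not supported"
//       const float t_attn = fmha_fwd_splitkv(ts, sa, sc);
//       return (t_attn < 0) ? -1.f : t_append + t_attn;
//   }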
+ // this is used to pattern-match internal kernel implementation, not to instantiate kernel template float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); +template +struct fmha_fwd_splitkv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN0 = kN0_; + static constexpr ck_tile::index_t kK0 = kK0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr ck_tile::index_t kK1 = kK1_; + static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr bool kIsPagedKV = kIsPagedKV_; +}; + template -void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args); +void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args); template std::string fmha_fwd_splitkv_get_name_(); @@ -487,11 +688,45 @@ struct fmha_fwd_splitkv_combine_traits_ }; template -void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_args); +void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args); template std::string fmha_fwd_splitkv_combine_get_name_(); +// this is used to pattern-match internal kernel implementation, not to instantiate kernel +template +struct fmha_fwd_appendkv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kTileSizeS = kTileSizeS_; + static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_; + static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_; + static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSk = kPadSk_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr auto RotaryEnum = RotaryEnum_; + static constexpr bool kIsPagedKV = kIsPagedKV_; +}; + +template +float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args); + // This is the public API, will be generated by script struct fmha_fwd_traits { @@ -508,4 +743,32 @@ struct fmha_fwd_traits // TODO: padding check is inside this api }; float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); -float fmha_fwd_splitkv(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); + +struct fmha_fwd_splitkv_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + bool is_v_rowmajor; + mask_enum mask_type; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi.
sync with BlockAttentionBiasEnum + bool has_lse; + bool do_fp8_static_quant; + // TODO: padding check is inside this api +}; +float fmha_fwd_splitkv(fmha_fwd_splitkv_traits, + fmha_fwd_splitkv_args, + const ck_tile::stream_config&); + +struct fmha_fwd_appendkv_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_v_rowmajor; + rope_enum rope_type; +}; +float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, + fmha_fwd_appendkv_args, + const ck_tile::stream_config&); diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 27347b4476..9b91d36fb2 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -5,25 +5,30 @@ import argparse from enum import IntEnum from pathlib import Path +import pkgutil +import sys from typing import List, Optional +import codegen.ops from codegen.cmake_config import * -from codegen.ops import ( - fmha_fwd, - fmha_fwd_splitkv, - fmha_bwd -) class HandlerId(IntEnum): LIST_BLOBS = 0 WRITE_BLOBS = 1 -handlers = { - 'fwd' : (fmha_fwd.list_blobs, fmha_fwd.write_blobs), - 'fwd_splitkv' : (fmha_fwd_splitkv.list_blobs, fmha_fwd_splitkv.write_blobs), - 'bwd' : (fmha_bwd.list_blobs, fmha_bwd.write_blobs), -} +# inspect all modules under 'codegen.ops' and register API handlers +ops = [] +for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__): + full_module_name = '%s.%s' % (codegen.ops.__name__, module_name) + if full_module_name not in sys.modules: + ops.append(importer.find_spec(module_name).loader.load_module(module_name)) +unwanted_prefix = 'fmha_' +handlers = dict( + [(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__, + (op.list_blobs, op.write_blobs)) for op in ops] +) +assert 0 < len(handlers) def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: if output_dir is None: @@ -103,4 +108,4 @@ def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter if args.list_blobs is not None: list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask) else: - write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask) \ No newline at end of file + write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask) diff --git a/example/ck_tile/01_fmha/rotary.hpp b/example/ck_tile/01_fmha/rotary.hpp new file mode 100644 index 0000000000..346f2a5e7e --- /dev/null +++ b/example/ck_tile/01_fmha/rotary.hpp @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +#include +#include +#include +#include +#include +#include +#include + +// keep sync with RotaryEmbeddingEnum +enum class rope_enum +{ + none = 0, + interleaved = 1, + half_rotated = 2, +}; + +template +std::tuple, ck_tile::HostTensor> +generate_rotary_cos_sin(ck_tile::index_t seqlen, + ck_tile::index_t rotary_dim, + std::optional seed = std::nullopt) +{ + // return dummy tensors if we won't apply RoPE at all + if(rotary_dim <= 0) + { + ck_tile::HostTensor dummy({1, 1}); + return std::make_tuple(dummy, dummy); + } + + std::mt19937 random_engine(seed.has_value() ? 
*seed : std::random_device{}()); + std::uniform_real_distribution generator(0.0f, 1.0f); + + const ck_tile::index_t num_rows = seqlen * 2; + const ck_tile::index_t num_cols = rotary_dim / 2; + + using std::begin, std::end; + + ck_tile::HostTensor angle({num_rows, num_cols}); + std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; }); + + ck_tile::HostTensor cos({num_rows, num_cols}); + std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) { + return ck_tile::type_convert(std::cos(origin_value)); + }); + + ck_tile::HostTensor sin({num_rows, num_cols}); + std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) { + return ck_tile::type_convert(std::sin(origin_value)); + }); + + return std::make_tuple(cos, sin); +} + +template +std::tuple, ck_tile::HostTensor> +slice_rotary_cos_sin(const ck_tile::HostTensor& cos, + const ck_tile::HostTensor& sin, + ck_tile::index_t seqlen_offset, + ck_tile::index_t seqlen) +{ + assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2); + assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1)); + + assert(static_cast(seqlen_offset + seqlen) <= cos.get_length(0)); + + const ck_tile::index_t num_rows = seqlen; + const ck_tile::index_t num_cols = cos.get_length(1); + + ck_tile::HostTensor cos_pt({num_rows, num_cols}); + cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); }); + + ck_tile::HostTensor sin_pt({num_rows, num_cols}); + sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); }); + + return std::make_tuple(cos_pt, sin_pt); +} diff --git a/example/ck_tile/01_fmha/script/benchmark_bwd.sh b/example/ck_tile/01_fmha/script/benchmark_bwd.sh index 7591f5442a..cfd792906c 100755 --- a/example/ck_tile/01_fmha/script/benchmark_bwd.sh +++ b/example/ck_tile/01_fmha/script/benchmark_bwd.sh @@ -1,7 +1,6 @@ #!/bin/sh -# TODO: run this script from CK root -BUILD=build -EXE=$BUILD/bin/tile_example_fmha_bwd +# TODO: run this script from CK root or build directory +EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)" VALID=0 for prec in "fp16" "bf16" ; do diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd.sh b/example/ck_tile/01_fmha/script/benchmark_fwd.sh index 859cff9f60..599c595a75 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd.sh @@ -1,7 +1,6 @@ #!/bin/sh -# TODO: run this script from CK root -BUILD=build -EXE=$BUILD/bin/tile_example_fmha_fwd +# TODO: run this script from CK root or build directory +EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)" VALID=0 for prec in "fp16" "bf16" ; do diff --git a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh index dbb592820e..5ba3425e26 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh @@ -1,7 +1,6 @@ #!/bin/sh -# TODO: run this script from CK root -BUILD=build -EXE=$BUILD/bin/tile_example_fmha_bwd +# TODO: run this script from CK root or build directory +EXE="$(find . 
-name tile_example_fmha_bwd -type f | head -n 1)" KNAME=1 export CK_WARMUP=0 diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 779e8d09ee..5dcc6ed42b 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -1,7 +1,6 @@ -#!/bin/sh -# TODO: run this script from CK root -BUILD=build -EXE=$BUILD/bin/tile_example_fmha_fwd +#!/bin/bash +# TODO: run this script from CK root or build directory +EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)" KNAME=1 export CK_WARMUP=0 @@ -10,44 +9,98 @@ export CK_REPEAT=1 COMMON_ARGS='-v=1 -warmup=0 -repeat=1' # mode=0 # export HIP_VISIBLE_DEVICES=4 -set -x -for prec in "fp16" "bf16" ; do -for mode in 1 0 ; do -for perm in 0 1 ; do -for vlayout in "r" "c" ; do -for hdim in 32 64 128 256 ; do -for lse in 0 1 ; do -for bias in "n" "e" "a" ; do -for p_drop in 0.0 0.2; do - -# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -done -done -done -done -done -done -done +TEST_SPLITKV=0 +TEST_APPENDKV=0 +# options: +# -s: run splitkv tests +# -a: run appendkv tests +while getopts ":sa" opt; do + case "${opt}" in + s) + TEST_SPLITKV=1 + ;; + a) + TEST_APPENDKV=1 + ;; + *) + ;; + esac done +run_fp16_bf16_tests() { + local NUM_SPLITS=(1) + local PAGE_BLOCK_SIZE=(0) + local CACHE_BATCH_IDX=(0) -for perm in 0 1 ; do -for bias in "n" "e" "a" ; do -for b in 1 2 ; do -for hdim in 64 128 256 ; do -$EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS -done -done -done -done -set +x + if [ $TEST_SPLITKV -eq 1 ] ; then + NUM_SPLITS+=(2 3) + PAGE_BLOCK_SIZE+=(128) + CACHE_BATCH_IDX+=(1) + fi + + for 
prec in "fp16" "bf16" ; do + for mode in 1 0 ; do + for perm in 0 1 ; do + for vlayout in "r" "c" ; do + for hdim in 32 64 128 256 ; do + for lse in 0 1 ; do + for bias in "n" "e" "a" ; do + for p_drop in 0.0 0.2 ; do + for num_splits in "${NUM_SPLITS[@]}" ; do + for page_block_size in "${PAGE_BLOCK_SIZE[@]}" ; do + for cache_batch_idx in "${CACHE_BATCH_IDX[@]}" ; do + + # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done ; done + done ; done ; done ; done ; done + done ; +} + +run_fp8_tests() { + for perm in 0 1 ; do + for bias in "n" "e" "a" ; do + for b in 1 2 ; do + for hdim in 64 128 256 ; do + + $EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done +} + +run_fp16_appendkv_tests() { + for s in 
$(seq 63 1 65) ; do + for s_k in 65 129 ; do + for s_knew in 0 64 $s_k ; do + for hdim in 32 64 128 256 ; do + for ri in 0 1 ; do + for rdim in 0 16 32 $hdim ; do + for page_block_size in 0 128 ; do + for cache_batch_idx in 0 1 ; do + + $EXE -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS + + done ; done ; done ; done ; done + done ; done ; done +} + +set -x + +run_fp16_bf16_tests +run_fp8_tests + +if [ $TEST_APPENDKV -eq 1 ] ; then + run_fp16_appendkv_tests +fi + +set +x \ No newline at end of file diff --git a/example/ck_tile/01_fmha/utils.hpp b/example/ck_tile/01_fmha/utils.hpp index 737efd8256..70a5844cde 100644 --- a/example/ck_tile/01_fmha/utils.hpp +++ b/example/ck_tile/01_fmha/utils.hpp @@ -3,15 +3,17 @@ #pragma once +#include #include #include +#include #include #include +#include +#include #include #include #include -#include -#include #include "ck_tile/core/container/span.hpp" @@ -37,18 +39,21 @@ std::vector to_seqstarts(ck_tile::span seqlens) return seqstarts; } -std::vector generate_seqlens(mode_enum mode, - unsigned count, +std::vector generate_seqlens(unsigned count, int32_t seqlen_avg, + int32_t seqlen_min = -1, // if not negative, clamp min int32_t seqlen_max = -1, // if not negative, clamp max std::optional seed = std::nullopt) { assert(0 < count); - std::vector seqlens( - count, seqlen_max > 0 ? (seqlen_avg < seqlen_max ? seqlen_avg : seqlen_max) : seqlen_avg); + seqlen_min = (0 < seqlen_min ? seqlen_min : 1); + seqlen_max = (0 < seqlen_max ? seqlen_max : std::numeric_limits::max()); + assert(seqlen_min <= seqlen_max); - if(mode == mode_enum::group && 1 < count) + std::vector seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max)); + + if(1 < count) { using size_type = std::vector::size_type; @@ -62,15 +67,15 @@ std::vector generate_seqlens(mode_enum mode, for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat) { const size_type to_decrease = next_idx(); - // make sure each elements of seqlens is always greater than 0 - if(seqlens[to_decrease] == 1) + // make sure each element of seqlens always stays at or above seqlen_min + if(seqlens[to_decrease] == seqlen_min) { continue; } const size_type to_increase = (to_decrease + next_step()) % count; - if(seqlen_max > 0 && seqlens[to_increase] >= seqlen_max) + if(seqlens[to_increase] >= seqlen_max) { continue; } @@ -83,13 +88,29 @@ std::vector generate_seqlens(mode_enum mode, return seqlens; } -std::vector generate_seqstarts(mode_enum mode, - unsigned count, - int32_t seqlen_avg, - int32_t seqlen_max = -1, - std::optional seed = std::nullopt) +// return random integer generated uniformly in range [low, high] +template +auto randint(Int low, Int high, std::optional seed = std::nullopt) + -> std::enable_if_t, Int> +{ + std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}()); + std::uniform_int_distribution dist(low, high); + return dist(engine); +} + +// return random integers generated uniformly in range [low, high] +template +auto randints(ForwardIterator first, + ForwardIterator last, + Int low, + Int high, + std::optional seed = std::nullopt) + -> std::enable_if_t> { - return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_max, seed)); + std::mt19937 engine(seed.has_value() ? 
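// [editor's note] randints fills [first, last) with uniform draws from [low, high];
// decode_seqlen below uses it to give every batch entry except the first a random cached
// seqlen_k in [seqlen_k_min, seqlen_k_max] once kvcache is in play.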
*seed : std::random_device{}()); + std::uniform_int_distribution dist(low, high); + + std::generate(first, last, [&] { return dist(engine); }); } /* @@ -112,16 +133,45 @@ decode_seqlen(mode_enum mode, std::string q_val, std::string k_val, std::string k_pad_val, - std::optional seed = std::nullopt) + ck_tile::index_t seqlen_k_min = 0, + bool use_kvcache = false, + std::optional seed = std::nullopt) { #define _S2I_(str_) static_cast(std::atoi((str_).c_str())) if(mode == mode_enum::batch) { ck_tile::index_t q = _S2I_(q_val); ck_tile::index_t k = _S2I_(k_val); - auto s_q = std::vector(batch, q); - auto s_k = std::vector(batch, k < 0 ? q : k); + + auto s_q = std::vector(batch, q); + auto s_k = [&] { + const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k); + std::vector seqlen_ks(batch, seqlen_k_max); + + if(1 < batch && use_kvcache) + { + // to keep the original s_k value, we always use seqlen_k_max in first batch + randints(std::next(seqlen_ks.begin()), + seqlen_ks.end(), + seqlen_k_min, + seqlen_k_max, + seed); + return seqlen_ks; + } + + return seqlen_ks; + }(); auto s_kpad = std::vector(batch, -1); // TODO: batch not support k_padding + + // s_k should be greater than or equal to seqlen_k_min if provided + if(s_k.back() < seqlen_k_min) + { + std::ostringstream msg; + msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back() + << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")"; + throw std::runtime_error(msg.str()); + } + return std::make_tuple(s_q, s_k, s_kpad); } else @@ -149,6 +199,16 @@ decode_seqlen(mode_enum mode, s_q.push_back(q); s_k.push_back(k < 0 ? q : k); s_kpad.push_back(kp); + + // s_k should be greater than or equal to seqlen_k_min + if(s_k.back() < seqlen_k_min) + { + std::ostringstream msg; + msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back() + << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")"; + throw std::runtime_error(msg.str()); + } + idx++; if(found_q == std::string::npos || idx >= batch) { @@ -160,8 +220,9 @@ decode_seqlen(mode_enum mode, } if(idx < batch) { - auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), s_kpad.back(), seed); - auto rem_k = generate_seqlens(mode, batch - idx, s_k.back(), s_kpad.back(), seed); + auto rem_q = generate_seqlens(batch - idx, s_q.back(), 1, s_kpad.back(), seed); + auto rem_k = + generate_seqlens(batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), seed); s_q.insert(s_q.end(), rem_q.begin(), rem_q.end()); s_k.insert(s_k.end(), rem_k.begin(), rem_k.end()); @@ -180,3 +241,15 @@ int env_get_int(const char* var_name, int default_int) r = std::atoi(v); return r; } + +template +std::enable_if_t> iota_shuffle(RandomAccessIterator first, + RandomAccessIterator last, + Int value, + std::optional seed = std::nullopt) +{ + std::iota(first, last, value); + + std::mt19937 engine(seed.has_value() ? 
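// [editor's note] iota_shuffle writes value, value+1, ... and then shuffles the range,
// i.e. it yields a random permutation; presumably the example uses it to build the
// shuffled block_table and cache_batch_idx buffers that the new tests exercise.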
*seed : std::random_device{}()); + std::shuffle(first, last, engine); +} diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index 9970bb3693..f512e50e0a 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -536,13 +536,20 @@ float log(float x) { return __logf(x); }; CK_TILE_HOST float log(float x) { return std::logf(x); }; -CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) +CK_TILE_DEVICE uint16_t sad_u16(uint16_t x, uint16_t y, uint16_t acc) { - // TODO: this is hacky, we use u16 return __builtin_amdgcn_sad_u16(x, y, acc); } -CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) +CK_TILE_DEVICE uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc) +{ + /// TODO: replace inline asm when intrinsic is available + uint32_t res; + asm volatile("v_sad_u32 %0, %1, %2, %3" : "=v"(res) : "v"(x), "v"(y), "v"(acc)); + return res; +} + +CK_TILE_HOST uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc) { return (x > y ? (x - y) : (y - x)) + acc; } diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 02c298e8a4..04ed44201b 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -214,6 +214,12 @@ struct tile_window_with_static_distribution CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; } + CK_TILE_DEVICE constexpr void + set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data) + { + bottom_tensor_view_.buf_.p_data_ = data; + } + // move thread's window adaptor coordinate and bottom tensor coordinate // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset] CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate( @@ -843,6 +849,17 @@ struct tile_window_with_static_lengths CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; } + CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin) + { + window_origin_ = new_window_origin; + } + + CK_TILE_DEVICE constexpr void + set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data) + { + bottom_tensor_view_.buf_.p_data_ = data; + } + // move window-origin CK_TILE_DEVICE void move(const BottomTensorIndex& step) { window_origin_ += step; } @@ -871,6 +888,39 @@ make_tile_window(const TensorView_& tensor_view, tensor_view, window_lengths, origin}; } +// duplicate tile window and replace its origin +template +CK_TILE_DEVICE constexpr auto +make_tile_window(const tile_window_with_static_lengths& tile_window, + const multi_index& origin) +{ + return tile_window_with_static_lengths{ + tile_window.get_bottom_tensor_view(), tile_window.get_window_lengths(), origin}; +} + +template +CK_TILE_DEVICE constexpr auto +make_tile_window(const tile_window_with_static_lengths& tile_window, + const multi_index& origin, + const StaticTileDistribution& tile_distribution) +{ + return make_tile_window(tile_window.get_bottom_tensor_view(), + tile_window.get_window_lengths(), + origin, + tile_distribution); +} + +template +CK_TILE_DEVICE constexpr auto +make_tile_window(const tile_window_with_static_lengths& tile_window, + const StaticTileDistribution& tile_distribution) +{ + return make_tile_window(tile_window.get_bottom_tensor_view(), + tile_window.get_window_lengths(), + tile_window.get_window_origin(), + tile_distribution); +} + template CK_TILE_DEVICE void move_tile_window( 
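// [editor's note] the three make_tile_window overloads above clone an existing
// tile_window_with_static_lengths while swapping in a new origin and/or distribution,
// and set_bottom_tensor_view_data_ptr() re-points a window at another buffer; together
// they appear to be what the new page_block_navigator.hpp builds on to hop a window
// across physical KV page blocks without rebuilding the tensor view.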
tile_window_with_static_lengths& window, diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp index f5dffda863..f6e133c759 100644 --- a/include/ck_tile/core/utility/type_traits.hpp +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -22,6 +22,23 @@ using remove_cvref_t = remove_cv_t>; template using remove_pointer_t = typename std::remove_pointer::type; +template +struct copy_const +{ + static_assert(!std::is_const_v); + + using type = To; +}; + +template +struct copy_const +{ + using type = std::add_const_t::type>; +}; + +template +using copy_const_t = typename copy_const::type; + namespace detail { template class Op, class... Args> struct detector diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 0e69a925d5..deebe90bf7 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -15,6 +15,7 @@ #include "ck_tile/host/reference/reference_batched_elementwise.hpp" #include "ck_tile/host/reference/reference_batched_gemm.hpp" #include "ck_tile/host/reference/reference_batched_masking.hpp" +#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp" #include "ck_tile/host/reference/reference_batched_softmax.hpp" #include "ck_tile/host/reference/reference_gemm.hpp" #include "ck_tile/host/reference/reference_im2col.hpp" diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index 43405ee69b..918abc69cc 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -155,7 +155,12 @@ struct HostTensorDescriptor return space; } + std::size_t get_length(std::size_t dim) const { return mLens[dim]; } + const std::vector& get_lengths() const { return mLens; } + + std::size_t get_stride(std::size_t dim) const { return mStrides[dim]; } + const std::vector& get_strides() const { return mStrides; } template @@ -325,8 +330,12 @@ struct HostTensor { } + std::size_t get_length(std::size_t dim) const { return mDesc.get_length(dim); } + decltype(auto) get_lengths() const { return mDesc.get_lengths(); } + std::size_t get_stride(std::size_t dim) const { return mDesc.get_stride(dim); } + decltype(auto) get_strides() const { return mDesc.get_strides(); } std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); } diff --git a/include/ck_tile/host/kernel_launch.hpp b/include/ck_tile/host/kernel_launch.hpp index e9c5a0c254..5c7bf12bfc 100644 --- a/include/ck_tile/host/kernel_launch.hpp +++ b/include/ck_tile/host/kernel_launch.hpp @@ -73,17 +73,17 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables) { // clang-format off if(!s.time_kernel_) { - (callables(s),...); hip_check_error(hipGetLastError()); + (callables(s),...); HIP_CHECK_ERROR(hipGetLastError()); return 0; } if(s.is_gpu_timer_) { gpu_timer timer {}; // warmup - for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); + for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError()); timer.start(s.stream_id_); - for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); + for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError()); timer.stop(s.stream_id_); return timer.duration() / s.nrepeat_; @@ -92,10 +92,10 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... 
callables) cpu_timer timer {}; // warmup - for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); + for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError()); timer.start(s.stream_id_); - for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); + for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError()); timer.stop(s.stream_id_); return timer.duration() / s.nrepeat_; diff --git a/include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp b/include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp new file mode 100644 index 0000000000..858144c8ba --- /dev/null +++ b/include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +#include +#include + +namespace ck_tile { + +template +CK_TILE_HOST void reference_batched_rotary_position_embedding(const HostTensor& input_bsd, + const HostTensor& cos_sd, + const HostTensor& sin_sd, + bool interleaved, + HostTensor& output_bsd, + bool use_1_row_sin_cos = false) +{ + assert(cos_sd.get_num_of_dimension() == 2 && sin_sd.get_num_of_dimension() == 2); + assert(cos_sd.get_length(0) == sin_sd.get_length(0) && + cos_sd.get_length(1) == sin_sd.get_length(1)); + + const index_t rotary_dim = cos_sd.get_length(1) * 2; + assert(static_cast(rotary_dim) <= input_bsd.get_length(2)); + + output_bsd.ForEach([&](auto& self, auto i) { + const index_t i_d = i[2]; + if(rotary_dim <= i_d) + { + self(i) = input_bsd(i); + return; + } + assert(i_d < rotary_dim); + + const index_t i_s = i[1]; + const index_t i_s_cos_sin = (use_1_row_sin_cos ? 0 : i_s); + + const ComputeDataType cos = type_convert( + interleaved ? cos_sd(i_s_cos_sin, i_d / 2) + : cos_sd(i_s_cos_sin, i_d % cos_sd.get_length(1))); + const ComputeDataType sin = type_convert( + interleaved ? sin_sd(i_s_cos_sin, i_d / 2) + : sin_sd(i_s_cos_sin, i_d % sin_sd.get_length(1))); + + const ComputeDataType half_rotated_input = [&] { + const index_t i_b = i[0]; + + if(interleaved) + { + const bool is_even = (i_d % 2 == 0); + const index_t pos = i_d + (is_even ? 1 : -1); + const ComputeDataType sign = (is_even ? -1 : 1); + return sign * type_convert(input_bsd(i_b, i_s, pos)); + } + else + { + const index_t half_rdim = (rotary_dim / 2); + const index_t pos = (i_d + half_rdim) % rotary_dim; + const ComputeDataType sign = (pos < half_rdim ? 
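// [editor's note] this branch implements the usual rotate_half formulation,
// out[d] = in[d]*cos + rot(in)[d]*sin, where rot(in)[d] = -in[d + r/2] for d < r/2 and
// rot(in)[d] = +in[d - r/2] otherwise (r = rotary_dim); the sign below therefore flips
// according to which half 'pos' lands in.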
1 : -1); + return sign * type_convert(input_bsd(i_b, i_s, pos)); + } + }(); + ComputeDataType result = + type_convert(input_bsd(i)) * cos + half_rotated_input * sin; + + self(i) = type_convert(result); + }); +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index cad3009473..9389a5397f 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -7,7 +7,11 @@ #include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/block/block_masking.hpp" #include "ck_tile/ops/fmha/block/block_position_encoding.hpp" +#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp" +#include "ck_tile/ops/fmha/block/page_block_navigator.hpp" #include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp" @@ -21,11 +25,11 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" diff --git a/include/ck_tile/ops/fmha/block/block_position_encoding.hpp b/include/ck_tile/ops/fmha/block/block_position_encoding.hpp index c2fdaf3a1a..703ec0967a 100644 --- a/include/ck_tile/ops/fmha/block/block_position_encoding.hpp +++ b/include/ck_tile/ops/fmha/block/block_position_encoding.hpp @@ -43,9 +43,12 @@ enum struct AlibiMode FROM_BOTTOM_RIGHT = 2, }; -template +template struct Alibi { + static_assert(1 <= LogMaxSadOprndSize && LogMaxSadOprndSize <= 32, + "for LogMaxSadOprndSize <= 16, we use SAD uint16_t, otherwise, use SAD uint32_t"); + // RowMajor here means if pixel within the same thread are along the row, or col // this may impact the performance of update(), while the result are the same. // e.g. 
fwd prefer use RowMajor=true, bwd some cases prefer use RowMajor=false @@ -79,6 +82,19 @@ struct Alibi mode = mode_; } + CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) { return sad_u32(x, y, acc); } + + CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) + { + if constexpr(LogMaxSadOprndSize <= 16) + { + return sad_u16( + static_cast(x), static_cast(y), static_cast(acc)); + } + + return sad_u32(x, y, acc); + } + CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx) { if constexpr(RowMajor) @@ -128,7 +144,7 @@ struct EmptyPositionEncoding // can convert from the FA style left/right to our generic coordinate // if left_size < 0 && right_size = 0, it is normal causal mask // local is left_size >=0 or right_size >=0 -template +template CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope, index_t window_left_size, index_t window_right_size, @@ -142,7 +158,7 @@ CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope, AlibiMode alibi_mode = is_causal ? AlibiMode::VERTICAL : static_cast(mask_enum) /*either top-left or bottom-right*/; - return Alibi{slope, y_total, x_total, alibi_mode}; + return Alibi{slope, y_total, x_total, alibi_mode}; } // https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742 diff --git a/include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp b/include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp new file mode 100644 index 0000000000..5173279299 --- /dev/null +++ b/include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +namespace ck_tile { + +// This class is used for codegen pattern matching +enum class RotaryEmbeddingEnum +{ + NONE = 0, + INTERLEAVED = 1, // combine dimensions 0 & 1, 2 & 3, etc + HALF_ROTATED = 2, // combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc +}; + +template +struct RotaryEmbeddingEnumToStr; + +template <> +struct RotaryEmbeddingEnumToStr +{ + static constexpr const char* name = ""; +}; +template <> +struct RotaryEmbeddingEnumToStr +{ + static constexpr const char* name = "inter"; +}; +template <> +struct RotaryEmbeddingEnumToStr +{ + static constexpr const char* name = "half"; +}; + +template +struct BlockRotaryEmbedding +{ + template + CK_TILE_HOST_DEVICE static void apply(DistributedTensor& tile, + OtherDramBlockWindow other_window, + RotaryCosDramBlockWindow rotary_cos_window, + RotarySinDramBlockWindow rotary_sin_window, + index_t rotary_dim, + index_t thread_end) + { + using DataType = typename remove_cvref_t::DataType; + + if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED) + { + auto rotary_cos_tile = load_tile(rotary_cos_window); + auto rotary_sin_tile = load_tile(rotary_sin_window); + + if(thread_end <= rotary_dim) + { + constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size(); + static_for<0, thread_buffer_size, 2>{}([&](auto idx) { + const auto left = type_convert(tile.thread_buf_[idx]); + const auto right = type_convert(tile.thread_buf_[idx + 1]); + + const auto cos = + type_convert(rotary_cos_tile.thread_buf_[idx / 2]); + const auto sin = + type_convert(rotary_sin_tile.thread_buf_[idx / 2]); + + tile.thread_buf_[idx] = type_convert(left * cos - right * sin); + tile.thread_buf_[idx + 1] = type_convert(right * cos + left * sin); + }); + } + } + else if 
constexpr(RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED) + { + if(thread_end <= rotary_dim) + { + const bool is_left = (thread_end <= (rotary_dim / 2)); + + move_tile_window(other_window, {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)}); + auto other_tile = load_tile(other_window); + + move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)}); + auto rotary_cos_tile = load_tile(rotary_cos_window); + + move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)}); + auto rotary_sin_tile = load_tile(rotary_sin_window); + + constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size(); + static_for<0, thread_buffer_size, 1>{}([&](auto idx) { + const auto curr = type_convert(tile.thread_buf_[idx]); + const auto other = type_convert(other_tile.thread_buf_[idx]); + + const auto cos = + type_convert(rotary_cos_tile.thread_buf_[idx]); + const auto sin = + type_convert(rotary_sin_tile.thread_buf_[idx]); + + tile.thread_buf_[idx] = + type_convert(curr * cos + other * (is_left ? -sin : sin)); + }); + } + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/block/page_block_navigator.hpp b/include/ck_tile/ops/fmha/block/page_block_navigator.hpp new file mode 100644 index 0000000000..e8abdc579b --- /dev/null +++ b/include/ck_tile/ops/fmha/block/page_block_navigator.hpp @@ -0,0 +1,279 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_window.hpp" + +namespace ck_tile { + +// assume that we have only 1 page-block/tensor view +template +struct TrivialPageBlockNavigator +{ + using DataType = typename TensorView::DataType; + using WindowOrigin = multi_index<2>; + + CK_TILE_HOST_DEVICE constexpr TrivialPageBlockNavigator(const TensorView& tensor_view_) + : tensor_view(tensor_view_) + { + } + + template + CK_TILE_HOST_DEVICE constexpr auto make_tile_window(const WindowLengths& window_lengths, + const WindowOrigin& window_origin) const + { + return make_tuple(/*block_index=*/0, + ck_tile::make_tile_window(tensor_view, window_lengths, window_origin)); + } + + template + CK_TILE_HOST_DEVICE constexpr auto + make_tile_window(const WindowLengths& window_lengths, + const WindowOrigin& window_origin, + const TileDistribution& tile_distribution) const + { + return make_tuple( + /*block_index=*/0, + ck_tile::make_tile_window( + tensor_view, window_lengths, window_origin, tile_distribution)); + } + + template + CK_TILE_HOST_DEVICE static index_t + move_tile_window(index_t /*block_index*/, + TileWindow& tile_window, + const typename remove_cvref_t::BottomTensorIndex& step) + { + ck_tile::move_tile_window(tile_window, step); + + return /*block_index=*/0; + } + + CK_TILE_HOST_DEVICE static constexpr WindowOrigin + to_local_window_origin(const WindowOrigin& global_window_origin) + { + return global_window_origin; + } + + CK_TILE_HOST_DEVICE static constexpr WindowOrigin + to_global_window_origin(index_t /*block_index*/, const WindowOrigin& local_window_origin) + { + return local_window_origin; + } + + private: + TensorView tensor_view; +}; + +// default page-block navigator, assume that tensor view size is same as page-block size or smaller +// if tile window on last page-block +template +struct PageBlockNavigator +{ + using DataType = DataType_; + static_assert(std::is_same_v); + static_assert(VirtualDim == 0 || VirtualDim == 1, "only support 2d tile window"); + using WindowOrigin = multi_index<2>; + + 
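PageBlockNavigator, whose constructor follows, carries everything needed to translate a position on the virtual K/V sequence into a physical page block. A standalone sketch of that translation, under the assumption (read off the members and helpers below) that a global origin is split into a block index plus an in-block offset and the block index is routed through the block table; the names here are illustrative:

```cpp
#include <cstdint>

struct PagedCoord
{
    int32_t block_index; // index into the block table
    int32_t local_pos;   // position inside that page block
};

// split a global sequence position, as get_block_index()/to_local_window_origin() do
inline PagedCoord to_paged(int32_t global_pos, int32_t page_block_size)
{
    return {global_pos / page_block_size, global_pos % page_block_size};
}

// base pointer of one physical block, mirroring get_block_ptr(): the logical
// block index is translated through the block table into the physical pool
template <typename T>
T* block_base_ptr(T* physical_blocks,
                  const int32_t* block_table, // physical_block_indices
                  int64_t block_stride,       // elements between physical blocks
                  int64_t fixed_offset,       // e.g. per-head offset inside a block
                  int32_t block_index)
{
    return physical_blocks + block_table[block_index] * block_stride + fixed_offset;
}
```

The tile window then works in local coordinates against complete_view (or last_view on the tail block), so crossing a page boundary reduces to swapping the window's data pointer and descriptor, which is exactly what move_tile_window below does.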
CK_TILE_HOST_DEVICE constexpr PageBlockNavigator(copy_const_t* physical_blocks_, + long_index_t block_stride_, + long_index_t fixed_offset_, + const int32_t* physical_block_indices_, + index_t num_blocks_, + index_t page_block_size_, + const TensorView& complete_view_, + const TensorView& last_view_) + : physical_blocks(reinterpret_cast(physical_blocks_)), + block_stride(block_stride_), + fixed_offset(fixed_offset_), + physical_block_indices(physical_block_indices_), + num_blocks(num_blocks_), + page_block_size(page_block_size_), + complete_view(complete_view_), + last_view(last_view_) + { + } + + template + CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths, + const WindowOrigin& window_origin) const + { + const index_t block_index = get_block_index(window_origin); + const WindowOrigin local_window_origin = to_local_window_origin(window_origin); + + auto new_tile_window = + ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view, + window_lengths, + local_window_origin); + new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index)); + + return make_tuple(block_index, new_tile_window); + } + + template + CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths, + const WindowOrigin& window_origin, + const TileDistribution& tile_distribution) const + { + const index_t block_index = get_block_index(window_origin); + const WindowOrigin local_window_origin = to_local_window_origin(window_origin); + + auto new_tile_window = + ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view, + window_lengths, + local_window_origin, + tile_distribution); + new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index)); + + return make_tuple(block_index, new_tile_window); + } + + template + CK_TILE_HOST_DEVICE index_t + move_tile_window(index_t block_index, + TileWindow& tile_window, + const typename remove_cvref_t::BottomTensorIndex& step) const + { + + ck_tile::move_tile_window(tile_window, step); + + const WindowOrigin global_window_origin = + to_global_window_origin(block_index, tile_window.get_window_origin()); + const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin); + + const index_t new_block_index = get_block_index(global_window_origin); + /// TODO: only update necessary attributes + tile_window.bottom_tensor_view_.desc_ = + (is_last_block(new_block_index) ? 
last_view : complete_view).get_tensor_descriptor(); + tile_window.set_window_origin(local_window_origin); + tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index)); + + return new_block_index; + } + + CK_TILE_HOST_DEVICE bool is_last_block(index_t block_index) const + { + return block_index == num_blocks - 1; + } + + template + CK_TILE_HOST_DEVICE bool is_cross_block(index_t block_index, + const TileWindow& tile_window) const + { + const index_t origin = tile_window.get_window_origin().at(number{}); + const index_t length = tile_window.get_window_lengths().at(number{}); + return (block_index < num_blocks - 1) && (page_block_size < origin + length); + } + + template + CK_TILE_HOST_DEVICE void + move_to_block(index_t block_index, TileWindow& tile_window, index_t new_block_index) const + { + const multi_index<2> step = [&]() { + const index_t origin_diff = (block_index - new_block_index) * page_block_size; + if constexpr(VirtualDim == 0) + { + return make_multi_index(origin_diff, 0); + } + else + { + return make_multi_index(0, origin_diff); + } + }(); + + /// TODO: only update necessary attributes + tile_window.bottom_tensor_view_.desc_ = + (is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor(); + tile_window.set_window_origin(tile_window.get_window_origin() + step); + tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index)); + } + + CK_TILE_HOST_DEVICE WindowOrigin + to_local_window_origin(const WindowOrigin& global_window_origin) const + { + if constexpr(VirtualDim == 0) + { + const index_t length = global_window_origin.at(number<0>{}); + const index_t num_complete_blocks = integer_divide_floor(length, page_block_size); + return make_multi_index(length - page_block_size * num_complete_blocks, + global_window_origin.at(number<1>{})); + } + else + { + const index_t length = global_window_origin.at(number<1>{}); + const index_t num_complete_blocks = integer_divide_floor(length, page_block_size); + return make_multi_index(global_window_origin.at(number<0>{}), + length - page_block_size * num_complete_blocks); + } + } + + CK_TILE_HOST_DEVICE WindowOrigin + to_global_window_origin(index_t block_index, const WindowOrigin& local_window_origin) const + { + if constexpr(VirtualDim == 0) + { + return make_multi_index(block_index * page_block_size + + local_window_origin.at(number<0>{}), + local_window_origin.at(number<1>{})); + } + else + { + return make_multi_index(local_window_origin.at(number<0>{}), + block_index * page_block_size + + local_window_origin.at(number<1>{})); + } + } + + private: + CK_TILE_HOST_DEVICE + DataType* get_block_ptr(index_t block_index) const + { + return physical_blocks + physical_block_indices[block_index] * block_stride + fixed_offset; + } + + CK_TILE_HOST_DEVICE int32_t get_block_index(const WindowOrigin& global_window_origin) const + { + return integer_divide_floor(global_window_origin.at(number{}), page_block_size); + } + + DataType* physical_blocks; + long_index_t block_stride; + long_index_t fixed_offset; + + const int32_t* physical_block_indices; + index_t num_blocks; + index_t page_block_size; + + TensorView complete_view; + TensorView last_view; +}; + +template +CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView& tensor_view) +{ + return TrivialPageBlockNavigator(tensor_view); +} + +template +CK_TILE_HOST_DEVICE auto make_page_block_navigator(copy_const_t* physical_blocks, + long_index_t block_stride, + long_index_t fixed_offset, + const int32_t* physical_block_indices, + index_t 
num_blocks, + index_t page_block_size, + const TensorView& complete_view, + const TensorView& last_view) +{ + return PageBlockNavigator(physical_blocks, + block_stride, + fixed_offset, + physical_block_indices, + num_blocks, + page_block_size, + complete_view, + last_view); +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp new file mode 100644 index 0000000000..d598f97433 --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -0,0 +1,679 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include +#include + +namespace ck_tile { + +template +struct FmhaFwdAppendKVKernel +{ + using TilePartitioner = ck_tile::remove_cvref_t; + using FmhaPipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; + + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + + using VLayout = ck_tile::remove_cvref_t; + static constexpr bool kApplyRoPE = FmhaPipeline::RotaryEnum != RotaryEmbeddingEnum::NONE; + static constexpr bool kIsPagedKV = FmhaPipeline::kIsPagedKV; + + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8"; }; + template <> struct t2s { static constexpr const char * name = "bf8"; }; + // clang-format on + + __host__ static std::string GetName() + { + // sync with generate.py + // clang-format off + + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadSeqLenK) n += "sk"; + if (kPadHeadDimQ) n += "d"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? n : std::string("p") + n; }(); + return + _SS_("fmha_fwd_appendkv_d") + _TS_(FmhaPipeline::kK0) + "_" + _SS_(t2s::name) + "_" + "b" + _TS_(FmhaPipeline::kM0) + "x" + _TS_(FmhaPipeline::kN0) + "x" + _TS_(FmhaPipeline::kK0) + "x" + + _TS_(FmhaPipeline::kN1) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + + (!kApplyRoPE ? _SS_("") : (_SS_("_") + RotaryEmbeddingEnumToStr::name)) + + (kIsPagedKV ? "_pagedkv" : "" ); + #undef _SS_ + #undef _TS_ + // clang-format on + } + + template // to avoid duplicated base class problem, introduce a template + // arg + struct EmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will be provided + // use inheritance to minimize karg size + // users need to use the MakeKargs() function to create kargs.
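The following kargs structs apply that comment literally: feature-specific argument groups are mixed into the final Kargs by inheritance, so a disabled feature contributes no bytes to what is passed to the kernel. A compact standalone sketch of the pattern, with invented member names:

```cpp
#include <cstdint>
#include <type_traits>

template <int> // distinct empty bases avoid the duplicated-base-class problem
struct Empty
{
};

struct Basic { int32_t seqlen_q; };
struct RoPE  { const void* rotary_cos_ptr; int32_t rotary_dim; };
struct Paged { const int32_t* block_table_ptr; int32_t page_block_size; };

template <bool kApplyRoPE, bool kIsPagedKV>
struct Kargs : Basic,
               std::conditional_t<kApplyRoPE, RoPE, Empty<0>>,
               std::conditional_t<kIsPagedKV, Paged, Empty<1>>
{
};

// with empty-base optimization (as on the Itanium ABI used by hipcc/clang),
// the disabled configuration costs nothing:
static_assert(sizeof(Kargs<false, false>) == sizeof(Basic), "empty bases folded away");
```

The real Kargs below conditions on kApplyRoPE and kIsPagedKV in the same way.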
+ struct BasicKargs + { + void* q_ptr; + void* k_ptr; + const void* knew_ptr; + void* v_ptr; + const void* vnew_ptr; + + const int32_t* seqlen_k_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t seqlen_knew; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + + ck_tile::index_t num_head_q; + // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck_tile::index_t nhead_ratio_qk; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_knew; + ck_tile::index_t stride_v; + ck_tile::index_t stride_vnew; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_knew; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_vnew; + + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_knew; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_vnew; + }; + + struct RoPEKargs + { + const void* rotary_cos_ptr; + const void* rotary_sin_ptr; + ck_tile::index_t rotary_dim; + bool has_mask; + }; + + struct PageBlockTableKargs + { + const int32_t* block_table_ptr; + ck_tile::index_t batch_stride_block_table; + ck_tile::index_t page_block_size; + }; + + struct CacheBatchIdxKargs + { + const int32_t* cache_batch_idx; + }; + + struct Kargs : BasicKargs, + std::conditional_t>, + std::conditional_t + { + }; + + __host__ static constexpr Kargs MakeKargs(void* q_ptr, + void* k_ptr, + const void* knew_ptr, + void* v_ptr, + const void* vnew_ptr, + ck_tile::index_t seqlen_q, + const void* seqlen_k_ptr, + ck_tile::index_t seqlen_knew, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + const void* rotary_cos_ptr, + const void* rotary_sin_ptr, + ck_tile::index_t rotary_dim, + bool has_mask, + const void* block_table_ptr, + ck_tile::index_t batch_stride_block_table, + ck_tile::index_t page_block_size, + const void* cache_batch_idx, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_knew, + ck_tile::index_t stride_v, + ck_tile::index_t stride_vnew, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_knew, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_vnew, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_knew, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_vnew) + { + Kargs kargs{ + {q_ptr, + k_ptr, + knew_ptr, + v_ptr, + vnew_ptr, + reinterpret_cast(seqlen_k_ptr), + seqlen_q, + -1, // seqlen_k will be updated by content of seqlen_k_ptr + seqlen_knew, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + stride_q, + stride_k, + stride_knew, + stride_v, + stride_vnew, + nhead_stride_q, + nhead_stride_k, + nhead_stride_knew, + nhead_stride_v, + nhead_stride_vnew, + batch_stride_q, + batch_stride_k, + batch_stride_knew, + batch_stride_v, + batch_stride_vnew}, // args for common karg + {}, // placeholder for rope + {} // placeholder for paged-block table or cache_batch_idx + }; + + if constexpr(kApplyRoPE) + { + kargs.rotary_cos_ptr = rotary_cos_ptr; + kargs.rotary_sin_ptr = rotary_sin_ptr; + kargs.rotary_dim = rotary_dim; + kargs.has_mask = has_mask; + } + + if constexpr(kIsPagedKV) + { + kargs.block_table_ptr = reinterpret_cast(block_table_ptr); + kargs.batch_stride_block_table = batch_stride_block_table; + 
kargs.page_block_size = page_block_size; + } + else + { + kargs.cache_batch_idx = reinterpret_cast(cache_batch_idx); + } + + return kargs; + } + + __host__ static constexpr auto GridSize(ck_tile::index_t batch_size, + ck_tile::index_t nhead, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_knew) + { + return TilePartitioner::GridSize(batch_size, nhead, seqlen_q, seqlen_knew); + } + + __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // divide problem + const auto [i_tile, i_nhead, i_batch] = TilePartitioner{}(); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kM0); + const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kN0); + + const index_t i_cache_batch = [&, i_batch_ = i_batch] { + if constexpr(kIsPagedKV) + { + return i_batch_; + } + else + { + return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_] + : i_batch_); + } + }(); + + const long_index_t batch_offset_q = + static_cast(i_batch) * kargs.batch_stride_q; + const long_index_t batch_offset_k = + static_cast(i_cache_batch) * kargs.batch_stride_k; + const long_index_t batch_offset_knew = + static_cast(i_batch) * kargs.batch_stride_knew; + const long_index_t batch_offset_v = + static_cast(i_cache_batch) * kargs.batch_stride_v; + const long_index_t batch_offset_vnew = + static_cast(i_batch) * kargs.batch_stride_vnew; + + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + + // for simplicity, batch stride we just modify the pointer + QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k; + const KDataType* knew_ptr = + reinterpret_cast(kargs.knew_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_knew + + batch_offset_knew; + VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v; + const VDataType* vnew_ptr = + reinterpret_cast(kargs.vnew_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_vnew + + batch_offset_vnew; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + const auto make_k_dram = [&](KDataType* data, index_t height) { + const auto k_dram_naive = make_naive_tensor_view( + data, // will update this pointer if using paged-kvcache + make_tuple(height, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }; + const auto k_dram = [&]() { + if constexpr(kIsPagedKV) + { + return make_k_dram(nullptr, kargs.page_block_size); + } + else + { + return make_k_dram(k_ptr, kargs.seqlen_k + kargs.seqlen_knew); + } + }(); + + const auto knew_dram = [&]() { + const auto knew_dram_naive = make_naive_tensor_view( + knew_ptr, + make_tuple(kargs.seqlen_knew, kargs.hdim_q), + make_tuple(kargs.stride_knew, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + knew_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + const auto make_v_dram = 
[&](VDataType* data, index_t length) { + if constexpr(std::is_same_v) + { + const auto v_dram_naive = make_naive_tensor_view( + data, // will update this pointer if using paged-kvcache + make_tuple(length, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + const auto v_dram_transposed = + transform_tensor_view(v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(length)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return pad_tensor_view( + v_dram_transposed, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + const auto v_dram_naive = make_naive_tensor_view( + data, // will update this pointer if using paged-kvcache + make_tuple(kargs.hdim_v, length), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }; + const auto v_dram = [&]() { + if constexpr(kIsPagedKV) + { + return make_v_dram(nullptr, kargs.page_block_size); + } + else + { + return make_v_dram(v_ptr, kargs.seqlen_k + kargs.seqlen_knew); + } + }(); + + const auto vnew_dram = [&]() { + if constexpr(std::is_same_v) + { + const auto vnew_dram_naive = make_naive_tensor_view( + vnew_ptr, + make_tuple(kargs.seqlen_knew, kargs.hdim_v), + make_tuple(kargs.stride_vnew, 1), + number{}, + number<1>{}); + + const auto vnew_dram_transposed = transform_tensor_view( + vnew_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_knew)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return pad_tensor_view( + vnew_dram_transposed, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + const auto vnew_dram_naive = make_naive_tensor_view( + vnew_ptr, + make_tuple(kargs.hdim_v, kargs.seqlen_knew), + make_tuple(kargs.stride_vnew, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + vnew_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + constexpr auto q_rotary_cos_sin_dram_window_lengths = + make_tuple(number{}, number{}); + const auto q_rotary_cos_dram_window = [&]() { + if constexpr(kApplyRoPE) + { + const auto rotary_cos_dram_native = + make_naive_tensor_view( + reinterpret_cast(kargs.rotary_cos_ptr) + + kargs.seqlen_k * (kargs.rotary_dim / 2), + make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2), + make_tuple(kargs.has_mask * (kargs.rotary_dim / 2), 1), + number<8>{}, + number<1>{}); + + const auto rotary_cos_dram = [&]() { + return pad_tensor_view(rotary_cos_dram_native, + q_rotary_cos_sin_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window( + rotary_cos_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(q_rotary_cos_sin_dram_window_lengths); + } + }(); + const auto q_rotary_sin_dram_window = [&]() { + if constexpr(kApplyRoPE) + { + const auto rotary_sin_dram_native = + make_naive_tensor_view( + reinterpret_cast(kargs.rotary_sin_ptr) + + kargs.seqlen_k * (kargs.rotary_dim / 2), + make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2), + make_tuple(kargs.has_mask * (kargs.rotary_dim / 2), 1), + number<8>{}, + number<1>{}); + + const auto rotary_sin_dram = [&]() { + return pad_tensor_view(rotary_sin_dram_native, + q_rotary_cos_sin_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window( + rotary_sin_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0}); + } + else + { + return 
make_null_tile_window(q_rotary_cos_sin_dram_window_lengths); + } + }(); + + constexpr auto knew_rotary_cos_sin_dram_window_lengths = + make_tuple(number{}, number{}); + const auto knew_rotary_cos_dram_window = [&]() { + if constexpr(kApplyRoPE) + { + const auto rotary_cos_dram_native = + make_naive_tensor_view( + reinterpret_cast(kargs.rotary_cos_ptr) + + kargs.seqlen_k * (kargs.rotary_dim / 2), + make_tuple(kargs.seqlen_knew, kargs.rotary_dim / 2), + make_tuple(kargs.rotary_dim / 2, 1), + number<8>{}, + number<1>{}); + + const auto rotary_cos_dram = [&]() { + return pad_tensor_view(rotary_cos_dram_native, + knew_rotary_cos_sin_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window( + rotary_cos_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0}); + } + else + { + return make_null_tile_window(knew_rotary_cos_sin_dram_window_lengths); + } + }(); + const auto knew_rotary_sin_dram_window = [&]() { + if constexpr(kApplyRoPE) + { + const auto rotary_sin_dram_native = + make_naive_tensor_view( + reinterpret_cast(kargs.rotary_sin_ptr) + + kargs.seqlen_k * (kargs.rotary_dim / 2), + make_tuple(kargs.seqlen_knew, kargs.rotary_dim / 2), + make_tuple(kargs.rotary_dim / 2, 1), + number<8>{}, + number<1>{}); + + const auto rotary_sin_dram = [&]() { + return pad_tensor_view(rotary_sin_dram_native, + knew_rotary_cos_sin_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window( + rotary_sin_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0}); + } + else + { + return make_null_tile_window(knew_rotary_cos_sin_dram_window_lengths); + } + }(); + + auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(kIsPagedKV) + { + const auto* block_indices = + reinterpret_cast(kargs.block_table_ptr) + + i_batch_ * kargs.batch_stride_block_table; + const index_t num_blocks = + integer_divide_ceil(kargs.seqlen_k + kargs.seqlen_knew, kargs.page_block_size); + + const long_index_t fixed_offset = + static_cast(i_nhead_ / kargs.nhead_ratio_qk) * + kargs.nhead_stride_k; + + return make_page_block_navigator( + kargs.k_ptr, + kargs.batch_stride_k, + fixed_offset, + block_indices, + num_blocks, + kargs.page_block_size, + k_dram, + make_k_dram(nullptr, + (kargs.seqlen_k + kargs.seqlen_knew) - + (num_blocks - 1) * kargs.page_block_size)); + } + else + { + return make_page_block_navigator(k_dram); + } + }(); + + auto v_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(kIsPagedKV) + { + const auto* block_indices = + reinterpret_cast(kargs.block_table_ptr) + + i_batch_ * kargs.batch_stride_block_table; + const index_t num_blocks = + integer_divide_ceil(kargs.seqlen_k + kargs.seqlen_knew, kargs.page_block_size); + + const long_index_t fixed_offset = + static_cast(i_nhead_ / kargs.nhead_ratio_qk) * + kargs.nhead_stride_v; + + return make_page_block_navigator( + kargs.v_ptr, + kargs.batch_stride_v, + fixed_offset, + block_indices, + num_blocks, + kargs.page_block_size, + v_dram, + make_v_dram(nullptr, + (kargs.seqlen_k + kargs.seqlen_knew) - + (num_blocks - 1) * kargs.page_block_size)); + } + else + { + return make_page_block_navigator(v_dram); + } + }(); + + auto q_dram_window = + make_tile_window(q_dram, + make_tuple(number{}, number{}), + {i_m0, 0}); + + const bool skip_append_kv = kargs.seqlen_knew <= i_n0; + // window origin = (0, 0) if no work to do for current block + auto [i_page_block_k, k_dram_window] = k_page_block_navigator.make_tile_window( + make_tuple(number{}, number{}), + {!skip_append_kv * (kargs.seqlen_k + 
i_n0), 0}); + + auto knew_dram_window = + make_tile_window(knew_dram, + make_tuple(number{}, number{}), + {i_n0, 0}); + + // window origin = (0, 0) if no work to do for current block + auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window( + make_tuple(number{}, number{}), + {0, !skip_append_kv * (kargs.seqlen_k + i_n0)}); + + auto vnew_dram_window = + make_tile_window(vnew_dram, + make_tuple(number{}, number{}), + {0, i_n0}); + + if constexpr(kApplyRoPE) + { + FmhaPipeline{}(q_dram_window, + k_dram_window, + i_page_block_k, + k_page_block_navigator, + knew_dram_window, + v_dram_window, + i_page_block_v, + v_page_block_navigator, + vnew_dram_window, + q_rotary_cos_dram_window, + q_rotary_sin_dram_window, + knew_rotary_cos_dram_window, + knew_rotary_sin_dram_window, + kargs.rotary_dim, + kargs.seqlen_q <= i_m0, + skip_append_kv); + } + else + { + FmhaPipeline{}(q_dram_window, + k_dram_window, + i_page_block_k, + k_page_block_navigator, + knew_dram_window, + v_dram_window, + i_page_block_v, + v_page_block_navigator, + vnew_dram_window, + q_rotary_cos_dram_window, + q_rotary_sin_dram_window, + knew_rotary_cos_dram_window, + knew_rotary_sin_dram_window, + 0, // rotary_dim not used + kargs.seqlen_q <= i_m0, + skip_append_kv); + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp new file mode 100644 index 0000000000..97c9b960c2 --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct FmhaFwdAppendKVTilePartitioner +{ + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN0 = kN0_; + static constexpr ck_tile::index_t kK0 = kK0_; + static constexpr ck_tile::index_t kN1 = kN1_; + + static_assert(kK0 == kN1); + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size, + ck_tile::index_t nhead, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_knew) + { + // TODO: this may need tuning + return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, kM0), + ck_tile::integer_divide_ceil(seqlen_knew, kN0)), + nhead, + batch_size); + } + + CK_TILE_DEVICE auto operator()() + { + const index_t i_tile = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(i_tile, i_nhead, i_batch); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 36c10db79c..22978f1a3c 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -32,8 +32,6 @@ struct FmhaFwdSplitKVKernel using KDataType = ck_tile::remove_cvref_t; using VDataType = ck_tile::remove_cvref_t; using BiasDataType = ck_tile::remove_cvref_t; - using RandValOutputDataType = - ck_tile::remove_cvref_t; using LSEDataType = ck_tile::remove_cvref_t; using SaccDataType = ck_tile::remove_cvref_t; using OaccDataType = remove_cvref_t; @@ -46,8 +44,10 @@ struct FmhaFwdSplitKVKernel static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; static constexpr auto BiasEnum = 
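The tile partitioner above sizes the grid's x-dimension with a max() because one launch must cover both the Q-rotation tiles (ceil(seqlen_q / kM0)) and the KV-append tiles (ceil(seqlen_knew / kN0)); surplus blocks simply take the skip paths checked in the kernel. A standalone recomputation with assumed sizes:

```cpp
#include <algorithm>
#include <cstdio>

int main()
{
    const int kM0 = 64, kN0 = 64;                // assumed tile shape
    const int seqlen_q = 100, seqlen_knew = 300; // assumed problem size
    const int grid_x = std::max((seqlen_q + kM0 - 1) / kM0,     // ceil -> 2
                                (seqlen_knew + kN0 - 1) / kN0); // ceil -> 5
    // blocks 2..4 satisfy seqlen_q <= i_m0, so they skip the Q-rotation path
    std::printf("grid.x = %d\n", grid_x); // prints: grid.x = 5
    return 0;
}
```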
FmhaPipeline::BiasEnum; - static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; + static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; + static_assert(!kIsGroupMode || (kIsGroupMode && !kIsPagedKV), + "paged-kvcache only supported by batch mode kernels"); using FmhaMask = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; @@ -85,8 +85,8 @@ struct FmhaFwdSplitKVKernel "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + - (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kDoFp8StaticQuant ? "_squant" : "") + (kIsPagedKV ? "_pagedkv" : "" ); #undef _SS_ #undef _TS_ // clang-format on @@ -110,7 +110,6 @@ struct FmhaFwdSplitKVKernel void* o_acc_ptr; ck_tile::index_t batch; - ck_tile::index_t max_seqlen_q; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -136,6 +135,7 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t nhead_stride_lse_acc; ck_tile::index_t nhead_stride_o_acc; + ck_tile::index_t batch_stride_lse_acc; ck_tile::index_t batch_stride_o_acc; ck_tile::index_t split_stride_lse_acc; @@ -173,32 +173,16 @@ struct FmhaFwdSplitKVKernel float scale_p; }; - struct CommonDropoutKargs + struct PageBlockTableKargs { - void init_dropout(const float p_drop, - const std::tuple& drop_seed_offset) - { - float p_undrop = 1.0 - p_drop; - p_undrop_in_uint8_t = - uint8_t(std::floor(p_undrop * std::numeric_limits::max())); - rp_undrop = 1.0 / p_undrop; - - drop_seed = std::get<0>(drop_seed_offset); - drop_offset = std::get<1>(drop_seed_offset); - } - float rp_undrop = 1; - uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); - bool is_store_randval = false; - uint64_t drop_seed = 1; - uint64_t drop_offset = 0; - void* rand_val_ptr = nullptr; - - ck_tile::index_t stride_randval = 0; - ck_tile::index_t nhead_stride_randval = 0; + const int32_t* block_table_ptr; + ck_tile::index_t batch_stride_block_table; + ck_tile::index_t page_block_size; }; - struct BatchModeDropoutKargs : CommonDropoutKargs + + struct CacheBatchIdxKargs { - ck_tile::index_t batch_stride_randval = 0; + const int32_t* cache_batch_idx; }; struct BatchModeKargs @@ -210,12 +194,13 @@ struct FmhaFwdSplitKVKernel EmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t { + const int32_t* seqlen_k_ptr; + ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; - ck_tile::index_t batch_stride_lse_acc; }; struct GroupModeKargs @@ -226,12 +211,14 @@ struct FmhaFwdSplitKVKernel AlibiKargs, EmptyKargs<0>>>, std::conditional_t>, - std::conditional_t>, - std::conditional_t> + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; const int32_t* seqlen_k_ptr; + + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; }; using Kargs = std::conditional_t; @@ -242,48 +229,45 @@ struct 
FmhaFwdSplitKVKernel const void* k_ptr, const void* v_ptr, const void* bias_ptr, - void* rand_val_ptr, void* lse_acc_ptr, void* o_acc_ptr, ck_tile::index_t batch, - ck_tile::index_t max_seqlen_q, ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, + ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified + const void* seqlen_k_ptr, // only used for (paged-) kvcache ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, ck_tile::index_t num_splits, + const void* block_table_ptr, + ck_tile::index_t batch_stride_block_table, + ck_tile::index_t page_block_size, + const void* cache_batch_idx, float scale_s, float scale_p, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, - ck_tile::index_t stride_randval, ck_tile::index_t stride_o_acc, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, - ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse_acc, ck_tile::index_t batch_stride_o_acc, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - bool s_randval, - const std::tuple& drop_seed_offset) + ck_tile::index_t mask_type) { Kargs kargs{{q_ptr, k_ptr, @@ -291,7 +275,6 @@ struct FmhaFwdSplitKVKernel lse_acc_ptr, o_acc_ptr, batch, - max_seqlen_q, seqlen_q, seqlen_k, hdim_q, @@ -313,17 +296,18 @@ struct FmhaFwdSplitKVKernel nhead_stride_v, nhead_stride_lse_acc, nhead_stride_o_acc, + batch_stride_lse_acc, batch_stride_o_acc, split_stride_lse_acc, split_stride_o_acc}, // args for common karg {}, // placeholder for bias {}, // placeholder for mask {}, // placeholder for fp8_static_quant args - {}, // placeholder for dropout + {}, // placeholder for paged-block table or cache_batch_idx + reinterpret_cast(seqlen_k_ptr), batch_stride_q, batch_stride_k, - batch_stride_v, - batch_stride_lse_acc}; + batch_stride_v}; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -347,14 +331,15 @@ struct FmhaFwdSplitKVKernel { kargs.scale_p = scale_p; } - if constexpr(kHasDropout) + if constexpr(kIsPagedKV) + { + kargs.block_table_ptr = reinterpret_cast(block_table_ptr); + kargs.batch_stride_block_table = batch_stride_block_table; + kargs.page_block_size = page_block_size; + } + else { - kargs.init_dropout(p_drop, drop_seed_offset); - kargs.rand_val_ptr = rand_val_ptr; - kargs.stride_randval = stride_randval; - kargs.nhead_stride_randval = nhead_stride_randval; - kargs.batch_stride_randval = batch_stride_randval; - kargs.is_store_randval = s_randval; + kargs.cache_batch_idx = reinterpret_cast(cache_batch_idx); } return kargs; @@ -366,11 +351,9 @@ struct FmhaFwdSplitKVKernel const void* k_ptr, const void* v_ptr, const void* bias_ptr, - void* rand_val_ptr, void* lse_acc_ptr, void* o_acc_ptr, ck_tile::index_t batch, - ck_tile::index_t max_seqlen_q, const void* seqstart_q_ptr, const void* seqstart_k_ptr, const void* seqlen_k_ptr, @@ -385,24 +368,22 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, - ck_tile::index_t 
stride_randval, ck_tile::index_t stride_o_acc, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_lse_acc, ck_tile::index_t batch_stride_o_acc, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - bool s_randval, - const std::tuple& drop_seed_offset) + ck_tile::index_t mask_type) { Kargs kargs{{q_ptr, k_ptr, @@ -410,9 +391,8 @@ struct FmhaFwdSplitKVKernel lse_acc_ptr, o_acc_ptr, batch, - max_seqlen_q, - -1, // seqlen will be updated by another pointer - -1, // + -1, // seqlen_q will be updated by another pointer + -1, // seqlen_k will be updated by another pointer hdim_q, hdim_v, num_head_q, @@ -432,16 +412,18 @@ struct FmhaFwdSplitKVKernel nhead_stride_v, nhead_stride_lse_acc, nhead_stride_o_acc, + batch_stride_lse_acc, batch_stride_o_acc, split_stride_lse_acc, split_stride_o_acc}, // args for common karg {}, // placeholder for bias {}, // placeholder for mask {}, // placeholder for fp8_static_quant args - {}, // placeholder for dropout reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), - reinterpret_cast(seqlen_k_ptr)}; + reinterpret_cast(seqlen_k_ptr), + batch_stride_k, + batch_stride_v}; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -464,14 +446,6 @@ struct FmhaFwdSplitKVKernel { kargs.scale_p = scale_p; } - if constexpr(kHasDropout) - { - kargs.init_dropout(p_drop, drop_seed_offset); - kargs.rand_val_ptr = rand_val_ptr; - kargs.stride_randval = stride_randval; - kargs.nhead_stride_randval = nhead_stride_randval; - kargs.is_store_randval = s_randval; - } return kargs; } @@ -508,7 +482,6 @@ struct FmhaFwdSplitKVKernel long_index_t batch_offset_k = 0; long_index_t batch_offset_v = 0; long_index_t batch_offset_bias = 0; - long_index_t batch_offset_randval = 0; long_index_t batch_offset_lse_acc = 0; const long_index_t batch_offset_o_acc = static_cast(i_batch) * kargs.batch_stride_o_acc; @@ -534,14 +507,9 @@ struct FmhaFwdSplitKVKernel { batch_offset_bias = query_start * kargs.stride_bias + key_start; } - if constexpr(kHasDropout) - { - batch_offset_randval = query_start * kargs.stride_randval; - } // get real # queries & # keys under group mode - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; // # of required blocks is different in each groups, terminate unnecessary blocks // earlier @@ -556,24 +524,36 @@ struct FmhaFwdSplitKVKernel } else { - const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; - kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch]; } } else { + const index_t i_cache_batch = [&, i_batch_ = i_batch] { + if constexpr(kIsPagedKV) + { + return i_batch_; + } + else + { + return (kargs.cache_batch_idx != nullptr ? 
kargs.cache_batch_idx[i_batch_] + : i_batch_); + } + }(); + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; - batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + batch_offset_k = static_cast(i_cache_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_cache_batch) * kargs.batch_stride_v; batch_offset_lse_acc = static_cast(i_batch) * kargs.batch_stride_lse_acc; + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; } - if constexpr(kHasDropout) + + if(kargs.seqlen_k_ptr != nullptr) { - batch_offset_randval = - static_cast(i_batch) * kargs.batch_stride_randval; + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; } } @@ -589,6 +569,7 @@ struct FmhaFwdSplitKVKernel reinterpret_cast(kargs.v_ptr) + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + batch_offset_v; + OaccDataType* o_acc_ptr = reinterpret_cast(kargs.o_acc_ptr) + static_cast(i_nhead) * kargs.nhead_stride_o_acc + batch_offset_o_acc + i_split * kargs.split_stride_o_acc; @@ -616,10 +597,11 @@ struct FmhaFwdSplitKVKernel sequence{}); } }(); - const auto k_dram = [&]() { + + const auto make_k_dram = [&](const KDataType* data, index_t height) { const auto k_dram_naive = make_naive_tensor_view( - k_ptr, - make_tuple(kargs.seqlen_k, kargs.hdim_q), + data, // will update this pointer if using paged-kvcache + make_tuple(height, kargs.hdim_q), make_tuple(kargs.stride_k, 1), number{}, number<1>{}); @@ -628,13 +610,24 @@ struct FmhaFwdSplitKVKernel k_dram_naive, make_tuple(number{}, number{}), sequence{}); + }; + const auto k_dram = [&]() { + if constexpr(kIsPagedKV) + { + return make_k_dram(nullptr, kargs.page_block_size); + } + else + { + return make_k_dram(k_ptr, kargs.seqlen_k); + } }(); - const auto v_dram = [&]() { + + const auto make_v_dram = [&](const VDataType* data, index_t length) { if constexpr(std::is_same_v) { const auto v_dram_naive = make_naive_tensor_view( - v_ptr, - make_tuple(kargs.seqlen_k, kargs.hdim_v), + data, // will update this pointer if using paged-kvcache + make_tuple(length, kargs.hdim_v), make_tuple(kargs.stride_v, 1), number{}, number<1>{}); @@ -642,7 +635,7 @@ struct FmhaFwdSplitKVKernel const auto v_dram_transposed = transform_tensor_view(v_dram_naive, make_tuple(make_pass_through_transform(kargs.hdim_v), - make_pass_through_transform(kargs.seqlen_k)), + make_pass_through_transform(length)), make_tuple(sequence<1>{}, sequence<0>{}), make_tuple(sequence<0>{}, sequence<1>{})); @@ -654,8 +647,8 @@ struct FmhaFwdSplitKVKernel else { const auto v_dram_naive = make_naive_tensor_view( - v_ptr, - make_tuple(kargs.hdim_v, kargs.seqlen_k), + data, // will update this pointer if using paged-kvcache + make_tuple(kargs.hdim_v, length), make_tuple(kargs.stride_v, 1), number{}, number<1>{}); @@ -665,6 +658,76 @@ struct FmhaFwdSplitKVKernel make_tuple(number{}, number{}), sequence{}); } + }; + const auto v_dram = [&]() { + if constexpr(kIsPagedKV) + { + return make_v_dram(nullptr, kargs.page_block_size); + } + else + { + return make_v_dram(v_ptr, kargs.seqlen_k); + } + }(); + + auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(kIsPagedKV) + { + const auto* block_indices = + reinterpret_cast(kargs.block_table_ptr) + + i_batch_ * kargs.batch_stride_block_table; + const index_t num_blocks = + integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size); + + const long_index_t fixed_offset = + 
static_cast(i_nhead_ / kargs.nhead_ratio_qk) * + kargs.nhead_stride_k; + + return make_page_block_navigator( + kargs.k_ptr, + kargs.batch_stride_k, + fixed_offset, + block_indices, + num_blocks, + kargs.page_block_size, + k_dram, + make_k_dram(nullptr, + kargs.seqlen_k - (num_blocks - 1) * kargs.page_block_size)); + } + else + { + return make_page_block_navigator(k_dram); + } + }(); + + auto v_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(kIsPagedKV) + { + const auto* block_indices = + reinterpret_cast(kargs.block_table_ptr) + + i_batch_ * kargs.batch_stride_block_table; + const index_t num_blocks = + integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size); + + const long_index_t fixed_offset = + static_cast(i_nhead_ / kargs.nhead_ratio_qk) * + kargs.nhead_stride_v; + + return make_page_block_navigator( + kargs.v_ptr, + kargs.batch_stride_v, + fixed_offset, + block_indices, + num_blocks, + kargs.page_block_size, + v_dram, + make_v_dram(nullptr, + kargs.seqlen_k - (num_blocks - 1) * kargs.page_block_size)); + } + else + { + return make_page_block_navigator(v_dram); + } }(); auto q_dram_window = make_tile_window( @@ -678,13 +741,11 @@ struct FmhaFwdSplitKVKernel }(), {i_m0, 0}); - auto k_dram_window = make_tile_window( - k_dram, make_tuple(number{}, number{}), {0, 0}); + auto k_dram_window_lengths = + make_tuple(number{}, number{}); + auto v_dram_window_lengths = + make_tuple(number{}, number{}); - auto v_dram_window = - make_tile_window(v_dram, - make_tuple(number{}, number{}), - {i_n1, 0}); /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove /// following copy capture of the 'i_nhead' if in C++20 const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { @@ -741,62 +802,6 @@ struct FmhaFwdSplitKVKernel return make_tile_window(lse_acc_dram, lse_acc_dram_window_lengths, {i_m0}); }(); - // dropout - float rp_undrop = 1; - uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); - uint64_t drop_seed = 0; - uint64_t drop_offset = 0; - bool is_store_randval = false; - - if constexpr(kHasDropout) - { - rp_undrop = kargs.rp_undrop; - p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t; - drop_seed = kargs.drop_seed; - drop_offset = kargs.drop_offset; - is_store_randval = kargs.is_store_randval; - } - BlockDropout dropout(i_batch, - i_nhead, - kargs.num_head_q, - drop_seed, - drop_offset, - rp_undrop, - p_undrop_in_uint8_t, - is_store_randval); - - auto randval_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto randval_dram_window_lengths = - make_tuple(number{}, number{}); - if constexpr(kHasDropout) - { - RandValOutputDataType* rand_val_ptr = - reinterpret_cast(kargs.rand_val_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_randval + - batch_offset_randval; - - const auto randval_dram = [&]() { - const auto randval_dram_naive = - make_naive_tensor_view( - rand_val_ptr, - make_tuple(kargs.seqlen_q, kargs.seqlen_k), - make_tuple(kargs.stride_randval, 1), - number<1>{}, - number<1>{}); - - return pad_tensor_view(randval_dram_naive, - randval_dram_window_lengths, - sequence{}); - }(); - - return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); - } - else - { - return make_null_tile_window(randval_dram_window_lengths); - } - }(); - FmhaMask mask = [&]() { if constexpr(kHasMask) return ck_tile::make_generic_attention_mask_from_lr_window( @@ -823,16 +828,16 @@ struct FmhaFwdSplitKVKernel #endif if constexpr(kHasMask) { - return make_alibi_from_lr_mask(slope, - kargs.window_size_left, - 
kargs.window_size_right, - kargs.seqlen_q, - kargs.seqlen_k, - kargs.mask_type); + return make_alibi_from_lr_mask(slope, + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type); } else { - return Alibi{ + return Alibi{ slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; } } @@ -847,13 +852,14 @@ { return FmhaPipeline{}(q_dram_window, identity{}, // q_element_func - k_dram_window, + k_dram_window_lengths, + k_page_block_navigator, identity{}, // k_element_func - v_dram_window, + v_dram_window_lengths, + v_page_block_navigator, identity{}, // v_element_func bias_dram_window, identity{}, // bias_element_func - randval_dram_window, lse_acc_dram_window, identity{}, // lse_element_func identity{}, // s_acc_element_func @@ -864,24 +870,23 @@ mask, position_encoding, kargs.scale_s, - smem_ptr, - dropout); + smem_ptr); } else { return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, + k_dram_window_lengths, + k_page_block_navigator, + v_dram_window_lengths, + v_page_block_navigator, bias_dram_window, - randval_dram_window, lse_acc_dram_window, kargs.num_splits, i_split_, mask, position_encoding, kargs.scale_s, - smem_ptr, - dropout); + smem_ptr); } }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp new file mode 100644 index 0000000000..5c5dbb3a96 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -0,0 +1,277 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp" + +namespace ck_tile { + +template +struct BlockFmhaFwdAppendKVPipeline +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = typename Problem::QDataType; + using KDataType = typename Problem::KDataType; + using VDataType = typename Problem::VDataType; + + using VLayout = typename Problem::VLayout; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = Problem::kM0; + static constexpr index_t kN0 = Problem::kN0; + static constexpr index_t kK0 = Problem::kK0; + static constexpr index_t kN1 = Problem::kN1; + + static constexpr auto RotaryEnum = Problem::RotaryEnum; + static constexpr bool kIsPagedKV = Problem::kIsPagedKV; + + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + + // last dimension vector length used to create tensor view (and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should be able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kK0 <= 32) + { + return 2; + } + else if constexpr(kK0 <= 64) + { + return 3; + } + else if constexpr(kK0 <= 128) + { + return 2; + } + else if constexpr(kK0 <= 256) + { + return 1; + } + } + }(); + + template + CK_TILE_HOST_DEVICE auto + operator()(QDramBlockWindow& q_dram_block_window, // M0*K0 tile + const QElementFunction& q_element_func, + KDramBlockWindow& k_dram_block_window, // N0*K0 tile + index_t i_page_block_k, + const KPageBlockNavigator& k_page_block_navigator, + const KnewDramBlockWindow& knew_dram_block_window, // N0*K0 tile + const KnewElementFunction& knew_element_func, + VDramBlockWindow& v_dram_block_window, // N1*N0 tile + index_t i_page_block_v, + const VPageBlockNavigator& v_page_block_navigator, + const VnewDramBlockWindow& vnew_dram_block_window, // N1*N0 tile + const VnewElementFunction& vnew_element_func, + const QRotaryCosDramBlockWindow q_rotary_cos_dram_block_window, + const QRotarySinDramBlockWindow q_rotary_sin_dram_block_window, + const KnewRotaryCosDramBlockWindow knew_rotary_cos_dram_block_window, + const KnewRotarySinDramBlockWindow knew_rotary_sin_dram_block_window, + index_t rotary_dim, + bool skip_rotate_q, + bool skip_rotate_append_kv) const + { + if(!skip_rotate_append_kv) + { + // append Knew to K + auto knew_window = make_tile_window( + knew_dram_block_window, Policy::template MakeKnewDramTileDistribution()); + + auto knew_tile = [&]() { + auto knew = load_tile(knew_window); + return tile_elementwise_in(knew_element_func, knew); + }(); + + // optionally apply rotary embedding to Knew + if constexpr(RotaryEnum != RotaryEmbeddingEnum::NONE) + { + auto rotary_cos_window = + make_tile_window(knew_rotary_cos_dram_block_window, + Policy::template MakeRotaryCosSinTileDistribution< + Problem, + /*IsRotaryCosSinForQ=*/false>()); + + auto rotary_sin_window = + make_tile_window(knew_rotary_sin_dram_block_window, + Policy::template MakeRotaryCosSinTileDistribution< + Problem, + /*IsRotaryCosSinForQ=*/false>()); + + // We assume that each thread owns contiguous elements on head dimension. And we + // will use the distribution to enable/disable threads in order to override partial + // knew_tile content + auto [thread_start, thread_end] = + Policy::template GetKnewThreadRangeAlongK(); + ignore = thread_start; + + BlockRotaryEmbedding::apply(knew_tile, + knew_window, + rotary_cos_window, + rotary_sin_window, + rotary_dim, + thread_end); + } + + store_tile(k_dram_block_window, knew_tile); + + // write tile to another block if necessary + if constexpr(kIsPagedKV) + { + if(k_page_block_navigator.is_cross_block(i_page_block_k, k_dram_block_window)) + { + k_page_block_navigator.move_to_block( + i_page_block_k, k_dram_block_window, i_page_block_k + 1); + store_tile(k_dram_block_window, knew_tile); + } + } + + // append Vnew to V + auto vnew_window = make_tile_window( + vnew_dram_block_window, Policy::template MakeVnewDramTileDistribution()); + + auto vnew_tile = [&]() { + auto vnew = load_tile(vnew_window); + return tile_elementwise_in(vnew_element_func, vnew); + }(); + + store_tile(v_dram_block_window, vnew_tile); + + // write tile to another block if necessary + if constexpr(kIsPagedKV) + { + if(v_page_block_navigator.is_cross_block(i_page_block_v, v_dram_block_window)) + { + v_page_block_navigator.move_to_block( + i_page_block_v, v_dram_block_window, i_page_block_v + 1); + store_tile(v_dram_block_window, vnew_tile); + } + } + } + + if(!skip_rotate_q) + { + // optionally apply rotary embedding to Q + if constexpr(RotaryEnum != RotaryEmbeddingEnum::NONE) + { + auto q_window = make_tile_window( + q_dram_block_window, Policy::template MakeQDramTileDistribution()); + + auto q_tile = [&]() { + auto q = load_tile(q_window); + return tile_elementwise_in(q_element_func, q); + }(); + + auto rotary_cos_window = + make_tile_window(q_rotary_cos_dram_block_window, + Policy::template MakeRotaryCosSinTileDistribution< + Problem, + /*IsRotaryCosSinForQ=*/true>()); + + auto rotary_sin_window = + make_tile_window(q_rotary_sin_dram_block_window, + Policy::template MakeRotaryCosSinTileDistribution< + Problem, + /*IsRotaryCosSinForQ=*/true>()); + + // We assume that each thread owns contiguous elements on head dimension.
And we + // will use the distribution to enable/disable threads in order to override partial + // q_tile content + auto [thread_start, thread_end] = Policy::template GetQThreadRangeAlongK(); + ignore = thread_start; + + BlockRotaryEmbedding::apply( + q_tile, q_window, rotary_cos_window, rotary_sin_window, rotary_dim, thread_end); + + store_tile(q_dram_block_window, q_tile); + } + } + } + + template + CK_TILE_HOST_DEVICE auto + operator()(QDramBlockWindow& q_dram_block_window, + KDramBlockWindow& k_dram_block_window, + index_t i_page_block_k, + const KPageBlockNavigator& k_page_block_navigator, + const KnewDramBlockWindow& knew_dram_block_window, + VDramBlockWindow& v_dram_block_window, + index_t i_page_block_v, + const VPageBlockNavigator& v_page_block_navigator, + const VnewDramBlockWindow& vnew_dram_block_window, + const QRotaryCosDramBlockWindow& q_rotary_cos_dram_block_window, + const QRotarySinDramBlockWindow& q_rotary_sin_dram_block_window, + const KnewRotaryCosDramBlockWindow& knew_rotary_cos_dram_block_window, + const KnewRotarySinDramBlockWindow& knew_rotary_sin_dram_block_window, + index_t rotary_dim, + bool skip_rotate_q, + bool skip_rotate_append_kv) const + { + return operator()(q_dram_block_window, + identity{}, + k_dram_block_window, + i_page_block_k, + k_page_block_navigator, + knew_dram_block_window, + identity{}, + v_dram_block_window, + i_page_block_v, + v_page_block_navigator, + vnew_dram_block_window, + identity{}, + q_rotary_cos_dram_block_window, + q_rotary_sin_dram_block_window, + knew_rotary_cos_dram_block_window, + knew_rotary_sin_dram_block_window, + rotary_dim, + skip_rotate_q, + skip_rotate_append_kv); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp new file mode 100644 index 0000000000..cf3f7466e7 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp @@ -0,0 +1,288 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +namespace ck_tile { + +// This pipeline is qkv all located in LDS +struct BlockFmhaFwdAppendKVPipelineDefaultPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() + { + using QDataType = remove_cvref_t; + + return 16 / sizeof(QDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK() + { + using KDataType = remove_cvref_t; + + return 16 / sizeof(KDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() + { + using VLayout = remove_cvref_t; + using VDataType = remove_cvref_t; + if constexpr(std::is_same_v) + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::kN0; + constexpr index_t kKPerBlock = Problem::kN1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + + // TODO: not correct! 
+ if constexpr(total_pixels > 4) + return 4; + else + return 2; + } + else + { + return 16 / sizeof(VDataType); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetQNumElemsPerRead() + { + using DataType = typename Problem::QDataType; + + if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED) + { + /// NOTICE: we might need to lower this to support smaller rotary_dim + return 16 / sizeof(DataType); + } + else + { + return 16 / sizeof(DataType); + } + } + + template + CK_TILE_DEVICE static auto GetQThreadRangeAlongK() + { + static_assert(Problem::RotaryEnum != RotaryEmbeddingEnum::NONE); + + if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED) + { + constexpr index_t KPerThread = GetQNumElemsPerRead(); + static_assert(Problem::kK0 % KPerThread == 0); + constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread; + index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread; + + return make_tuple(start_pos, start_pos + KPerThread); + } + else + { + constexpr index_t KPerThread = GetQNumElemsPerRead(); + static_assert(Problem::kK0 % KPerThread == 0); + constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread; + index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread; + + return make_tuple(start_pos, start_pos + KPerThread); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::kM0; + constexpr index_t kKPerBlock = Problem::kK0; + + constexpr index_t KPerThread = GetQNumElemsPerRead(); + constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread; + constexpr index_t MThreadPerWarp = get_warp_size() / KThreadPerBlock; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t MPerThread = kMPerBlock / (NumWarps * MThreadPerWarp); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetKnewNumElemsPerRead() + { + using DataType = typename Problem::KDataType; + + if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED) + { + /// NOTICE: we might need to lower this to support smaller rotary_dim + return 16 / sizeof(DataType); + } + else + { + return 16 / sizeof(DataType); + } + } + + template + CK_TILE_DEVICE static auto GetKnewThreadRangeAlongK() + { + static_assert(Problem::RotaryEnum != RotaryEmbeddingEnum::NONE); + + if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED) + { + constexpr index_t KPerThread = GetKnewNumElemsPerRead(); + constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread; + index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread; + + return make_tuple(start_pos, start_pos + KPerThread); + } + else + { + constexpr index_t KPerThread = GetKnewNumElemsPerRead(); + constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread; + index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread; + + return make_tuple(start_pos, start_pos + KPerThread); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKnewDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::kN0; + constexpr index_t kKPerBlock = Problem::kK0; + + constexpr index_t KPerThread = GetKnewNumElemsPerRead(); + constexpr index_t KThreadPerBlock =
kKPerBlock / KPerThread; + constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV() + { + // TODO: this is for 3d layout + using VDataType = remove_cvref_t; + return 16 / sizeof(VDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVnewDramTileDistribution() + { + using VLayout = remove_cvref_t; + using VDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::kN1; + constexpr index_t kKPerBlock = Problem::kN0; + + if constexpr(std::is_same_v) + { + + constexpr index_t NPerThread = 16 / sizeof(VDataType); + constexpr index_t NThreadPerBlock = kNPerBlock / NPerThread; + constexpr index_t KThreadPerWarp = get_warp_size() / NThreadPerBlock; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t KPerThread = kKPerBlock / (NumWarps * KThreadPerWarp); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<1, 2>, + sequence<1, 0>>{}); + } + else + { + constexpr index_t KPerThread = 16 / sizeof(VDataType); + constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread; + constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetRotaryCosSinTileSize() + { + constexpr index_t height = (IsRotaryCosSinForQ ? 
Problem::kM0 : Problem::kN0); + + if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED) + { + return make_tuple(number{}, number{}); + } + else + { + return make_tuple(number{}, number{}); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRotaryCosSinTileDistribution() + { + using DataType = std::conditional_t; + + constexpr auto TileSize = GetRotaryCosSinTileSize(); + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = TileSize[number<0>{}]; + constexpr index_t kKPerBlock = TileSize[number<1>{}]; + + constexpr index_t KPerThread = []() { + if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED) + { + /// NOTICE: we might need to lower down this to support smaller rotary_dim + return 16 / sizeof(DataType); + } + else + { + return 8 / sizeof(DataType); + } + }(); + constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread; + constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index a6d74b3885..b257b9e93d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -6,7 +6,6 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" -#include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -15,19 +14,18 @@ namespace ck_tile { template struct BlockFmhaFwdSplitKVPipelineQRKSVS { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using RandValOutputDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using FmhaMask = remove_cvref_t; + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; @@ -49,8 +47,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kStoreLSE = true; // always store LSE (acc) - static constexpr bool kHasDropout = false; // ignore this flag + static constexpr bool 
kStoreLSE = true; // always store LSE (acc) + static constexpr bool kIsPagedKV = Problem::kIsPagedKV; static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; // last dimension vector length used to create tensor view(and decide buffer_load vector length) @@ -106,10 +104,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } template > && - std::is_same_v> && - std::is_same_v>, + std::is_same_v> && + std::is_same_v>, "wrong!"); static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && - kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN0 == KDramBlockWindowLengths{}[number<0>{}] && + kK0 == KDramBlockWindowLengths{}[number<1>{}] && + kN1 == VDramBlockWindowLengths{}[number<0>{}] && + kK1 == VDramBlockWindowLengths{}[number<1>{}] && kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); @@ -213,12 +212,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX( q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); - const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); - // check early exit if masked and no work to do. if constexpr(FmhaMask::IsMasking || kHasUnevenSplits) { - if(num_total_loop <= 0) + const index_t original_num_total_loop = + integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + if(original_num_total_loop <= 0) { if constexpr(kStoreLSE) { @@ -237,26 +236,34 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } } - auto k_dram_block_window = - make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), - k_dram_block_window_tmp.get_window_lengths(), - {seqlen_k_start, 0}); + // make sure the first tile is completely located in page-block + const index_t adjusted_seqlen_k_start = [&, seqlen_k_start_ = seqlen_k_start] { + if constexpr(kIsPagedKV) + { + return kN0 * integer_divide_floor(seqlen_k_start_, kN0); + } + else + { + return seqlen_k_start_; + } + }(); + const index_t num_total_loop = + integer_divide_ceil(seqlen_k_end - adjusted_seqlen_k_start, kN0); + + auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window( + k_dram_block_window_lengths, {adjusted_seqlen_k_start, 0}); const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = make_tile_window( bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + {bias_origin.at(number<0>{}), adjusted_seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); - auto randval_dram_window = dropout.MakeRandvalDramWindow( - randval_dram_block_window_tmp, seqlen_k_start); - - auto v_dram_window = - make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), - v_dram_block_window_tmp.get_window_lengths(), - {0, seqlen_k_start}, // TODO: hdim split? - Policy::template MakeVDramTileDistribution()); + auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window( + v_dram_block_window_lengths, + {0, adjusted_seqlen_k_start}, // TODO: hdim split? 
+ Policy::template MakeVDramTileDistribution()); auto q_tile = tile_elementwise_in(q_element_func, q); @@ -271,14 +278,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS { // STAGE 1, QK gemm auto k_dram_window = make_tile_window( - k_dram_block_window.get_bottom_tensor_view(), - k_dram_block_window.get_window_lengths(), - k_dram_block_window.get_window_origin(), + k_dram_block_window, Policy::template MakeKDramTileDistribution()); // K DRAM tile window for // load auto k_block_tile = load_tile(k_dram_window); { + // moving k_dram_window is an in-page-block operation, so there is + // no need to invoke k_page_block_navigator.move_tile_window() here. move_tile_window(k_dram_window, {0, kK0}); clear_tile(s_acc); // initialize C store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); @@ -355,7 +362,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) { - const auto k_origin = k_dram_block_window.get_window_origin(); + const auto k_origin = k_page_block_navigator.to_global_window_origin( + i_page_block_k, k_dram_block_window.get_window_origin()); constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); s_acc = tile_elementwise_in(s_acc_element_func, s_acc); sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { @@ -381,22 +389,32 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } move_tile_window(bias_dram_window, {0, kN0}); - /// TODO: only check in last iteration without increasing code size + /// TODO: only check in first/last iteration without increasing code size if constexpr(kHasUnevenSplits) { - const auto k_origin = k_dram_block_window.get_window_origin(); + const auto k_origin = k_page_block_navigator.to_global_window_origin( + i_page_block_k, k_dram_block_window.get_window_origin()); set_tile_if(s_acc, -numeric::infinity(), - [&, seqlen_k_end_ = seqlen_k_end](auto tile_idx) { + [&, seqlen_k_start_ = seqlen_k_start, seqlen_k_end_ = seqlen_k_end]( + auto tile_idx) { const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return seqlen_k_end_ <= col; + if constexpr(kIsPagedKV) + { + return col < seqlen_k_start_ || seqlen_k_end_ <= col; + } + else + { + return seqlen_k_end_ <= col; + } }); } if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { - const auto k_origin = k_dram_block_window.get_window_origin(); + const auto k_origin = k_page_block_navigator.to_global_window_origin( + i_page_block_k, k_dram_block_window.get_window_origin()); bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), k_origin.at(number<0>{}), number{}, @@ -501,12 +519,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS }); }); - if constexpr(kHasDropout) - { - dropout.Run( - smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window); - } - block_sync_lds(); if constexpr(std::is_same_v) { @@ -522,7 +534,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS store_tile(v_lds_window, tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch } - move_tile_window(v_dram_window, {0, kK1}); + i_page_block_v = + v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1}); const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); @@ -530,8 +543,10 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS // STAGE 3, KV gemm if constexpr(k1_loops > 1) { - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - const auto v = load_tile(v_dram_window); // load next v + static_for<0, k1_loops - 1, 1>{}([&, + &i_page_block_v_ = i_page_block_v, + &v_dram_window_ = 
v_dram_window](auto i_k1) { + const auto v = load_tile(v_dram_window_); // load next v block_sync_lds(); gemm_1(o_acc, get_slice_tile( @@ -552,11 +567,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS store_tile(v_lds_window, tile_elementwise_in(v_element_func, v)); // store next v } - move_tile_window(v_dram_window, {0, kK1}); + i_page_block_v_ = v_page_block_navigator.move_tile_window( + i_page_block_v_, v_dram_window_, {0, kK1}); }); } // move K tile windows - move_tile_window(k_dram_block_window, {kN0, 0}); + i_page_block_k = k_page_block_navigator.move_tile_window( + i_page_block_k, k_dram_block_window, {kN0, 0}); // tail { block_sync_lds(); @@ -618,36 +635,38 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } template CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile + const KPageBlockNavigator& k_page_block_navigator, + const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile + const VPageBlockNavigator& v_page_block_navigator, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile index_t num_splits, index_t i_split, FmhaMask mask, PositionEncoding position_encoding, float scale_s, - void* smem_ptr, - BlockDropout& dropout) const + void* smem_ptr) const { return operator()(q_dram_block_window_tmp, identity{}, - k_dram_block_window_tmp, + k_dram_block_window_lengths, + k_page_block_navigator, identity{}, - v_dram_block_window_tmp, + v_dram_block_window_lengths, + v_page_block_navigator, identity{}, bias_dram_block_window_tmp, identity{}, - randval_dram_block_window_tmp, lse_acc_dram_block_window_tmp, identity{}, identity{}, @@ -658,8 +677,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS mask, position_encoding, scale_s, - smem_ptr, - dropout); + smem_ptr); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp deleted file mode 100644 index ae363a4978..0000000000 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp +++ /dev/null @@ -1,770 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/common/tensor_layout.hpp" -#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp" -#include "ck_tile/ops/fmha/block/block_dropout.hpp" -#include "ck_tile/ops/reduce/block/block_reduce.hpp" - -namespace ck_tile { - -// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future) -template -struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync -{ - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using RandValOutputDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using FmhaMask = remove_cvref_t; - - using BlockFmhaShape = remove_cvref_t; - using VLayout = remove_cvref_t; - static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once - static_assert(kQLoadOnce == Policy::QLoadOnce); - - static constexpr index_t kBlockSize = Problem::kBlockSize; - - static constexpr index_t kM0 = BlockFmhaShape::kM0; - static constexpr index_t kN0 = BlockFmhaShape::kN0; - static constexpr index_t kK0 = BlockFmhaShape::kK0; - static constexpr index_t kN1 = BlockFmhaShape::kN1; - static constexpr index_t kK1 = BlockFmhaShape::kK1; - static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; - - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x) - // only need special care about seq_k padding (oob need set -INF of p instead of zero) - static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true && - Problem::kPadHeadDimV == true); - static constexpr bool kPadSeqLenQ = true; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) - static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) - static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kStoreLSE = true; // always store LSE (acc) - static constexpr bool kHasDropout = false; // ignore this flag - static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; - - // last dimension vector length used to create tensor view(and decide buffer_load vector length) - // ... together with tensor distribution. tensor dist should able to overwrite this - static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ(); - static constexpr index_t kAlignmentK = Policy::template GetAlignmentK(); - static constexpr index_t kAlignmentV = []() { - if constexpr(std::is_same_v) - return Policy::template GetAlignmentV(); - else - return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); - }(); - static constexpr index_t kAlignmentO = Policy::template GetAlignmentO(); - static constexpr index_t kAlignmentBias = - kPadSeqLenK ? 
1 : Policy::template GetAlignmentBias(); - -#if CK_TILE_FMHA_FWD_FAST_EXP2 - static constexpr auto R_LOG2E = 1.0 / log2e_v; -#endif - - static constexpr index_t kBlockPerCu = []() { - if constexpr(Problem::kBlockPerCu != -1) - return Problem::kBlockPerCu; - else - { - if constexpr(kK0BlockLength <= 32) - { - if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS && - FmhaMask::IsMasking) - return 1; - else - return 2; - } - else if constexpr(kK0BlockLength <= 64) - { - if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - return 2; - else - return 3; - } - else if constexpr(kK0BlockLength <= 128) - { - if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - return 1; - else - return 2; - } - else if constexpr(kK0BlockLength <= 256) - { - return 1; - } - } - }(); - - static constexpr const char* name = "qr_async"; - - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() - { - return Policy::template GetSmemSize(); - } - - template - CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const QElementFunction& q_element_func, - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const KElementFunction& /*k_element_func*/, - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const VElementFunction& v_element_func, - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - const BiasElementFunction& bias_element_func, - RandValDramBlockWindowTmp& randval_dram_block_window_tmp, - LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile - const LSEaccElementFunction& lse_acc_element_func, - const SAccElementFunction& s_acc_element_func, - const PComputeElementFunction& p_compute_element_func, - const OAccElementFunction& o_acc_element_func, - index_t num_splits, - index_t i_split, - FmhaMask mask, - PositionEncoding position_encoding, - float scale_s, - void* smem_ptr, - BlockDropout& dropout) const - { - static_assert( - std::is_same_v> && - std::is_same_v> && - std::is_same_v>, - "wrong!"); - - static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && - kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && - kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], - "wrong!"); - - constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); - - // K tile in LDS - auto k_lds_ptr = reinterpret_cast(smem_ptr); - auto k_lds_store = generate_tuple( - [&](auto i_buf) { - return make_tile_window( - make_tensor_view( - k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)), - Policy::template MakeKLdsStoreBlockDescriptor(i_buf).get_lengths(), - {0, 0, 0}); - }, - number{}); - -#if K_LDS_LOAD_USE_OFFSET_TRANSFORM - auto k_lds_load = generate_tuple( - [&](auto i_buf) { - return make_tile_window( - make_tensor_view( - k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor(i_buf)), - Policy::template MakeKLdsLoadBlockDescriptor(i_buf).get_lengths(), - {0, 0}); - }, - number{}); -#else - auto k_lds_Load_view = make_tensor_view( - k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor()); - - auto k_lds_load = - make_tile_window(k_lds_Load_view, - Policy::template 
MakeKLdsLoadBlockDescriptor().get_lengths(), - {0, 0}); -#endif - - // V tile in LDS - auto v_lds = make_tensor_view( - reinterpret_cast(smem_ptr), - Policy::template MakeVLdsBlockDescriptor()); - auto v_lds_window = make_tile_window( - v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); - - // Block GEMM - constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); - constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); - - auto q_dram_window = make_tile_window( - q_dram_block_window_tmp.get_bottom_tensor_view(), - q_dram_block_window_tmp.get_window_lengths(), - q_dram_block_window_tmp.get_window_origin(), - Policy::template MakeQDramTileDistribution()); - - // TODO: we use async Copy for K, which is inline asm - // a side effect is we have to use inline asm for q as well - auto q = decltype(load_tile(q_dram_window)){}; - set_tile(q, number<0>{}); // use per-dword clear to avoid scratch - load_tile_raw(q, q_dram_window); - __builtin_amdgcn_sched_barrier(0); - - using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); - auto s_acc = SaccBlockTileType{}; - - // reduction function for softmax - const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; - const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; - - // infer Sacc, S, P, M, L, Oacc type - using SBlockTileType = decltype(cast_tile(s_acc)); - - using MLBlockTileType = decltype(block_tile_reduce( - SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); - - using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); - - // init Oacc, M, L - auto o_acc = OaccBlockTileType{}; - auto m = MLBlockTileType{}; - auto l = MLBlockTileType{}; - - clear_tile(o_acc); - set_tile(m, -numeric::infinity()); - clear_tile(l); - - __builtin_amdgcn_sched_barrier(0); - const auto q_origin = q_dram_window.get_window_origin(); - const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX( - q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); - - const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); - - // check early exit if masked and no work to do. - if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits) - { - if(num_total_loop <= 0) - { - if constexpr(kStoreLSE) - { - auto lse_acc = - make_static_distributed_tensor(m.get_tile_distribution()); - - set_tile(lse_acc, -numeric::infinity()); - - store_tile(lse_acc_dram_window_tmp, - tile_elementwise_in(lse_acc_element_func, lse_acc)); - } - buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) - // otherwise will have compute error(maybe compiler bug?) 
- - // Note: here occ are all cleard, return it - return o_acc; - } - __builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check - } - - auto k_dram_block_window = - make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), - k_dram_block_window_tmp.get_window_lengths(), - {seqlen_k_start, 0}); - - auto k_dram_window = make_tile_window( - k_dram_block_window.get_bottom_tensor_view(), - k_dram_block_window.get_window_lengths(), - k_dram_block_window.get_window_origin(), - Policy::template MakeKDramTileDistribution()); // K DRAM tile window for - // load - const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); - auto bias_dram_window = make_tile_window( - bias_dram_block_window_tmp.get_bottom_tensor_view(), - bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N - Policy::template MakeBiasDramTileDistribution()); - - auto randval_dram_window = dropout.MakeRandvalDramWindow( - randval_dram_block_window_tmp, seqlen_k_start); - - auto v_dram_window = - make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), - v_dram_block_window_tmp.get_window_lengths(), - {0, seqlen_k_start}, // TODO: hdim split? - Policy::template MakeVDramTileDistribution()); - - // prefetch K tile - async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window); - move_tile_window(k_dram_window, {0, kK0}); - __builtin_amdgcn_sched_barrier(0); - - buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer()); - (void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32 - // auto q_tile = q; // tile_elementwise_in(q_element_func, q); - - index_t i_total_loops = 0; - constexpr index_t k0_loops = kK0BlockLength / kK0; - constexpr index_t k1_loops = kN0 / kK1; - - static_assert(1 <= k0_loops); - static_assert(1 <= k1_loops); - // main loop - do - { - // STAGE 1, QK gemm - clear_tile(s_acc); // initialize C - if constexpr(k0_loops > 1) - { - static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { - async_load_tile_raw(k_lds_store(number{})>{}), - k_dram_window); - if constexpr(i_k0 < k0_loops - 1) - move_tile_window(k_dram_window, {0, kK0}); - - async_load_fence(k_dram_window.get_num_access()); - __builtin_amdgcn_s_barrier(); - __builtin_amdgcn_sched_barrier(0); - gemm_0(s_acc, - get_slice_tile( - q, sequence<0, i_k0 * kK0>{}, sequence{}), -#if K_LDS_LOAD_USE_OFFSET_TRANSFORM - k_lds_load[number{})>{}]); - -#else - get_slice_tile(k_lds_load, - sequence<(LdsSeq.at(number{})) * kN0, 0>{}, - sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); -#endif - }); - } - - // TODO: this to fix a bug when loop smaller than 2, - // the following fence/barrier will be scheduled inside 1st loop - if constexpr(k0_loops <= 2) - __builtin_amdgcn_sched_barrier(0); - - async_load_fence(); - __builtin_amdgcn_s_barrier(); - - const auto bias_tile = load_tile(bias_dram_window); // load bias tile - auto v_buf = load_tile(v_dram_window, bool_constant{}); - __builtin_amdgcn_sched_barrier(0); - { // tail - gemm_0(s_acc, - get_slice_tile( - q, sequence<0, (k0_loops - 1) * kK0>{}, sequence{}), -#if K_LDS_LOAD_USE_OFFSET_TRANSFORM - k_lds_load[number{})>{}]); - -#else - get_slice_tile( - k_lds_load, - sequence<(LdsSeq.at(number{})) * kN0, 0>{}, - sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); -#endif - } - __builtin_amdgcn_sched_barrier(1); - - // STAGE 2, scale_s, add bias, mask, softmax - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - s_acc = 
tile_elementwise_in(s_acc_element_func, s_acc); - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); - tile_elementwise_inout( - [&](auto& x, const auto& y) { -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - x += type_convert(bias_element_func(y)); -#else - x += log2e_v * - type_convert(bias_element_func(y)); -#endif - }, - s_acc, - bias_tile); - } - else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - const auto k_origin = k_dram_block_window.get_window_origin(); - constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); - - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); - - s_acc(i_j_idx) *= scale_s; - position_encoding.update(s_acc(i_j_idx), row, col); - }); - }); - } - else - { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); -#endif - } - move_tile_window(bias_dram_window, {0, kN0}); - - /// TODO: only check in last iteration without increasing code size - if constexpr(kHasUnevenSplits) - { - const auto k_origin = k_dram_block_window.get_window_origin(); - set_tile_if(s_acc, - -numeric::infinity(), - [&, seqlen_k_end_ = seqlen_k_end](auto tile_idx) { - const auto col = - k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return seqlen_k_end_ <= col; - }); - } - - if constexpr(kPadSeqLenK || FmhaMask::IsMasking) - { - const auto k_origin = k_dram_block_window.get_window_origin(); - bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), - k_origin.at(number<0>{}), - number{}, - number{}); - - if(need_perpixel_check) - { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col); - }); - } - } - - const auto s = cast_tile(s_acc); // S{j} - auto m_local = block_tile_reduce( - s, - sequence<1>{}, - f_max, - -numeric::infinity()); // m_local = rowmax(S{j}) - block_tile_reduce_sync(m_local, f_max, bool_constant{}); - - const auto m_old = m; // m{j-1} - tile_elementwise_inout( - [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} - - auto p_compute = make_static_distributed_tensor( - s.get_tile_distribution()); // Pcompute{j} - - __builtin_amdgcn_sched_barrier(0x7F); - // store & prefetch next v, after the max reduction - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, v_buf); - - auto v_lds_window_tmp = - get_slice_tile(v_lds_window, - sequence<(LdsSeq.at(number{})) * kN1, 0>{}, - sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); - - store_tile( - v_lds_window_tmp, - tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch - } - else - { - auto v_lds_window_tmp = - get_slice_tile(v_lds_window, - sequence<(LdsSeq.at(number{})) * kN1, 0>{}, - sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); - store_tile(v_lds_window_tmp, - 
tile_elementwise_in(v_element_func, v_buf)); // store the prefetch - } - - if constexpr(k1_loops > 1) - { - move_tile_window( - v_dram_window, - {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... - v_buf = load_tile(v_dram_window, bool_constant{}); // load next v_buf - } - __builtin_amdgcn_sched_barrier(0); - - static const auto get_validated_m = [](SMPLComputeDataType raw_m) { - /// NOTICE: bias might be materialized mask including -inf values, need - /// consideration. alibi does not have this problem - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - FmhaMask::IsMasking) - { - return raw_m == -numeric::infinity() - ? type_convert(0.f) - : raw_m; - } - else - { - return raw_m; - } - }; - - constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); - sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - auto row_max = scale_s * get_validated_m(m[i_idx]); -#endif - sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); - } - else - { - p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); - } -#else - p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); -#endif - }); - }); - - auto rowsum_p = block_tile_reduce( - p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) - - block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); - // l{j}, Oacc{j} - constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); - sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - const auto tmp = [&]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); - } - else - { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); - } - }(); -#else - const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); -#endif - l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; - sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - // FIXME: this use different equation from FA v2 paper, - // but produce correc result. - // Is the equation wrong? 
- o_acc(i_j_idx) *= tmp; - }); - }); - - if constexpr(kHasDropout) - { - auto randval_ptr = - reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); - dropout.Run( - randval_ptr, - seqlen_k_start + i_total_loops * kN0, - p_compute, - randval_dram_window); - } - - const auto p = - cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); - - // STAGE 3, KV gemm - if constexpr(k1_loops > 1) - { - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) - { - v_buf = load_tile(v_dram_window, bool_constant{}); // load next v_buf - } - block_sync_lds(); - gemm_1(o_acc, - get_slice_tile( - p, sequence<0, i_k1 * kK1>{}, sequence{}), - get_slice_tile( - v_lds_window, - sequence<(LdsSeq.at(number{})) * kN1, 0>{}, - sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); - - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, v_buf); - auto v_lds_window_tmp = get_slice_tile( - v_lds_window, - sequence<(LdsSeq.at(number{})) * kN1, 0>{}, - sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); - store_tile(v_lds_window_tmp, - tile_elementwise_in(v_element_func, - v_shuffle_tmp)); // store the prefetch - } - else - { - auto v_lds_window_tmp = get_slice_tile( - v_lds_window, - sequence<(LdsSeq.at(number{})) * kN1, 0>{}, - sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); - store_tile(v_lds_window_tmp, - tile_elementwise_in(v_element_func, v_buf)); // store next v_buf - } - if constexpr(i_k1 < k1_loops - 1) - move_tile_window(v_dram_window, {0, kK1}); - }); - } - i_total_loops++; - if(i_total_loops < num_total_loop) - { - // move K tile windows - move_tile_window(k_dram_block_window, {kN0, 0}); - k_dram_window = - make_tile_window(k_dram_block_window.get_bottom_tensor_view(), - k_dram_block_window.get_window_lengths(), - k_dram_block_window.get_window_origin(), - Policy::template MakeKDramTileDistribution()); - - if constexpr(k1_loops >= 2 && - LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) - __builtin_amdgcn_s_barrier(); - async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window); - move_tile_window(k_dram_window, {0, kK0}); - } - // tail - { - block_sync_lds(); - gemm_1( - o_acc, - get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), - get_slice_tile( - v_lds_window, - sequence<(LdsSeq.at(number{})) * kN1, 0>{}, - sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); - } - } while(i_total_loops < num_total_loop); - - // store lse acc - if constexpr(kStoreLSE) - { - auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - - constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans(); - sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - lse_acc(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]); - } - else - { - lse_acc(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]); - } -#else - lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]); -#endif - }); - - store_tile(lse_acc_dram_window_tmp, tile_elementwise_in(lse_acc_element_func, lse_acc)); - } - - // finally, O - constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); - - sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - const auto tmp = [&]() { - if 
constexpr(FmhaMask::IsMasking) - { - return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; - } - else - return 1 / l[i_idx]; - }(); - sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - o_acc(i_j_idx) *= tmp; - }); - }); - - o_acc = tile_elementwise_in(o_acc_element_func, o_acc); - - return o_acc; - } - - template - CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile - LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile - index_t num_splits, - index_t i_split, - FmhaMask mask, - PositionEncoding position_encoding, - float scale_s, - void* smem_ptr, - BlockDropout& dropout) const - { - return operator()(q_dram_block_window_tmp, - identity{}, - k_dram_block_window_tmp, - identity{}, - v_dram_block_window_tmp, - identity{}, - bias_dram_block_window_tmp, - identity{}, - randval_dram_block_window_tmp, - lse_acc_dram_block_window_tmp, - identity{}, - identity{}, - identity{}, - identity{}, - num_splits, - i_split, - mask, - position_encoding, - scale_s, - smem_ptr, - dropout); - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp deleted file mode 100644 index 6109fa5ab9..0000000000 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp +++ /dev/null @@ -1,19 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" - -namespace ck_tile { - -// This pipeline is qkv all located in LDS -using BlockFmhaFwdSplitKVPipelineQRKSVSAsyncDefaultPolicy = - BlockFmhaPipelineQXKSVSCustomPolicy; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 23b75f16ac..d254f07e2d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -54,38 +54,50 @@ struct BlockFmhaPipelineProblem static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; -template -struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem +template +struct BlockFmhaFwdSplitKVPipelineProblem { - static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using BlockFmhaShape = remove_cvref_t; + using FmhaMask = remove_cvref_t; + using Traits = remove_cvref_t; + + static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); + static constexpr bool kIsGroupMode = kIsGroupMode_; + + // attributes from traits + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr auto BiasEnum = Traits::BiasEnum; + static constexpr bool kStoreLSE = Traits::kStoreLSE; + static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; + static constexpr bool kIsPagedKV = Traits::kIsPagedKV; + static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; template +struct BlockFmhaFwdAppendKVPipelineProblem +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using Traits = remove_cvref_t; + + static constexpr index_t kBlockSize = 256; + + static constexpr index_t kM0 = kM0_; + static constexpr index_t kN0 = kN0_; + static constexpr index_t kK0 = kK0_; + static constexpr index_t kN1 = kN1_; + + using VLayout = std::conditional_t; + + static constexpr auto RotaryEnum = RotaryEnum_; + static constexpr bool kIsPagedKV = kIsPagedKV_; + + // attributes from traits + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 12af81bb98..80fbc8e380 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -707,16 +707,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy() + GetSmemSizeDropout(); + return GetSmemSizeKV() + GetSmemSizeDropout(0); } else { - return ck_tile::max(GetSmemSizeKV(), GetSmemSizeDropout()); + return ck_tile::max(GetSmemSizeKV(), GetSmemSizeDropout(0)); } } + // this method is only available when Problem::kHasDropout is present template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout() + CK_TILE_HOST_DEVICE static constexpr std:: + enable_if_t, ck_tile::index_t> + GetSmemSizeDropout(int) { if constexpr(Problem::kHasDropout) { @@ -736,6 +739,13 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout(...) + { + return 0; + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() { diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index be4fdfd711..e3187042d2 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp" namespace ck_tile { @@ -32,30 +33,31 @@ struct TileFmhaTraits static constexpr index_t kBlockPerCu = kBlockPerCu_; }; -template -struct TileFmhaFwdSplitKVTraits : TileFmhaTraits +template +struct TileFmhaFwdSplitKVTraits { + static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; + static constexpr bool kPadSeqLenK = kPadSeqLenK_; + static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; + static constexpr bool kPadHeadDimV = kPadHeadDimV_; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kHasBiasGrad = kHasBiasGrad_; + static constexpr bool kStoreLSE = kStoreLSE_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kIsPagedKV = kIsPagedKV_; // determine if some split (length) is not divisible by tile size static constexpr bool kHasUnevenSplits = kHasUnevenSplits_; + static constexpr index_t kBlockPerCu = kBlockPerCu_; }; template +struct TileFmhaFwdAppendKVTraits +{ + static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; + static constexpr bool kPadSeqLenK = kPadSeqLenK_; + static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; + static constexpr bool kPadHeadDimV = kPadHeadDimV_; + static constexpr index_t kBlockPerCu = kBlockPerCu_; +}; + template
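
The append-kv pipeline above dispatches on RotaryEmbeddingEnum (INTERLEAVED vs HALF_ROTATED) both when sizing the rotary cos/sin tiles and when computing per-thread K ranges. As a reference for what the two layouts mean element-wise, here is a scalar model of rotating one row of Q or Knew. The pairing convention, (2i, 2i+1) for interleaved and (i, i+rotary_dim/2) for half-rotated, is the common one and is an assumption here, as is the rotary_dim/2-wide cos/sin input; the actual tile-level contract of BlockRotaryEmbedding::apply is defined in block_rotary_embedding.hpp.

    #include <cstddef>
    #include <vector>

    enum class RotaryLayout { Interleaved, HalfRotated };

    // Rotate one row in place. cos_vals/sin_vals hold rotary_dim / 2 entries
    // for this row's position; elements at and beyond rotary_dim pass through
    // unchanged, matching the partial-rotary case the pipeline handles via
    // its thread ranges.
    void apply_rope_row(std::vector<float>& x,
                        const std::vector<float>& cos_vals,
                        const std::vector<float>& sin_vals,
                        std::size_t rotary_dim,
                        RotaryLayout layout)
    {
        const std::size_t half = rotary_dim / 2;
        for(std::size_t i = 0; i < half; ++i)
        {
            // interleaved pairs adjacent elements; half-rotated pairs the two halves
            const std::size_t i0 = (layout == RotaryLayout::Interleaved) ? 2 * i : i;
            const std::size_t i1 = (layout == RotaryLayout::Interleaved) ? 2 * i + 1 : i + half;
            const float x0 = x[i0];
            const float x1 = x[i1];
            x[i0] = x0 * cos_vals[i] - x1 * sin_vals[i];
            x[i1] = x0 * sin_vals[i] + x1 * cos_vals[i];
        }
    }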
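
Both rewritten pipelines thread a k/v page-block navigator through their loops instead of plain tile windows: the append-kv path asks is_cross_block/move_to_block when a stored tile straddles two physical blocks, and the split-kv path relies on move_tile_window and to_global_window_origin to advance through the paged cache and to recover global columns for masking and alibi. The sketch below is a minimal host-side model of that bookkeeping only; block_table, page_block_size, and the member signatures are illustrative stand-ins, not the actual ck_tile navigator interface, which operates on tile windows.

    #include <cstdint>
    #include <utility>
    #include <vector>

    struct PageBlockNavigatorModel
    {
        std::vector<int32_t> block_table; // logical block index -> physical block id
        int page_block_size;              // K/V rows held by one page block

        // map a global seqlen_k row to (logical block index, row within block)
        std::pair<int, int> locate(int global_row) const
        {
            return {global_row / page_block_size, global_row % page_block_size};
        }

        // the physical block that backs a logical block, as recorded by the host
        int32_t physical_block(int i_block) const { return block_table[i_block]; }

        // a tile of kN0 rows starting at row_in_block spills into the next
        // physical block when it does not fit in the current one; this is the
        // case where the append-kv path stores the same tile a second time
        bool is_cross_block(int row_in_block, int kN0) const
        {
            return page_block_size < row_in_block + kN0;
        }

        // recover the global column of a window origin, as the split-kv loop
        // needs for alibi and for the uneven-split masking predicate
        int to_global_window_origin(int i_block, int row_in_block) const
        {
            return i_block * page_block_size + row_in_block;
        }
    };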
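
The split-kv hunk floors each split's start column to a tile boundary (adjusted_seqlen_k_start) when the cache is paged, so that every kN0-wide tile is loaded whole from a single page block, and then widens the set_tile_if predicate to also mask the columns in front of the true start. A small self-contained check of that arithmetic, with made-up sizes:

    constexpr int kN0            = 64;  // tile width along seqlen_k
    constexpr int seqlen_k_start = 96;  // this split's first valid column
    constexpr int seqlen_k_end   = 224; // one past this split's last valid column

    // floor the window origin to a tile boundary, as the paged path does
    constexpr int adjusted_start = kN0 * (seqlen_k_start / kN0); // == 64
    constexpr int num_total_loop =
        (seqlen_k_end - adjusted_start + kN0 - 1) / kN0; // ceil -> 3 tiles

    static_assert(adjusted_start == 64 && num_total_loop == 3);

    // mirrors the widened predicate: columns loaded only for alignment
    // (before the real start) are set to -inf, as are columns past the end
    constexpr bool masked(int col)
    {
        return col < seqlen_k_start || seqlen_k_end <= col;
    }

    static_assert(masked(80) && !masked(96) && !masked(223) && masked(224));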
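
The GetSmemSizeDropout change in block_fmha_pipeline_qx_ks_vs_custom_policy.hpp is the classic overload-ranking idiom: the (int) overload is SFINAE-constrained on Problem::kHasDropout being a valid expression, while the (...) overload is the lower-ranked fallback returning 0, so problem types that no longer declare the flag (such as the rewritten split-kv problem) still compile. A minimal standalone sketch of the same idiom, with hypothetical names and a stand-in size computation:

    #include <cstddef>
    #include <type_traits>

    // Preferred overload: viable only when Problem::kHasDropout is a valid
    // expression; the comma operator folds it into a dummy `true`.
    template <typename Problem>
    constexpr auto smem_size_dropout(int)
        -> std::enable_if_t<(Problem::kHasDropout, true), std::size_t>
    {
        return Problem::kHasDropout ? 1024 : 0; // stand-in for the real size math
    }

    // Fallback: chosen via the lower-ranked ellipsis conversion when the
    // expression above is ill-formed (no kHasDropout member at all).
    template <typename Problem>
    constexpr std::size_t smem_size_dropout(...)
    {
        return 0;
    }

    struct WithDropout    { static constexpr bool kHasDropout = true; };
    struct WithoutDropout {}; // e.g. a problem type without the flag

    static_assert(smem_size_dropout<WithDropout>(0) == 1024);
    static_assert(smem_size_dropout<WithoutDropout>(0) == 0);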