Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Support for cached multi-query attention towards speculative decoding #1679

Closed
wants to merge 14 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "csrc/flash_attn_v2/third_party/cutlass"]
path = csrc/flash_attn_v2/third_party/cutlass
url = https://github.com/NVIDIA/cutlass.git
29 changes: 29 additions & 0 deletions csrc/flash_attn_v2/paged_flash/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
BSD 3-Clause License

Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
181 changes: 181 additions & 0 deletions csrc/flash_attn_v2/paged_flash/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
# CMAKE generated file: DO NOT EDIT!
# Generated by "Unix Makefiles" Generator, CMake Version 3.25

# Default target executed when no arguments are given to make.
default_target: all
.PHONY : default_target

# Allow only one "make -f Makefile2" at a time, but pass parallelism.
.NOTPARALLEL:

#=============================================================================
# Special targets provided by cmake.

# Disable implicit rules so canonical targets will work.
.SUFFIXES:

# Disable VCS-based implicit rules.
% : %,v

# Disable VCS-based implicit rules.
% : RCS/%

# Disable VCS-based implicit rules.
% : RCS/%,v

# Disable VCS-based implicit rules.
% : SCCS/s.%

# Disable VCS-based implicit rules.
% : s.%

.SUFFIXES: .hpux_make_needs_suffix_list

# Command-line flag to silence nested $(MAKE).
$(VERBOSE)MAKESILENT = -s

#Suppress display of executed commands.
$(VERBOSE).SILENT:

# A target that is always out of date.
cmake_force:
.PHONY : cmake_force

#=============================================================================
# Set environment variables for the build.

# The shell in which to execute make rules.
SHELL = /bin/sh

# The CMake executable.
CMAKE_COMMAND = /home/deepspeed/.local/lib/python3.8/site-packages/cmake/data/bin/cmake

# The command to remove a file.
RM = /home/deepspeed/.local/lib/python3.8/site-packages/cmake/data/bin/cmake -E rm -f

# Escaping for special characters.
EQUALS = =

# The top-level source directory on which CMake was run.
CMAKE_SOURCE_DIR = /data/private_dev/DeepSpeed-Kernels/inf_flash_attn/blocked_flash

# The top-level build directory on which CMake was run.
CMAKE_BINARY_DIR = /data/private_dev/DeepSpeed-Kernels/inf_flash_attn/blocked_flash

#=============================================================================
# Targets provided globally by CMake.

# Special rule for the target edit_cache
edit_cache:
@$(CMAKE_COMMAND) -E cmake_echo_color --switch=$(COLOR) --cyan "No interactive CMake dialog available..."
/home/deepspeed/.local/lib/python3.8/site-packages/cmake/data/bin/cmake -E echo No\ interactive\ CMake\ dialog\ available.
.PHONY : edit_cache

# Special rule for the target edit_cache
edit_cache/fast: edit_cache
.PHONY : edit_cache/fast

# Special rule for the target rebuild_cache
rebuild_cache:
@$(CMAKE_COMMAND) -E cmake_echo_color --switch=$(COLOR) --cyan "Running CMake to regenerate build system..."
/home/deepspeed/.local/lib/python3.8/site-packages/cmake/data/bin/cmake --regenerate-during-build -S$(CMAKE_SOURCE_DIR) -B$(CMAKE_BINARY_DIR)
.PHONY : rebuild_cache

# Special rule for the target rebuild_cache
rebuild_cache/fast: rebuild_cache
.PHONY : rebuild_cache/fast

# The main all target
all: cmake_check_build_system
$(CMAKE_COMMAND) -E cmake_progress_start /data/private_dev/DeepSpeed-Kernels/inf_flash_attn/blocked_flash/CMakeFiles /data/private_dev/DeepSpeed-Kernels/inf_flash_attn/blocked_flash//CMakeFiles/progress.marks
$(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 all
$(CMAKE_COMMAND) -E cmake_progress_start /data/private_dev/DeepSpeed-Kernels/inf_flash_attn/blocked_flash/CMakeFiles 0
.PHONY : all

# The main clean target
clean:
$(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 clean
.PHONY : clean

# The main clean target
clean/fast: clean
.PHONY : clean/fast

# Prepare targets for installation.
preinstall: all
$(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 preinstall
.PHONY : preinstall

# Prepare targets for installation.
preinstall/fast:
$(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 preinstall
.PHONY : preinstall/fast

# clear depends
depend:
$(CMAKE_COMMAND) -S$(CMAKE_SOURCE_DIR) -B$(CMAKE_BINARY_DIR) --check-build-system CMakeFiles/Makefile.cmake 1
.PHONY : depend

#=============================================================================
# Target rules for targets named gemm

# Build rule for target.
gemm: cmake_check_build_system
$(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 gemm
.PHONY : gemm

# fast build rule for target.
gemm/fast:
$(MAKE) $(MAKESILENT) -f CMakeFiles/gemm.dir/build.make CMakeFiles/gemm.dir/build
.PHONY : gemm/fast

flash_fwd_hdim32_bf16_sm80.o: flash_fwd_hdim32_bf16_sm80.cu.o
.PHONY : flash_fwd_hdim32_bf16_sm80.o

# target to build an object file
flash_fwd_hdim32_bf16_sm80.cu.o:
$(MAKE) $(MAKESILENT) -f CMakeFiles/gemm.dir/build.make CMakeFiles/gemm.dir/flash_fwd_hdim32_bf16_sm80.cu.o
.PHONY : flash_fwd_hdim32_bf16_sm80.cu.o

flash_fwd_hdim32_bf16_sm80.i: flash_fwd_hdim32_bf16_sm80.cu.i
.PHONY : flash_fwd_hdim32_bf16_sm80.i

# target to preprocess a source file
flash_fwd_hdim32_bf16_sm80.cu.i:
$(MAKE) $(MAKESILENT) -f CMakeFiles/gemm.dir/build.make CMakeFiles/gemm.dir/flash_fwd_hdim32_bf16_sm80.cu.i
.PHONY : flash_fwd_hdim32_bf16_sm80.cu.i

flash_fwd_hdim32_bf16_sm80.s: flash_fwd_hdim32_bf16_sm80.cu.s
.PHONY : flash_fwd_hdim32_bf16_sm80.s

# target to generate assembly for a file
flash_fwd_hdim32_bf16_sm80.cu.s:
$(MAKE) $(MAKESILENT) -f CMakeFiles/gemm.dir/build.make CMakeFiles/gemm.dir/flash_fwd_hdim32_bf16_sm80.cu.s
.PHONY : flash_fwd_hdim32_bf16_sm80.cu.s

# Help Target
help:
@echo "The following are some of the valid targets for this Makefile:"
@echo "... all (the default if no target is provided)"
@echo "... clean"
@echo "... depend"
@echo "... edit_cache"
@echo "... rebuild_cache"
@echo "... gemm"
@echo "... flash_fwd_hdim32_bf16_sm80.o"
@echo "... flash_fwd_hdim32_bf16_sm80.i"
@echo "... flash_fwd_hdim32_bf16_sm80.s"
.PHONY : help



#=============================================================================
# Special targets to cleanup operation of make.

# Special rule to run CMake to check the build system integrity.
# No rule that depends on this can have commands that come from listfiles
# because they might be regenerated.
cmake_check_build_system:
$(CMAKE_COMMAND) -S$(CMAKE_SOURCE_DIR) -B$(CMAKE_BINARY_DIR) --check-build-system CMakeFiles/Makefile.cmake 0
.PHONY : cmake_check_build_system

27 changes: 27 additions & 0 deletions csrc/flash_attn_v2/paged_flash/attention_atom.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

#pragma once

#include <cstdint>
#include "cuda.h"
#include "cute/pointer.hpp"

struct __align__(32) AttentionAtom {
using index_t = uint32_t;

index_t* block_idx_list;

index_t q_start_idx;
index_t q_len;
index_t kv_blocks;
index_t total_extent;
index_t global_q_idx;
index_t unused;

template <int threads>
__device__ void load_kv_block_idxs(cute::smem_ptr<int32_t> block_idx_list_shr, int tidx) const
{
for (int i = tidx; i < kv_blocks; i += threads) { block_idx_list_shr[i] = block_idx_list[i]; }
// Aggressive (but safe) sync
__syncthreads();
}
};
83 changes: 83 additions & 0 deletions csrc/flash_attn_v2/paged_flash/flash.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/

#pragma once

#include <cuda.h>
#include <vector>

#include "attention_atom.h"

constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;

#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Qkv_params {
using index_t = uint32_t;
// The QKV matrices.
void* __restrict__ q_ptr;
void* __restrict__ k_ptr;
void* __restrict__ v_ptr;

// The stride between rows of the Q, K and V matrices.
index_t q_row_stride;
index_t k_row_stride;
index_t v_row_stride;
index_t q_head_stride;
index_t k_head_stride;
index_t v_head_stride;

// The number of heads.
int h, h_k;
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
// different from nheads (query).
int h_h_k_ratio; // precompute h / h_k,
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_fwd_params : public Qkv_params {
// The O matrix (output).
void* __restrict__ o_ptr;

// The attention metadata
// AttentionAtom* __restrict__ atoms;

// Total attention atoms
// int num_atoms;

// PagedAttention metadata
int num_seqs;
int max_num_query;
int max_context_len;
int block_size;
int max_num_blocks_per_seq;

index_t* __restrict__ block_tables;
index_t* __restrict__ context_lens;
index_t* __restrict__ draft_lens;

// The stride between rows of O.
index_t o_row_stride;
index_t o_head_stride;

// The dimensions
int d, d_rounded;

// The scaling factors for the kernel.
float scale_softmax;
float scale_softmax_log2;

bool is_bf16;
bool is_causal;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, int Headdim>
void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream);
Loading
Loading