Try to reduce the library's size. (#6)
* Hide the internal symbols and remove the hdim16 implementation.

* Remove SYMBOL_EXPORT.

* Remove some debugging statements.

* Change has_attn_bias and has_attn_mask to runtime arguments instead of template parameters.

* Avoid compiling the unused .cu files.

* Remove the return_softmax-related template parameter and argument.

* Remove is_causal support from the implementation with mask and bias.

* Reorganize code.

* Polish code.

* Add a check of softmax_scale (see the sketch after this list).
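A minimal sketch of the kind of softmax_scale validation the last bullet refers to. The exact condition added in flash_attn_with_bias_mask.cpp is not visible in this excerpt, so the function name and policy below are illustrative assumptions, not the library's code.

// Hypothetical illustration only -- not the exact check added by this commit.
#include <cmath>
#include <cstdio>

bool check_softmax_scale(float softmax_scale) {
    // Reject NaN/Inf and non-positive scaling factors for Q*K^T.
    if (!std::isfinite(softmax_scale) || softmax_scale <= 0.0f) {
        std::fprintf(stderr, "softmax_scale must be a positive, finite value\n");
        return false;
    }
    return true;
}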
Xreki authored May 18, 2023
1 parent 209f02b commit 18106c1
Showing 20 changed files with 443 additions and 636 deletions.
19 changes: 17 additions & 2 deletions csrc/flash_attn/CMakeLists.txt
@@ -13,8 +13,23 @@ include_directories(
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
)

file(GLOB SOURCES_CU "src/*.cu")
file(GLOB SOURCES_CPP "src/*.cpp")
#file(GLOB SOURCES_CU "src/*.cu")
#file(GLOB SOURCES_CPP "src/*.cpp")
set(SOURCES_CU
src/fmha_fwd_hdim32.cu
src/fmha_fwd_hdim64.cu
src/fmha_fwd_hdim128.cu
src/fmha_bwd_hdim32.cu
src/fmha_bwd_hdim64.cu
src/fmha_bwd_hdim128.cu
src/fmha_fwd_with_mask_bias_hdim32.cu
src/fmha_fwd_with_mask_bias_hdim64.cu
src/fmha_fwd_with_mask_bias_hdim128.cu
src/fmha_bwd_with_mask_bias_hdim32.cu
src/fmha_bwd_with_mask_bias_hdim64.cu
src/fmha_bwd_with_mask_bias_hdim128.cu
src/utils.cu)
set(SOURCES_CPP src/cuda_utils.cpp)

#add_library(flashattn OBJECT
add_library(flashattn SHARED
3 changes: 0 additions & 3 deletions csrc/flash_attn/flash_attn.h
@@ -87,11 +87,9 @@ bool flash_attn_fwd_with_bias_and_mask(
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
const bool is_bf16,
const int num_splits, // SMs per attention matrix, can be 1
void *softmax_lse_ptr, // softmax log_sum_exp
void *softmax_ptr,
void *workspace_ptr,
uint64_t *workspace_size,
cudaStream_t stream,
@@ -124,7 +122,6 @@ bool flash_attn_bwd_with_bias_and_mask(
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
const bool is_bf16,
const int num_splits,
const void *softmax_lse_ptr,
362 changes: 183 additions & 179 deletions csrc/flash_attn/flash_attn_with_bias_mask.cpp

Large diffs are not rendered by default.

26 changes: 12 additions & 14 deletions csrc/flash_attn/src/fmha.h
@@ -167,16 +167,16 @@ struct FMHA_dgrad_params : public FMHA_fprop_params {
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_params>
struct Launch_params{
struct Launch_params {
Launch_params(cudaDeviceProp * props_,
cudaStream_t stream_,
bool is_dropout_,
bool return_softmax_)
: elts_per_thread(0)
, props(props_)
, stream(stream_)
, is_dropout(is_dropout_)
, return_softmax(return_softmax_) {
: elts_per_thread(0),
props(props_),
stream(stream_),
is_dropout(is_dropout_),
return_softmax(return_softmax_) {
}

size_t elts_per_thread;
@@ -206,15 +206,13 @@ void run_fmha_bwd_hdim32(FMHA_dgrad_params &params, cudaStream_t stream, const b
void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
void run_fmha_bwd_hdim128(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);

void run_fmha_fwd_with_mask_bias_hdim16(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
void run_fmha_fwd_with_mask_bias_hdim32(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
void run_fmha_fwd_with_mask_bias_hdim64(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
void run_fmha_fwd_with_mask_bias_hdim128(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
bool run_fmha_fwd_with_mask_bias_hdim32(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
bool run_fmha_fwd_with_mask_bias_hdim64(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
bool run_fmha_fwd_with_mask_bias_hdim128(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);

void run_fmha_bwd_with_mask_bias_hdim16(FMHA_dgrad_params &params, cudaStream_t stream);
void run_fmha_bwd_with_mask_bias_hdim32(FMHA_dgrad_params &params, cudaStream_t stream);
void run_fmha_bwd_with_mask_bias_hdim64(FMHA_dgrad_params &params, cudaStream_t stream);
void run_fmha_bwd_with_mask_bias_hdim128(FMHA_dgrad_params &params, cudaStream_t stream);
bool run_fmha_bwd_with_mask_bias_hdim32(FMHA_dgrad_params &params, cudaStream_t stream);
bool run_fmha_bwd_with_mask_bias_hdim64(FMHA_dgrad_params &params, cudaStream_t stream);
bool run_fmha_bwd_with_mask_bias_hdim128(FMHA_dgrad_params &params, cudaStream_t stream);

void run_fmha_block_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);

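The forward and backward launchers declared above now return a bool status, and the hdim16 variants are gone. A minimal caller-side sketch, assuming fmha.h is included; the dispatch helper and its head-dim selection policy are illustrative, not part of this commit.

#include "fmha.h"  // Launch_params, FMHA_fprop_params, run_fmha_fwd_with_mask_bias_hdim*

// Illustrative helper (not in this commit): pick the head-dim-specific entry
// point and propagate its bool status to the caller.
bool dispatch_fwd_with_mask_bias(Launch_params<FMHA_fprop_params> &launch_params,
                                 const int head_dim, const bool configure) {
    if (head_dim <= 32)  return run_fmha_fwd_with_mask_bias_hdim32(launch_params, configure);
    if (head_dim <= 64)  return run_fmha_fwd_with_mask_bias_hdim64(launch_params, configure);
    if (head_dim <= 128) return run_fmha_fwd_with_mask_bias_hdim128(launch_params, configure);
    return false;  // head dims above 128 are not handled by these entry points
}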
2 changes: 1 addition & 1 deletion csrc/flash_attn/src/fmha/gmem_tile.h
@@ -741,7 +741,7 @@ struct Gmem_tile_mma_bias {
static_assert(Mma_tile::N_PER_MMA == 16);

// The distance between two blocks (in bytes).
// TODO: mask is [bs, head, seq_q, seq_k]
// TODO: bias is [bs, head, seq_q, seq_k]
// The block index.
// uint32_t bidx = binfo.bidb * params.h + binfo.bidh;
uint32_t bidx = ( binfo.bidb % params.bias_mod_size ) * params.h + binfo.bidh;
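The modulo in the block-index computation above lets a bias tensor whose batch dimension (params.bias_mod_size) is smaller than the input batch be reused across batches. A toy, standalone illustration of how the index wraps; the constants are made up and this is not library code.

#include <cstdio>

int main() {
    const int num_heads = 4;      // plays the role of params.h
    const int bias_mod_size = 2;  // bias batch dimension, smaller than the input batch
    for (int bidb = 0; bidb < 4; ++bidb) {              // batch index
        for (int bidh = 0; bidh < num_heads; ++bidh) {  // head index
            const int bidx = (bidb % bias_mod_size) * num_heads + bidh;
            std::printf("batch %d, head %d -> bias block %d\n", bidb, bidh, bidx);
        }
    }
    return 0;
}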
2 changes: 1 addition & 1 deletion csrc/flash_attn/src/fmha/softmax.h
@@ -458,7 +458,7 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
}

template<bool zero=false, typename Fragment, typename Mask>
inline __device__ void apply_attn_mask(const Fragment (&attn_mask)[MMAS_M][MMAS_N], const Mask &mask, int l = 0, int loop_step_idx = 0) {
inline __device__ void apply_attn_mask(const Fragment (&attn_mask)[MMAS_M][MMAS_N], const Mask &mask, int l = 0) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
126 changes: 0 additions & 126 deletions csrc/flash_attn/src/fmha_bwd_launch_template.h
@@ -44,11 +44,6 @@ __global__ void fmha_bwd_q_dk_dv_loop_seqparallel_kernel(FMHA_dgrad_params param
fmha::compute_dq_dk_dv_seqparallel<Kernel_traits, Is_dropout, Is_causal>(params);
}

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1>
__global__ void fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) {
fmha::compute_dq_dk_dv_1xN<Kernel_traits, Is_dropout, Is_causal, loop_steps>(params);
}

template<typename Kernel_traits>
void run_fmha_bwd_loop(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
@@ -118,124 +113,3 @@ void run_fmha_bwd_loop(FMHA_dgrad_params &params, cudaStream_t stream, const boo
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}));
}

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Need_attn_mask, bool Need_attn_bias, int loop_steps=-1>
__global__ void fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) {
fmha::compute_dq_dk_dv_1xN_with_bias_mask<Kernel_traits, Is_dropout, Is_causal, Need_attn_mask, Need_attn_bias, loop_steps>(params);
}

template<typename Kernel_traits>
void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_dq = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;

using Smem_tile_s = fmha::Smem_tile_mma_transposed<typename Kernel_traits::Cta_tile_p>;
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
static_assert(smem_size_s == 16 * Kernel_traits::Cta_tile_p::N * 2);
static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N);

constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2;
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
// printf("blocksize_c = %d, WARPS_N = %d, Smem size = %d\n", blocksize_c, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);

bool is_dropout = params.p_dropout < 1.f; // params.p_dropout is the probability of "keeping"

bool has_attn_mask = !(params.attn_mask_ptr == nullptr);
bool has_attn_bias = !(params.attn_bias_ptr == nullptr);

if (has_attn_mask) {
if (has_attn_bias) {
BOOL_SWITCH_FUNC(is_dropout, IsDropoutConst, [&] {
auto kernel = params.is_causal
? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, true, true>
: &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, true, true>;
if (params.seqlen_k == blocksize_c) {
kernel = params.is_causal
? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, true, true, /*loop_steps=*/1>
: &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, true, true, /*loop_steps=*/1>;
} else if (params.seqlen_k == blocksize_c * 2) {
kernel = params.is_causal
? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, true, true, /*loop_steps=*/2>
: &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, true, true, /*loop_steps=*/2>;
}
if( smem_size_dq_dk_dv >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
dim3 grid(params.b, params.h);
kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
});
}else{
BOOL_SWITCH_FUNC(is_dropout, IsDropoutConst, [&] {
auto kernel = params.is_causal
? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, true, false>
: &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, true, false>;
if (params.seqlen_k == blocksize_c) {
kernel = params.is_causal
? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, true, false, /*loop_steps=*/1>
: &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, true, false, /*loop_steps=*/1>;
} else if (params.seqlen_k == blocksize_c * 2) {
kernel = params.is_causal
? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, true, false, /*loop_steps=*/2>
: &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, true, false, /*loop_steps=*/2>;
}
if( smem_size_dq_dk_dv >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
dim3 grid(params.b, params.h);
kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
});
}
}else{
if (has_attn_bias) {
BOOL_SWITCH_FUNC(is_dropout, IsDropoutConst, [&] {
auto kernel = params.is_causal
? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, false, true>
: &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, false, true>;
if (params.seqlen_k == blocksize_c) {
kernel = params.is_causal
? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, false, true, /*loop_steps=*/1>
: &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, false, true, /*loop_steps=*/1>;
} else if (params.seqlen_k == blocksize_c * 2) {
kernel = params.is_causal
? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, false, true, /*loop_steps=*/2>
: &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, false, true, /*loop_steps=*/2>;
}
if( smem_size_dq_dk_dv >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
dim3 grid(params.b, params.h);
kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
});
}else{
BOOL_SWITCH_FUNC(is_dropout, IsDropoutConst, [&] {
auto kernel = params.is_causal
? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, false, false>
: &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, false, false>;
if (params.seqlen_k == blocksize_c) {
kernel = params.is_causal
? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, false, false, /*loop_steps=*/1>
: &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, false, false, /*loop_steps=*/1>;
} else if (params.seqlen_k == blocksize_c * 2) {
kernel = params.is_causal
? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, false, false, /*loop_steps=*/2>
: &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, false, false, /*loop_steps=*/2>;
}
if( smem_size_dq_dk_dv >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
dim3 grid(params.b, params.h);
kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
});
}
}
}
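The launcher deleted above expands the is_dropout / has_attn_mask / has_attn_bias combinations into four nearly identical branches around BOOL_SWITCH_FUNC. For reference, here is a self-contained sketch of that runtime-bool-to-template-bool idiom; the BOOL_DISPATCH macro and the function names are illustrative stand-ins, not the library's actual definitions.

#include <cstdio>

// Illustrative stand-in for BOOL_SWITCH_FUNC: turn a runtime bool into a
// compile-time constant visible to the code passed in __VA_ARGS__.
#define BOOL_DISPATCH(COND, CONST_NAME, ...)   \
    do {                                       \
        if (COND) {                            \
            constexpr bool CONST_NAME = true;  \
            __VA_ARGS__();                     \
        } else {                               \
            constexpr bool CONST_NAME = false; \
            __VA_ARGS__();                     \
        }                                      \
    } while (0)

template <bool HasMask, bool HasBias>
void launch_kernel() {  // stands in for launching the templated CUDA kernel
    std::printf("instantiated with HasMask=%d HasBias=%d\n", HasMask, HasBias);
}

void dispatch(bool has_attn_mask, bool has_attn_bias) {
    BOOL_DISPATCH(has_attn_mask, kHasMask, [&] {
        BOOL_DISPATCH(has_attn_bias, kHasBias, [&] {
            launch_kernel<kHasMask, kHasBias>();
        });
    });
}

int main() {
    dispatch(true, false);  // selects launch_kernel<true, false>
    dispatch(false, true);  // selects launch_kernel<false, true>
    return 0;
}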
10 changes: 5 additions & 5 deletions csrc/flash_attn/src/fmha_bwd_with_mask_bias_hdim128.cu
@@ -1,12 +1,12 @@
// Copyright (c) 2022, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.
#include "fmha_bwd_with_mask_bias_launch_template.h"

#include "fmha_bwd_launch_template.h"

void run_fmha_bwd_with_mask_bias_hdim128(FMHA_dgrad_params &params, cudaStream_t stream) {
bool run_fmha_bwd_with_mask_bias_hdim128(FMHA_dgrad_params &params, cudaStream_t stream) {
bool status = true;
FP16_SWITCH(params.is_bf16, ([&] {
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
status = run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
}));
return status;
}
22 changes: 0 additions & 22 deletions csrc/flash_attn/src/fmha_bwd_with_mask_bias_hdim16.cu

This file was deleted.

12 changes: 6 additions & 6 deletions csrc/flash_attn/src/fmha_bwd_with_mask_bias_hdim32.cu
@@ -1,17 +1,17 @@
// Copyright (c) 2022, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.
#include "fmha_bwd_with_mask_bias_launch_template.h"

#include "fmha_bwd_launch_template.h"

void run_fmha_bwd_with_mask_bias_hdim32(FMHA_dgrad_params &params, cudaStream_t stream) {
bool run_fmha_bwd_with_mask_bias_hdim32(FMHA_dgrad_params &params, cudaStream_t stream) {
bool status = true;
FP16_SWITCH(params.is_bf16, ([&] {
if( params.seqlen_k == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
status = run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
} else if( params.seqlen_k >= 256 ) {
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
status = run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
}
}));
return status;
}
16 changes: 8 additions & 8 deletions csrc/flash_attn/src/fmha_bwd_with_mask_bias_hdim64.cu
@@ -1,30 +1,30 @@
// Copyright (c) 2022, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.
#include "fmha_bwd_with_mask_bias_launch_template.h"

#include "fmha_bwd_launch_template.h"

void run_fmha_bwd_with_mask_bias_hdim64(FMHA_dgrad_params &params, cudaStream_t stream) {
bool run_fmha_bwd_with_mask_bias_hdim64(FMHA_dgrad_params &params, cudaStream_t stream) {
bool status = true;
auto dprops = GetDeviceProperties(-1);
FP16_SWITCH(params.is_bf16, ([&] {
if( params.seqlen_k == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
status = run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
} else if( params.seqlen_k >= 256 ) {
if (dprops->major == 8 && dprops->minor == 0) {
// Don't share smem for K & V, and don't keep V in registers
// This speeds things up by 2-3% by avoiding register spills, but it
// uses more shared memory, which is fine on A100 but not other GPUs.
// For other GPUs, we keep V in registers.
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
status = run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
} else if (dprops->major == 8 && dprops->minor > 0) {
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
status = run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
} else if (dprops->major == 7 && dprops->minor == 5) {
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
status = run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
}
}
}));
return status;
}
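The hdim64 launcher above picks its kernel traits by SM architecture (sm_80 does not share shared memory for K and V, while sm_86/sm_89 and sm_75 use the smaller-shared-memory configuration). A standalone sketch of that device query using the plain CUDA runtime API rather than the library's GetDeviceProperties helper; the printed descriptions paraphrase the comments in the diff.

#include <cuda_runtime.h>
#include <cstdio>

int main() {
    cudaDeviceProp prop;
    if (cudaGetDeviceProperties(&prop, /*device=*/0) != cudaSuccess) {
        std::fprintf(stderr, "failed to query device properties\n");
        return 1;
    }
    if (prop.major == 8 && prop.minor == 0) {
        std::printf("sm_80 (A100): larger-smem traits, K/V not shared\n");
    } else if (prop.major == 8 && prop.minor > 0) {
        std::printf("sm_86/sm_89: smaller-smem traits\n");
    } else if (prop.major == 7 && prop.minor == 5) {
        std::printf("sm_75 (Turing): seqlen-128 traits\n");
    } else {
        std::printf("architecture not covered by this launcher\n");
    }
    return 0;
}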