Skip to content

Commit

Permalink
Opitmization for AlphaFold2 model (#4)
Browse files Browse the repository at this point in the history
* add for alpha_fold2

* add some extra setting

* fix some bugs

* fix some changes

* fix some bugs 2nd

* Add another initition of Gmem_tile_qkv and Gmem_tile_o

* add some compensation for try..catch

* fix mistake in flash_attn_fwd

* commit for code style and bug check

* fix some bugs for flash_attn_with_bias-mask

* add more print for pointer debug

* add some bug test cases.

* backward function

* fix bugs

* make some changes for backward

* Fix compiling error.

* quote all printf debug

* quote all printf debug and fix interface error

* quote all printf debug and fix interface error, fix typo

* remove all printf

* split files

* remove useless debug code

* split fwd and bwd execution function

* split fwd and bwd execution function

* remove useless codes

* remove useless codes

* remove useless codes 3rd times

* remove useless codes 4th times

* Fix compiling error.

* Remove const.
  • Loading branch information
JamesLim-sy authored May 8, 2023
1 parent 5ff4bbf commit 209f02b
Show file tree
Hide file tree
Showing 27 changed files with 3,033 additions and 13 deletions.
1 change: 1 addition & 0 deletions csrc/flash_attn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ add_library(flashattn SHARED
${SOURCES_CU}
${SOURCES_CPP}
flash_attn.cpp
flash_attn_with_bias_mask.cpp
)

target_compile_options(flashattn PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:
Expand Down
3 changes: 2 additions & 1 deletion csrc/flash_attn/flash_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
*
******************************************************************************/

#include "flash_attn.h"
#include "fmha.h"
#include "utils.h"
#include "cuda_utils.h"
Expand Down Expand Up @@ -62,7 +63,7 @@ extern "C" {

static thread_local std::unique_ptr<char[]> flash_attn_err_msg;

static void flash_attn_set_error(const char *msg) {
void flash_attn_set_error(const char *msg) {
if (msg == nullptr || *msg == '\0') {
msg = "unknown error";
}
Expand Down
73 changes: 73 additions & 0 deletions csrc/flash_attn/flash_attn.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,79 @@ bool flash_attn_bwd(
uint64_t offset
);

bool flash_attn_fwd_with_bias_and_mask(
const void *q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const void *k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const void *v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
void *out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const int32_t *cu_seqlens_q, // int32, batch_size+1, starting offset of each sequence
const int32_t *cu_seqlens_k, // int32, batch_size+1, starting offset of each sequence
const int total_q,
const int total_k,
const int batch_size,
const int num_heads,
const int head_size,
const int max_seqlen_q_,
const int max_seqlen_k_,
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
const bool is_bf16,
const int num_splits, // SMs per attention matrix, can be 1
void *softmax_lse_ptr, // softmax log_sum_exp
void *softmax_ptr,
void *workspace_ptr,
uint64_t *workspace_size,
cudaStream_t stream,
uint64_t seed,
uint64_t offset,
const void *attn_mask,
const void *attn_bias,
const int64_t* mask_dims,
const int64_t* bias_dims
);

bool flash_attn_bwd_with_bias_and_mask(
const void *q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const void *k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const void *v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
void *dq, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
void *dk, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
void *dv, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const void *out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const void *dout, // total_q x num_heads, x head_size
const int32_t *cu_seqlens_q, // int32, batch_size+1
const int32_t *cu_seqlens_k, // int32, batch_size+1
const int total_q,
const int total_k,
const int batch_size,
const int num_heads,
const int head_size,
const int max_seqlen_q_,
const int max_seqlen_k_,
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
const bool is_bf16,
const int num_splits,
const void *softmax_lse_ptr,
void *dsoftmax_ptr,
void *dbias_ptr,
void *workspace_ptr,
uint64_t *workspace_size,
cudaStream_t stream,
uint64_t seed,
uint64_t offset,
const void* attn_mask,
const void* attn_bias,
const int64_t* mask_dims,
const int64_t* bias_dims
);

void flash_attn_set_error(const char *msg);

const char *flash_attn_error();

#ifdef __cplusplus
Expand Down
Loading

0 comments on commit 209f02b

Please sign in to comment.