forked from Dao-AILab/flash-attention
-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #74 from ROCm/ck_tile/kvcache
Ck tile/kvcache
- Loading branch information
Showing
9 changed files
with
950 additions
and
33 deletions.
There are no files selected for viewing
Submodule composable_kernel
updated
77 files
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
/****************************************************************************** | ||
* Copyright (c) 2024, Tri Dao. | ||
******************************************************************************/ | ||
|
||
#include "flash_common.hpp" | ||
|
||
namespace flash { | ||
int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits) | ||
{ | ||
int device; | ||
auto status = hipGetDevice(&device); | ||
if(status != hipSuccess) | ||
return num_splits; | ||
|
||
hipDeviceProp_t props{}; | ||
status = hipGetDeviceProperties(&props, device); | ||
if(status != hipSuccess) | ||
return num_splits; | ||
|
||
// TODO - tile size should match the TileFmhaShape, hardcode for now | ||
const int kM0 = 128; | ||
const int kN1 = hdim_v; | ||
|
||
const int num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; | ||
const int num_n_blocks = (hdim_v + kN1 - 1) / kN1; | ||
|
||
if(num_splits < 1 && p_drop == 0.0f) | ||
return num_splits_heuristic_ck( | ||
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128); | ||
|
||
return num_splits; | ||
} | ||
|
||
} // namespace flash |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.