Skip to content

Commit

Permalink
resolve page offsets absolutely not relatively
Browse files Browse the repository at this point in the history
  • Loading branch information
skrider committed Mar 26, 2024
1 parent 725d0fe commit b47b419
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 38 deletions.
16 changes: 8 additions & 8 deletions csrc/flash_attn/src/flash_fwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -609,9 +609,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor tVsV = make_tensor(tVsV_.data(), reshape_thread_tile(tVsV_.layout()));

if (block_table != nullptr) {
tKgK.data() = gK.data() + flash::init_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
block_table, params.k_batch_stride, params.k_row_stride);
tVgV.data() = gV.data() + flash::init_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride);
}

Expand Down Expand Up @@ -769,9 +769,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else {
if (n_block > n_block_copy_min) {
tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride);
tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
block_table, params.k_batch_stride, params.k_row_stride);
}
}
Expand Down Expand Up @@ -865,7 +865,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
} else {
tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size,
tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride);
}
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
Expand Down Expand Up @@ -897,7 +897,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else {
tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
block_table, params.k_batch_stride, params.k_row_stride);
}
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
Expand Down Expand Up @@ -937,7 +937,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
} else {
tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size,
tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride);
}
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
Expand All @@ -955,7 +955,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else {
tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
block_table, params.k_batch_stride, params.k_row_stride);
}
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
Expand Down
32 changes: 2 additions & 30 deletions csrc/flash_attn/src/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,11 +292,11 @@ void cp_async_wait() {

////////////////////////////////////////////////////////////////////////////////////////////////////

// resolves initial base offset of a slice of a paged kv copy from gmem.
// resolves offset of a slice of a paged kv copy from gmem.
// assumes that the tensor has already been positioned at the correct head.
template <typename Kernel_traits>
__forceinline__ __device__
int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size,
int resolve_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size,
const int* block_table, const int page_stride, const int row_stride) {
constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow;
constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread;
Expand All @@ -313,34 +313,6 @@ int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, cons
+ page_offset * row_stride
+ col_offset;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// advances base address of a slice of a paged copy from gmem
template <typename Kernel_traits>
__forceinline__ __device__
int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const int page_block_size,
const int* block_table, const int page_stride, const int row_stride) {
constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow;
constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread;
constexpr int kBlockN = Kernel_traits::kBlockN;

const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;

const int global_row_offset_cur = block_row_offset + n_block * kBlockN;
const int global_row_offset_next = block_row_offset + (n_block - 1) * kBlockN;

const int page_offset_cur = global_row_offset_cur % page_block_size;
const int page_offset_next = global_row_offset_next % page_block_size;

const int virtual_page_idx_cur = global_row_offset_cur / page_block_size;
const int virtual_page_idx_next = global_row_offset_next / page_block_size;

const int table_diff = block_table[virtual_page_idx_next] - block_table[virtual_page_idx_cur];
const int offset_diff = page_offset_next - page_offset_cur;

return table_diff * page_stride + offset_diff * row_stride;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

Expand Down

0 comments on commit b47b419

Please sign in to comment.