Skip to content

Commit

Permalink
upd
Browse files Browse the repository at this point in the history
  • Loading branch information
yzh119 committed Jul 24, 2024
1 parent fb16238 commit cf7a7d4
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 93 deletions.
161 changes: 83 additions & 78 deletions include/flashinfer/attention/cascade.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -158,39 +158,42 @@ __global__ void MergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S
uint32_t pos = blockIdx.x;
uint32_t head_idx = ty;

if (num_index_sets > 1) {
vec_t<float, vec_size> v_merged_vec;
state_t<vec_size> st;
v_merged_vec.fill(0.f);
#pragma unroll 2
for (uint32_t iter = 0; iter < num_index_sets; ++iter) {
float s = S[(pos * num_index_sets + iter) * num_heads + head_idx];
vec_t<float, vec_size> v;
v.cast_load(V + ((pos * num_index_sets + iter) * num_heads + head_idx) * head_dim +
tx * vec_size);
st.merge(v, s, 1);
}

st.normalize();
st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
if (s_merged != nullptr) {
s_merged[pos * num_heads + head_idx] = st.get_lse();
}
} else if (num_index_sets == 1) {
if (num_index_sets == 0) {
vec_t<DTypeOut, vec_size> v;
v.cast_load(V + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
v.fill(DTypeOut(0));
v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
if (s_merged != nullptr) {
s_merged[pos * num_heads + head_idx] = S[pos * num_heads + head_idx];
s_merged[pos * num_heads + head_idx] = -5e4;
}
} else {
// num_index_sets == 0
return;
}

if (num_index_sets == 1) {
vec_t<DTypeOut, vec_size> v;
v.fill(DTypeOut(0));
v.cast_load(V + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
if (s_merged != nullptr) {
s_merged[pos * num_heads + head_idx] = -5e4;
s_merged[pos * num_heads + head_idx] = S[pos * num_heads + head_idx];
}
return;
}

vec_t<float, vec_size> v_merged_vec;
state_t<vec_size> st;
v_merged_vec.fill(0.f);
#pragma unroll 2
for (uint32_t iter = 0; iter < num_index_sets; ++iter) {
float s = S[(pos * num_index_sets + iter) * num_heads + head_idx];
vec_t<float, vec_size> v;
v.cast_load(V + ((pos * num_index_sets + iter) * num_heads + head_idx) * head_dim +
tx * vec_size);
st.merge(v, s, 1);
}

st.normalize();
st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
if (s_merged != nullptr) {
s_merged[pos * num_heads + head_idx] = st.get_lse();
}
}

Expand Down Expand Up @@ -313,68 +316,70 @@ __global__ void VariableLengthMergeStatesKernel(DTypeIn* __restrict__ V, float*
float* s_smem = (float*)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn));
const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos];

if (num_index_sets > 1) {
#pragma unroll
for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
v_smem + (iter * bdy + ty) * head_dim + tx * vec_size,
V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size,
(iter * bdy + ty) < num_index_sets);
cp_async::commit_group();
}
#pragma unroll 4
for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) {
if (iter % bdx == 0) {
s_smem[ty * bdx + tx] =
iter * bdy + (ty * bdx + tx) < num_index_sets
? S[(indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx]
: 0.f;
__syncthreads();
}
cp_async::wait_group<num_smem_stages - 1>();
__syncthreads();
vec_t<float, vec_size> v;
v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size);
if (iter * bdy + ty < num_index_sets) {
float s = s_smem[(iter % bdx) * bdy + ty];
st.merge(v, s, 1);
}
__syncthreads();
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size,
V +
((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) *
head_dim +
tx * vec_size,
(iter + num_smem_stages) * bdy + ty < num_index_sets);
cp_async::commit_group();
}
cp_async::wait_group<0>();
__syncthreads();

st.normalize();
threadblock_sync_state<bdx, bdy, vec_size>(st, v_smem, s_smem);
st.normalize();

st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
if (num_index_sets == 0) {
vec_t<DTypeOut, vec_size> v;
v.fill(DTypeOut(0));
v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
if (s_merged != nullptr) {
s_merged[pos * num_heads + head_idx] = st.get_lse();
s_merged[pos * num_heads + head_idx] = -5e4;
}
} else if (num_index_sets == 1) {
return;
}

if (num_index_sets == 1) {
vec_t<DTypeOut, vec_size> v;
v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size);
v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
if (s_merged != nullptr) {
s_merged[pos * num_heads + head_idx] = S[indptr[pos] * num_heads + head_idx];
}
} else {
// num_index_sets == 0
vec_t<DTypeOut, vec_size> v;
v.fill(DTypeOut(0));
v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
if (s_merged != nullptr) {
s_merged[pos * num_heads + head_idx] = -5e4;
}

#pragma unroll
for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
v_smem + (iter * bdy + ty) * head_dim + tx * vec_size,
V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size,
(iter * bdy + ty) < num_index_sets);
cp_async::commit_group();
}
#pragma unroll 4
for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) {
if (iter % bdx == 0) {
s_smem[ty * bdx + tx] =
iter * bdy + (ty * bdx + tx) < num_index_sets
? S[(indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx]
: 0.f;
__syncthreads();
}
cp_async::wait_group<num_smem_stages - 1>();
__syncthreads();
vec_t<float, vec_size> v;
v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size);
if (iter * bdy + ty < num_index_sets) {
float s = s_smem[(iter % bdx) * bdy + ty];
st.merge(v, s, 1);
}
__syncthreads();
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size,
V +
((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) *
head_dim +
tx * vec_size,
(iter + num_smem_stages) * bdy + ty < num_index_sets);
cp_async::commit_group();
}
cp_async::wait_group<0>();
__syncthreads();

st.normalize();
threadblock_sync_state<bdx, bdy, vec_size>(st, v_smem, s_smem);
st.normalize();

st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
if (s_merged != nullptr) {
s_merged[pos * num_heads + head_idx] = st.get_lse();
}
}

Expand Down
30 changes: 15 additions & 15 deletions src/test_cascade.cu
Original file line number Diff line number Diff line change
Expand Up @@ -178,21 +178,21 @@ void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t n
thrust::device_vector<float> S_merged_1_device(seq_len * num_heads);

if (num_index_sets > 1) {
// Method 0: use MergeState
MergeState(thrust::raw_pointer_cast(V_device_trans_f32.data()),
thrust::raw_pointer_cast(S_device_trans.data()),
thrust::raw_pointer_cast(V_device_trans_f32.data() + seq_len * num_heads * head_dim),
thrust::raw_pointer_cast(S_device_trans.data() + seq_len * num_heads),
thrust::raw_pointer_cast(V_merged_0_device.data()),
thrust::raw_pointer_cast(S_merged_0_device.data()), seq_len, num_heads, head_dim);
for (uint i = 2; i < num_index_sets; ++i) {
MergeStateInPlace(
thrust::raw_pointer_cast(V_merged_0_device.data()),
thrust::raw_pointer_cast(S_merged_0_device.data()),
thrust::raw_pointer_cast(V_device_trans_f32.data() + i * seq_len * num_heads * head_dim),
thrust::raw_pointer_cast(S_device_trans.data() + i * seq_len * num_heads), seq_len,
num_heads, head_dim);
}
// Method 0: use MergeState
MergeState(thrust::raw_pointer_cast(V_device_trans_f32.data()),
thrust::raw_pointer_cast(S_device_trans.data()),
thrust::raw_pointer_cast(V_device_trans_f32.data() + seq_len * num_heads * head_dim),
thrust::raw_pointer_cast(S_device_trans.data() + seq_len * num_heads),
thrust::raw_pointer_cast(V_merged_0_device.data()),
thrust::raw_pointer_cast(S_merged_0_device.data()), seq_len, num_heads, head_dim);
for (uint i = 2; i < num_index_sets; ++i) {
MergeStateInPlace(
thrust::raw_pointer_cast(V_merged_0_device.data()),
thrust::raw_pointer_cast(S_merged_0_device.data()),
thrust::raw_pointer_cast(V_device_trans_f32.data() + i * seq_len * num_heads * head_dim),
thrust::raw_pointer_cast(S_device_trans.data() + i * seq_len * num_heads), seq_len,
num_heads, head_dim);
}
} else {
V_merged_0_device = V_device;
S_merged_0_device = S_device;
Expand Down

0 comments on commit cf7a7d4

Please sign in to comment.