
Commit 16b919b

[Dev] Adjust computation logic to avoid precision loss when casting acc_s from float to float16 (#141)
- Remove redundant `acc_s_0` fragment in flash attention kernel
- Simplify memory copy and reduction operations
- Reorder memory copy and scaling steps for improved performance
- Add Hopper-specific synchronization method in CUDA reduce template
- Update reduce operation to use architecture-specific synchronization
1 parent d5d4247 commit 16b919b
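
Why the reorder avoids precision loss, in a minimal NumPy sketch (illustrative values, not part of the commit): the old kernel round-tripped the raw fp32 scores through the fp16 `S_shared` staging buffer before the softmax, where large logits lose mantissa bits that `exp2` then amplifies; the new kernel exponentiates in fp32 first and only casts the resulting probabilities, which all lie in [0, 1].

```python
import numpy as np

LOG2E = np.float32(1.4426950)  # log2(e), folded into `scale` in the kernel
s = np.float32(100.1234)       # hypothetical raw attention score (fp32 accumulator)
m = np.float32(102.0)          # running row maximum

# Old order: cast the raw score to fp16 before exponentiating.
# fp16 spacing near 100 is 0.0625, so 100.1234 rounds to 100.125.
p_old = np.exp2((np.float32(np.float16(s)) - m) * LOG2E)

# New order: exponentiate in fp32, cast the probability afterwards.
p_new = np.float16(np.exp2((s - m) * LOG2E))

print(p_old, np.float32(p_new))  # rounding the raw score shifts the weight by ~0.2%
```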

File tree: 3 files changed (+36 −20)

- examples/deepseek_mla/example_mla_decode.py
- src/op/reduce.cc
- src/tl_templates/cuda/reduce.h

examples/deepseek_mla/example_mla_decode.py (+11 −16)
@@ -31,7 +31,6 @@ def flash_attn(
             K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
             O_shared = T.alloc_shared([block_H, dim], dtype)
             acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
-            acc_s_0 = T.alloc_fragment([block_H, block_N], accum_dtype)
             acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
             acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
             scores_max = T.alloc_fragment([block_H], accum_dtype)
@@ -57,28 +56,27 @@ def flash_attn(
             for k in T.Pipelined(loop_range, num_stages=2):
                 T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
                 T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared)
-                T.clear(acc_s_0)
+                T.clear(acc_s)
                 T.gemm(
-                    Q_shared, KV_shared, acc_s_0, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
+                    Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                 T.gemm(
                     Q_pe_shared,
                     K_pe_shared,
-                    acc_s_0,
+                    acc_s,
                     transpose_B=True,
                     policy=T.GemmWarpPolicy.FullCol)
                 T.copy(scores_max, scores_max_prev)
                 T.fill(scores_max, -T.infinity(accum_dtype))
-                T.copy(acc_s_0, S_shared)
-                T.copy(S_shared, acc_s)
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                 for i in T.Parallel(block_H):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                 for i, j in T.Parallel(block_H, block_N):
                     acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                 T.reduce_sum(acc_s, scores_sum, dim=1)
+                T.copy(acc_s, S_shared)
+                T.copy(S_shared, acc_s_cast)
                 for i in T.Parallel(block_H):
                     logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
-                T.copy(acc_s, acc_s_cast)
                 for i, j in T.Parallel(block_H, dim):
                     acc_o[i, j] *= scores_scale[i]
                 T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
@@ -105,7 +103,6 @@ def flash_attn_split(
             K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
             O_shared = T.alloc_shared([block_H, dim], dtype)
             acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
-            acc_s_0 = T.alloc_fragment([block_H, block_N], accum_dtype)
             acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
             acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
             scores_max = T.alloc_fragment([block_H], accum_dtype)
@@ -131,31 +128,29 @@ def flash_attn_split(
             for k in T.Pipelined(loop_range, num_stages=2):
                 kv_start = (seqlen_kv // num_split) * bz + k * block_N
                 kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N
-
                 T.copy(KV[bx, kv_start:kv_end, cur_kv_head, :], KV_shared)
                 T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
-                T.clear(acc_s_0)
+                T.clear(acc_s)
                 T.gemm(
-                    Q_shared, KV_shared, acc_s_0, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
+                    Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                 T.gemm(
                     Q_pe_shared,
                     K_pe_shared,
-                    acc_s_0,
+                    acc_s,
                     transpose_B=True,
                     policy=T.GemmWarpPolicy.FullCol)
                 T.copy(scores_max, scores_max_prev)
                 T.fill(scores_max, -T.infinity(accum_dtype))
-                T.copy(acc_s_0, S_shared)
-                T.copy(S_shared, acc_s)
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                 for i in T.Parallel(block_H):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                 for i, j in T.Parallel(block_H, block_N):
                     acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                 T.reduce_sum(acc_s, scores_sum, dim=1)
+                T.copy(acc_s, S_shared)
+                T.copy(S_shared, acc_s_cast)
                 for i in T.Parallel(block_H):
                     logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
-                T.copy(acc_s, acc_s_cast)
                 for i, j in T.Parallel(block_H, dim):
                     acc_o[i, j] *= scores_scale[i]
                 T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
@@ -301,4 +296,4 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
     print("All close")
     latency = mod.do_bench(mod.func, n_warmup=10, n_repeat=10, profiler="torch")
     print("Tile-lang: {:.2f} ms".format(latency))
-    print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
+    print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))

src/op/reduce.cc (+7 −2)
@@ -161,8 +161,13 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
       continue;
     int reducing_threads = (*extent) * (*scale);
     std::stringstream ss;
-    ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
-       << reducing_threads << ", " << (*scale) << ">::run";
+    if (Downcast<String>(T.target->attrs["arch"]) == "sm_90") {
+      ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
+         << reducing_threads << ", " << (*scale) << ">::run_hopper";
+    } else {
+      ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
+         << reducing_threads << ", " << (*scale) << ">::run";
+    }
     Array<PrimExpr> thread_reduce_args = {
         StringImm(ss.str()), BufferLoad(dst_buffer, dst_indices)};
     if (reducing_threads >= 32) {
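
The change is purely a codegen dispatch: on `sm_90` targets the lowered reduce call names the `run_hopper` entry point, which synchronizes with named `bar.sync` barriers over exactly `threads` threads rather than a block-wide `__syncthreads()` (presumably so that warp groups not participating in the reduction are not stalled on Hopper). A hypothetical Python mirror of the string choice, with invented names:

```python
def allreduce_symbol(arch: str, reducer: str, threads: int, scale: int) -> str:
    # Mirrors ReduceOp::Lower above: only Hopper (sm_90) gets run_hopper.
    entry = "run_hopper" if arch == "sm_90" else "run"
    return f"tl::AllReduce<{reducer}, {threads}, {scale}>::{entry}"

assert allreduce_symbol("sm_90", "tl::SumOp", 128, 1).endswith("::run_hopper")
assert allreduce_symbol("sm_80", "tl::SumOp", 128, 1).endswith("::run")
```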

src/tl_templates/cuda/reduce.h (+18 −2)
@@ -33,10 +33,8 @@ template <class Reducer, int threads, int scale> struct AllReduce {
     constexpr int offset = threads / 2;
     if constexpr (offset >= 32) {
       __syncthreads();
-      // asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(256));
       red_buf[threadIdx.x] = x;
       __syncthreads();
-      // asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(256));
       x = Reducer()(x, red_buf[threadIdx.x ^ offset]);
     } else {
       x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset)));
@@ -47,6 +45,24 @@ template <class Reducer, int threads, int scale> struct AllReduce {
       return AllReduce<Reducer, offset, scale>::run(x, red_buf);
     }
   }
+
+  template <typename T>
+  static TL_DEVICE T run_hopper(T x, T *red_buf = nullptr) {
+    constexpr int offset = threads / 2;
+    if constexpr (offset >= 32) {
+      asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(threads));
+      red_buf[threadIdx.x] = x;
+      asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(threads));
+      x = Reducer()(x, red_buf[threadIdx.x ^ offset]);
+    } else {
+      x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset)));
+    }
+    if constexpr (offset == scale) {
+      return x;
+    } else {
+      return AllReduce<Reducer, offset, scale>::run_hopper(x, red_buf);
+    }
+  }
 };
 
 } // namespace tl
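
Both `run` and `run_hopper` unroll the same XOR-butterfly exchange: each thread combines its value with the partner at `threadIdx.x ^ offset`, and `offset` halves until it reaches `scale`. A small Python simulation of that pattern, with one list element standing in for one thread's register (illustrative only, not part of the commit):

```python
def butterfly_allreduce(vals, reducer, scale=1):
    """Simulate tl::AllReduce<Reducer, len(vals), scale> (one value per thread)."""
    offset = len(vals) // 2
    while offset >= scale:
        # every "thread" i reads its partner i ^ offset, as via red_buf/__shfl_xor_sync
        vals = [reducer(vals[i], vals[i ^ offset]) for i in range(len(vals))]
        offset //= 2
    return vals

print(butterfly_allreduce(list(range(8)), lambda a, b: a + b))  # [28] * 8
```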
