
Commit adcba27

Alex4210987, xinxyxiao, and LeiWang1999 authored
Add Flash Attn example on AMD MI300 series (#682)
* [Enhancement] Refactor buffer index handling for improved precision and clarity (#668)
  - Enhanced buffer index handling to address precision issues by removing redundant operations.
  - Streamlined the logic for determining buffer overlaps, ensuring more accurate conflict detection.
  - Updated related documentation to reflect changes in buffer management practices.
* Remove obsolete test script for AMD example, streamlining the examples directory.
* Remove unused dtype_size variable in AMD example script to streamline code.
* Add input configuration file and update AMD example script for enhanced flexibility
  - Introduced a new input.txt file for configurable parameters.
  - Modified the example_amd_flash_attn_fwd.py script to allow for a wider range of configurations, including additional options for num_stages, enable_rasterization, and k_pack.
  - Streamlined the main function for better clarity and organization.
  - Added a new test script to facilitate running the example with specified parameters.
* Remove input configuration file and obsolete test script; enhance AMD example with swizzle layout annotations
  - Deleted input.txt and test.sh files as they are no longer needed.
  - Updated example_amd_flash_attn_fwd.py to include swizzle layout annotations for shared memory, improving bank conflict avoidance.
  - Reintroduced swizzle usage in the kernel for better performance.
* Refactor AMD example script for FlashAttention-2
  - Updated function names for clarity, changing `get_v2_configs` to `get_configs` and `fast_flashattn_v2` to `fast_flashattn`.
  - Streamlined the main function by renaming `main_v2` to `main` and adjusting the corresponding calls.
  - Removed outdated comments and improved code organization for better readability.
* Refactor formatting in AMD FlashAttention example script
  - Improved code readability by adjusting line breaks and indentation in the `fast_flashattn` function.
  - Streamlined the `main` function parameter formatting for consistency.
  - Removed unnecessary blank lines to enhance overall code organization.
* Update example_amd_flash_attn_fwd.py

---------

Co-authored-by: xinxyxiao <xinyxiao@amd.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
1 parent 05f2fc6 commit adcba27

File tree

2 files changed: +239 -2 lines changed

example_amd_flash_attn_fwd.py

Lines changed: 237 additions & 0 deletions
@@ -0,0 +1,237 @@
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
import itertools
import argparse
from functools import partial


def ref_program(Q, K, V, is_causal, groups=1):
    assert Q.size(
        2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
    assert Q.size(
        2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
    dim = Q.size(-1)
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
    return output


def get_configs():
    """Generates configurations for the autotuner, tailored for FA-2 style parallelism."""
    block_M = [64, 128, 256]
    block_N = [32, 64, 128]
    threads = [128, 256, 512]
    num_split_q = [32, 64, 128]
    num_stages = [0, 1, 2]
    enable_rasterization = [True, False]
    k_pack = [1, 2]

    valid_configs = []

    for m, n, s, t, stages, r, k in itertools.product(block_M, block_N, num_split_q, threads,
                                                      num_stages, enable_rasterization, k_pack):
        valid_configs.append({
            "block_M": m,
            "block_N": n,
            "num_split_q": s,
            "threads": t,
            "num_stages": stages,
            "enable_rasterization": r,
            "k_pack": k
        })
    valid_configs.append({
        'block_M': 64,
        'block_N': 64,
        'num_split_q': 64,
        'threads': 256,
        'num_stages': 1,
        'enable_rasterization': True,
        'k_pack': 2
    })
    return valid_configs


@tilelang.autotune(configs=get_configs(), cache_input_tensors=True)
@tilelang.jit(out_idx=[3])
def fast_flashattn(
    batch,
    heads,
    seq_len,
    dim,
    is_causal,
    groups,
    block_M: int,
    block_N: int,
    num_split_q: int,
    threads: int,
    num_stages: int,
    enable_rasterization: bool,
    k_pack: int,
):
    scale = (1.0 / dim)**0.5 * 1.44269504
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim]
    kv_shape = [batch, seq_len, head_kv, dim]
    dtype = "float16"
    accum_dtype = "float"

    v_vec_size = 4
    vec_size = 4 * k_pack

    @T.prim_func
    def main(
            Q: T.Tensor(q_shape, dtype),
            K: T.Tensor(kv_shape, dtype),
            V: T.Tensor(kv_shape, dtype),
            Output: T.Tensor(q_shape, dtype),
    ):
        with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined):
            T.use_swizzle(10, enable=enable_rasterization)

            bz = byz_combined // heads
            by = byz_combined % heads

            num_q_blocks = T.ceildiv(seq_len, block_M)

            bx = T.alloc_var("int32")
            bx[0] = b_split

            with T.While(bx[0] < num_q_blocks):
                acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
                m_i = T.alloc_fragment([block_M], accum_dtype)
                l_i = T.alloc_fragment([block_M], accum_dtype)
                T.fill(acc_o, 0)
                T.fill(m_i, -T.infinity(accum_dtype))
                T.fill(l_i, 0)

                current_bx = bx[0]
                q_block_offset = current_bx * block_M

                Q_shared = T.alloc_shared([block_M, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                V_shared = T.alloc_shared([block_N, dim], dtype)
                P_shared = T.alloc_shared([block_M, block_N], dtype)

                acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
                m_prev = T.alloc_fragment([block_M], accum_dtype)
                scale_factor = T.alloc_fragment([block_M], accum_dtype)

                T.copy(
                    Q[bz, q_block_offset:q_block_offset + block_M, by, :],
                    Q_shared,
                    coalesced_width=vec_size)

                loop_end_k = T.ceildiv(q_block_offset + block_M,
                                       block_N) if is_causal else T.ceildiv(seq_len, block_N)

                for k in T.Pipelined(loop_end_k, num_stages=num_stages):
                    kv_idx = k * block_N

                    T.copy(
                        K[bz, kv_idx:kv_idx + block_N, by // groups, :],
                        K_shared,
                        coalesced_width=vec_size)
                    T.copy(
                        V[bz, kv_idx:kv_idx + block_N, by // groups, :],
                        V_shared,
                        coalesced_width=v_vec_size)

                    T.clear(acc_s)
                    T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, k_pack=k_pack)

                    if is_causal:
                        for i, j in T.Parallel(block_M, block_N):
                            acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j,
                                                         acc_s[i, j], -T.infinity(acc_s.dtype))

                    T.copy(m_i, m_prev)
                    T.reduce_max(acc_s, m_i, dim=1, clear=False)

                    for i in T.Parallel(block_M):
                        sf = T.exp2(m_prev[i] * scale - m_i[i] * scale)
                        l_i[i] *= sf
                        scale_factor[i] = sf

                    for i, j in T.Parallel(block_M, dim):
                        acc_o[i, j] *= scale_factor[i]

                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.exp2(acc_s[i, j] * scale - m_i[i] * scale)

                    row_sum = T.alloc_fragment([block_M], accum_dtype)
                    T.reduce_sum(acc_s, row_sum, dim=1)
                    for i in T.Parallel(block_M):
                        l_i[i] += row_sum[i]

                    T.copy(acc_s, P_shared)
                    T.sync_threads()

                    T.gemm(P_shared, V_shared, acc_o)

                l_inv = T.alloc_fragment([block_M], accum_dtype)
                for i in T.Parallel(block_M):
                    safe_l = T.if_then_else(l_i[i] > 1e-6, l_i[i], 1.0)
                    l_inv[i] = 1.0 / safe_l

                for i, j in T.Parallel(block_M, dim):
                    Output[bz, q_block_offset + i, by, j] = acc_o[i, j] * l_inv[i]

                bx[0] = current_bx + num_split_q

    return main


def main(batch: int = 1,
         heads: int = 8,
         seq_len: int = 4096,
         dim: int = 128,
         is_causal: bool = False,
         groups: int = 1):

    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
    total_flops = 2 * flops_per_matmul
    if is_causal:
        total_flops *= 0.5

    print("Starting autotuning for FlashAttention-V2...")
    kernel = fast_flashattn(batch, heads, seq_len, dim, is_causal, groups=groups)
    print(f"Autotuning finished. Best Configuration: {kernel.config}")

    ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)

    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

    print("Verifying correctness...")
    profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
    print("All checks pass.")

    latency = profiler.do_bench(ref_program_processed, warmup=100)
    print(f"Reference (PyTorch): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")

    latency = profiler.do_bench(warmup=100)
    print(
        f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops"
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=1, help='batch size')
    parser.add_argument('--heads', type=int, default=8, help='heads')
    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
    parser.add_argument('--dim', type=int, default=128, help='dim')
    parser.add_argument('--is_causal', action='store_true', help='causal')
    parser.add_argument('--groups', type=int, default=1, help='groups')
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups)
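
For readers following the kernel above, the standalone sketch below (not part of the commit) replays the FA-2 online-softmax update that fast_flashattn performs block by block with m_i, l_i, acc_o, and the exp2-folded scale, and checks it against a plain softmax on one query row. All names and sizes in the sketch are illustrative.

    # Illustrative sketch only: mirrors the kernel's per-block rescaling on a single row.
    import torch

    torch.manual_seed(0)
    seq_len, block_N, dim_qk, dim_v = 8, 4, 64, 2
    sm_scale = (1.0 / dim_qk) ** 0.5
    scale = sm_scale * 1.44269504          # log2(e) folded in, as in fast_flashattn

    scores = torch.randn(seq_len)          # one row of Q @ K^T
    v = torch.randn(seq_len, dim_v)        # the matching rows of V

    # Reference: ordinary softmax over the whole row.
    ref = torch.softmax(scores * sm_scale, dim=0) @ v

    # Online version: visit the row in blocks of block_N, rescaling as the max grows.
    m = torch.tensor(float("-inf"))        # running row max     (m_i in the kernel)
    l = torch.tensor(0.0)                  # running denominator (l_i)
    acc_o = torch.zeros(dim_v)             # running numerator   (acc_o)
    for start in range(0, seq_len, block_N):
        s_blk = scores[start:start + block_N]
        m_prev = m
        m = torch.maximum(m_prev, s_blk.max())
        sf = torch.exp2(m_prev * scale - m * scale)   # scale_factor in the kernel
        l = l * sf
        acc_o = acc_o * sf
        p = torch.exp2(s_blk * scale - m * scale)
        l = l + p.sum()
        acc_o = acc_o + p @ v[start:start + block_N]

    assert torch.allclose(acc_o / l, ref, atol=1e-5)
    print("online softmax matches the reference")

Assuming the new file is the example_amd_flash_attn_fwd.py named in the commit message, the example itself would be launched with the flags defined in its argument parser, e.g. `python example_amd_flash_attn_fwd.py --batch 1 --heads 8 --seq_len 4096 --dim 128`, adding `--is_causal` and `--groups` as needed.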

src/tl_templates/hip/reduce.h

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@ struct MinOp {
   }
 };
 
-template <class Reducer, int threads, int scale> struct AllReduce {
+template <class Reducer, int threads, int scale, int thread_offset = 0> struct AllReduce {
   static_assert(threads == 1024 || threads == 512 || threads == 256 ||
                 threads == 128 || threads == 64 || threads == 32 ||
                 threads == 16 || threads == 8 || threads == 4 || threads == 2);
@@ -43,7 +43,7 @@ template <class Reducer, int threads, int scale> struct AllReduce {
     if constexpr (offset == scale) {
       return x;
     } else {
-      return AllReduce<Reducer, offset, scale>::run(x, red_buf);
+      return AllReduce<Reducer, offset, scale, thread_offset>::run(x, red_buf);
     }
   }
 };

0 commit comments
