
Commit 9acab61

Commit message: fmt
1 parent: b47bb95

File tree

331 files changed: +6642 -7990 lines

Note: large commits have some content hidden by default, so only a subset of the 331 changed files is shown below.


benchmark/blocksparse_attention/benchmark_library_dense_fmha.py

Lines changed: 4 additions & 3 deletions
@@ -7,9 +7,10 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape
     # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
-    dense_mask = torch.full(
-        [bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device
-    )
+    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
+                            False,
+                            dtype=torch.bool,
+                            device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
         dense_mask[:, :, -2:, :] = True
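For readers skimming the diff: the reformatted call above builds the top-k block mask used throughout these benchmarks. A small standalone sketch (not part of the commit; the sizes and topk value are illustrative assumptions) of what the helper computes:

import torch

# Illustrative sizes, not the benchmark defaults
bsz, num_head, downsample_len, topk = 1, 2, 8, 3

# Block-level scores; in the benchmarks these come from a downsampled score map
x = torch.randn(bsz, num_head, downsample_len, downsample_len)

# Indices of the top-k scoring key blocks for each query block
sparse_index = torch.topk(x, topk, dim=-1).indices

# Boolean block mask: True marks key blocks that attention may visit
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
                        False,
                        dtype=torch.bool,
                        device=x.device)
dense_mask.scatter_(-1, sparse_index, True)

# Exactly topk key blocks are kept per query block
assert dense_mask.sum(-1).eq(topk).all()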

benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py

Lines changed: 38 additions & 39 deletions
@@ -15,9 +15,10 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape
     # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
-    dense_mask = torch.full(
-        [bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device
-    )
+    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
+                            False,
+                            dtype=torch.bool,
+                            device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
         dense_mask[:, :, -2:, :] = True
@@ -38,7 +39,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
     block_N = 64
     num_stages = 2
     threads = 128
-    scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
+    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     shape = [batch, heads, seq_len, dim]
     block_mask_shape = [batch, heads, downsample_len, downsample_len]
@@ -47,6 +48,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
     block_mask_dtype = "bool"

     def kernel_func(block_M, block_N, num_stages, threads):
+
         @T.macro
         def MMA0(
             K: T.Tensor(shape, dtype),
@@ -58,12 +60,11 @@ def MMA0(
             by: T.int32,
             bz: T.int32,
         ):
-            T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
+            T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
             if is_causal:
                 for i, j in T.Parallel(block_M, block_N):
-                    acc_s[i, j] = T.if_then_else(
-                        bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)
-                    )
+                    acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
+                                                 -T.infinity(acc_s.dtype))
             else:
                 T.clear(acc_s)
             T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -78,18 +79,18 @@ def MMA1(
             by: T.int32,
             bz: T.int32,
         ):
-            T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
+            T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
             T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

         @T.macro
         def Softmax(
-            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
-            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
-            scores_max: T.FragmentBuffer([block_M], accum_dtype),
-            scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
-            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
-            scores_sum: T.FragmentBuffer([block_M], accum_dtype),
-            logsum: T.FragmentBuffer([block_M], accum_dtype),
+                acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
+                acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+                scores_max: T.FragmentBuffer([block_M], accum_dtype),
+                scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
+                scores_scale: T.FragmentBuffer([block_M], accum_dtype),
+                scores_sum: T.FragmentBuffer([block_M], accum_dtype),
+                logsum: T.FragmentBuffer([block_M], accum_dtype),
         ):
             T.copy(scores_max, scores_max_prev)
             T.fill(scores_max, -T.infinity(accum_dtype))
@@ -113,25 +114,26 @@ def Softmax(

         @T.macro
         def Rescale(
-            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
-            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
+                acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
+                scores_scale: T.FragmentBuffer([block_M], accum_dtype),
         ):
             for i, j in T.Parallel(block_M, dim):
                 acc_o[i, j] *= scores_scale[i]

         @T.prim_func
         def main(
-            Q: T.Tensor(shape, dtype),
-            K: T.Tensor(shape, dtype),
-            V: T.Tensor(shape, dtype),
-            BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
-            Output: T.Tensor(shape, dtype),
+                Q: T.Tensor(shape, dtype),
+                K: T.Tensor(shape, dtype),
+                V: T.Tensor(shape, dtype),
+                BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
+                Output: T.Tensor(shape, dtype),
         ):
-            with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (
-                bx,
-                by,
-                bz,
-            ):
+            with T.Kernel(
+                    T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (
+                        bx,
+                        by,
+                        bz,
+                    ):
                 Q_shared = T.alloc_shared([block_M, dim], dtype)
                 K_shared = T.alloc_shared([block_N, dim], dtype)
                 V_shared = T.alloc_shared([block_N, dim], dtype)
@@ -146,7 +148,7 @@ def main(
                 logsum = T.alloc_fragment([block_M], accum_dtype)
                 block_mask = T.alloc_local([downsample_len], block_mask_dtype)

-                T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
+                T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
                 T.fill(acc_o, 0)
                 T.fill(logsum, 0)
                 T.fill(scores_max, -T.infinity(accum_dtype))
@@ -155,10 +157,8 @@ def main(
                     block_mask[vj] = BlockSparseMask[bz, by, bx, vj]

                 loop_range = (
-                    T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
-                    if is_causal
-                    else T.ceildiv(seq_len, block_N)
-                )
+                    T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
+                        (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))

                 for k in T.Pipelined(loop_range, num_stages=num_stages):
                     if block_mask[k]:
@@ -177,7 +177,7 @@ def main(
                 for i, j in T.Parallel(block_M, dim):
                     acc_o[i, j] /= logsum[i]
                 T.copy(acc_o, O_shared)
-                T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
+                T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])

         return main

@@ -199,14 +199,13 @@ def benchmark_topk_sparse_attention():
     # Create sparse mask (downsampled to block level)
     downsample_factor = BLOCK
     downsample_len = math.ceil(SEQ_LEN / downsample_factor)
-    x_ds = torch.randn(
-        [BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16
-    )
+    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
+                       device="cuda",
+                       dtype=torch.bfloat16)
     x_ds[:, :, :, 0] = 100
     block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
     program = blocksparse_flashattn(
-        BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True
-    )
+        BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
     kernel = tilelang.compile(program, out_idx=4)

     def benchmark_fn():
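The scale change in the second hunk only drops the spaces around the power operator; the constant 1.44269504 is log2(e), which flash-attention-style kernels commonly fold into the softmax scale so base-2 exponentials can be used. A tiny illustrative check of the identity involved (an aside, not code from this commit):

import math

x = 1.2345                 # arbitrary score value
log2_e = 1.44269504        # the constant appearing in the kernel, approx. log2(e)

# exp(x) == 2**(x * log2(e)); pre-scaling scores by log2(e) lets exp2 replace exp
assert math.isclose(math.exp(x), 2 ** (x * log2_e), rel_tol=1e-6)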

benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py

Lines changed: 7 additions & 6 deletions
@@ -10,9 +10,10 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape
     # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
-    dense_mask = torch.full(
-        [bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device
-    )
+    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
+                            False,
+                            dtype=torch.bool,
+                            device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
         dense_mask[:, :, -2:, :] = True
@@ -45,9 +46,9 @@ def benchmark_topk_sparse_attention():
     # Create sparse mask (downsampled to block level)
     downsample_factor = BLOCK
     downsample_len = math.ceil(SEQ_LEN / downsample_factor)
-    x_ds = torch.randn(
-        [BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16
-    )
+    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
+                       device="cuda",
+                       dtype=torch.bfloat16)
     x_ds[:, :, :, 0] = 100
     block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
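As background for this torch baseline, a dense PyTorch sketch (illustrative only, with assumed small sizes; not the benchmark's actual implementation) of how a block-level mask can be expanded to token granularity and applied to attention:

import math
import torch

BATCH, N_HEADS, SEQ_LEN, D_HEAD, BLOCK = 1, 2, 256, 64, 64  # assumed sizes

q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD)

downsample_len = math.ceil(SEQ_LEN / BLOCK)
block_mask = torch.rand(BATCH, N_HEADS, downsample_len, downsample_len) > 0.5
block_mask[..., 0] = True  # keep the first key block, echoing x_ds[:, :, :, 0] = 100 above

# Expand each block entry to a BLOCK x BLOCK tile of token-level mask entries
token_mask = block_mask.repeat_interleave(BLOCK, dim=-2).repeat_interleave(BLOCK, dim=-1)
token_mask = token_mask[..., :SEQ_LEN, :SEQ_LEN]

scores = (q @ k.transpose(-1, -2)) / math.sqrt(D_HEAD)
scores = scores.masked_fill(~token_mask, float("-inf"))
out = torch.softmax(scores, dim=-1) @ v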

benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py

Lines changed: 12 additions & 14 deletions
@@ -15,9 +15,10 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape
     # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
-    dense_mask = torch.full(
-        [bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device
-    )
+    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
+                            False,
+                            dtype=torch.bool,
+                            device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
         dense_mask[:, :, -2:, :] = True
@@ -70,9 +71,8 @@ def _fwd_kernel_inner(

     # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
     if LAST_K_BLOCK:
-        qk += tl.where(
-            offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf")
-        )
+        qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0,
+                       float("-inf"))

     m_ij = tl.maximum(m_i, tl.max(qk, 1))
     qk -= m_ij[:, None]
@@ -191,11 +191,8 @@ def _fwd_kernel(
     acc = acc.to(Out.dtype.element_ty)

     off_o = (
-        off_z * stride_oz
-        + off_h * stride_oh
-        + offs_m[:, None] * stride_om
-        + offs_d[None, :] * stride_od
-    )
+        off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om +
+        offs_d[None, :] * stride_od)
     out_ptrs = Out + off_o
     tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX)

@@ -257,6 +254,7 @@ def _forward(


 class _sparse_attention(torch.autograd.Function):
+
     @staticmethod
     def forward(ctx, q, k, v, block_sparse_dense, sm_scale):
         # shape constraints
@@ -289,9 +287,9 @@ def benchmark_topk_sparse_attention():
     # Create sparse mask (downsampled to block level)
     downsample_factor = BLOCK
     downsample_len = math.ceil(SEQ_LEN / downsample_factor)
-    x_ds = torch.randn(
-        [BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16
-    )
+    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
+                       device="cuda",
+                       dtype=torch.bfloat16)
     x_ds[:, :, :, 0] = 100
     block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
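The off_o hunk above only reflows the strided offset arithmetic. As an aside (a sketch with assumed sizes, not code from this commit), the same formula can be checked against plain tensor indexing for a contiguous output:

import torch

BATCH, N_HEADS, SEQ_LEN, D_HEAD = 2, 3, 8, 4  # assumed small sizes
out = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD)
stride_oz, stride_oh, stride_om, stride_od = out.stride()

off_z, off_h, m, d = 1, 2, 5, 3  # arbitrary in-range indices
off_o = off_z * stride_oz + off_h * stride_oh + m * stride_om + d * stride_od

# For a contiguous tensor, the flat offset addresses the same element as 4-D indexing
assert out.reshape(-1)[off_o] == out[off_z, off_h, m, d]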

benchmark/matmul/benchmark_matmul.py

Lines changed: 7 additions & 10 deletions
@@ -101,10 +101,9 @@ def get_configs(args, kwargs):
             policy=[T.GemmWarpPolicy.Square],
             enable_rasteration=[True, False],
         )
-        return [
-            {k: v for k, v in zip(iter_params, values)}
-            for values in itertools.product(*iter_params.values())
-        ]
+        return [{
+            k: v for k, v in zip(iter_params, values)
+        } for values in itertools.product(*iter_params.values())]
     return configs

@@ -113,9 +112,7 @@ def get_configs(args, kwargs):
     warmup=3,
     rep=20,
 )
-@jit(
-    out_idx=[2],
-)
+@jit(out_idx=[2],)
 def matmul(
     M,
     N,
@@ -162,9 +159,9 @@ def matmul(

     @T.prim_func
     def main(
-        A: T.Tensor((M, K), dtype),
-        B: T.Tensor((N, K), dtype),
-        C: T.Tensor((M, N), dtype),
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((N, K), dtype),
+            C: T.Tensor((M, N), dtype),
     ):
         """
         The compiled TVM function for block-level matrix multiplication.
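The get_configs reflow above keeps the same enumeration: one dict per point in the Cartesian product of the tuning parameters. A self-contained sketch of that pattern (parameter names and values here are illustrative, not the benchmark's full search space):

import itertools

iter_params = dict(
    block_M=[64, 128],
    block_N=[64, 128],
    num_stages=[2, 3],
    enable_rasteration=[True, False],
)

# One dict per combination, e.g. {'block_M': 64, 'block_N': 64, 'num_stages': 2, ...}
configs = [{
    k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]

print(len(configs))  # 2 * 2 * 2 * 2 = 16 candidate configurations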

benchmark/matmul/benchmark_matmul_intrinsic.py

Lines changed: 13 additions & 20 deletions
@@ -6,8 +6,7 @@
 import tilelang.language as T
 from tilelang.intrinsics import get_swizzle_layout
 from tilelang.intrinsics.mma_macro_generator import (
-    TensorCoreIntrinEmitter,
-)
+    TensorCoreIntrinEmitter,)
 from tilelang.transform import simplify_prim_func
 from tilelang.autotuner import autotune
 import itertools
@@ -104,9 +103,9 @@ def tl_matmul(

     @T.prim_func
     def main(
-        A: T.Tensor(A_shape, in_dtype),
-        B: T.Tensor(B_shape, in_dtype),
-        C: T.Tensor((M, N), out_dtype),
+            A: T.Tensor(A_shape, in_dtype),
+            B: T.Tensor(B_shape, in_dtype),
+            C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
@@ -116,12 +115,10 @@ def main(
             B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
             C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

-            T.annotate_layout(
-                {
-                    A_shared: make_swizzle_layout(A_shared),
-                    B_shared: make_swizzle_layout(B_shared),
-                }
-            )
+            T.annotate_layout({
+                A_shared: make_swizzle_layout(A_shared),
+                B_shared: make_swizzle_layout(B_shared),
+            })

             # Improve L2 Cache
             T.use_swizzle(panel_size=10, enable=enable_rasteration)
@@ -232,10 +229,9 @@ def get_configs(args, kwargs):
             stage=[0, 2],
             enable_rasteration=[True, False],
         )
-        return [
-            {k: v for k, v in zip(iter_params, values)}
-            for values in itertools.product(*iter_params.values())
-        ]
+        return [{
+            k: v for k, v in zip(iter_params, values)
+        } for values in itertools.product(*iter_params.values())]

     return configs

@@ -247,9 +243,7 @@ def get_configs(args, kwargs):
     ref_prog=ref_program,
     skip_check=True,
 )
-@tl.jit(
-    out_idx=[2],
-)
+@tl.jit(out_idx=[2],)
 def matmul(
     M,
     N,
@@ -300,8 +294,7 @@ def kernel():
         help="Whether to use roller to deduce search spaces",
     )
     parser.add_argument(
-        "--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type"
-    )
+        "--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type")
     args = parser.parse_args()

     M, N, K = args.m, args.n, args.k
