Commit cda16d5

nicunxiao committed
Merge branch 'main' into fix_1046
2 parents c49edee + 86c8bb4 commit cda16d5

168 files changed (+2168, -1443 lines)

.clang-tidy

Lines changed: 10 additions & 5 deletions
@@ -1,4 +1,13 @@
-Checks: >
+---
+InheritParentConfig: true
+ExtraArgs: ['-v']
+FormatStyle: file
+UseColor: true
+WarningsAsErrors: '*'
+ExcludeHeaderFilterRegex: '^(3rdparty|tvm)/.*$'
+
+# NOTE: there must be no spaces before the '-', so put the comma last.
+Checks: >-
   # 1. Retained categories: easier to find bugs/performance issues
   clang-analyzer-*,
   cppcoreguidelines-pro-type-static-cast-downcast,
@@ -47,7 +56,3 @@ Checks: >
   -clang-analyzer-deadcode.DeadStores,
   -clang-analyzer-optin.cplusplus.VirtualCall,
   -clang-diagnostic-tautological-constant-compare,
-
-WarningsAsErrors: '*'
-
-HeaderFilterRegex: '^(?!.*(3rdparty|build)).*$'

.github/workflows/ci.yml

Lines changed: 30 additions & 12 deletions
@@ -287,21 +287,39 @@ jobs:
           echo "Clearing uv cache at ${UV_CACHE_DIR} due to failure."
           uv cache clean
 
-      - name: Run format check
-        id: format-check
+      - name: Run clang-tidy
+        id: clang-tidy
+        if: runner.os == 'Linux'
         run: |
-          mkdir -p build
+          echo "\$ $(command -v clang-tidy) --version" && clang-tidy --version
+
+          if [[ -x "$(command -v run-clang-tidy)" ]]; then
+            echo "Using run-clang-tidy from $(command -v run-clang-tidy)"
+            CLANG_TIDY=(run-clang-tidy)
+          else
+            echo "Downloading run-clang-tidy script"
+            wget -O run-clang-tidy.py https://raw.githubusercontent.com/llvm/llvm-project/refs/heads/release/21.x/clang-tools-extra/clang-tidy/tool/run-clang-tidy.py
+            CLANG_TIDY=(uv run --no-project --script -- run-clang-tidy.py)
+          fi
+          if [[ -x "$(command -v clang-apply-replacements)" ]]; then
+            echo "Using clang-apply-replacements from $(command -v clang-apply-replacements)"
+            CLANG_TIDY+=(-fix -clang-apply-replacements-binary="$(command -v clang-apply-replacements)")
+          else
+            echo "::warning::clang-apply-replacements not found in PATH, automatic fixing disabled."
+          fi
+
           # Run cmake to create the build directory with compile_commands.json
-          (
-            cd build
-            cmake .. ${CLANG_TIDY_CMAKE_OPTIONS} # no quotes here
-          )
+          cmake -S . -B cmake-build --fresh ${CLANG_TIDY_CMAKE_OPTIONS} # no quotes here
+
+          CXX_FILES=$(find src -type f -iname "*.[ch]pp" -o -iname "*.cc" -o -iname "*.c" -o -iname "*.h")
           rc=0
-          bash format.sh || rc="$?"
-          rm -rf build
-          if [[ "${rc}" -ne 0 ]]; then
-            echo "::error::Format check failed. Please run 'bash format.sh' locally to fix the issues."
-            exit 1
+          "${CLANG_TIDY[@]}" -clang-tidy-binary="$(command -v clang-tidy)" \
+            -p="cmake-build" ${CXX_FILES} || rc="$?"
+          rm -rf cmake-build run-clang-tidy.py
+          if (( rc != 0 )); then
+            echo "::error::clang-tidy found issues (exit code: ${rc}). Please run 'clang-tidy --fix' locally to fix them."
+            git diff --color=always || true
+            exit "${rc}"
           fi
 
       - name: Enable core dump generation (Linux / GitHub-hosted runners)
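
A note on the `find` expression in the new step: in `find`, the implicit AND binds tighter than `-o`, so `-type f` constrains only the first `-iname` test; the remaining patterns would also match a directory whose name happens to end in `.cc`, `.c`, or `.h`. The following is a minimal Python sketch of the intended selection (regular files with a C/C++ suffix, compared case-insensitively). The helper name `collect_cxx_files` and the constant `CXX_SUFFIXES` are hypothetical and not part of the commit.

```python
from pathlib import Path

# Suffixes corresponding to the patterns in the workflow's find call:
# *.[ch]pp, *.cc, *.c, *.h (compared case-insensitively, like -iname).
CXX_SUFFIXES = {".cpp", ".hpp", ".cc", ".c", ".h"}


def collect_cxx_files(root):
    """Return sorted files under `root` whose suffix is a C/C++ one.

    Note: unlike `find -type f`, Path.is_file() also accepts symlinks
    that point at regular files.
    """
    return sorted(
        path
        for path in Path(root).rglob("*")
        if path.is_file() and path.suffix.lower() in CXX_SUFFIXES
    )


if __name__ == "__main__":
    # Assumes the script is run from the repository root, as the CI step is.
    for path in collect_cxx_files("src"):
        print(path)
```

In shell, the equivalent grouping would wrap the `-iname` alternatives in escaped parentheses so that `-type f` applies to all of them.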

.gitignore

Lines changed: 4 additions & 0 deletions
@@ -97,3 +97,7 @@ tilelang/jit/adapter/cython/.cycache
 
 # claude
 **/.claude
+
+# CMake
+cmake-build/
+cmake-build-*/

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
@@ -32,7 +32,7 @@ repos:
         args: [--ignore-case]
         files: ^docs/spelling_wordlist\.txt$
   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v15.0.7 # sync with requirements-lint.txt
+    rev: v21.1.2 # sync with requirements-lint.txt
     hooks:
       - id: clang-format
         exclude: |
@@ -41,7 +41,7 @@ repos:
             ^.+\.json$
           )
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.14.0 # sync with requirements-lint.txt
+    rev: v0.14.1 # sync with requirements-lint.txt
     hooks:
       - id: ruff-check
         args: [--fix, --exit-non-zero-on-fix]

benchmark/mamba2/README.md

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+# Mamba2_chunk_scan Benchmark
+
+This document records the throughput achieved by `benchmark_mamba_chunk_scan.py` when computing `batch = 8`, `heads = 80`, `groups = 1`, `chunk_size = 256`, `dim = 64`, and `dstate = 128` across different `seq_len` using the default autotuning search space.
+
+## Environment
+
+- Repository commit: `8a5eb569704bfea64478c29adcfe3a09e3c2b12c`
+- GPUs: `NVIDIA H800 SXM` on driver `560.35.05`
+
+## How to Reproduce
+
+```bash
+cd benchmark/mamba2
+python - <<'PY'
+from benchmark_mamba_chunk_scan import chunk_scan_fwd
+
+batch = 8
+heads = 80
+groups = 1
+chunk_size = 256
+dim = 64
+dstate = 128
+for seq_len in [1024, 2048, 4096, 8192, 16384, 32768]:
+    res = chunk_scan_fwd(
+        batch,
+        seq_len,
+        chunk_size,
+        groups,
+        heads,
+        dim,
+        dstate)
+    tflops = (2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate) / res.latency * 1e-9
+    print(f"seq_len={seq_len:5d} latency={res.latency:.6f}ms TFlops={tflops:.3f}")
+PY
+```
+
+## Results
+
+| Seq_len| Latency (ms) | Throughput (TFLOPs) |
+|-------|-------------|---------------------|
+| 1024 | 0.169 | 126.477 |
+| 2048 | 0.329 | 130.195 |
+| 4096 | 0.645 | 133.054 |
+| 8192 | 1.278 | 134.362 |
+| 16384 | 2.531 | 135.711 |
+| 32768 | 5.076 | 135.379 |
+
+<figure style="text-align: center">
+  <a href="mamba_benchmark_result.png">
+    <img src="mamba_benchmark_result.png" alt="Mamba2_chunk_scan Performance Comparison on H100">
+  </a>
+  <figcaption style="text-align: center;">Performance comparison across compilers on NVIDIA H100</figcaption>
+</figure>
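
The throughput column in the README follows from the FLOP formula in its reproduce snippet. As a rough sketch of that arithmetic, the function below applies the same formula to a latency in milliseconds. The name `chunk_scan_tflops` and the labels on the two terms are assumptions made for illustration, and feeding in the table's rounded latencies reproduces the reported throughput only approximately.

```python
def chunk_scan_tflops(batch, seq_len, chunk_size, heads, dim, dstate, latency_ms):
    """Throughput in TFLOPS, using the FLOP count from the README formula."""
    # First term of the README formula; the 0.5 factor halves the
    # chunk_size x chunk_size product (read here as the causal half).
    first_term = 2 * batch * seq_len * chunk_size * heads * dim * 0.5
    # Second term of the README formula; scales with dstate.
    second_term = 2 * batch * seq_len * heads * dim * dstate
    total_flops = first_term + second_term
    # FLOPs per millisecond * 1e-9 equals TFLOPS (1e3 ms per s, 1e-12 per tera).
    return total_flops / latency_ms * 1e-9


if __name__ == "__main__":
    # Latency taken from the seq_len = 1024 row of the results table.
    print(f"{chunk_scan_tflops(8, 1024, 256, 80, 64, 128, 0.169):.2f}")
```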

benchmark/mamba2/benchmark_mamba_chunk_scan.py

Lines changed: 223 additions & 0 deletions
@@ -0,0 +1,223 @@
+import argparse
+import torch
+import tilelang
+from tilelang.autotuner import *
+import tilelang.language as T
+from einops import rearrange, repeat
+import itertools
+
+
+def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
+    """
+    Argument:
+        cb: (batch, nchunks, ngroups, chunk_size, chunk_size)
+        x: (batch, seqlen, nheads, headdim)
+        dt: (batch, nheads, nchunks, chunk_size)
+        dA_cumsum: (batch, nheads, nchunks, chunk_size)
+        C: (batch, seqlen, ngroups, dstate)
+        prev_states: (batch, nchunks, nheads, headdim, dstate)
+        D: (nheads, headdim) or (nheads,)
+        z: (batch, seqlen, nheads, headdim)
+    Return:
+        out: (batch, seqlen, nheads, headdim)
+    """
+    _, _, ngroups, _, _ = cb.shape
+    batch, seqlen, nheads, headdim = x.shape
+    # _, _, ngroups, dstate = B.shape
+    # assert B.shape == (batch, seqlen, ngroups, dstate)
+    _, _, nchunks, chunk_size = dt.shape
+    assert seqlen == nchunks * chunk_size
+    # assert C.shape == B.shape
+    # B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
+    C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups)
+    cb = repeat(cb, "b c g l s -> b c (g h) l s", h=nheads // ngroups)
+    # CB = torch.einsum("bclhn,bcshn->bchls", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks),
+    #                   rearrange(B, "b (c s) h n -> b c s h n", c=nchunks))
+    # (batch, nheads, nchunks, chunksize, chunksize)
+    dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]
+    decay = torch.exp(dt_segment_sum)
+    scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s")
+    causal_mask = torch.tril(
+        torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
+    scores_decay = scores_decay.masked_fill(~causal_mask, 0)
+    out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype),
+                       rearrange(x, "b (c s) h p -> b c s h p", c=nchunks))
+    state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1"))
+    out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(
+        C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out
+    out = out + out_prev
+    out = rearrange(out, "b c l h p -> b (c l) h p")
+    if D is not None:
+        if D.dim() == 1:
+            D = rearrange(D, "h -> h 1")
+        out = out + x * D
+    return out
+
+
+def get_configs():
+    iter_params = dict(
+        block_M=[64, 128, 256],
+        block_N=[32, 64],
+        block_K=[64, 128, 256],
+        block_Dstate=[128],
+        num_stages=[1, 2, 3, 4, 5])
+    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
+
+
+@autotune(configs=get_configs(), warmup=10, rep=10)
+@tilelang.jit(
+    out_idx=[7],
+    pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    },
+)
+def chunk_scan_fwd(batch,
+                   seqlen,
+                   chunk_size,
+                   ngroups,
+                   nheads,
+                   headdim,
+                   dstate,
+                   block_M=64,
+                   block_N=64,
+                   block_K=64,
+                   block_Dstate=128,
+                   num_stages=2,
+                   threads=128):
+    dtype = "float16"
+    accum_dtype = "float"
+    nchunks = T.ceildiv(seqlen, chunk_size)
+    p = 1.44269504
+
+    @T.prim_func
+    def main(
+            cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore
+            x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore
+            dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
+            dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
+            C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore
+            prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore
+            D: T.Tensor((nheads), dtype), # type: ignore
+            Output: T.Tensor((batch, seqlen, nheads, headdim), dtype) # type: ignore
+    ):
+        with T.Kernel(
+                nheads,
+                T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N),
+                batch * nchunks,
+                threads=threads) as (bz, bx, by):
+            acc_o = T.alloc_fragment((block_M, block_N), accum_dtype)
+            acc_o_shared = T.alloc_shared((block_M, block_N), dtype)
+            cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn")
+            cb_local = T.alloc_fragment((block_M, block_K), dtype)
+            dA_cs_k_shared = T.alloc_shared((block_K), dtype, scope="shared")
+            dA_cs_k_local = T.alloc_fragment((block_K), accum_dtype)
+            dA_cs_m_local = T.alloc_fragment((block_M), accum_dtype)
+            dt_shared = T.alloc_shared((block_K), dtype, scope="shared")
+            dt_local = T.alloc_fragment((block_K), accum_dtype)
+            x_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.dyn")
+            dA_cs_m_shared = T.alloc_shared((block_M), dtype, scope="shared")
+            scale_m_local = T.alloc_fragment((block_M), accum_dtype)
+            C_shared = T.alloc_shared((block_M, block_Dstate), dtype)
+            prev_state_shared = T.alloc_shared((block_N, block_Dstate), dtype)
+            D_local = T.alloc_fragment((1), accum_dtype)
+            x_residual_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.dyn")
+            x_residual_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+
+            batch_idx = by % batch
+            chunk_idx = by // batch
+            # m: chunk_size
+            # n : headdim
+            m_idx = bx // T.ceildiv(headdim, block_N)
+            n_idx = bx % T.ceildiv(headdim, block_N)
+
+            T.annotate_layout({
+                acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared),
+                cb_shared: tilelang.layout.make_swizzled_layout(cb_shared),
+                x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared)
+            })
+
+            T.no_set_max_nreg()
+
+            T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M],
+                   dA_cs_m_shared)
+            T.copy(dA_cs_m_shared, dA_cs_m_local)
+            T.clear(acc_o)
+
+            for i in T.Parallel(block_M):
+                scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p)
+            T.copy(
+                C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
+                  (m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared)
+            T.copy(
+                prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N,
+                            0:block_Dstate], prev_state_shared)
+            T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True)
+            for i, j in T.Parallel(block_M, block_N):
+                acc_o[i, j] *= scale_m_local[i]
+
+            loop_range = T.ceildiv((m_idx + 1) * block_M, block_K)
+
+            for k in T.Pipelined(loop_range, num_stages=num_stages):
+                T.copy(
+                    cb[batch_idx, chunk_idx, bz // (nheads // ngroups),
+                       m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K],
+                    cb_shared)
+                T.copy(cb_shared, cb_local)
+                T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K],
+                       dA_cs_k_shared)
+                T.copy(dA_cs_k_shared, dA_cs_k_local)
+                for i, j in T.Parallel(block_M, block_K):
+                    cb_local[i,
+                             j] = cb_local[i,
+                                           j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p)
+                T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
+                T.copy(dt_shared, dt_local)
+                for i, j in T.Parallel(block_M, block_K):
+                    cb_local[i, j] *= dt_local[j]
+                for i, j in T.Parallel(block_M, block_K):
+                    cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j,
+                                                    cb_local[i, j], 0)
+                T.copy(
+                    x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
+                      (k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared)
+                T.gemm(cb_local, x_shared, acc_o)
+
+            D_local[0] = D[bz]
+            T.copy(
+                x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
+                  (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N],
+                x_residual_shared)
+            T.copy(x_residual_shared, x_residual_local)
+            for i, j in T.Parallel(block_M, block_N):
+                acc_o[i, j] += x_residual_local[i, j] * D_local[0]
+
+            T.copy(acc_o, acc_o_shared)
+            T.copy(
+                acc_o_shared,
+                Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
+                       (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N])
+
+    return main
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch', type=int, default=8, help='batch size')
+    parser.add_argument('--heads', type=int, default=80, help='heads')
+    parser.add_argument('--groups', type=int, default=1, help='groups')
+    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
+    parser.add_argument('--chunk_size', type=int, default=256, help='chunk size')
+    parser.add_argument('--dim', type=int, default=64, help='dim')
+    parser.add_argument('--dstate', type=int, default=128, help='dstate')
+    parser.add_argument('--tune', action='store_true', help='tune configs')
+    args = parser.parse_args()
+    batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate
+    total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate
+
+    kernel = chunk_scan_fwd(batch, seq_len, chunk_size, groups, heads, dim, dstate)
+    best_latency = kernel.latency
+    best_config = kernel.config
+    ref_latency = kernel.ref_latency
+    print(f"Best latency: {best_latency}")
+    print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
+    print(f"Best config: {best_config}")
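
To get a feel for the size of the autotuning search space defined by `get_configs()` in this file, the standalone sketch below rebuilds the same parameter grid using only the standard library and counts the candidates. It copies the parameter lists from the diff and does not import `tilelang`; how the autotuner treats each candidate is not modeled here.

```python
import itertools


def get_configs():
    # Same parameter grid as in the benchmark script in this commit.
    iter_params = dict(
        block_M=[64, 128, 256],
        block_N=[32, 64],
        block_K=[64, 128, 256],
        block_Dstate=[128],
        num_stages=[1, 2, 3, 4, 5],
    )
    # Cartesian product of all value lists, one dict per combination.
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]


configs = get_configs()
print(len(configs))  # 3 * 2 * 3 * 1 * 5 = 90 candidate configurations
print(configs[0])    # first candidate: the first value of every list
```

Each run of the benchmark with the default search space therefore compiles and times up to 90 kernel variants, which is worth keeping in mind when reproducing the README numbers.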
Binary file not shown (85.6 KB)

docs/conf.py

Lines changed: 2 additions & 4 deletions
@@ -1,12 +1,10 @@
-# -*- coding: utf-8 -*-
-
 # General information about the project.
 project = "Tile Language <br>"
 author = "Tile Lang Contributors"
-copyright = "2025-2025, %s" % author
+copyright = f"2025-2025, {author}"
 
 # Version information.
-with open("../VERSION", "r") as f:
+with open("../VERSION") as f:
     version = f.read().strip()
 release = version
 