Skip to content

Commit 19e8a7d

Browse files
authored
Merge branch 'main' into fix_1046
2 parents 7f1a507 + f7ba45d commit 19e8a7d

35 files changed

+2593
-370
lines changed

.clang-tidy

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ ExtraArgs: ['-v']
44
FormatStyle: file
55
UseColor: true
66
WarningsAsErrors: '*'
7-
ExcludeHeaderFilterRegex: '^(3rdparty|tvm)/.*$'
7+
HeaderFilterRegex: '^(?!.*(?:/|^)(3rdparty|tvm)/).*'
88

99
# NOTE: there must be no spaces before the '-', so put the comma last.
1010
Checks: >-

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ env:
2222
PYTHONDEVMODE: "1"
2323
PYTHONUNBUFFERED: "1"
2424
PYTHONPATH: "" # explicit cleanup
25+
COLUMNS: "100"
2526
FORCE_COLOR: "1"
2627
CLICOLOR_FORCE: "1"
2728
UV_INDEX_STRATEGY: "unsafe-best-match"

.github/workflows/dist.yml

Lines changed: 80 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,74 @@ concurrency:
2828
group: "${{ github.workflow }}-${{ github.ref }}"
2929
cancel-in-progress: true
3030

31+
env:
32+
PYTHONDEVMODE: "1"
33+
PYTHONUNBUFFERED: "1"
34+
COLUMNS: "100"
35+
FORCE_COLOR: "1"
36+
CLICOLOR_FORCE: "1"
37+
3138
jobs:
39+
build-sdist:
40+
name: Build SDist
41+
if: |
42+
github.repository_owner == 'tile-ai' &&
43+
(github.event_name != 'pull_request' || !github.event.pull_request.draft)
44+
runs-on: macos-latest
45+
timeout-minutes: 30
46+
env:
47+
NO_VERSION_LABEL: ${{ github.event_name == 'release' && 'OFF' || 'ON' }}
48+
# NO_GIT_VERSION disables embedding the git commit hash in version metadata.
49+
# Otherwise, the version of the SDist has a git hash suffix (e.g., 0.1.0+gitabcdef12),
50+
# but the package built from the SDist has no way to get the git hash (it is not a git repo),
51+
# leading to inconsistent versions between SDist and built packages (+gitabcdef12 vs. +gitunknown).
52+
NO_GIT_VERSION: "ON"
53+
54+
steps:
55+
- name: Checkout repository
56+
uses: actions/checkout@v5
57+
with:
58+
fetch-depth: 1
59+
submodules: recursive
60+
61+
- name: Setup Python and uv with caching
62+
id: setup-uv
63+
uses: astral-sh/setup-uv@v7
64+
with:
65+
python-version: "3.12"
66+
activate-environment: true
67+
68+
- name: Build SDist
69+
run: |
70+
uv run --no-project --with=build -m -- build --sdist --outdir=dist
71+
72+
- name: Setup ccache
73+
uses: hendrikmuhs/ccache-action@v1
74+
with:
75+
create-symlink: true
76+
key: ccache-${{ runner.os }}-${{ runner.arch }}
77+
evict-old-files: "7d"
78+
79+
- name: Test SDist buildable
80+
run: |
81+
TEMP_DIR="$(mktemp -d -t tilelang-sdist-test)"
82+
cp -r dist "${TEMP_DIR}/dist"
83+
uv venv --seed "${TEMP_DIR}/venv"
84+
source "${TEMP_DIR}/venv/bin/activate"
85+
cd "${TEMP_DIR}"
86+
python3 -m pip install --upgrade pip setuptools wheel
87+
python3 -m pip install -v dist/*.tar.gz
88+
python3 -c "import tilelang; print(tilelang.__version__)"
89+
90+
- name: Upload SDist
91+
# Not PR to save artifact storage, as SDist is only needed for releases.
92+
if: github.event_name != 'pull_request'
93+
uses: actions/upload-artifact@v4
94+
with:
95+
name: sdist
96+
path: dist/*.tar.gz
97+
if-no-files-found: error
98+
3299
build-wheels:
33100
name: Build wheels for Python ${{ matrix.python-version }} on ${{ matrix.target.runner }} with ${{ matrix.target.toolkit }}
34101
if: |
@@ -94,22 +161,30 @@ jobs:
94161
- name: Upload wheels
95162
# Not PR to save artifact storage, as wheels are only needed for releases.
96163
if: github.event_name != 'pull_request'
97-
uses: actions/upload-artifact@v4
164+
uses: actions/upload-artifact@v5
98165
with:
99166
name: wheels-${{ matrix.python-version }}-${{ runner.os }}-${{ runner.arch }}-${{ matrix.target.toolkit }}
100167
path: wheelhouse/*.whl
101168
if-no-files-found: error
102169

103170
list-artifacts:
104171
name: List artifacts
105-
# Not PR to save artifact storage, as wheels are only needed for releases.
172+
# Not PR to save artifact storage, as artifacts are only needed for releases.
106173
if: github.event_name != 'pull_request'
107174
runs-on: ubuntu-latest
108-
needs: [build-wheels]
175+
needs: [build-sdist, build-wheels]
109176
timeout-minutes: 15
110177
steps:
111-
- name: Download built wheels
178+
- name: Download built SDist
112179
uses: actions/download-artifact@v5
180+
with:
181+
# unpacks default artifact into dist/
182+
# if `name: artifact` is omitted, the action will create extra parent dir
183+
name: sdist
184+
path: dist
185+
186+
- name: Download built wheels
187+
uses: actions/download-artifact@v6
113188
with:
114189
pattern: wheels-*
115190
path: dist
@@ -119,7 +194,7 @@ jobs:
119194
run: ls -lh dist/*
120195

121196
- name: Upload artifacts
122-
uses: actions/upload-artifact@v4
197+
uses: actions/upload-artifact@v5
123198
with:
124199
name: artifacts
125200
path: dist/*

MANIFEST.in

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,35 @@
1+
# Reference: https://setuptools.pypa.io/en/latest/userguide/miscellaneous.html
2+
3+
# Include licenses
14
include VERSION
2-
include CMakeLists.txt
3-
include requirements.txt
4-
include requirements-test.txt
5-
include requirements-dev.txt
5+
include LICENSE
6+
include THIRDPARTYNOTICES.txt
7+
8+
# Version and dependency files
9+
include version_provider.py
10+
include requirements*.txt
611
include tilelang/jit/adapter/cython/cython_wrapper.pyx
7-
recursive-include src *
8-
recursive-include 3rdparty *
9-
recursive-exclude 3rdparty/clang* *
10-
recursive-exclude 3rdparty/llvm* *
12+
13+
# Include source files in SDist
14+
include CMakeLists.txt
15+
graft src
16+
graft cmake
17+
graft 3rdparty
18+
19+
# Include test suites in SDist
20+
graft testing
21+
graft examples
22+
global-exclude .coverage .coverage.* coverage.xml coverage-*.xml coverage.*.xml
23+
global-exclude .junit .junit.* junit.xml junit-*.xml junit.*.xml
24+
25+
# Exclude unneeded files and directories
26+
prune .git
27+
prune .github
28+
prune */.git
29+
prune */.github
30+
prune 3rdparty/clang*
31+
prune 3rdparty/llvm*
32+
33+
# Prune compiled files
34+
prune */__pycache__
35+
global-exclude *~ *.py[cod] *.so *.a *.dylib *.pxd *.dll *.lib *.o *.obj

benchmark/mamba2/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ PY
4545
| 16384 | 2.531 | 135.711 |
4646
| 32768 | 5.076 | 135.379 |
4747

48+
49+
## Compare with Baselines
50+
51+
- Triton: v3.5.0, mamba-ssm: v2.2.6.post3
52+
- Helion: v0.2.1
53+
4854
<figure style="text-align: center">
4955
<a href="mamba_benchmark_result.png">
5056
<img src="mamba_benchmark_result.png" alt="Mamba2_chunk_scan Performance Comparison on H100">

benchmark/mamba2/benchmark_mamba_chunk_scan.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,20 @@
55
import tilelang.language as T
66
from einops import rearrange, repeat
77
import itertools
8+
import math
9+
from tilelang.profiler import do_bench
10+
11+
try:
12+
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd
13+
except ImportError as err:
14+
raise ImportError("Please install mamba-ssm to use the triton chunk scan operator.") from err
15+
16+
try:
17+
import helion
18+
from helion._testing import run_example
19+
import helion.language as hl
20+
except ImportError as err:
21+
raise ImportError("Please install helion to use the helion chunk scan operator.") from err
822

923

1024
def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
@@ -54,6 +68,119 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
5468
return out
5569

5670

71+
def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D):
72+
out, _ = _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D)
73+
return out
74+
75+
76+
def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
77+
78+
@helion.kernel()
79+
def helion_mamba2_chunk_scan_kernel(
80+
cb: torch.Tensor,
81+
x: torch.Tensor,
82+
dt: torch.Tensor,
83+
dA_cumsum: torch.Tensor,
84+
C: torch.Tensor,
85+
prev_states: torch.Tensor,
86+
D: torch.Tensor,
87+
) -> torch.Tensor:
88+
"""
89+
Argument:
90+
cb: (batch, nchunks, ngroups, chunk_size, chunk_size)
91+
x: (batch, seqlen, nheads, headdim)
92+
dt: (batch, nheads, nchunks, chunk_size)
93+
dA_cumsum: (batch, nheads, nchunks, chunk_size)
94+
C: (batch, seqlen, ngroups, dstate)
95+
prev_states: (batch, nchunks, nheads, headdim, dstate)
96+
D: (nheads,)
97+
Return:
98+
out: (batch, seqlen, nheads, headdim)
99+
"""
100+
101+
batch, nchunks, ngroups, chunk_size, _ = cb.shape
102+
_, seqlen, nheads, headdim = x.shape
103+
_, _, _, dstate = C.shape
104+
assert nchunks == (seqlen + chunk_size - 1) // chunk_size
105+
106+
block_m = hl.register_block_size(chunk_size)
107+
block_n = hl.register_block_size(headdim)
108+
block_k = hl.register_block_size(64, 64)
109+
dstate = hl.specialize(dstate)
110+
111+
assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
112+
assert x.shape == (batch, seqlen, nheads, headdim)
113+
assert dt.shape == (batch, nheads, nchunks, chunk_size)
114+
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
115+
assert C.shape == (batch, seqlen, ngroups, dstate)
116+
assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)
117+
assert D.shape == (nheads,)
118+
119+
dtype = cb.dtype
120+
accum_dtype = torch.float32
121+
assert (x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype ==
122+
dtype)
123+
124+
out = torch.empty_like(x)
125+
126+
p = 1.44269504
127+
128+
for tile_h, tile_m, tile_n, tile_b, tile_c in hl.tile(
129+
[nheads, chunk_size, headdim, batch, nchunks],
130+
block_size=[1, block_m, block_n, 1, 1],
131+
):
132+
acc_o = hl.zeros([tile_m, tile_n], dtype=accum_dtype)
133+
dA_cumsum_local_m = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin,
134+
tile_m].to(torch.float32)
135+
scale_m_local = torch.exp2(dA_cumsum_local_m * p)
136+
137+
C_local = C[
138+
tile_b.begin,
139+
tile_m.index + tile_c.begin * chunk_size,
140+
tile_h.begin // (nheads // ngroups),
141+
:,
142+
]
143+
prev_states_local = prev_states[tile_b.begin, tile_c.begin, tile_h.begin, tile_n, :]
144+
acc_o = hl.dot(C_local, prev_states_local.T, acc=acc_o)
145+
acc_o *= scale_m_local[:, None]
146+
147+
for tile_k in hl.tile((tile_m.id + 1) * block_m, block_size=block_k):
148+
cb_local = cb[
149+
tile_b.begin,
150+
tile_c.begin,
151+
tile_h.begin // (nheads // ngroups),
152+
tile_m,
153+
tile_k,
154+
]
155+
dA_cumsum_local_k = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin,
156+
tile_k].to(torch.float32)
157+
cb_local *= torch.exp2(dA_cumsum_local_m[:, None] * p -
158+
dA_cumsum_local_k[None, :] * p)
159+
dt_local = dt[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32)
160+
cb_local = (cb_local * dt_local[None, :]).to(dtype)
161+
pred = (tile_m.index + 0)[:, None] >= (tile_k.index + 0)[None, :]
162+
cb_local = torch.where(pred, cb_local, torch.zeros_like(cb_local))
163+
x_local = x[
164+
tile_b.begin,
165+
tile_c.begin * chunk_size + tile_k.index,
166+
tile_h.begin,
167+
tile_n,
168+
]
169+
acc_o = hl.dot(cb_local, x_local, acc=acc_o)
170+
171+
D_local = D[tile_h.begin].to(torch.float32)
172+
x_residual = x[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin,
173+
tile_n].to(torch.float32)
174+
acc_o += x_residual * D_local
175+
out[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin,
176+
tile_n] = acc_o.to(dtype=dtype)
177+
178+
return out
179+
180+
args = (cb, x, dt, dA_cumsum, C, states, D)
181+
run_example(helion_mamba2_chunk_scan_kernel, ref_program, args)
182+
183+
57184
def get_configs():
58185
iter_params = dict(
59186
block_M=[64, 128, 256],
@@ -212,12 +339,30 @@ def main(
212339
parser.add_argument('--tune', action='store_true', help='tune configs')
213340
args = parser.parse_args()
214341
batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate
342+
nchunks = math.ceil(seq_len / chunk_size)
215343
total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate
216344

345+
print("Benchmarking TileLang...")
217346
kernel = chunk_scan_fwd(batch, seq_len, chunk_size, groups, heads, dim, dstate)
218347
best_latency = kernel.latency
219348
best_config = kernel.config
220349
ref_latency = kernel.ref_latency
221350
print(f"Best latency: {best_latency}")
222351
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
223352
print(f"Best config: {best_config}")
353+
354+
cb = torch.randn(batch, nchunks, groups, chunk_size, chunk_size).half().cuda()
355+
x = torch.randn(batch, seq_len, heads, dim).half().cuda()
356+
dt = torch.randn(batch, heads, nchunks, chunk_size).half().cuda()
357+
dA_cumsum = torch.randn(batch, heads, nchunks, chunk_size).half().cuda()
358+
C = torch.randn(batch, seq_len, groups, dstate).half().cuda()
359+
states = torch.randn(batch, nchunks, heads, dim, dstate).half().cuda()
360+
D = torch.randn(heads).half().cuda()
361+
362+
print("Benchmarking Triton...")
363+
triton_latency = do_bench(
364+
lambda: chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D), _n_warmup=10, _n_repeat=10)
365+
print(f"Triton TFlops: {total_flops / triton_latency * 1e-9}")
366+
367+
print("Benchmarking Helion...")
368+
chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D)

0 commit comments

Comments
 (0)