Commit e9a608e
xwhzz and LeiWang1999 authored
[Bugfix][CI] Bug fixing and migrate CI from ada to hopper (#652)
* fix CI bugs in hopper
* lint fix
* Update bulk_copy.cc
* Refactor bulk copy logic in LowerBulkCopy function
  - Removed unnecessary blank lines for improved code readability.
  - Enhanced stride validation by checking for null pointers in global stride calculations, ensuring robustness against symbolic strides.
  - Updated pass configuration handling in dynamic tile language tests to streamline dynamic alignment and TMA lower pass settings.
* test fix
* ci fix
* Update flash-attention dependencies and clean up example code
  - Downgraded `flash-attn` dependency version in `requirements-test.txt` to `<=2.2.0`.
  - Removed unused imports and commented-out code in various example files to enhance readability and maintainability.
  - Updated the `flashattn` function signature to include default parameters for `block_M`, `block_N`, `num_stages`, and `threads`.
  - Cleaned up the `example_mha_fwd_varlen.py` and `example_mha_bwd_wgmma_pipelined.py` files by removing unnecessary comments and improving code clarity.
  - Deleted the `example_mha_inference.py` file as it is no longer needed.
* Update CI workflow to remove `--user` flag from pip install commands
  - Removed the `--user` flag from the pip install commands in both the development and testing sections of the CI workflow to ensure proper installation of dependencies in the virtual environment.
* Update CI workflow to include `--no-user` flag in pip install commands
  - Added the `--no-user` flag to the pip install commands in both the development and testing sections of the CI workflow to ensure dependencies are installed correctly within the virtual environment.
* Update CI workflow to include `--no-user` flag in pip install command for wheel mode
  - Added the `--no-user` flag to the pip install command in the wheel mode section of the CI workflow to ensure dependencies are installed correctly within the virtual environment.
* test fix
* avoid conflict with system environments
* test fix
* add comments

---------

Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent 5bd3f94 commit e9a608e


41 files changed (+419, -1740 lines)
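
Among the example clean-ups, the commit message notes that `flashattn` now takes default values for its tuning parameters. A minimal sketch of that kind of defaulted signature is shown below; the parameter names come from the commit message, while the default values and the stub body are illustrative placeholders rather than the actual kernel.

```python
# Sketch only: the default values here are placeholders, not the repository's.
def flashattn(batch, heads, seq_len, dim,
              block_M=64, block_N=64, num_stages=2, threads=128):
    # With defaults on the tuning knobs, callers (and CI smoke tests) can
    # invoke flashattn(batch, heads, seq_len, dim) without repeating block
    # sizes, pipeline stages, and thread counts at every call site.
    ...
```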

.github/workflows/ci.yml

Lines changed: 7 additions & 5 deletions

@@ -23,8 +23,8 @@ jobs:
       - name: Activate virtual environment and install dependencies
         run: |
           source tilelang_ci/bin/activate
-          python -m pip install --upgrade pip
-          if [ -f requirements-dev.txt ]; then python -m pip install -r requirements-dev.txt; fi
+          python -m pip install --upgrade pip --no-user
+          if [ -f requirements-dev.txt ]; then python -m pip install -r requirements-dev.txt --no-user; fi

       - name: Update submodules recursively
         run: git submodule update --init --recursive
@@ -55,22 +55,24 @@ jobs:
       - name: Activate virtual environment and install dependencies
         run: |
           source tilelang_ci/bin/activate
-          python -m pip install --upgrade pip
-          if [ -f requirements-test.txt ]; then PIP_NO_BUILD_ISOLATION=1 python -m pip install -r requirements-test.txt; fi
+          python -m pip install --upgrade pip --no-user
+          if [ -f requirements-test.txt ]; then PIP_NO_BUILD_ISOLATION=1 python -m pip install -r requirements-test.txt --no-user; fi

       - name: Install project in wheel mode
         run: |
           source tilelang_ci/bin/activate
-          python -m pip install .
+          python -m pip install . --no-user

       - name: Run examples
         run: |
           source tilelang_ci/bin/activate
           cd examples
+          unset PYTHONPATH
           python -m pytest **/test*.py

       - name: Run tests
         run: |
           source tilelang_ci/bin/activate
           cd testing/python
+          unset PYTHONPATH
           python -m pytest

3rdparty/tvm

Submodule tvm updated from db50d4e to 979c8e7

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py

Lines changed: 2 additions & 16 deletions

@@ -1,7 +1,6 @@
 import torch
 import torch.nn.functional as F
 import tilelang
-from tilelang.autotuner import *
 import tilelang.language as T
 from einops import rearrange, einsum
 import argparse
@@ -71,7 +70,7 @@ def flash_attn_split(
            loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0))
            start = blocks_per_split * sid + T.min(sid, remaining_blocks)
            has_valid_block = False
-           # if (start < num_blocks):
+
            for k in T.Pipelined(loop_range, num_stages=num_stages):
                i_s = block_indices[bid, cur_kv_head, start + k]
                if i_s >= 0:
@@ -238,23 +237,12 @@ def forward(self, query, key, value, block_indices, cache_seqlens):
            size_one_kv_head,
            is_causal_or_local=True,
            max_splits=128)
-        # print("num_split: ", num_split)
-        # Function to compile
-        # def compute_actual_num_blocks(block_indices):
-        #     actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)
-        #     actual_num_blocks = actual_num_blocks[:, 0] # [batch]
-        #     return actual_num_blocks
-        # compiled_fn = torch.compile(compute_actual_num_blocks)
-        # actual_num_blocks = compiled_fn(block_indices)
+
        glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
        output_partial = torch.empty((batch, heads, num_split, dim_v),
                                     dtype=torch.float32,
                                     device='cuda')

-        # output = self.kernel(
-        #     query, key, value, block_indices, cache_seqlens,
-        #     actual_num_blocks, glse, output_partial
-        # )
        output = self.kernel(query, key, value, block_indices, cache_seqlens, glse, output_partial)
        return output

@@ -377,8 +365,6 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
    all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol)
    print(name + " all_close={}".format(all_close))
    if not all_close:
-        # print(expect[3, 28])
-        # print(actual[3, 28])
        diff = (expect - actual).abs()
        print("all_close={}, max={}, min={}, mean={}".format(all_close,
                                                             diff.max().item(),

examples/convolution/example_convolution.py

Lines changed: 2 additions & 3 deletions

@@ -116,9 +116,8 @@ def main(argv=None):
    block_k = 32
    num_stages = 3
    threads = 256
-
-    kernel = tilelang.compile(
-        convolution(N, C, H, W, F, K, S, D, P, block_m, block_n, block_k, num_stages, threads), out_idx=[2])
+    program = convolution(N, C, H, W, F, K, S, D, P, block_m, block_n, block_k, num_stages, threads)
+    kernel = tilelang.compile(program, out_idx=[2])

    out_c = kernel(a, b)
    ref_c = ref_program(S, P, D)(a, b)
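
The change above only splits program construction from compilation, but it makes the intent easier to read: build the TileLang program, compile it once, then call the compiled kernel. A fragment re-stating the new call pattern follows, assuming `out_idx` keeps its usual tilelang meaning; `convolution`, the problem sizes, and the tensors `a` and `b` are defined earlier in `example_convolution.py` and are only referenced here, not re-defined.

```python
# Two-step pattern from the diff above (names come from the example file).
program = convolution(N, C, H, W, F, K, S, D, P,
                      block_m, block_n, block_k, num_stages, threads)
kernel = tilelang.compile(program, out_idx=[2])  # buffer 2 (the output) is
                                                 # allocated and returned

out_c = kernel(a, b)  # pass the two inputs; the output tensor comes back
```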

examples/convolution/test_example_convolution.py

Lines changed: 5 additions & 0 deletions

@@ -4,10 +4,15 @@
 import example_convolution_autotune


+# TODO(@cy): TMA with convolution must be fixed in future.
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_le(8, 9)
 def test_example_convolution():
     example_convolution.main([])


+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_le(8, 9)
 def test_example_convolution_autotune():
     example_convolution_autotune.main()
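
The decorators added above gate the convolution tests by backend and compute capability. Judging by their names (the semantics are inferred here, not verified), `requires_cuda` skips the test when no CUDA device is available, and `requires_cuda_compute_version_le(8, 9)` skips it on GPUs newer than compute capability 8.9, i.e. on the Hopper runners this commit migrates CI to, where the TMA convolution path is still pending a fix. The same guard pattern, spelled out with comments:

```python
import tilelang
import tilelang.testing

import example_convolution  # example module exercised by the test


# Run only when a CUDA device is present and its compute capability is <= 8.9
# (Ada and older); skip on Hopper (9.0) until the TMA convolution fix lands.
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def test_example_convolution():
    example_convolution.main([])
```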

examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@
 tilelang.testing.set_random_seed(42)


-@tilelang.jit(out_idx=[2])
+@tilelang.jit
 def tl_gemm(
     M,
     N,
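
For context on the one-line change above: with `out_idx`, the `tilelang.jit` wrapper allocates the listed buffers at call time and returns them, so the caller passes only the inputs; with the bare decorator, the caller passes every buffer, including the output, which is how the FP8 example now drives `tl_gemm`. A minimal sketch in the style of the tilelang quickstart GEMM (block sizes, dtypes, and the kernel body are illustrative, not taken from this file):

```python
import torch
import tilelang
import tilelang.language as T


@tilelang.jit(out_idx=[2])  # buffer 2 (C) is allocated by the wrapper and returned
def matmul(M, N, K, block_M=128, block_N=128, block_K=32,
           dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

# With out_idx=[2]: pass the inputs only and get the output back.
c = matmul(1024, 1024, 1024)(a, b)

# With a bare @tilelang.jit (as in the diff above), the caller would instead
# allocate c itself and pass it as the third argument.
```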

examples/flash_attention/example_mha_bwd_wgmma_pipelined.py

Lines changed: 3 additions & 6 deletions

@@ -23,7 +23,6 @@ def flash_fwd(
 ):
     with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz):
         Q_shared = T.alloc_shared([block_M, dim], dtype)
-        # Q_local = T.alloc_fragment([block_M, dim], dtype)
         K_shared = T.alloc_shared([block_N, dim], dtype)
         V_shared = T.alloc_shared([block_N, dim], dtype)
         acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
@@ -40,9 +39,7 @@ def flash_fwd(
         T.fill(acc_o, 0)
         T.fill(logsum, 0)
         T.fill(scores_max, -T.infinity(accum_dtype))
-        # T.copy(Q_shared, Q_local)
-        # for i, j in T.Parallel(block_M, dim):
-        #     Q_local[i, j] *= scale
+
         loop_range = (
             T.ceildiv(
                 (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
@@ -264,8 +261,8 @@ def maybe_contiguous(x):
         return x

     do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
-    block_M = 64
-    block_N = 64 if D_HEAD <= 64 else 32
+    block_M = 128
+    block_N = 128 if D_HEAD <= 64 else 32
     mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD)
     mod_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD)
     delta = mod_prep(o, do)
