
Commit 70deb79

Update imports in flash attention test file to use new backward and forward examples for better clarity and consistency.
1 parent: ae10dd4

2 files changed, 8 insertions(+), 8 deletions(-)

examples/flash_attention/example_mha_fwd_bhsd.py

Lines changed: 4 additions & 4 deletions

@@ -223,10 +223,10 @@ def main(
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=16, help='batch size')
-    parser.add_argument('--heads', type=int, default=64, help='heads')
-    parser.add_argument('--seq_q', type=int, default=298, help='query sequence length')
-    parser.add_argument('--seq_kv', type=int, default=298, help='key/value sequence length')
+    parser.add_argument('--batch', type=int, default=1, help='batch size')
+    parser.add_argument('--heads', type=int, default=1, help='heads')
+    parser.add_argument('--seq_q', type=int, default=256, help='query sequence length')
+    parser.add_argument('--seq_kv', type=int, default=256, help='key/value sequence length')
     parser.add_argument('--dim', type=int, default=64, help='dim')
     parser.add_argument('--is_causal', action='store_true', help='causal', default=False)
     parser.add_argument('--tune', action='store_true', help='tune configs')
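
With these defaults, the forward example now runs a small single-head configuration out of the box. For reference, the equivalent explicit command line (assuming the script is invoked from the repository root) would be:

    python examples/flash_attention/example_mha_fwd_bhsd.py --batch 1 --heads 1 --seq_q 256 --seq_kv 256 --dim 64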

examples/flash_attention/test_example_flash_attention.py

Lines changed: 4 additions & 4 deletions

@@ -2,15 +2,15 @@
 
 import example_gqa_bwd
 import example_gqa_bwd_wgmma_pipelined
-import example_mha_bwd
+import example_mha_bwd_bshd
 import example_mha_bwd_bhsd
 import example_mha_fwd_bhsd_wgmma_pipelined
 import example_gqa_fwd_bshd
 import example_mha_fwd_bshd
 import example_gqa_fwd_bshd_wgmma_pipelined
 import example_mha_fwd_bshd_wgmma_pipelined
 import example_mha_fwd_varlen
-import example_mha_bwd_wgmma_pipelined
+import example_mha_bwd_bshd_wgmma_pipelined
 import example_mha_fwd_bhsd
 import example_gqa_bwd_tma_reduce_varlen
 
@@ -33,7 +33,7 @@ def test_example_gqa_bwd_wgmma_pipelined():
 
 @tilelang.testing.requires_cuda
 def test_example_mha_bwd():
-    example_mha_bwd.main(
+    example_mha_bwd_bshd.main(
         BATCH=1,
         H=16,
         N_CTX=512,
@@ -56,7 +56,7 @@ def test_example_mha_bwd_bhsd():
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_mha_bwd_wgmma_pipelined():
-    example_mha_bwd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False)
+    example_mha_bwd_bshd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False)
 
 
 @tilelang.testing.requires_cuda
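
The test calls keep the same keyword arguments as before; only the imported module names change. Assuming the suite is run with pytest (as the tilelang.testing decorators suggest), the touched tests could be exercised on their own with something like:

    pytest examples/flash_attention/test_example_flash_attention.py -k "mha_bwd"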
