Closed

251 commits
cd393e0
[Build] Update version of setuptools used to generate core package (#…
tmm1 Jan 29, 2025
bb135af
Don't compile for CUDA 11, compile for official pytorch 2.6.0
tridao Jan 29, 2025
979702c
Bump to v2.7.4
tridao Jan 29, 2025
5231d95
Drop Pytorch 2.1
tridao Jan 29, 2025
454ce31
[FA3] Compile with nvcc 12.8 instead of 12.3
tridao Jan 29, 2025
803f609
Fix comment in assert
tridao Jan 30, 2025
02541ac
[CE] Assert logit_scale > 0
tridao Jan 30, 2025
2a20412
Implement HeadDim_V != HeadDim_QK, support hdimQK=192, hdimV=128
tridao Feb 3, 2025
6d199aa
Fix shape_O in epilogue params when kHeadDimV != kHeadDim
tridao Feb 4, 2025
86bcd05
Remove old combine.h
tridao Feb 4, 2025
e3b2400
Fix loading paged V when kHeadDimV != kHeadDim
tridao Feb 4, 2025
9e07d6d
Fix shape_V for storing new KV when kHeadDimV != kHeadDim
tridao Feb 4, 2025
f0f2523
Implement the case of LargeHeadDimV
tridao Feb 4, 2025
4c8819d
Rename Mma0->MmaQK, Mma1->MmaPV, use Cluster only if hdimV >= 192
tridao Feb 7, 2025
dd87691
Pass _1 or _0 to cute::aligned_struct
tridao Feb 8, 2025
ed53b5f
Fix compilation for FP8 when kHeadDimV != kHeadDim
tridao Feb 8, 2025
4e8496a
Support Qv
tridao Feb 8, 2025
893a22a
Test varlen_q=True by default for kvcache
tridao Feb 8, 2025
5fab938
Fix num_splits heuristic being called before get_pack_gqa
tridao Feb 8, 2025
5fc5ebf
Fix num_splits heuristic again when PackGQA
tridao Feb 8, 2025
5378bc3
Tile fwd_combine kernel along headdim, don't need kBlockM > 128
tridao Feb 8, 2025
db8ca79
Use bf16 instead of fp16 in benchmark_gemm.py
tridao Feb 9, 2025
982c480
Update Cutlass to 3.7
tridao Feb 9, 2025
58ebfa5
Use nvcc 12.6 but ptxas 12.8
tridao Feb 9, 2025
ed435c6
cicc uses the same version as ptxas
tridao Feb 9, 2025
8668823
Split hdimdiff into a separate translation unit
tridao Feb 9, 2025
b2fc79d
Update benchmark script
tridao Feb 9, 2025
c091545
Update Cutlass to 3.8
tridao Feb 9, 2025
5e39b10
Adjust tile size for hdim 64
tridao Feb 9, 2025
1a7f4df
Adjust ninja build file
tridao Feb 10, 2025
15cf7ee
Rename collective_mainloop -> mainloop, move tile_scheduler variable
tridao Feb 11, 2025
9f313c7
Move functions getting number of m/n blocks to a separate file
tridao Feb 12, 2025
eafd53c
Update cutlass 3.8 to fix error w cudaGetDriverEntryPointByVersion
tridao Feb 12, 2025
fa445ff
Fix FP8 test
tridao Feb 12, 2025
a09abcd
make seqused optional on top level interface (#1497)
vasqu Feb 16, 2025
40cbd52
Temporarily change package name of FA3 to allow FA2 & FA3 install
tridao Feb 18, 2025
91917b4
Update benchmark_split_kv.py to work w new API
tridao Feb 18, 2025
ea3ecea
Add tp_degree to benchmark_split_kv
tridao Feb 18, 2025
74dfa43
Fix divide by 0 in causal tile_scheduler for large seqlen
tridao Feb 19, 2025
b36ad4e
Use split for super long sequences that don't fit into L2
tridao Feb 19, 2025
ecdb528
Make rotary test optional in FA3
tridao Feb 22, 2025
06e34f6
Enable MLA flag in FA3 (rope=64, latent=512) (#1504)
tzadouri Feb 23, 2025
6aed835
Add simple script to benchmark MLA decode
tridao Feb 24, 2025
6752d62
Add dynamic splits
tridao Feb 24, 2025
cdda5bf
Update to Cutlass 3.8.0 tag
tridao Feb 24, 2025
9505c74
Adjust seqlen_q in MLA decode benchmark script
tridao Feb 24, 2025
3b5047d
Fix loop in prepare_scheduler.cu (h/t Jay Shah)
tridao Feb 25, 2025
dec83a1
fix: add "typename" prior to dependent type name (#1517)
zhiweij1 Feb 28, 2025
08f4c80
Add FLOPS to MLA decode benchmark
tridao Feb 28, 2025
085ce58
Change margin in prepare_scheduler.cu from 20% to 10%
tridao Feb 28, 2025
39e7197
Fix cuda 12.1 build (#1511)
LucasWilkinson Mar 1, 2025
20b84d6
Don't use IntraWGOverlap for hdim 64,512
tridao Mar 2, 2025
5458c78
Remove sink token
tridao Mar 2, 2025
6865e60
fix: prompt index to type longlong to avoid numerical overflow (#1500)
xin-w8023 Mar 2, 2025
45c48af
Add option for WG1 to use RS MMA but WG2 using SS MMA
tridao Mar 4, 2025
3edf7e0
Add kwargs to _write_ninja_file for compatibility with new torch
tridao Mar 4, 2025
4f0640d
Move writing P to smem as separate function
tridao Mar 5, 2025
d82bbf2
Fix causal scheduler not considering hdim_v != hdim
tridao Mar 5, 2025
9c036e4
Always split fwd_combine_kernel on batch
tridao Mar 7, 2025
81643fa
For each batch, if num_splits=1, write to O instead of O_partial
tridao Mar 8, 2025
1d30bb4
Enable TMA when page size is a multiple of kBlockN
tridao Mar 9, 2025
a3a9cc5
Update ptxas to 12.8.93 (i.e. 12.8.1)
tridao Mar 9, 2025
322bec9
Use tile size 192 x 128 for hdim 64 causal
tridao Mar 9, 2025
5639b9d
Update benchmark_mla_decode.py
tridao Mar 9, 2025
48b3acb
Benchmark MHA, GQA, MQA, MLA in the same script
tridao Mar 11, 2025
d904855
Benchmark FlashMLA if it's available
tridao Mar 11, 2025
cdaf2de
Run all 4 attn variants in benchmark
tridao Mar 12, 2025
cf1b809
Move scheduler.get_next_work to before the epilogue
tridao Mar 12, 2025
3cf8998
Enable Cluster for hdim128 back
tridao Mar 12, 2025
6063dc5
Move tOrO init in mainloop
tridao Mar 12, 2025
430954a
Adjust heuristic for get_pagedkv_tma
tridao Mar 12, 2025
000090d
Enable PDL
tridao Mar 13, 2025
46e1d4a
Simplify prepare_varlen_num_blocks_kernel, restrict to batch <= 992
tridao Mar 13, 2025
897c845
Fix: num_splits_dynamic_ptr needs to be set before get_num_splits
tridao Mar 14, 2025
90f27a2
Loop on num_splits instead of parameterizing it in kvcache test
tridao Mar 15, 2025
fa60e7c
Add option to precompute scheduler metadata
tridao Mar 15, 2025
6c87fac
Update MLA decode benchmark to use get_scheduler_metadata
tridao Mar 15, 2025
4b5eeab
Fix FP8 test to quantize KV cache for reference impl as well
tridao Mar 15, 2025
27f501d
Dynamic autotune configs for devices with warp size != 32 (#1534)
schung-amd Mar 15, 2025
7ae5f8c
Add option for rotary_seqlens
tridao Mar 21, 2025
fef4fcf
Use StreamkBarrier0/1 barriers instead of TileCountSmemEmpty/Full
tridao Mar 22, 2025
b1951a4
Update Cutlass to 3.9
tridao Mar 22, 2025
df11fca
Support hdim 64,256
tridao Mar 22, 2025
f6a294a
Update benchmark with GLA
tridao Mar 22, 2025
29ef580
Adjust warp scheduler sync for HasQv case
tridao Mar 22, 2025
2f9ef08
num_head -> args.num_head (#1552)
yeqcharlotte Mar 25, 2025
1a58058
Fix zeroing out the scheduler semaphore when reusing metadata
tridao Mar 29, 2025
2dd8078
fix deprecation warning for newer torch versions (#1565)
vasqu Apr 1, 2025
7ff1b62
Don't use FusedDense anymore to simplify code
tridao Apr 7, 2025
aa04de6
Fix FA3 qkvpacked interface
tridao Apr 7, 2025
2afa43c
Launch more thread blocks in layer_norm_bwd
tridao Apr 8, 2025
9f2d2ae
check valid tile before storing num_splits in split_idx (#1578)
jayhshah Apr 9, 2025
d836a6b
Tune rotary kernel to use 2 warps if rotary_dim <= 64
tridao Apr 9, 2025
909eb7c
Implement attention_chunk
tridao Apr 10, 2025
7ff73af
Fix missed attention chunk size param for block specifics in `mma_pv`…
wanderingai Apr 10, 2025
c1352b6
[AMD ROCm] Support MI350 (#1586)
rocking5566 Apr 11, 2025
7bb8e82
Make attention_chunk work for non-causal cases
tridao Apr 12, 2025
fb4c510
Use tile size 128 x 96 for hdim 64,256
tridao Apr 12, 2025
757c5ad
Fix kvcache tests for attention_chunk when precomputing metadata
tridao Apr 12, 2025
fc5a6fa
Fix kvcache test with precomputed metadata: pass in max_seqlen_q
tridao Apr 12, 2025
4d9ba4f
Pass 0 as attention_chunk in the bwd for now
tridao Apr 12, 2025
4d3d2ff
[LayerNorm] Implement option for zero-centered weight
tridao Apr 13, 2025
934f6ad
Make hopper build more robust (#1598)
classner Apr 17, 2025
5e0c258
Fix L2 swizzle in causal tile scheduler
tridao Apr 21, 2025
1522dc7
Use LPT scheduler for causal backward pass
tridao Apr 21, 2025
75f90d6
add sm_margin for hopper flash_attn_qkvpacked_func (#1603)
TopIdiot Apr 22, 2025
37c816a
Support hdimQK != hdimV backward (#1604)
shcho1118 Apr 24, 2025
35e5f00
[AMD] Triton Backend for ROCm #2 (#1610)
micmelesse Apr 24, 2025
f7ba107
Fix (#1602)
co63oc Apr 24, 2025
9b5ae42
feat: add support for torch2.7 (#1574)
NanoCode012 Apr 24, 2025
dc8fd70
[Rotary] Block over seqlen and nheads dimension, use Triton 3.x
tridao Apr 24, 2025
a1be1cc
[CI] Drop support for pytorch 2.2 and 2.3
tridao Apr 24, 2025
1870a0d
[Rotary] Clean up, remove option pos_idx_in_fp32=False
tridao Apr 24, 2025
ef0bbd9
[Rotary] Refactor, test with torch.compile
tridao Apr 25, 2025
93690e2
[Rotary] Wrap apply_rotary_emb_qkv_inplace as a custom op
tridao Apr 25, 2025
41a21d6
Fix import error
tridao Apr 25, 2025
a9a3170
[Rotary] Don't need to wrap in custom_op, just need wrap_triton
tridao Apr 25, 2025
de94700
[LayerNorm] Make compatible with torch.compile
tridao Apr 29, 2025
515e263
[LayerNorm] Add triton_op util function
tridao Apr 29, 2025
ce21272
Don't specialize for hdim 160 anymore
tridao Apr 29, 2025
d462023
[CI] Compile with nvcc 12.8.1
tridao Apr 29, 2025
6ba57ef
Reduce specialization for Alibi to reduce compilation time
tridao Apr 29, 2025
fd2fc9d
[LayerNorm] Don't let torch.compile trace inside _layer_norm_bwd
tridao Apr 30, 2025
98edb0d
[AMD ROCm] Update backend to improve performance (#1654)
rocking5566 May 8, 2025
e9e96d3
Sync the compile flag with CK (#1670)
rocking5566 May 19, 2025
db4baba
[fa3] Use Python stable ABI (#1662)
danthe3rd May 22, 2025
0e79d71
[BE] use more minimal torch headers for hopper/flash_api.cpp (#1674)
janeyx99 May 22, 2025
8e595e5
Indent bwd_sm80.hpp
tridao Jun 1, 2025
931fb8c
[Cute] Implement fwd and bwd for Sm80 in Cute-DSL
tridao Jun 1, 2025
6bec3fb
[Cute] Support GQA
tridao Jun 2, 2025
ea8fe36
[Cute] Implement GQA bwd epilogue
tridao Jun 2, 2025
fad8398
[Cute] Move sm80 util functions to a separate file
tridao Jun 2, 2025
df1847a
[Cute] Move check_type, get_tiled_mma, get_shared_storage to methods
tridao Jun 3, 2025
dcaa072
[Cute] Use WGMMA for attn fwd on Sm90
tridao Jun 4, 2025
47078db
[Cute] Use TMA and warp specialization for attn fwd on Sm90
tridao Jun 7, 2025
d3d95dc
[Cute] Implement inter-warpgroup overlap
tridao Jun 7, 2025
2d8635c
[Cute] Implement PipelineTmaAsyncNoCluster
tridao Jun 8, 2025
3637516
[Cute] Use consumer_try_wait before consumer_wait
tridao Jun 8, 2025
fc27c4f
[fa3] Some fixes for windows build (#1698)
danthe3rd Jun 8, 2025
847025a
[fa3] API default values + backward compatibility (#1700)
danthe3rd Jun 8, 2025
14b0fec
[Cute] Implement intra-warpgroup overlap for attn fwd on Sm90
tridao Jun 8, 2025
c856912
[Cute] Refactor a bit
tridao Jun 8, 2025
69133f8
[Cute] Use TMA to store O in attn fwd epilogue
tridao Jun 9, 2025
8ede036
[Cute] Refactor Softmax and BlockInfo objects
tridao Jun 9, 2025
9a79170
[Cute] Implement varlen_q and varlen_q for attn fwd Sm90
tridao Jun 10, 2025
a737ade
[Cute] Use TMA for O when not varlen
tridao Jun 10, 2025
d31da73
[Cute] Implement PackGQA for attn fwd Sm90
tridao Jun 14, 2025
d417a5b
[CI] Compile with nvcc 12.9.0
tridao Jun 14, 2025
d738303
Update Cutlass to 4.0
tridao Jun 14, 2025
6f8f040
Bump to v2.8.0
tridao Jun 14, 2025
71f7ac2
[CI] Compile with ubuntu-22.04 instead of ubuntu-20.04
tridao Jun 14, 2025
de79b13
[CI] Build with NVCC_THREADS=2 to avoid OOM
tridao Jun 14, 2025
14bfeb3
[Cute] Use NameBarrier, replace cute.elem_less
tridao Jun 15, 2025
32c491f
fix: add tile shape to copy op template args (#1719)
rafacelente Jun 16, 2025
3ba6f82
Fix(hopper): Correct C++ syntax in epilogue_fwd.hpp (#1723)
QuentinFitteRey Jun 21, 2025
b3ae496
[AMD ROCm] Fix intrinsic for ROCm7 (#1729)
rocking5566 Jun 25, 2025
ddfcbed
[Cute] Set check_inf=True always, return smem_pipe_read
tridao Jun 15, 2025
3733dbb
Set line-length for ruff
tridao Jun 15, 2025
ecccf02
[Cute] Refactor Softmax, add fmax_reduce and fadd_reduce
tridao Jun 15, 2025
6c5f5ba
[Cute] Move load and mma to separate functions
tridao Jun 25, 2025
a5e1a3c
[Cute] Add first version of flash_fwd_sm100
tridao Jun 29, 2025
cc25213
[Cute] Don't need neg_inf_if_ge ptx any more
tridao Jun 29, 2025
96acd0f
[Cute] Test flash_fwd_sm100.py with hdim 64
tridao Jun 29, 2025
4834bb5
[Cute] Test flash_fwd_sm100.py with hdim 96
tridao Jun 30, 2025
b517a59
[Cute] Write out LSE for flash_fwd_sm100
tridao Jun 30, 2025
7661781
[Cute] Fix fwd_sm90 epilogue when varlen
tridao Jun 30, 2025
10a8916
[Cute] Implement sliding window for forward pass
tridao Jul 2, 2025
de2ce8f
[Cute] Add ruff options
tridao Jul 2, 2025
217c9d3
[Cute] Run ruff on utility files
tridao Jul 2, 2025
3222ea3
[Cute] Run ruff on bwd_pre/postprocess.py
tridao Jul 2, 2025
62349eb
[Cute] Move tile scheduler to a separate file
tridao Jul 3, 2025
8d454a3
[Cute] Add FastDivmod
tridao Jul 3, 2025
e94e0c2
[Cute] Refactor TileScheduler classes
tridao Jul 3, 2025
525fb43
[Cute] Port SingleTileLPTScheduler from C++ to Python
tridao Jul 3, 2025
60e1e89
[Cute] Update comment about cute version
tridao Jul 4, 2025
6a44198
[Cute] Update to cute-dsl 4.1.0.dev0
tridao Jul 4, 2025
25bd20c
[Cute] Use RS WGMMA for fwd_sm90
tridao Jul 5, 2025
0d0ab1b
[Cute] Use tile_scheduler in fwd_sm90
tridao Jul 5, 2025
312bb9b
[Cute] Add SingleTileVarlenScheduler to fwd_sm90
tridao Jul 5, 2025
10e8c39
[Cute] Do manual f32->f16x2 conversion for fwd_sm90
tridao Jul 6, 2025
3fc8c3c
[Cute] Split tP arrival for fwd_sm100
tridao Jul 6, 2025
723c36b
[Cute] Set tP arrival split to be 3/4
tridao Jul 6, 2025
e540fc1
[Cute] Fix missing tmem_store fence
tridao Jul 6, 2025
aace11d
[Cute] Tune num registers for fwd_sm100
tridao Jul 6, 2025
f14dcb1
[Cute] Check that compute_capability is 9.x or 10.x
tridao Jul 9, 2025
8ba246f
[BE] Better compress flash attention binaries (#1744)
Skylion007 Jul 9, 2025
944811e
adding changes for Windows compile fix for MSVC. (#1716)
loscrossos Jul 9, 2025
1e55644
[CI] Compile with nvcc 12.9.1
tridao Jul 9, 2025
7b0bfcc
Bump to v2.8.1
tridao Jul 9, 2025
adf27d1
[WIP] Add benchmarking script
tridao Jul 9, 2025
ed20940
[FA3] Don't return lse
tridao Jul 11, 2025
87855ac
[Cute] Clean up flash_fwd_sm90 and flash_fwd_sm100 a bit
tridao Jul 13, 2025
3d0e14a
[Cute] Support varlen in flash_fwd_sm100
tridao Jul 13, 2025
730e230
[Cute] Don't need max_seqlen_q for varlen fwd anymore
tridao Jul 13, 2025
10ee063
[Cute] Fix varlen scheduler when SeqUsedQ is not passed in
tridao Jul 13, 2025
c5b0c63
[Cute] Use LPT for SingleTileVarlenScheduler
tridao Jul 14, 2025
bac1001
[Cute] Use bit manipulation for masking in sm100
tridao Jul 15, 2025
b959a98
[Cute] Don't need a separate masking iter if causal for fwd_sm100
tridao Jul 15, 2025
ed6964c
[Cute] Back to having a separate iteration with masking
tridao Jul 15, 2025
c909b67
[Cute] Try e2e
tridao Jul 15, 2025
75c7d99
[Cute] Bench hdim 64
tridao Jul 15, 2025
5639535
[Cute] Bench both hdim 64 and 128
tridao Jul 15, 2025
5d98558
[Cute] Tune num regs
tridao Jul 15, 2025
50e0736
[Cute] Tune regs a bit
tridao Jul 15, 2025
34a3656
[Cute] Bench multiple seqlens
tridao Jul 15, 2025
24f0957
Revert "[BE] Better compress flash attention binaries (#1744)" (#1751)
imoneoi Jul 22, 2025
7321879
Bump to v2.8.2
tridao Jul 24, 2025
413d07e
[AMD ROCm] Fix compilation issue in gfx942 (#1787)
rocking5566 Jul 30, 2025
1a15733
[Cute] Support hdim_v != hdim_qk
tridao Aug 1, 2025
1b36ab1
[Cute] Support hdim (192,128)
tridao Aug 1, 2025
7337307
[Cute] Use kv_stage=3 for hdim (192,128)
tridao Aug 2, 2025
d6dbdaf
[Cute] Simplify some variables, be more careful about self.q_stage
tridao Aug 2, 2025
b8eb683
[Cute] Update to nvidia-cutlass-dsl==4.1.0
tridao Aug 3, 2025
cc5c574
[Cute] Implement additive sink for fwd_sm100
tridao Jul 29, 2025
5bdd30e
[Cute] Sink values in bf16
tridao Jul 29, 2025
e81c237
[Cute] Fix sink impl
tridao Aug 6, 2025
2f78d48
[Cute] Fix row_max not being written to smem when there's sink
tridao Aug 6, 2025
dc742f2
[Cute] Make flash_attn.cute installable as a standalone package
tridao Aug 6, 2025
66ee1b5
[Cute] No longer assume Q, K, V are compact
tridao Aug 9, 2025
5844fa6
[Cute] Fix not allocating enough smem for sScale when there's sink
tridao Aug 9, 2025
8c348fd
[FA3] Fix doc: page block size can be arbitrary
tridao Aug 9, 2025
81cdf4c
[Cute] Don't need i64_to_f32x2 anymore
tridao Aug 12, 2025
c4be578
Remove old xentropy kernel
tridao Aug 12, 2025
3edef7c
Remove old fused softmax kernel from apex/Megatron
tridao Aug 12, 2025
2715c53
Remove old attn decode kernel from FasterTransformer
tridao Aug 12, 2025
f28841d
Remove old rotary kernel
tridao Aug 12, 2025
a1c2e22
[Cute] Implement page table with TMA for fwd_sm100
tridao Aug 12, 2025
581b68d
[Cute] Remove trailing bracket (#1809)
jduprat Aug 13, 2025
3c51f15
[Cute] Make sure R2P happen
tridao Aug 13, 2025
d2e3fc3
feat: add support for pytorch2.8 (#1801)
NanoCode012 Aug 13, 2025
69b33b5
[Cute] Implement PackGQA with TMA for fwd_sm100
tridao Aug 14, 2025
060c918
Bump to v2.8.3
tridao Aug 14, 2025
cd9383f
[BugFix] Fix flash_attn_with_kvcache with scalar cache_seqlen (#1795)
stepinto Aug 15, 2025
b31ae1e
[Cute] Port fwd_combine kernel from C++ to cute-dsl
tridao Aug 17, 2025
591dc7e
[Cute] Simplify tile scheduler storing params
tridao Aug 17, 2025
f8b4f15
[Cute] Implement sink for fwd_sm90
tridao Aug 17, 2025
e1407db
[Cute] Implement PackGQA with TMA for fwd_sm90
tridao Aug 17, 2025
0e60e39
[Cute] Use R2P for masking in fwd_sm90
tridao Aug 17, 2025
199401d
Add sorting and head swizzle to varlen scheduler (#1823)
jayhshah Aug 22, 2025
632fe2a
Fixes incorrect variable reference in comment (#1775)
LoserCheems Aug 24, 2025
832d544
Update the initialization of dk/dv_semaphore (#1839)
y-sq Aug 25, 2025
478841a
Update tile_scheduler.hpp (#1841)
ghadiaravi13 Aug 26, 2025
6f2b052
ci: Move build job to workflow template (#1835)
ko3n1g Aug 27, 2025
b247655
ci: Build via workflow template (#1844)
ko3n1g Aug 27, 2025
d0ed097
ci: Switch to workflow_dispatch (#1847)
ko3n1g Aug 29, 2025
203b9b3
[`FA3`] Allow returning LSE via kwarg (#1851)
vasqu Aug 29, 2025
27b64c7
[BugFix] fix flash_fwd.FlashAttentionForwardSm80 bugs (#1856)
mingyangHao Sep 2, 2025
6387433
[FIX] Allow m_block_size == 192 and mma_pv_is_rs == False in Sm90 CuT…
reubenconducts Sep 2, 2025
afc97c6
make FA3 compatible with CUDA 13 Builds (#1860)
johnnynunez Sep 4, 2025
dfb6649
[BUILD] SBSA wheels + CUDA 13 Support (#1865)
johnnynunez Sep 5, 2025
225 changes: 225 additions & 0 deletions .github/workflows/_build.yml
@@ -0,0 +1,225 @@
name: ~Build wheel template

on:
  workflow_call:
    inputs:
      runs-on:
        description: "The runner to use for the build"
        required: true
        type: string
      python-version:
        description: "The Python version to use for the build"
        required: true
        type: string
      cuda-version:
        description: "The CUDA version to use for the build"
        required: true
        type: string
      torch-version:
        description: "The PyTorch version to use for the build"
        required: true
        type: string
      cxx11_abi:
        description: "The C++11 ABI to use for the build"
        required: true
        type: string
      upload-to-release:
        description: "Upload wheel to this release"
        required: false
        type: boolean
        default: false
      release-version:
        description: "Upload wheel to this release"
        required: false
        type: string

defaults:
  run:
    shell: bash -x -e -u -o pipefail {0}

jobs:
  build-wheel:
    runs-on: ${{ inputs.runs-on }}
    name: Build wheel (${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }})
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          ref: ${{ inputs.release-version }}
          submodules: recursive

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ inputs.python-version }}

      - name: Set CUDA and PyTorch versions
        run: |
          echo "MATRIX_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
          echo "MATRIX_TORCH_VERSION=$(echo ${{ inputs.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
          echo "WHEEL_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV
          echo "MATRIX_PYTHON_VERSION=$(echo ${{ inputs.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV

      - name: Free up disk space
        if: ${{ runner.os == 'Linux' }}
        # https://github.com/easimon/maximize-build-space/blob/master/action.yml
        # https://github.com/easimon/maximize-build-space/tree/test-report
        run: |
          sudo rm -rf /usr/share/dotnet
          sudo rm -rf /opt/ghc
          sudo rm -rf /opt/hostedtoolcache/CodeQL

      - name: Set up swap space
        if: runner.os == 'Linux'
        uses: pierotofy/set-swap-space@v1.0
        with:
          swap-size-gb: 10

      - name: Install CUDA ${{ inputs.cuda-version }}
        if: ${{ inputs.cuda-version != 'cpu' }}
        uses: Jimver/cuda-toolkit@v0.2.27
        id: cuda-toolkit
        with:
          cuda: ${{ inputs.cuda-version }}
          linux-local-args: '["--toolkit"]'
          # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
          # method: ${{ (inputs.cuda-version == '11.8.0' || inputs.cuda-version == '12.1.0') && 'network' || 'local' }}
          method: "network"
          sub-packages: '["nvcc"]'

      - name: Install PyTorch ${{ inputs.torch-version }}+cu${{ inputs.cuda-version }}
        run: |
          pip install --upgrade pip
          # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error
          # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable
          pip install typing-extensions==4.12.2
          # We want to figure out the CUDA version to download pytorch
          # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
          # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
          # This code is ugly, maybe there's a better way to do this.
          export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
            minv = {'2.5': 118, '2.6': 118, '2.7': 118, '2.8': 126, '2.9': 126}[env['MATRIX_TORCH_VERSION']]; \
            maxv = {'2.5': 124, '2.6': 126, '2.7': 128, '2.8': 129, '2.9': 130}[env['MATRIX_TORCH_VERSION']]; \
            print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
          )
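          # Example (illustrative): with MATRIX_TORCH_VERSION=2.8, a system CUDA of 12.x
          # (MATRIX_CUDA_VERSION >= 120) selects the newest supported wheel index, cu129,
          # while CUDA 11.8 would fall back to the minimum supported index, cu126.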
          # detect if we're on ARM
          if [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then
            PLAT=linux_aarch64
          else
            PLAT=manylinux_2_27_x86_64.manylinux_2_28_x86_64
          fi
          echo "PLAT=$PLAT" >> $GITHUB_ENV
          if [[ ${{ inputs.torch-version }} == *"dev"* ]]; then
            # pip install --no-cache-dir --pre torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
            # Can't use --no-deps because we need cudnn etc.
            # Hard-coding this version of pytorch-triton for torch 2.9.0.dev20250904
            pip install jinja2
            TRITON_URL=https://download.pytorch.org/whl/nightly/pytorch_triton-3.4.0%2Bgitf7888497-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-${PLAT}.whl
            TORCH_URL=https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ inputs.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-manylinux_2_28_$(uname -m).whl
            pip install --no-cache-dir --pre "${TRITON_URL}"
            pip install --no-cache-dir --pre "${TORCH_URL}"
          else
            pip install --no-cache-dir torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
          fi
          nvcc --version
          python --version
          python -c "import torch; print('PyTorch:', torch.__version__)"
          python -c "import torch; print('CUDA:', torch.version.cuda)"
          python -c "from torch.utils import cpp_extension; print(cpp_extension.CUDA_HOME)"

      - name: Restore build cache
        uses: actions/cache/restore@v4
        with:
          path: build.tar
          key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }}
          restore-keys: |
            build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-

      - name: Unpack build cache
        run: |
          echo ::group::Adjust timestamps
          sudo find / -exec touch -t 197001010000 {} + || true
          echo ::endgroup::

          if [ -f build.tar ]; then
            find . -mindepth 1 -maxdepth 1 ! -name 'build.tar' -exec rm -rf {} +
            tar -xpvf build.tar -C .
          else
            echo "No build.tar found, skipping"
          fi

          ls -al ./
          ls -al build/ || true
          ls -al csrc/ || true

      - name: Build wheel
        id: build_wheel
        run: |
          # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
          # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
          # However this still fails so I'm using a newer version of setuptools
          pip install setuptools==75.8.0
          pip install ninja packaging wheel
          export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
          export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
          # Limit MAX_JOBS otherwise the github runner goes OOM
          # nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM
          export MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "129" ] && echo 1 || echo 2)
          export NVCC_THREADS=2
          export FLASH_ATTENTION_FORCE_BUILD="TRUE"
          export FLASH_ATTENTION_FORCE_CXX11_ABI=${{ inputs.cxx11_abi }}

          # 5h timeout since GH allows max 6h and we want some buffer
          EXIT_CODE=0
          timeout 5h python setup.py bdist_wheel --dist-dir=dist || EXIT_CODE=$?

          if [ $EXIT_CODE -eq 0 ]; then
            tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }}
            wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
            ls dist/*whl | xargs -I {} mv {} dist/${wheel_name}
            echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
          fi
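          # Example (illustrative): a wheel built as flash_attn-2.8.3-cp312-cp312-linux_x86_64.whl
          # would be renamed to flash_attn-2.8.3+cu12torch2.8cxx11abiTRUE-cp312-cp312-linux_x86_64.whl.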

          # Store the exit code as a step output so later steps can check it
          echo "build_exit_code=$EXIT_CODE" | tee -a "$GITHUB_OUTPUT"

          # Do not fail the job if timeout killed the build
          exit $EXIT_CODE

      - name: Log build logs after timeout
        if: always() && steps.build_wheel.outputs.build_exit_code == 124
        run: |
          ls -al ./
          tar -cvf build.tar . --atime-preserve=replace

      - name: Save build cache timeout
        if: always() && steps.build_wheel.outputs.build_exit_code == 124
        uses: actions/cache/save@v4
        with:
          key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }}
          path: build.tar

      - name: Log Built Wheels
        run: |
          ls dist

      - name: Get Release with tag
        id: get_current_release
        uses: joutvhu/get-release@v1
        with:
          tag_name: ${{ inputs.release-version }}
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

      - name: Upload Release Asset
        id: upload_release_asset
        if: inputs.upload-to-release
        uses: actions/upload-release-asset@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          upload_url: ${{ steps.get_current_release.outputs.upload_url }}
          asset_path: ./dist/${{ env.wheel_name }}
          asset_name: ${{ env.wheel_name }}
          asset_content_type: application/*
47 changes: 47 additions & 0 deletions .github/workflows/build.yml
@@ -0,0 +1,47 @@
name: Build wheels

on:
  workflow_dispatch:
    inputs:
      runs-on:
        description: "The runner to use for the build"
        required: true
        type: string
        default: ubuntu-22.04
      python-version:
        description: "The Python version to use for the build"
        required: true
        type: string
      cuda-version:
        description: "The CUDA version to use for the build"
        required: true
        type: string
      torch-version:
        description: "The PyTorch version to use for the build"
        required: true
        type: string
      cxx11_abi:
        description: "Enable torch flag C++11 ABI (TRUE/FALSE)"
        required: true
        type: string
      upload-to-release:
        description: "Upload wheel to this release"
        required: false
        type: boolean
        default: false
      release-version:
        description: "Upload wheel to this release"
        required: false
        type: string

jobs:
  build-wheels:
    uses: ./.github/workflows/_build.yml
    with:
      runs-on: ${{ inputs.runs-on }}
      python-version: ${{ inputs.python-version }}
      cuda-version: ${{ inputs.cuda-version }}
      torch-version: ${{ inputs.torch-version }}
      cxx11_abi: ${{ inputs.cxx11_abi }}
      upload-to-release: ${{ inputs.upload-to-release }}
      release-version: ${{ inputs.release-version }}
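
For reference, a minimal sketch of how this dispatch-only workflow could be triggered from the GitHub CLI; the runner, version values, and release tag below are hypothetical examples, not values mandated by the workflow:

  gh workflow run build.yml \
    -f runs-on=ubuntu-22.04 \
    -f python-version=3.12 \
    -f cuda-version=12.8.1 \
    -f torch-version=2.8.0 \
    -f cxx11_abi=TRUE \
    -f upload-to-release=true \
    -f release-version=v2.8.3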