Commit 54101e9

johnnynunez and yzh119 authored
[NVIDIA] Thor & Spark Support (#2028)
## πŸ“Œ Description

Thor and Spark support when wheels are generated.

## πŸ” Related Issues

The build output reports that these architectures are not compatible; only the JIT path works.

## Summary by CodeRabbit

* **New Features**
  * Broadened GPU architecture support to include additional newer architectures.
* **Documentation**
  * Updated README and installation docs to show the revised CUDA architecture example list.
* **Chores**
  * Adjusted release/nightly workflows and build scripts to select architectures using an expanded CUDA-version threshold and branching logic.
* **Performance**
  * Extended architecture-specific build/runtime handling to cover an additional GPU architecture affecting memory-related behavior.

---------

Co-authored-by: Zihao Ye <expye@outlook.com>
Co-authored-by: yzh119 <zihaoy@nvidia.com>
1 parent b433fc7 commit 54101e9

File tree

7 files changed (+17, βˆ’7 lines)


.github/workflows/nightly-release.yml

Lines changed: 1 addition & 1 deletion

@@ -145,7 +145,7 @@ jobs:
       - name: Build wheel in container
         env:
           DOCKER_IMAGE: ${{ matrix.arch == 'aarch64' && format('pytorch/manylinuxaarch64-builder:cuda{0}', matrix.cuda) || format('pytorch/manylinux2_28-builder:cuda{0}', matrix.cuda) }}
-          FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda == '12.8' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 12.0a' }}
+          FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda < '13.0' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 11.0f 12.0f' }}
           FLASHINFER_DEV_RELEASE_SUFFIX: ${{ needs.setup.outputs.dev_suffix }}
         run: |
           # Extract CUDA major and minor versions

.github/workflows/release.yml

Lines changed: 1 addition & 1 deletion

@@ -182,7 +182,7 @@ jobs:
       - name: Build wheel in container
         env:
           DOCKER_IMAGE: ${{ matrix.arch == 'aarch64' && format('pytorch/manylinuxaarch64-builder:cuda{0}', matrix.cuda) || format('pytorch/manylinux2_28-builder:cuda{0}', matrix.cuda) }}
-          FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda == '12.8' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 12.0a' }}
+          FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda < '13.0' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 11.0f 12.0f' }}
         run: |
           # Extract CUDA major and minor versions
           CUDA_MAJOR=$(echo "${{ matrix.cuda }}" | cut -d'.' -f1)

README.md

Lines changed: 1 addition & 1 deletion

@@ -90,7 +90,7 @@ python -m pip install dist/*.whl
 
 `flashinfer-jit-cache` (customize `FLASHINFER_CUDA_ARCH_LIST` for your target GPUs):
 ```bash
-export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 12.0a"
+export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 11.0f 12.0f"
 cd flashinfer-jit-cache
 python -m build --no-isolation --wheel
 python -m pip install dist/*.whl

csrc/xqa/mha.cu

Lines changed: 1 addition & 1 deletion

@@ -93,7 +93,7 @@ __constant__ constexpr uint32_t cacheVTileSeqLen = 32;
 constexpr uint32_t preferedKHeadPartBytes = 64;
 __constant__ constexpr uint32_t cacheVTileSeqLen = 32;
 #elif __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 870 || __CUDA_ARCH__ == 900 || \
-    __CUDA_ARCH__ == 1000 || __CUDA_ARCH__ == 1030
+    __CUDA_ARCH__ == 1000 || __CUDA_ARCH__ == 1030 || __CUDA_ARCH__ == 1100
 constexpr uint32_t preferedKHeadPartBytes = 128;
 __constant__ constexpr uint32_t cacheVTileSeqLen = 64;
 #else
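
Thor reports compute capability 11.0 under CUDA 13, so `__CUDA_ARCH__ == 1100` routes it to the branch already used by SM 8.0/8.7/9.0/10.0/10.3. Below is a minimal, self-contained sketch of the compile-time dispatch pattern this file relies on; the names `kKHeadPartBytes` and `reportTuning` are illustrative, not FlashInfer symbols:

```cuda
// Sketch only: __CUDA_ARCH__ is defined separately for each device-compilation
// pass, so every architecture in FLASHINFER_CUDA_ARCH_LIST resolves its own
// value of the constant at compile time.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 870 || \
    __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000 ||                           \
    __CUDA_ARCH__ == 1030 || __CUDA_ARCH__ == 1100)
constexpr uint32_t kKHeadPartBytes = 128;  // large-SMEM parts, now including Thor
#else
constexpr uint32_t kKHeadPartBytes = 64;   // conservative fallback (and host pass)
#endif

__global__ void reportTuning() {
  if (threadIdx.x == 0) printf("K-head part bytes: %u\n", kKHeadPartBytes);
}

int main() {
  reportTuning<<<1, 32>>>();
  return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}
```

Without the extra `1100` case, a Thor build in this sketch would fall through to the fallback branch and compile with the smaller tile constants.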

csrc/xqa/utils.cuh

Lines changed: 2 additions & 1 deletion

@@ -46,7 +46,8 @@ __constant__ constexpr float kE4M3_MAX = 448.F;
 constexpr uint32_t kMAX_SMEM_SIZE = (99u << 10);
 #elif __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 870
 constexpr uint32_t kMAX_SMEM_SIZE = (163u << 10);
-#elif __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000 || __CUDA_ARCH__ == 1030
+#elif __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000 || __CUDA_ARCH__ == 1030 || \
+    __CUDA_ARCH__ == 1100
 constexpr uint32_t kMAX_SMEM_SIZE = (227u << 10);
 #endif
 #endif
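
A cap like `kMAX_SMEM_SIZE` matters because dynamic shared memory above the 48 KiB default must be opted into explicitly per kernel. A self-contained sketch of how such a cap is typically consumed, assuming a hypothetical `bigSmemKernel` (the attribute call is the standard CUDA runtime API, not FlashInfer's exact code):

```cuda
// Sketch: raise a kernel's dynamic shared-memory limit to the per-arch cap
// before launching with a size above the 48 KiB default.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

constexpr uint32_t kSmemBytes = 227u << 10;  // 227 KiB, the SM 9.0/10.0/10.3/11.0 cap above

__global__ void bigSmemKernel() {
  extern __shared__ unsigned char smem[];
  smem[threadIdx.x] = static_cast<unsigned char>(threadIdx.x);  // touch the buffer
}

int main() {
  // Opt in to large dynamic shared memory; this fails on parts whose
  // per-block limit is below kSmemBytes.
  cudaError_t err = cudaFuncSetAttribute(
      bigSmemKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemBytes);
  if (err != cudaSuccess) {
    printf("cannot raise SMEM limit: %s\n", cudaGetErrorString(err));
    return 1;
  }
  bigSmemKernel<<<1, 256, kSmemBytes>>>();
  return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}
```

The diff above extends the 227 KiB assumption, already in place for SM 9.0/10.0/10.3, to SM 11.0.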

docs/installation.rst

Lines changed: 1 addition & 1 deletion

@@ -92,7 +92,7 @@ You can follow the steps below to install FlashInfer from source code:
 
 .. code-block:: bash
 
-   export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 12.0a"
+   export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 11.0f 12.0f"
    cd flashinfer-jit-cache
    python -m build --no-isolation --wheel
    python -m pip install dist/*.whl

scripts/task_test_jit_cache_package_build_import.sh

Lines changed: 10 additions & 1 deletion

@@ -43,7 +43,16 @@ arches = ["7.5", "8.0", "8.9", "9.0a"]
 if cuda_ver is not None:
     try:
         major, minor = map(int, cuda_ver.split(".")[:2])
-        if (major, minor) >= (12, 8):
+        if (major, minor) >= (13, 0):
+            arches.append("10.0a")
+            arches.append("10.3a")
+            arches.append("11.0f")
+            arches.append("12.0f")
+        elif (major, minor) >= (12, 9):
+            arches.append("10.0a")
+            arches.append("10.3a")
+            arches.append("12.0f")
+        elif (major, minor) >= (12, 8):
             arches.append("10.0a")
             arches.append("12.0a")
     except Exception:
