Skip to content

Commit dfb6649

Browse files
authored
[BUILD] SBSA wheels + CUDA 13 Support (Dao-AILab#1865)
* [BUILD] Update CUDA toolkit and PyTorch versions in CI configuration * [BUILD] Update CUDA toolkit and PyTorch versions in CI configuration * [BUILD] Update CUDA toolkit and PyTorch versions in CI configuration * [BUILD] Update CUDA toolkit and PyTorch versions in CI configuration * [BUILD] Update CUDA toolkit and PyTorch versions in CI configuration * drop 12.4 * drop 12.4 * fix correct name * fix correct name * fix correct name * fix correct name * cibuildwheel.yml
1 parent afc97c6 commit dfb6649

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

.github/workflows/_build.yml

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ jobs:
7777

7878
- name: Install CUDA ${{ inputs.cuda-version }}
7979
if: ${{ inputs.cuda-version != 'cpu' }}
80-
uses: Jimver/cuda-toolkit@v0.2.26
80+
uses: Jimver/cuda-toolkit@v0.2.27
8181
id: cuda-toolkit
8282
with:
8383
cuda: ${{ inputs.cuda-version }}
@@ -98,17 +98,26 @@ jobs:
9898
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
9999
# This code is ugly, maybe there's a better way to do this.
100100
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
101-
minv = {'2.4': 118, '2.5': 118, '2.6': 118, '2.7': 118, '2.8': 126}[env['MATRIX_TORCH_VERSION']]; \
102-
maxv = {'2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128, '2.8': 129}[env['MATRIX_TORCH_VERSION']]; \
101+
minv = {'2.5': 118, '2.6': 118, '2.7': 118, '2.8': 126, '2.9': 126}[env['MATRIX_TORCH_VERSION']]; \
102+
maxv = {'2.5': 124, '2.6': 126, '2.7': 128, '2.8': 129, '2.9': 130}[env['MATRIX_TORCH_VERSION']]; \
103103
print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
104104
)
105+
# detect if we're on ARM
106+
if [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then
107+
PLAT=linux_aarch64
108+
else
109+
PLAT=manylinux_2_27_x86_64.manylinux_2_28_x86_64
110+
fi
111+
echo "PLAT=$PLAT" >> $GITHUB_ENV
105112
if [[ ${{ inputs.torch-version }} == *"dev"* ]]; then
106113
# pip install --no-cache-dir --pre torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
107114
# Can't use --no-deps because we need cudnn etc.
108-
# Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001
115+
# Hard-coding this version of pytorch-triton for torch 2.9.0.dev20250904
109116
pip install jinja2
110-
pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
111-
pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ inputs.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
117+
TRITON_URL=https://download.pytorch.org/whl/nightly/pytorch_triton-3.4.0%2Bgitf7888497-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-${PLAT}.whl
118+
TORCH_URL=https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ inputs.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-manylinux_2_28_$(uname -m).whl
119+
pip install --no-cache-dir --pre "${TRITON_URL}"
120+
pip install --no-cache-dir --pre "${TORCH_URL}"
112121
else
113122
pip install --no-cache-dir torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
114123
fi

.github/workflows/publish.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ jobs:
4040
matrix:
4141
# Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the
4242
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
43-
os: [ubuntu-22.04]
43+
os: [ubuntu-22.04, ubuntu-22.04-arm]
4444
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
4545
torch-version: ["2.4.0", "2.5.1", "2.6.0", "2.7.1", "2.8.0"]
4646
cuda-version: ["12.9.1"]
@@ -49,6 +49,9 @@ jobs:
4949
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
5050
# when building without C++11 ABI and using it on nvcr images.
5151
cxx11_abi: ["FALSE", "TRUE"]
52+
include:
53+
- torch-version: "2.9.0.dev20250904"
54+
cuda-version: "13.0"
5255
exclude:
5356
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
5457
# Pytorch < 2.5 does not support Python 3.13

0 commit comments

Comments
 (0)