diff --git a/recipe/bld.bat b/recipe/bld.bat index 30cc5d4f..b785aa32 100644 --- a/recipe/bld.bat +++ b/recipe/bld.bat @@ -16,7 +16,8 @@ if "%build_with_cuda%" == "" goto cuda_flags_end set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v%desired_cuda% set CUDA_BIN_PATH=%CUDA_PATH%\bin -set TORCH_CUDA_ARCH_LIST=5.0;6.0;6.1;7.0;7.5;8.0;8.6;8.9;9.0+PTX +REM Keep this list in sync with https://github.com/pytorch/pytorch/blob/07fa6e2c8b003319f85a469307f1b1dd73f6026c/.ci/manywheel/build_cuda.sh#L60 +set TORCH_CUDA_ARCH_LIST=5.0;6.0;7.0;7.5;8.0;8.6;9.0+PTX set TORCH_NVCC_FLAGS=-Xfatbin -compress-all :cuda_flags_end diff --git a/recipe/build.sh b/recipe/build.sh index 19da0ad4..851f5c58 100644 --- a/recipe/build.sh +++ b/recipe/build.sh @@ -168,7 +168,8 @@ elif [[ ${cuda_compiler_version} != "None" ]]; then esac case ${cuda_compiler_version} in 12.6) - export TORCH_CUDA_ARCH_LIST="5.0;6.0;6.1;7.0;7.5;8.0;8.6;8.9;9.0+PTX" + # Keep this list in sync with https://github.com/pytorch/pytorch/blob/07fa6e2c8b003319f85a469307f1b1dd73f6026c/.ci/manywheel/build_cuda.sh#L60 + export TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6;9.0+PTX" ;; *) echo "unsupported cuda version. edit build.sh"