-
-
Notifications
You must be signed in to change notification settings - Fork 15k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
cudaPackages.nccl: support building with CUDA < 11.4 with cudatoolkit
(cherry picked from commit d2800c5)
- Loading branch information
1 parent
a12e4b0
commit 4efe723
Showing
1 changed file
with
99 additions
and
71 deletions.
There are no files selected for viewing
170 changes: 99 additions & 71 deletions
170
pkgs/development/libraries/science/math/nccl/default.nix
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,85 +1,113 @@ | ||
{ lib | ||
, backendStdenv | ||
, fetchFromGitHub | ||
, python3 | ||
, which | ||
, autoAddOpenGLRunpathHook | ||
, cuda_cccl | ||
, cuda_cudart | ||
, cuda_nvcc | ||
, cudaFlags | ||
, cudaVersion | ||
# passthru.updateScript | ||
, gitUpdater | ||
# NOTE: Though NCCL is called within the cudaPackages package set, we avoid passing in | ||
# the names of dependencies from that package set directly to avoid evaluation errors | ||
# in the case redistributable packages are not available. | ||
{ | ||
lib, | ||
fetchFromGitHub, | ||
python3, | ||
which, | ||
cudaPackages, | ||
# passthru.updateScript | ||
gitUpdater, | ||
}: | ||
let | ||
# Output looks like "-gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_86,code=compute_86" | ||
gencode = lib.concatStringsSep " " cudaFlags.gencode; | ||
|
||
inherit (cudaPackages) | ||
autoAddOpenGLRunpathHook | ||
backendStdenv | ||
cuda_cccl | ||
cuda_cudart | ||
cuda_nvcc | ||
cudaFlags | ||
cudatoolkit | ||
cudaVersion | ||
; | ||
in | ||
backendStdenv.mkDerivation (finalAttrs: { | ||
pname = "nccl"; | ||
version = "2.19.3-1"; | ||
backendStdenv.mkDerivation ( | ||
finalAttrs: { | ||
pname = "nccl"; | ||
version = "2.19.3-1"; | ||
|
||
src = fetchFromGitHub { | ||
owner = "NVIDIA"; | ||
repo = finalAttrs.pname; | ||
rev = "v${finalAttrs.version}"; | ||
hash = "sha256-59FlOKM5EB5Vkm4dZBRCkn+IgIcdQehE+FyZAdTCT/A="; | ||
}; | ||
src = fetchFromGitHub { | ||
owner = "NVIDIA"; | ||
repo = finalAttrs.pname; | ||
rev = "v${finalAttrs.version}"; | ||
hash = "sha256-59FlOKM5EB5Vkm4dZBRCkn+IgIcdQehE+FyZAdTCT/A="; | ||
}; | ||
|
||
outputs = [ "out" "dev" ]; | ||
strictDeps = true; | ||
|
||
nativeBuildInputs = [ | ||
which | ||
autoAddOpenGLRunpathHook | ||
cuda_nvcc | ||
python3 | ||
]; | ||
outputs = [ | ||
"out" | ||
"dev" | ||
]; | ||
|
||
buildInputs = [ | ||
cuda_cudart | ||
] | ||
# NOTE: CUDA versions in Nixpkgs only use a major and minor version. When we do comparisons | ||
# against other version, like below, it's important that we use the same format. Otherwise, | ||
# we'll get incorrect results. | ||
# For example, lib.versionAtLeast "12.0" "12.0.0" == false. | ||
++ lib.optionals (lib.versionAtLeast cudaVersion "12.0") [ | ||
cuda_cccl | ||
]; | ||
nativeBuildInputs = | ||
[ | ||
which | ||
autoAddOpenGLRunpathHook | ||
python3 | ||
] | ||
++ lib.optionals (lib.versionOlder cudaVersion "11.4") [cudatoolkit] | ||
++ lib.optionals (lib.versionAtLeast cudaVersion "11.4") [cuda_nvcc]; | ||
|
||
buildInputs = | ||
lib.optionals (lib.versionOlder cudaVersion "11.4") [cudatoolkit] | ||
++ lib.optionals (lib.versionAtLeast cudaVersion "11.4") [ | ||
cuda_nvcc.dev # crt/host_config.h | ||
cuda_cudart | ||
] | ||
# NOTE: CUDA versions in Nixpkgs only use a major and minor version. When we do comparisons | ||
# against other version, like below, it's important that we use the same format. Otherwise, | ||
# we'll get incorrect results. | ||
# For example, lib.versionAtLeast "12.0" "12.0.0" == false. | ||
++ lib.optionals (lib.versionAtLeast cudaVersion "12.0") [cuda_cccl]; | ||
|
||
preConfigure = '' | ||
patchShebangs ./src/device/generate.py | ||
makeFlagsArray+=( | ||
"NVCC_GENCODE=${gencode}" | ||
) | ||
''; | ||
env.NIX_CFLAGS_COMPILE = toString ["-Wno-unused-function"]; | ||
|
||
makeFlags = [ | ||
"CUDA_HOME=${cuda_nvcc}" | ||
"CUDA_LIB=${lib.getLib cuda_cudart}/lib" | ||
"CUDA_INC=${lib.getDev cuda_cudart}/include" | ||
"PREFIX=$(out)" | ||
]; | ||
preConfigure = '' | ||
patchShebangs ./src/device/generate.py | ||
makeFlagsArray+=( | ||
"NVCC_GENCODE=${lib.concatStringsSep " " cudaFlags.gencode}" | ||
) | ||
''; | ||
|
||
postFixup = '' | ||
moveToOutput lib/libnccl_static.a $dev | ||
''; | ||
makeFlags = | ||
["PREFIX=$(out)"] | ||
++ lib.optionals (lib.versionOlder cudaVersion "11.4") [ | ||
"CUDA_HOME=${cudatoolkit}" | ||
"CUDA_LIB=${lib.getLib cudatoolkit}/lib" | ||
"CUDA_INC=${lib.getDev cudatoolkit}/include" | ||
] | ||
++ lib.optionals (lib.versionAtLeast cudaVersion "11.4") [ | ||
"CUDA_HOME=${cuda_nvcc}" | ||
"CUDA_LIB=${lib.getLib cuda_cudart}/lib" | ||
"CUDA_INC=${lib.getDev cuda_cudart}/include" | ||
]; | ||
|
||
env.NIX_CFLAGS_COMPILE = toString [ "-Wno-unused-function" ]; | ||
enableParallelBuilding = true; | ||
|
||
# Run the update script with: `nix-shell maintainers/scripts/update.nix --argstr package cudaPackages.nccl` | ||
passthru.updateScript = gitUpdater { | ||
inherit (finalAttrs) pname version; | ||
rev-prefix = "v"; | ||
}; | ||
postFixup = '' | ||
moveToOutput lib/libnccl_static.a $dev | ||
''; | ||
|
||
enableParallelBuilding = true; | ||
passthru.updateScript = gitUpdater { | ||
inherit (finalAttrs) pname version; | ||
rev-prefix = "v"; | ||
}; | ||
|
||
meta = with lib; { | ||
description = "Multi-GPU and multi-node collective communication primitives for NVIDIA GPUs"; | ||
homepage = "https://developer.nvidia.com/nccl"; | ||
license = licenses.bsd3; | ||
platforms = platforms.linux; | ||
maintainers = with maintainers; [ mdaiter orivej ] ++ teams.cuda.members; | ||
}; | ||
}) | ||
meta = with lib; { | ||
description = "Multi-GPU and multi-node collective communication primitives for NVIDIA GPUs"; | ||
homepage = "https://developer.nvidia.com/nccl"; | ||
license = licenses.bsd3; | ||
platforms = platforms.linux; | ||
maintainers = | ||
with maintainers; | ||
[ | ||
mdaiter | ||
orivej | ||
] | ||
++ teams.cuda.members; | ||
}; | ||
} | ||
) |