Skip to content

Commit

Permalink
Merge pull request #266081 from ConnorBaker/fix/torch-jetson
Browse files Browse the repository at this point in the history
python3Packages.torch: patch cpp_extension.py for Jetson support
  • Loading branch information
Connor Baker authored Nov 9, 2023
2 parents 417c205 + 2a42503 commit 47f07ca
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion pkgs/development/python-modules/torch/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@

let
inherit (lib) attrsets lists strings trivial;
inherit (cudaPackages) cudaFlags cudnn nccl;
inherit (cudaPackages) cudaFlags cudnn;

# Some packages are not available on all platforms
nccl = cudaPackages.nccl or null;

setBool = v: if v then "1" else "0";

Expand Down Expand Up @@ -178,6 +181,13 @@ in buildPythonPackage rec {
'message(FATAL_ERROR "Found NCCL header version and library version' \
'message(WARNING "Found NCCL header version and library version'
''
# TODO(@connorbaker): Remove this patch after 2.1.0 lands.
+ lib.optionalString cudaSupport ''
substituteInPlace torch/utils/cpp_extension.py \
--replace \
"'8.6', '8.9'" \
"'8.6', '8.7', '8.9'"
''
# error: no member named 'aligned_alloc' in the global namespace; did you mean simply 'aligned_alloc'
# This lib overrided aligned_alloc hence the error message. Tltr: his function is linkable but not in header.
+ lib.optionalString (stdenv.isDarwin && lib.versionOlder stdenv.targetPlatform.darwinSdkVersion "11.0") ''
Expand Down Expand Up @@ -253,6 +263,7 @@ in buildPythonPackage rec {
PYTORCH_BUILD_VERSION = version;
PYTORCH_BUILD_NUMBER = 0;

USE_NCCL = setBool (nccl != null);
USE_SYSTEM_NCCL = setBool useSystemNccl; # don't build pytorch's third_party NCCL
USE_STATIC_NCCL = setBool useSystemNccl;

Expand Down Expand Up @@ -316,6 +327,8 @@ in buildPythonPackage rec {
libcusolver.lib
libcusparse.dev
libcusparse.lib
] ++ lists.optionals (nccl != null) [
# Some platforms do not support NCCL (i.e., Jetson)
nccl.dev # Provides nccl.h AND a static copy of NCCL!
] ++ lists.optionals (strings.versionOlder cudaVersion "11.8") [
cuda_nvprof.dev # <cuda_profiler_api.h>
Expand Down

0 comments on commit 47f07ca

Please sign in to comment.