From faaf22ce880fbc76a7842ad5cf08de0c0d33e010 Mon Sep 17 00:00:00 2001 From: Connor Baker Date: Tue, 22 Aug 2023 14:58:14 +0000 Subject: [PATCH] libtorch: work on some cuda refactoring --- .../science/math/libtorch/default.nix | 73 ++++++++++++------- 1 file changed, 48 insertions(+), 25 deletions(-) diff --git a/pkgs/development/libraries/science/math/libtorch/default.nix b/pkgs/development/libraries/science/math/libtorch/default.nix index fe143ef73494b..826e0638b3cca 100644 --- a/pkgs/development/libraries/science/math/libtorch/default.nix +++ b/pkgs/development/libraries/science/math/libtorch/default.nix @@ -2,7 +2,6 @@ fetchFromGitHub, fetchpatch, pkgs, - symlinkJoin, # nativeBuildInputs asmjit, blas, @@ -23,6 +22,7 @@ mpi, ninja, numactl, + onnx, protobuf, psimd, pthreadpool, @@ -43,7 +43,7 @@ useXnnpack ? true, useZstd ? true, }: let - inherit (lib) lists; + inherit (lib) lists strings; setBool = bool: if bool then "ON" @@ -82,28 +82,6 @@ }; }); - cuda-redist = symlinkJoin { - name = "cuda-redist"; - paths = with cudaPackages; - [ - autoAddOpenGLRunpathHook - cuda_cccl # and CUB - cuda_cudart - cuda_cupti # Needed by Kineto for GPU profiling - cuda_nvcc - cuda_nvml_dev - cuda_nvrtc - cuda_nvtx - libcublas - libcufft - libcurand - libcusolver - libcusparse - nccl.dev - ] - ++ lists.optionals useCudnn [cudnn]; - }; - mkDerivation = if useCuda then cudaPackages.backendStdenv.mkDerivation @@ -171,6 +149,7 @@ in rm -rf FXdiv* rm -rf gloo* rm -rf ideep/mkl-dnn* + rm -rf onnx* rm -rf protobuf* rm -rf psimd* rm -rf pthreadpool* @@ -235,6 +214,7 @@ in fxdiv gflags glog + onnx protobuf psimd pthreadpool @@ -248,7 +228,14 @@ in zlib ] # Optional dependencies - ++ lists.optionals useCuda [cuda-redist] + ++ lists.optionals useCuda ( + # TODO(@connorbaker): Is this correct that we need both cudart and nvcc as native dependencies? + with cudaPackages; [ + autoAddOpenGLRunpathHook + cuda_cudart # cuda_runtime.h + cuda_nvcc # crt/host_config.h + ] + ) ++ lists.optionals useGloo [gloo] ++ lists.optionals useMagma [magma] ++ lists.optionals useMkldnn [oneDNN.dev] # oneDNN is the new name for MKL-DNN @@ -257,6 +244,39 @@ in ++ lists.optionals useXnnpack [xnnpack] ++ lists.optionals useZstd [zstd.dev]; + # TODO(@connorbaker): Currently CUDA build fails with: + # CMake Error at cmake/public/cuda.cmake:65 (message): + # Found two conflicting CUDA installs: + # + # V11.8.89 in + # '/nix/store/rsjxr5b5zifa0wbpziwqfzg7lncfz0f0-cuda_cudart-11.8.89/include' + # and + # + # V11.8.89 in + # '/nix/store/rsjxr5b5zifa0wbpziwqfzg7lncfz0f0-cuda_cudart-11.8.89/include;/nix/store/nljxvgbp6fy0q7cbrp5l5igv57p5fa3v-cuda_nvcc-11.8.89/include;/nix/store/mfk63jcw2r77asgai82rzbzbph10dhh8-cuda_cccl-11.8.89/include;/nix/store/0xhbghrnf7x289m78c8ha2dm6n83wfbg-cuda_cupti-11.8.87/include;/nix/store/4x7gb192a6pskj2skwn9s3m0vnn73bff-cuda_nvml_dev-11.8.86/include;/nix/store/00p0i6kqw6qjbrc4fddqfnv07zcg7gi1-cuda_nvrtc-11.8.89/include;/nix/store/953p97p0inb7wdj50qcz47dy3lh58vhq-cuda_nvtx-11.8.86/include;/nix/store/qsm8bjydfnapr77wzlyzyzcsnkc0yrh2-libcublas-11.11.3.6/include;/nix/store/fszipvg6jw9dsj2lz1izwy7363mwh4fj-libcufft-10.9.0.58/include;/nix/store/8r9kj0rh0kk9iqi32kkm1bdxqb8jipbr-libcurand-10.3.0.86/include;/nix/store/f0d08h7g4apgngbyrgqvpjxmlp3azf0m-libcusolver-11.4.1.48/include;/nix/store/141gw8r2ypg27186mzg81rhndl402l80-libcusparse-11.7.5.86/include;/nix/store/z5ppzlnw5wzy5bbvhm76kfmjmirpkqhb-cuda_profiler_api-11.8.86/include' + buildInputs = lists.optionals useCuda (with cudaPackages; + [ + (lib.getDev nccl) + cuda_cccl # + cuda_cupti + cuda_nvml_dev # + cuda_nvrtc + cuda_nvtx # -llibNVToolsExt + libcublas + libcufft + libcurand + libcusolver + libcusparse + nccl + ] + ++ lists.optionals useCudnn [cudnn] + ++ lists.optionals (strings.versionOlder cudaVersion "11.8") [ + cuda_nvprof # + ] + ++ lists.optionals (strings.versionAtLeast cudaVersion "11.8") [ + cuda_profiler_api # + ]); + cmakeFlags = # Core configuration options [ @@ -264,6 +284,8 @@ in "-DBUILD_PYTHON:BOOL=OFF" "-DBUILD_SHARED_LIBS:BOOL=ON" "-DCMAKE_BUILD_TYPE:STRING=Release" + "-DCMAKE_C_STANDARD:STRING=17" + "-DCMAKE_CXX_STANDARD:STRING=17" "-DUSE_PRECOMPILED_HEADERS:BOOL=ON" ] # Core dependencies @@ -279,6 +301,7 @@ in "-DUSE_SYSTEM_FMT:BOOL=ON" "-DUSE_SYSTEM_FP16:BOOL=ON" "-DUSE_SYSTEM_FXDIV:BOOL=ON" + "-DUSE_SYSTEM_ONNX:BOOL=ON" "-DUSE_SYSTEM_PSIMD:BOOL=ON" "-DUSE_SYSTEM_PTHREADPOOL:BOOL=ON" "-DUSE_SYSTEM_PYBIND11:BOOL=ON"