From 5113b4ea6cbfcbe8af9fd68ddcf42454f8e09ac3 Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Fri, 30 Aug 2024 11:45:14 -0700 Subject: [PATCH 1/8] Starting work to make GPU/CUDA optional. Work in progress. --- CMakeLists.txt | 91 ++++++++++--------- examples/cpp/cart_pole/CMakeLists.txt | 8 +- src/csrc/distributed.cpp | 20 ++++ src/csrc/include/internal/distributed.h | 6 +- src/csrc/include/internal/nvtx.h | 7 ++ src/csrc/include/internal/rl/distributions.h | 2 - src/csrc/include/internal/rl/noise_actor.h | 4 - src/csrc/include/internal/rl/off_policy.h | 10 ++ .../include/internal/rl/off_policy/ddpg.h | 4 - src/csrc/include/internal/rl/off_policy/sac.h | 4 - src/csrc/include/internal/rl/off_policy/td3.h | 4 - src/csrc/include/internal/rl/on_policy.h | 12 ++- src/csrc/include/internal/rl/on_policy/ppo.h | 4 - src/csrc/include/internal/rl/policy.h | 2 + src/csrc/include/internal/rl/replay_buffer.h | 4 - src/csrc/include/internal/rl/rollout_buffer.h | 4 - src/csrc/include/internal/rl/utils.h | 2 + src/csrc/include/internal/training.h | 6 ++ src/csrc/include/internal/utils.h | 2 + src/csrc/include/torchfort.h | 4 + src/csrc/include/torchfort_rl.h | 4 + src/csrc/model_wrapper.cpp | 2 + src/csrc/rl/off_policy/interface.cpp | 6 +- src/csrc/rl/on_policy/interface.cpp | 4 + src/csrc/torchfort.cpp | 2 + src/csrc/utils.cpp | 6 +- 26 files changed, 148 insertions(+), 76 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 963491f..28e4e10 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,6 +14,7 @@ set(TORCHFORT_YAML_CPP_ROOT CACHE STRING "Path to search for yaml-cpp installati option(TORCHFORT_BUILD_FORTRAN "Build Fortran bindings" ON) option(TORCHFORT_BUILD_EXAMPLES "Build examples" OFF) option(TORCHFORT_BUILD_TESTS "Build tests" OFF) +option(TORCHFORT_ENABLE_GPU "Enable GPU/CUDA support" ON) # For backward-compatibility with existing variable if (YAML_CPP_ROOT) @@ -54,51 +55,53 @@ endif() find_package(MPI REQUIRED) # CUDA -find_package(CUDAToolkit REQUIRED) - -# HPC SDK -# Locate and append NVHPC CMake configuration if available -find_program(NVHPC_CXX_BIN "nvc++") -if (NVHPC_CXX_BIN) - string(REPLACE "compilers/bin/nvc++" "cmake" NVHPC_CMAKE_DIR ${NVHPC_CXX_BIN}) - set(CMAKE_PREFIX_PATH "${CMAKE_PREFIX_PATH};${NVHPC_CMAKE_DIR}") - find_package(NVHPC COMPONENTS "") -endif() - -# Get NCCL library (with optional override) -if (TORCHFORT_NCCL_ROOT) - find_path(NCCL_INCLUDE_DIR REQUIRED - NAMES nccl.h - HINTS ${TORCHFORT_NCCL_ROOT}/include - ) - - find_library(NCCL_LIBRARY REQUIRED - NAMES nccl - HINTS ${TORCHFORT_NCCL_ROOT}/lib - ) -else() - if (NVHPC_FOUND) - find_package(NVHPC REQUIRED COMPONENTS NCCL) - find_library(NCCL_LIBRARY +if (TORCHFORT_ENABLE_GPU) + find_package(CUDAToolkit REQUIRED) + + # HPC SDK + # Locate and append NVHPC CMake configuration if available + find_program(NVHPC_CXX_BIN "nvc++") + if (NVHPC_CXX_BIN) + string(REPLACE "compilers/bin/nvc++" "cmake" NVHPC_CMAKE_DIR ${NVHPC_CXX_BIN}) + set(CMAKE_PREFIX_PATH "${CMAKE_PREFIX_PATH};${NVHPC_CMAKE_DIR}") + find_package(NVHPC COMPONENTS "") + endif() + + # Get NCCL library (with optional override) + if (TORCHFORT_NCCL_ROOT) + find_path(NCCL_INCLUDE_DIR REQUIRED + NAMES nccl.h + HINTS ${TORCHFORT_NCCL_ROOT}/include + ) + + find_library(NCCL_LIBRARY REQUIRED NAMES nccl - HINTS ${NVHPC_NCCL_LIBRARY_DIR} + HINTS ${TORCHFORT_NCCL_ROOT}/lib ) - string(REPLACE "/lib" "/include" NCCL_INCLUDE_DIR ${NVHPC_NCCL_LIBRARY_DIR}) else() - message(FATAL_ERROR "Cannot find NCCL library. 
Please set TORCHFORT_NCCL_ROOT to NCCL installation directory.")
+    if (NVHPC_FOUND)
+      find_package(NVHPC REQUIRED COMPONENTS NCCL)
+      find_library(NCCL_LIBRARY
+        NAMES nccl
+        HINTS ${NVHPC_NCCL_LIBRARY_DIR}
+      )
+      string(REPLACE "/lib" "/include" NCCL_INCLUDE_DIR ${NVHPC_NCCL_LIBRARY_DIR})
+    else()
+      message(FATAL_ERROR "Cannot find NCCL library. Please set TORCHFORT_NCCL_ROOT to NCCL installation directory.")
+    endif()
   endif()
+
+  message(STATUS "Using NCCL library: ${NCCL_LIBRARY}")
+
+  # PyTorch
+  # Set TORCH_CUDA_ARCH_LIST string to match TORCHFORT_CUDA_CC_LIST
+  foreach(CUDA_CC ${TORCHFORT_CUDA_CC_LIST})
+    string(REGEX REPLACE "([0-9])$" ".\\1" CUDA_CC_W_DOT ${CUDA_CC})
+    list(APPEND TORCH_CUDA_ARCH_LIST ${CUDA_CC_W_DOT})
+  endforeach()
+  list(JOIN TORCH_CUDA_ARCH_LIST " " TORCH_CUDA_ARCH_LIST)
 endif()
 
-message(STATUS "Using NCCL library: ${NCCL_LIBRARY}")
-
-# PyTorch
-# Set TORCH_CUDA_ARCH_LIST string to match TORCHFORT_CUDA_CC_LIST
-foreach(CUDA_CC ${TORCHFORT_CUDA_CC_LIST})
-  string(REGEX REPLACE "([0-9])$" ".\\1" CUDA_CC_W_DOT ${CUDA_CC})
-  list(APPEND TORCH_CUDA_ARCH_LIST ${CUDA_CC_W_DOT})
-endforeach()
-list(JOIN TORCH_CUDA_ARCH_LIST " " TORCH_CUDA_ARCH_LIST)
-
 find_package(Torch REQUIRED)
 
 # yaml-cpp
@@ -160,16 +163,22 @@ target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES})
 target_link_libraries(${PROJECT_NAME} PRIVATE ${NCCL_LIBRARY})
 target_link_libraries(${PROJECT_NAME} PRIVATE MPI::MPI_CXX)
 target_link_libraries(${PROJECT_NAME} PRIVATE ${YAML_CPP_LIBRARY})
-target_link_libraries(${PROJECT_NAME} PRIVATE CUDA::cudart)
 target_include_directories(${PROJECT_NAME}
   PRIVATE
   ${YAML_CPP_INCLUDE_DIR}
   ${MPI_CXX_INCLUDE_DIRS}
   ${TORCH_INCLUDE_DIRS}
+)
+if (TORCHFORT_ENABLE_GPU)
+  target_include_directories(${PROJECT_NAME}
+    PRIVATE
+    ${CUDAToolkit_INCLUDE_DIRS}
   ${NCCL_INCLUDE_DIR}
-)
+  )
+  target_link_libraries(${PROJECT_NAME} PRIVATE CUDA::cudart)
+  target_compile_definitions(${PROJECT_NAME} PRIVATE ENABLE_GPU)
+endif()
 target_compile_definitions(${PROJECT_NAME} PRIVATE YAML_CPP_STATIC_DEFINE)
 target_compile_options(${PROJECT_NAME} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:${TORCH_CXX_FLAGS}>)
 
diff --git a/examples/cpp/cart_pole/CMakeLists.txt b/examples/cpp/cart_pole/CMakeLists.txt
index a9b57d8..92fb5dc 100644
--- a/examples/cpp/cart_pole/CMakeLists.txt
+++ b/examples/cpp/cart_pole/CMakeLists.txt
@@ -35,9 +35,15 @@ foreach(tgt ${cart_pole_example_targets})
   target_link_libraries(${tgt} PRIVATE MPI::MPI_CXX)
   target_link_libraries(${tgt} PRIVATE ${YAML_CPP_LIBRARY})
   target_link_libraries(${tgt} PRIVATE environments)
-  target_link_libraries(${tgt} PRIVATE CUDA::cudart)
   target_compile_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:${TORCH_CXX_FLAGS}>)
   target_link_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:${TORCH_CXX_FLAGS}>)
+  if (TORCHFORT_ENABLE_GPU)
+    target_include_directories(${tgt}
+      PRIVATE
+      ${CUDAToolkit_INCLUDE_DIRS}
+    )
+    target_link_libraries(${tgt} PRIVATE CUDA::cudart)
+  endif()
 endforeach()
 
 # installation
diff --git a/src/csrc/distributed.cpp b/src/csrc/distributed.cpp
index 35cbef9..2800333 100644
--- a/src/csrc/distributed.cpp
+++ b/src/csrc/distributed.cpp
@@ -29,9 +29,11 @@
  */
 
 #include
+#ifdef ENABLE_GPU
 #include
 #include
+#endif
 #include
 
 #include "internal/defines.h"
@@ -51,6 +53,7 @@ static MPI_Datatype get_mpi_dtype(torch::Tensor tensor) {
   }
 }
 
+#ifdef ENABLE_GPU
 static ncclDataType_t get_nccl_dtype(torch::Tensor tensor) {
   auto dtype = tensor.dtype();
 
@@ -77,11 +80,13 @@ static ncclComm_t ncclCommFromMPIComm(MPI_Comm mpi_comm) {
 
   return nccl_comm;
 }
+#endif
 
 void Comm::initialize(bool initialize_nccl) {
CHECK_MPI(MPI_Comm_rank(mpi_comm, &rank));
   CHECK_MPI(MPI_Comm_size(mpi_comm, &size));
 
+#ifdef ENABLE_GPU
   if (initialize_nccl) {
     nccl_comm = ncclCommFromMPIComm(mpi_comm);
@@ -91,17 +96,21 @@ void Comm::initialize(bool initialize_nccl) {
     CHECK_CUDA(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
   }
+#endif
 
   initialized = true;
 }
 
 void Comm::finalize() {
+#ifdef ENABLE_GPU
   if (nccl_comm) CHECK_NCCL(ncclCommDestroy(nccl_comm));
   if (stream) CHECK_CUDA(cudaStreamDestroy(stream));
   if (event) CHECK_CUDA(cudaEventDestroy(event));
+#endif
 }
 
 void Comm::allreduce(torch::Tensor& tensor, bool average) const {
+#ifdef ENABLE_GPU
   if (tensor.device().type() == torch::kCUDA) {
     auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
     CHECK_CUDA(cudaEventRecord(event, torch_stream));
@@ -121,6 +130,7 @@ void Comm::allreduce(torch::Tensor& tensor, bool average) const {
     CHECK_CUDA(cudaEventRecord(event, stream));
     CHECK_CUDA(cudaStreamWaitEvent(torch_stream, event));
   } else if (tensor.device().type() == torch::kCPU) {
+#endif
     auto count = torch::numel(tensor);
     MPI_Datatype mpi_dtype;
     if (torch::is_complex(tensor)) {
@@ -135,28 +145,34 @@ void Comm::allreduce(torch::Tensor& tensor, bool average) const {
     if (average) {
       tensor /= size;
     }
+#ifdef ENABLE_GPU
   }
+#endif
 }
 
 void Comm::allreduce(std::vector<torch::Tensor>& tensors, bool average) const {
+#ifdef ENABLE_GPU
   if (tensors[0].device().type() == torch::kCUDA) {
     auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
     CHECK_CUDA(cudaEventRecord(event, torch_stream));
     CHECK_CUDA(cudaStreamWaitEvent(stream, event));
     CHECK_NCCL(ncclGroupStart());
   }
+#endif
   for (auto& t : tensors) {
     allreduce(t, average);
   }
+#ifdef ENABLE_GPU
   if (tensors[0].device().type() == torch::kCUDA) {
     CHECK_NCCL(ncclGroupEnd());
     auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
     CHECK_CUDA(cudaEventRecord(event, stream));
     CHECK_CUDA(cudaStreamWaitEvent(torch_stream, event));
   }
+#endif
 }
 
 void Comm::allreduce(double& val, bool average) const {
   CHECK_MPI(MPI_Allreduce(MPI_IN_PLACE, &val, 1, MPI_DOUBLE, MPI_SUM, mpi_comm));
@@ -174,6 +190,7 @@ void Comm::allreduce(float& val, bool average) const {
 
 void Comm::broadcast(torch::Tensor& tensor, int root) const {
   auto count = torch::numel(tensor);
+#ifdef ENABLE_GPU
   if (tensor.device().type() == torch::kCUDA) {
     // Use NCCL for GPU tensors
     ncclDataType_t nccl_dtype;
@@ -194,6 +211,7 @@ void Comm::broadcast(torch::Tensor& tensor, int root) const {
     CHECK_CUDA(cudaEventRecord(event, stream));
     CHECK_CUDA(cudaStreamWaitEvent(torch_stream, event));
   } else if (tensor.device().type() == torch::kCPU) {
+#endif
     // Use MPI for CPU tensors
     MPI_Datatype mpi_dtype;
     if (torch::is_complex(tensor)) {
@@ -205,7 +223,9 @@ void Comm::broadcast(torch::Tensor& tensor, int root) const {
 
     CHECK_MPI(MPI_Bcast(tensor.data_ptr(), count, mpi_dtype, root, mpi_comm));
+#ifdef ENABLE_GPU
   }
+#endif
 }
 
 } // namespace torchfort
diff --git a/src/csrc/include/internal/distributed.h b/src/csrc/include/internal/distributed.h
index 3bca38f..68c26dc 100644
--- a/src/csrc/include/internal/distributed.h
+++ b/src/csrc/include/internal/distributed.h
@@ -30,9 +30,11 @@
 
 #pragma once
 
+#ifdef ENABLE_GPU
 #include
-#include
 #include
+#endif
+#include
 
 #include
 
@@ -50,9 +52,11 @@ struct Comm {
   int rank;
   int size;
   MPI_Comm mpi_comm;
+#ifdef ENABLE_GPU
   ncclComm_t nccl_comm = nullptr;
   cudaStream_t stream = nullptr;
   cudaEvent_t event = nullptr;
+#endif
   bool initialized = false;
 
   Comm(MPI_Comm mpi_comm) : mpi_comm(mpi_comm) {};
diff --git a/src/csrc/include/internal/nvtx.h b/src/csrc/include/internal/nvtx.h
index e5ea9bd..accb704 100644
--- a/src/csrc/include/internal/nvtx.h
+++ b/src/csrc/include/internal/nvtx.h
@@ -32,13 +32,16 @@
 
 #include
 
+#ifdef ENABLE_GPU
 #include
+#endif
 
 namespace torchfort {
 
 // Helper class for NVTX ranges
 class nvtx {
 public:
+#ifdef ENABLE_GPU
   static void rangePush(const std::string& range_name) {
     static constexpr int ncolors_ = 8;
     static constexpr int colors_[ncolors_] = {0x3366CC, 0xDC3912, 0xFF9900, 0x109618,
@@ -56,6 +59,10 @@ class nvtx {
   }
 
   static void rangePop() { nvtxRangePop(); }
+#else
+  static void rangePush(const std::string& range_name) {}
+  static void rangePop() {}
+#endif
 };
 
 } // namespace torchfort
diff --git a/src/csrc/include/internal/rl/distributions.h b/src/csrc/include/internal/rl/distributions.h
index ea37c20..48a57ae 100644
--- a/src/csrc/include/internal/rl/distributions.h
+++ b/src/csrc/include/internal/rl/distributions.h
@@ -29,8 +29,6 @@
 */
 #pragma once
 
-#include
-
 #include
 #include
 
diff --git a/src/csrc/include/internal/rl/noise_actor.h b/src/csrc/include/internal/rl/noise_actor.h
index 194abf0..33f33e1 100644
--- a/src/csrc/include/internal/rl/noise_actor.h
+++ b/src/csrc/include/internal/rl/noise_actor.h
@@ -31,10 +31,6 @@
 #pragma once
 
 #include
-#include
-
-#include
-#include
 #include
 
 #include "internal/model_pack.h"
diff --git a/src/csrc/include/internal/rl/off_policy.h b/src/csrc/include/internal/rl/off_policy.h
index fe3069f..51277f7 100644
--- a/src/csrc/include/internal/rl/off_policy.h
+++ b/src/csrc/include/internal/rl/off_policy.h
@@ -33,10 +33,12 @@
 
 #include
 
+#ifdef ENABLE_GPU
 #include
 #include
 #include
+#endif
 #include
 
 #include "internal/defines.h"
@@ -95,12 +97,14 @@ static void update_replay_buffer(const char* name, T* state_old, T* state_new, s
   // no grad
   torch::NoGradGuard no_grad;
 
+#ifdef ENABLE_GPU
   c10::cuda::OptionalCUDAStreamGuard guard;
   auto rb_device = registry[name]->rbDevice();
   if (rb_device.is_cuda()) {
     auto stream = c10::cuda::getStreamFromExternal(ext_stream, rb_device.index());
     guard.reset_stream(stream);
   }
+#endif
 
   // get tensors and copy:
   auto state_old_tensor = get_tensor(state_old, state_dim, state_shape)
@@ -119,6 +123,7 @@ template <typename T>
 static void predict_explore(const char* name, T* state, size_t state_dim, int64_t* state_shape, T* action,
                             size_t action_dim, int64_t* action_shape, cudaStream_t ext_stream) {
+#ifdef ENABLE_GPU
   // device and stream handling
   c10::cuda::OptionalCUDAStreamGuard guard;
   auto model_device = registry[name]->modelDevice();
@@ -126,6 +131,7 @@ static void predict_explore(const char* name, T* state, size_t state_dim, int64_
     auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index());
     guard.reset_stream(stream);
   }
+#endif
 
   // create tensors
   auto state_tensor = get_tensor(state, state_dim, state_shape)
@@ -143,6 +149,7 @@ template <typename T>
 static void predict(const char* name, T* state, size_t state_dim, int64_t* state_shape, T* action,
                     size_t action_dim, int64_t* action_shape, cudaStream_t ext_stream) {
+#ifdef ENABLE_GPU
   // device and stream handling
   c10::cuda::OptionalCUDAStreamGuard guard;
   auto model_device = registry[name]->modelDevice();
@@ -150,6 +157,7 @@ static void predict(const char* name, T* state, size_t state_dim, int64_t* state
     auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index());
     guard.reset_stream(stream);
   }
+#endif
 
   // create tensors
   auto state_tensor = get_tensor(state, state_dim, state_shape)
@@ -168,6 +176,7 @@ static void policy_evaluate(const char* name, T* state, size_t state_dim, int64_
 int64_t* action_shape, T* reward, size_t reward_dim, int64_t* reward_shape, cudaStream_t ext_stream) {
+#ifdef ENABLE_GPU
   // device and stream handling
   c10::cuda::OptionalCUDAStreamGuard guard;
   auto model_device = registry[name]->modelDevice();
@@ -175,6 +184,7 @@ static void policy_evaluate(const char* name, T* state, size_t state_dim, int64_
     auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index());
     guard.reset_stream(stream);
   }
+#endif
 
   // create tensors
   auto state_tensor = get_tensor(state, state_dim, state_shape)
diff --git a/src/csrc/include/internal/rl/off_policy/ddpg.h b/src/csrc/include/internal/rl/off_policy/ddpg.h
index 607bfd3..f1a4783 100644
--- a/src/csrc/include/internal/rl/off_policy/ddpg.h
+++ b/src/csrc/include/internal/rl/off_policy/ddpg.h
@@ -33,10 +33,6 @@
 
 #include
 
-#include
-
-#include
-#include
 #include
 
 #include "internal/defines.h"
diff --git a/src/csrc/include/internal/rl/off_policy/sac.h b/src/csrc/include/internal/rl/off_policy/sac.h
index e51e519..c969e62 100644
--- a/src/csrc/include/internal/rl/off_policy/sac.h
+++ b/src/csrc/include/internal/rl/off_policy/sac.h
@@ -33,10 +33,6 @@
 
 #include
 
-#include
-
-#include
-#include
 #include
 
 #include "internal/defines.h"
diff --git a/src/csrc/include/internal/rl/off_policy/td3.h b/src/csrc/include/internal/rl/off_policy/td3.h
index 2d86e94..054afe8 100644
--- a/src/csrc/include/internal/rl/off_policy/td3.h
+++ b/src/csrc/include/internal/rl/off_policy/td3.h
@@ -33,10 +33,6 @@
 
 #include
 
-#include
-
-#include
-#include
 #include
 
 #include "internal/defines.h"
diff --git a/src/csrc/include/internal/rl/on_policy.h b/src/csrc/include/internal/rl/on_policy.h
index 30f5899..92217bd 100644
--- a/src/csrc/include/internal/rl/on_policy.h
+++ b/src/csrc/include/internal/rl/on_policy.h
@@ -33,10 +33,12 @@
 
 #include
 
+#ifdef ENABLE_GPU
 #include
 #include
 #include
+#endif
 #include
 
 #include "internal/defines.h"
@@ -98,9 +100,10 @@ static void update_rollout_buffer(const char* name, T* state, size_t state_dim,
   torch::NoGradGuard no_grad;
 
   // we need to sync carefully here
-  c10::cuda::OptionalCUDAStreamGuard guard;
   auto model_device = registry[name]->modelDevice();
   auto rb_device = registry[name]->rbDevice();
+#ifdef ENABLE_GPU
+  c10::cuda::OptionalCUDAStreamGuard guard;
   if (model_device.is_cuda()) {
     auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index());
     guard.reset_stream(stream);
@@ -108,6 +111,7 @@ static void update_rollout_buffer(const char* name, T* state, size_t state_dim,
     auto stream = c10::cuda::getStreamFromExternal(ext_stream, rb_device.index());
     guard.reset_stream(stream);
   }
+#endif
 
   // get tensors and copy:
   torch::Tensor state_tensor = get_tensor(state, state_dim, state_shape)
@@ -125,6 +129,7 @@ template <typename T>
 static void predict_explore(const char* name, T* state, size_t state_dim, int64_t* state_shape, T* action,
                             size_t action_dim, int64_t* action_shape, cudaStream_t ext_stream) {
+#ifdef ENABLE_GPU
   // device and stream handling
   c10::cuda::OptionalCUDAStreamGuard guard;
   auto model_device = registry[name]->modelDevice();
@@ -132,6 +137,7 @@ static void predict_explore(const char* name, T* state, size_t state_dim, int64_
     auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index());
     guard.reset_stream(stream);
   }
+#endif
 
   // create tensors
   auto state_tensor = get_tensor(state, state_dim, state_shape)
@@ -149,6 +155,7 @@ template <typename T>
 static void predict(const char* name, T* state, size_t state_dim, int64_t* state_shape, T* action,
                     size_t action_dim, int64_t* action_shape, cudaStream_t ext_stream) {
+#ifdef ENABLE_GPU
   // device and stream handling
   c10::cuda::OptionalCUDAStreamGuard guard;
   auto model_device = registry[name]->modelDevice();
@@ -156,6 +163,7 @@ static void predict(const char* name, T* state, size_t state_dim, int64_t* state
     auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index());
     guard.reset_stream(stream);
   }
+#endif
 
   // create tensors
   auto state_tensor = get_tensor(state, state_dim, state_shape)
@@ -174,6 +182,7 @@ static void policy_evaluate(const char* name, T* state, size_t state_dim, int64_
                             int64_t* action_shape, T* reward, size_t reward_dim, int64_t* reward_shape,
                             cudaStream_t ext_stream) {
+#ifdef ENABLE_GPU
   // device and stream handling
   c10::cuda::OptionalCUDAStreamGuard guard;
   auto model_device = registry[name]->modelDevice();
@@ -181,6 +190,7 @@ static void policy_evaluate(const char* name, T* state, size_t state_dim, int64_
     auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index());
     guard.reset_stream(stream);
   }
+#endif
 
   // create tensors
   auto state_tensor = get_tensor(state, state_dim, state_shape)
diff --git a/src/csrc/include/internal/rl/on_policy/ppo.h b/src/csrc/include/internal/rl/on_policy/ppo.h
index 1846d3b..a3b4568 100644
--- a/src/csrc/include/internal/rl/on_policy/ppo.h
+++ b/src/csrc/include/internal/rl/on_policy/ppo.h
@@ -33,10 +33,6 @@
 
 #include
 
-#include
-
-#include
-#include
 #include
 
 #include "internal/defines.h"
diff --git a/src/csrc/include/internal/rl/policy.h b/src/csrc/include/internal/rl/policy.h
index 272c936..33bb92e 100644
--- a/src/csrc/include/internal/rl/policy.h
+++ b/src/csrc/include/internal/rl/policy.h
@@ -33,10 +33,12 @@
 
 #include
 
+#ifdef ENABLE_GPU
 #include
 #include
 #include
+#endif
 #include
 
 #include "internal/defines.h"
diff --git a/src/csrc/include/internal/rl/replay_buffer.h b/src/csrc/include/internal/rl/replay_buffer.h
index da7281a..992ce47 100644
--- a/src/csrc/include/internal/rl/replay_buffer.h
+++ b/src/csrc/include/internal/rl/replay_buffer.h
@@ -33,10 +33,6 @@
 
 #include
 #include
 
-#include
-
-#include
-#include
 #include
 
 #include "internal/defines.h"
diff --git a/src/csrc/include/internal/rl/rollout_buffer.h b/src/csrc/include/internal/rl/rollout_buffer.h
index 49fc9a9..5f45af2 100644
--- a/src/csrc/include/internal/rl/rollout_buffer.h
+++ b/src/csrc/include/internal/rl/rollout_buffer.h
@@ -33,10 +33,6 @@
 
 #include
 #include
 
-#include
-
-#include
-#include
 #include
 
 #include "internal/defines.h"
diff --git a/src/csrc/include/internal/rl/utils.h b/src/csrc/include/internal/rl/utils.h
index 6d286b2..2a7320c 100644
--- a/src/csrc/include/internal/rl/utils.h
+++ b/src/csrc/include/internal/rl/utils.h
@@ -31,10 +31,12 @@
 #pragma once
 
 #include
+#ifdef ENABLE_GPU
 #include
 #include
 #include
+#endif
 #include
 
 #include "internal/defines.h"
diff --git a/src/csrc/include/internal/training.h b/src/csrc/include/internal/training.h
index dae11d3..7bbc5df 100644
--- a/src/csrc/include/internal/training.h
+++ b/src/csrc/include/internal/training.h
@@ -32,10 +32,12 @@
 #include
 #include
 
+#ifdef ENABLE_GPU
 #include
 #include
 #include
+#endif
 #include
 
 #include
@@ -60,11 +62,13 @@ void inference(const char* name, T* input, size_t input_dim, int64_t* input_shap
 
   auto model = models[name].model.get();
 
+#ifdef ENABLE_GPU
   c10::cuda::OptionalCUDAStreamGuard guard;
   if (model->device().is_cuda()) {
     auto stream = c10::cuda::getStreamFromExternal(ext_stream, model->device().index());
     guard.reset_stream(stream);
   }
+#endif
 
   auto input_tensor_in = get_tensor(input, input_dim, input_shape);
   auto output_tensor_in = get_tensor(output, output_dim, output_shape);
@@ -93,11 +97,13 @@ void train(const char* name, T* input, size_t input_dim, int64_t* input_shape, T
 
   auto model = models[name].model.get();
 
+#ifdef ENABLE_GPU
   c10::cuda::OptionalCUDAStreamGuard guard;
   if (model->device().is_cuda()) {
     auto stream = c10::cuda::getStreamFromExternal(ext_stream, model->device().index());
     guard.reset_stream(stream);
   }
+#endif
 
   auto input_tensor_in = get_tensor(input, input_dim, input_shape);
   auto label_tensor_in = get_tensor(label, label_dim, label_shape);
diff --git a/src/csrc/include/internal/utils.h b/src/csrc/include/internal/utils.h
index a4e55bc..d630052 100644
--- a/src/csrc/include/internal/utils.h
+++ b/src/csrc/include/internal/utils.h
@@ -37,7 +37,9 @@
 #include
 #include
 
+#ifdef ENABLE_GPU
 #include
+#endif
 
 #include "internal/exceptions.h"
 #include "internal/nvtx.h"
diff --git a/src/csrc/include/torchfort.h b/src/csrc/include/torchfort.h
index 7711896..160f9cd 100644
--- a/src/csrc/include/torchfort.h
+++ b/src/csrc/include/torchfort.h
@@ -30,7 +30,11 @@
 #pragma once
 
 #include
+#ifdef ENABLE_GPU
 #include <cuda_runtime.h>
+#else
+typedef void* cudaStream_t;
+#endif
 #include
 
 #include "torchfort_enums.h"
diff --git a/src/csrc/include/torchfort_rl.h b/src/csrc/include/torchfort_rl.h
index 0057488..aa9a1c0 100644
--- a/src/csrc/include/torchfort_rl.h
+++ b/src/csrc/include/torchfort_rl.h
@@ -30,7 +30,11 @@
 #pragma once
 
 #include "torchfort_enums.h"
+#ifdef ENABLE_GPU
 #include <cuda_runtime.h>
+#else
+typedef void* cudaStream_t;
+#endif
 
 #define RL_OFF_POLICY_WANDB_LOG_FUNC(dtype) \
   torchfort_result_t torchfort_rl_off_policy_wandb_log_##dtype(const char* name, const char* metric_name, \
diff --git a/src/csrc/model_wrapper.cpp b/src/csrc/model_wrapper.cpp
index 4fb932b..a821bf9 100644
--- a/src/csrc/model_wrapper.cpp
+++ b/src/csrc/model_wrapper.cpp
@@ -33,7 +33,9 @@
 #include
 #include
 
+#ifdef ENABLE_GPU
 #include
+#endif
 #include
 #include
 
diff --git a/src/csrc/rl/off_policy/interface.cpp b/src/csrc/rl/off_policy/interface.cpp
index 652005b..00be420 100644
--- a/src/csrc/rl/off_policy/interface.cpp
+++ b/src/csrc/rl/off_policy/interface.cpp
@@ -33,7 +33,9 @@
 #include
 #include
 
+#ifdef ENABLE_GPU
 #include
+#endif
 #include
 #include
 #include
@@ -164,9 +166,10 @@ torchfort_result_t torchfort_rl_off_policy_train_step(const char* name, float* p
   using namespace torchfort;
 
   // TODO: we need to figure out what to do if RB and Model streams are different
-  c10::cuda::OptionalCUDAStreamGuard guard;
   auto model_device = rl::off_policy::registry[name]->modelDevice();
   auto rb_device = rl::off_policy::registry[name]->rbDevice();
+#ifdef ENABLE_GPU
+  c10::cuda::OptionalCUDAStreamGuard guard;
   if (model_device.is_cuda()) {
     auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index());
     guard.reset_stream(stream);
@@ -174,6 +177,7 @@ torchfort_result_t torchfort_rl_off_policy_train_step(const char* name, float* p
     auto stream = c10::cuda::getStreamFromExternal(ext_stream, rb_device.index());
     guard.reset_stream(stream);
   }
+#endif
 
   try {
     // perform a training step
diff --git a/src/csrc/rl/on_policy/interface.cpp b/src/csrc/rl/on_policy/interface.cpp
index 92cf8c1..59621b9 100644
--- a/src/csrc/rl/on_policy/interface.cpp
+++ b/src/csrc/rl/on_policy/interface.cpp
@@ -33,7 +33,9 @@
 #include
 #include
 
+#ifdef ENABLE_GPU
 #include
+#endif
 #include
 #include
 #include
@@ -157,6 +159,7 @@ torchfort_result_t torchfort_rl_on_policy_train_step(const char* name, float* p_
                                                      cudaStream_t ext_stream) {
   using namespace torchfort;
 
+#ifdef ENABLE_GPU
  // TODO: we need to figure out what to do if RB and Model streams are different
  c10::cuda::OptionalCUDAStreamGuard guard;
  auto model_device = rl::on_policy::registry[name]->modelDevice();
@@ -164,6 +167,7 @@ torchfort_result_t torchfort_rl_on_policy_train_step(const char* name, float* p_
     auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index());
     guard.reset_stream(stream);
   }
+#endif
 
   try {
     // perform a training step
diff --git a/src/csrc/torchfort.cpp b/src/csrc/torchfort.cpp
index 399692c..6ce9bb5 100644
--- a/src/csrc/torchfort.cpp
+++ b/src/csrc/torchfort.cpp
@@ -36,7 +36,9 @@
 #include
 #include
 
+#ifdef ENABLE_GPU
 #include
+#endif
 #include
 #include
 #include
diff --git a/src/csrc/utils.cpp b/src/csrc/utils.cpp
index 3a5ba10..81fdf0b 100644
--- a/src/csrc/utils.cpp
+++ b/src/csrc/utils.cpp
@@ -60,16 +60,19 @@ std::string filename_sanitize(std::string s) {
 
 torch::Device get_device(int device) {
   torch::Device device_torch(torch::kCPU);
+#ifdef ENABLE_GPU
   if (device != TORCHFORT_DEVICE_CPU) {
     device_torch = torch::Device(torch::kCUDA, device);
   }
+#endif
   return device_torch;
 }
 
 torch::Device get_device(const void* ptr) {
+  torch::Device device = torch::Device(torch::kCPU);
+#ifdef ENABLE_GPU
   cudaPointerAttributes attr;
   CHECK_CUDA(cudaPointerGetAttributes(&attr, ptr));
-  torch::Device device = torch::Device(torch::kCPU);
   switch (attr.type) {
   case cudaMemoryTypeHost:
   case cudaMemoryTypeUnregistered:
@@ -78,6 +81,7 @@ torch::Device get_device(const void* ptr) {
   case cudaMemoryTypeDevice:
     device = torch::Device(torch::kCUDA);
     break;
   }
+#endif
   return device;
 }
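The core pattern PATCH 1 applies across the whole tree is a compile-time guard: public headers keep cudaStream_t in their signatures by substituting an opaque typedef when GPU support is off, and implementation files fence every CUDA call behind ENABLE_GPU. A minimal self-contained sketch of that pattern follows; the mylib_zero helper is hypothetical, not TorchFort API, and only the guard structure mirrors the patch:

#include <cstddef>

#ifdef ENABLE_GPU
#include <cuda_runtime.h>
#else
typedef void* cudaStream_t;  // opaque stand-in so public signatures keep compiling
#endif

// Hypothetical helper: zero a buffer on either build; CPU callers pass nullptr.
int mylib_zero(float* data, std::size_t n, cudaStream_t stream) {
#ifdef ENABLE_GPU
  // GPU build: asynchronous device memset on the caller's stream.
  if (cudaMemsetAsync(data, 0, n * sizeof(float), stream) != cudaSuccess) return 1;
  return (cudaStreamSynchronize(stream) == cudaSuccess) ? 0 : 1;
#else
  (void)stream;  // CPU build: the opaque stream is simply ignored.
  for (std::size_t i = 0; i < n; ++i) data[i] = 0.0f;
  return 0;
#endif
}

Because the typedef keeps the ABI of the stream argument a pointer in both builds, callers do not need to be recompiled against different prototypes, which is exactly what the torchfort.h change above relies on.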
From 276de319e4798ffc7efbc389aae8e8ae20930b61 Mon Sep 17 00:00:00 2001
From: Josh Romero
Date: Fri, 30 Aug 2024 12:05:24 -0700
Subject: [PATCH 2/8] Add work-in-progress CPU-only Dockerfile.

---
 docker/Dockerfile_cpu | 55 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 55 insertions(+)
 create mode 100644 docker/Dockerfile_cpu

diff --git a/docker/Dockerfile_cpu b/docker/Dockerfile_cpu
new file mode 100644
index 0000000..f2121a9
--- /dev/null
+++ b/docker/Dockerfile_cpu
@@ -0,0 +1,55 @@
+FROM ubuntu:22.04
+
+# Install System Dependencies
+ENV DEBIAN_FRONTEND noninteractive
+RUN apt update -y && \
+    apt install -y build-essential && \
+    apt install -y curl unzip wget cmake python3 python-is-python3 python3-pip python3-pybind11 git vim gfortran doxygen libibverbs-dev
+
+# Download HPCX and compile with Fortran support
+RUN cd /opt && \
+    wget https://download.open-mpi.org/release/open-mpi/v5.0/openmpi-5.0.5.tar.gz && \
+    tar xzf openmpi-5.0.5.tar.gz && \
+    cd openmpi-5.0.5 && \
+    FC=gfortran CC=gcc CXX=g++ ./configure --prefix=/opt/openmpi \
+        --with-libevent=internal \
+        --enable-mpi1-compatibility \
+        --without-xpmem \
+        --with-slurm && \
+    make -j$(nproc) install && \
+    cd /opt && rm -rf openmpi-5.0.5 && rm openmpi-5.0.5.tar.gz
+
+ENV PATH /opt/openmpi/bin:$PATH
+ENV LD_LIBRARY_PATH /opt/openmpi/lib:$LD_LIBRARY_PATH
+
+# Install PyTorch
+RUN pip3 install torch==2.2.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+
+# Install yaml-cpp
+RUN git clone https://github.com/jbeder/yaml-cpp.git --branch 0.8.0 && \
+    cd yaml-cpp && \
+    mkdir build && cd build && \
+    cmake -DCMAKE_INSTALL_PREFIX=/opt/yaml-cpp \
+          -DCMAKE_CXX_FLAGS:="-D_GLIBCXX_USE_CXX11_ABI=0" \
+          -DBUILD_SHARED_LIBS=OFF \
+          -DCMAKE_POSITION_INDEPENDENT_CODE=ON .. && \
+    make -j$(nproc) && make install
+ENV LD_LIBRARY_PATH /opt/yaml-cpp/lib:${LD_LIBRARY_PATH}
+
+# Install HDF5
+RUN wget https://github.com/HDFGroup/hdf5/archive/refs/tags/hdf5-1_14_3.tar.gz && \
+    tar xzf hdf5-1_14_3.tar.gz && \
+    cd hdf5-hdf5-1_14_3 && \
+    CC=mpicc FC=mpifort \
+    ./configure --enable-parallel \
+                --enable-fortran \
+                --prefix=/opt/hdf5 && \
+    make -j$(nproc) install && \
+    cd .. && \
+    rm -rf hdf5-hdf5-1_14_3 hdf5-1_14_3.tar.gz
+ENV LD_LIBRARY_PATH /opt/hdf5/lib:$LD_LIBRARY_PATH
+
+# Install additional Python dependencies
+RUN pip3 install wandb ruamel-yaml h5py matplotlib pygame moviepy
+
+ENTRYPOINT bash

From 952edb344fecceb39fc8a94245f189bd972430a7 Mon Sep 17 00:00:00 2001
From: Josh Romero
Date: Fri, 30 Aug 2024 12:12:54 -0700
Subject: [PATCH 3/8] Temporarily disable cart_pole example for CPU only build.

---
 CMakeLists.txt | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 28e4e10..7e7c8ed 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -243,7 +243,10 @@ install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/docs DESTINATION ${CMAKE_INSTALL_P
 
 # build examples
 if (TORCHFORT_BUILD_EXAMPLES)
-  add_subdirectory(examples/cpp/cart_pole)
+  if (TORCHFORT_ENABLE_GPU)
+    # TODO: Enable cart_pole example for CPU only builds
+    add_subdirectory(examples/cpp/cart_pole)
+  endif()
   if (TORCHFORT_BUILD_FORTRAN)
     add_subdirectory(examples/fortran/simulation)
   endif()

From 8abe9dae8a5acf2ac56db96a3dbc0c9c0e7897d7 Mon Sep 17 00:00:00 2001
From: Thorsten Kurth
Date: Tue, 3 Sep 2024 03:06:01 -0700
Subject: [PATCH 4/8] conditional linking of tests against cudart

---
 tests/rl/CMakeLists.txt | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/rl/CMakeLists.txt b/tests/rl/CMakeLists.txt
index 45086ab..e181361 100644
--- a/tests/rl/CMakeLists.txt
+++ b/tests/rl/CMakeLists.txt
@@ -56,10 +56,12 @@ foreach(tgt ${test_targets})
   target_link_libraries(${tgt} PRIVATE ${TORCH_LIBRARIES})
   target_link_libraries(${tgt} PRIVATE ${YAML_CPP_LIBRARY})
   target_link_libraries(${tgt} PRIVATE MPI::MPI_CXX)
-  target_link_libraries(${tgt} PRIVATE CUDA::cudart)
   target_link_libraries(${tgt} PRIVATE GTest::gtest_main)
   target_compile_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:${TORCH_CXX_FLAGS}>)
   target_link_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:${TORCH_CXX_FLAGS}>)
+if (TORCHFORT_ENABLE_GPU)
+  target_link_libraries(${tgt} PRIVATE CUDA::cudart)
+endif()
 
   # discover tests: we have an issue with the work dir of gtest so disable that for now
   #gtest_discover_tests(${tgt})

From 5eedb0fdb1413cc0c2e5aa7cad17380acfc4def9 Mon Sep 17 00:00:00 2001
From: Thorsten Kurth
Date: Tue, 3 Sep 2024 07:16:49 -0700
Subject: [PATCH 5/8] fixed config file for cartpole example, made GPU training for that example optional.
Fixed test compilation when cuda is disabled --- CMakeLists.txt | 5 +-- examples/cpp/cart_pole/config.yaml | 2 +- .../cpp/cart_pole/python/initialize_models.py | 11 ++++--- examples/cpp/cart_pole/python/visualize.py | 11 ++++--- examples/cpp/cart_pole/train.cpp | 33 ++++++++++++++++++- 5 files changed, 48 insertions(+), 14 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7e7c8ed..28e4e10 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -243,10 +243,7 @@ install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/docs DESTINATION ${CMAKE_INSTALL_P # build examples if (TORCHFORT_BUILD_EXAMPLES) - if (TORCHFORT_ENABLE_GPU) - # TODO: Enable cart_pole example for CPU only builds - add_subdirectory(examples/cpp/cart_pole) - endif() + add_subdirectory(examples/cpp/cart_pole) if (TORCHFORT_BUILD_FORTRAN) add_subdirectory(examples/fortran/simulation) endif() diff --git a/examples/cpp/cart_pole/config.yaml b/examples/cpp/cart_pole/config.yaml index a52a6b8..82528d9 100644 --- a/examples/cpp/cart_pole/config.yaml +++ b/examples/cpp/cart_pole/config.yaml @@ -14,7 +14,7 @@ algorithm: gamma: 0.99 rho: 0.99 -action: +actor: type: space_noise parameters: a_low: -1.0 diff --git a/examples/cpp/cart_pole/python/initialize_models.py b/examples/cpp/cart_pole/python/initialize_models.py index 10e3e60..052dd42 100644 --- a/examples/cpp/cart_pole/python/initialize_models.py +++ b/examples/cpp/cart_pole/python/initialize_models.py @@ -40,10 +40,13 @@ def main(args): # set seed torch.manual_seed(666) - torch.cuda.manual_seed(666) - - # script model: - device = torch.device("cuda:0") + + # CUDA check + if torch.cuda.is_available(): + torch.cuda.manual_seed(666) + device = torch.device("cuda:0") + else: + device = torch.device("cpu") # parameters batch_size = 64 diff --git a/examples/cpp/cart_pole/python/visualize.py b/examples/cpp/cart_pole/python/visualize.py index cbdfbba..0d1b1c4 100644 --- a/examples/cpp/cart_pole/python/visualize.py +++ b/examples/cpp/cart_pole/python/visualize.py @@ -130,11 +130,14 @@ def main(args): # set seed torch.manual_seed(666) - torch.cuda.manual_seed(666) - - # script model: - device = torch.device("cuda:0") + # CUDA check + if torch.cuda.is_available(): + torch.cuda.manual_seed(666) + device = torch.device("cuda:0") + else: + device = torch.device("cpu") + # parameters batch_size = 1 diff --git a/examples/cpp/cart_pole/train.cpp b/examples/cpp/cart_pole/train.cpp index 0813db4..8ea5865 100644 --- a/examples/cpp/cart_pole/train.cpp +++ b/examples/cpp/cart_pole/train.cpp @@ -28,6 +28,7 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*/
 
+#include
 #include
 #include
 #include
@@ -49,6 +50,7 @@
     } \
   } while (false)
 
+
 int main(int argc, char* argv[]) {
 
   // load config file
@@ -91,8 +93,13 @@ int main(int argc, char* argv[]) {
   float theta_width = state_max[2] - state_min[2];
 
   // instantiate torchfort
+#ifdef ENABLE_GPU
   CHECK_TORCHFORT(torchfort_rl_off_policy_create_system("td3_system", "config.yaml", 0, TORCHFORT_DEVICE_CPU));
+#else
+  CHECK_TORCHFORT(torchfort_rl_off_policy_create_system("td3_system", "config.yaml",
+                                                        TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU));
+#endif
 
   // define variables
   StateVector state, state_new;
@@ -103,11 +110,17 @@ int main(int argc, char* argv[]) {
   bool terminate;
 
   // allocate cuda arrays
-  cudaSetDevice(0);
   float *dstate, *dstate_new, *daction, *dreward;
+#ifdef ENABLE_GPU
+  cudaSetDevice(0);
   cudaMalloc(&dstate, state.size() * sizeof(float));
   cudaMalloc(&dstate_new, state.size() * sizeof(float));
   cudaMalloc(&daction, action.size() * sizeof(float));
+#else
+  dstate = static_cast<float*>(std::malloc(state.size() * sizeof(float)));
+  dstate_new = static_cast<float*>(std::malloc(state.size() * sizeof(float)));
+  daction = static_cast<float*>(std::malloc(action.size() * sizeof(float)));
+#endif
 
   int64_t step_total = 0;
   bool is_eval = false;
@@ -136,7 +149,11 @@ int main(int argc, char* argv[]) {
       step_total++;
 
       // copy data to device
+#ifdef ENABLE_GPU
       cudaMemcpy(dstate, state.data(), state.size() * sizeof(float), cudaMemcpyHostToDevice);
+#else
+      std::memcpy(dstate, state.data(), state.size() * sizeof(float));
+#endif
 
       // state check
       std::cout << prefix + "state: " << state[0] << ", " << state[1] << ", " << state[2] << ", " << state[3]
@@ -151,7 +168,11 @@ int main(int argc, char* argv[]) {
                                                       0));
 
       // copy data to host
+#ifdef ENABLE_GPU
       cudaMemcpy(action.data(), daction, action.size() * sizeof(float), cudaMemcpyDeviceToHost);
+#else
+      std::memcpy(action.data(), daction, action.size() * sizeof(float));
+#endif
 
       // action check
       std::cout << prefix + "action: " << action[0] << std::endl;
@@ -167,7 +188,11 @@ int main(int argc, char* argv[]) {
       }
 
       // copy data to device
+#ifdef ENABLE_GPU
       cudaMemcpy(dstate_new, state_new.data(), state_new.size() * sizeof(float), cudaMemcpyHostToDevice);
+#else
+      std::memcpy(dstate_new, state_new.data(), state_new.size() * sizeof(float));
+#endif
 
       // update replay buffer
       CHECK_TORCHFORT(torchfort_rl_off_policy_update_replay_buffer(
@@ -211,9 +236,15 @@ int main(int argc, char* argv[]) {
   }
 
   // clean up
+#ifdef ENABLE_GPU
   cudaFree(dstate);
   cudaFree(dstate_new);
   cudaFree(daction);
+#else
+  std::free(dstate);
+  std::free(dstate_new);
+  std::free(daction);
+#endif
 
   return 0;
 }
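The train.cpp changes in PATCH 5 repeat the same #ifdef ENABLE_GPU block for every allocation, copy, and free. A possible follow-up, sketched here rather than taken from the series, is to fold the split into small helpers so call sites stay guard-free; the buf_* names are hypothetical:

#include <cstddef>
#include <cstdlib>
#include <cstring>
#ifdef ENABLE_GPU
#include <cuda_runtime.h>
#endif

// Allocate a float buffer on the device (GPU build) or the host (CPU build).
static float* buf_alloc(std::size_t n) {
#ifdef ENABLE_GPU
  float* p = nullptr;
  cudaMalloc(&p, n * sizeof(float));
  return p;
#else
  return static_cast<float*>(std::malloc(n * sizeof(float)));
#endif
}

// Copy host data into the buffer; on CPU builds this is a plain memcpy.
static void buf_to_device(float* dst, const float* src, std::size_t n) {
#ifdef ENABLE_GPU
  cudaMemcpy(dst, src, n * sizeof(float), cudaMemcpyHostToDevice);
#else
  std::memcpy(dst, src, n * sizeof(float));
#endif
}

// Release the buffer with the allocator that produced it.
static void buf_free(float* p) {
#ifdef ENABLE_GPU
  cudaFree(p);
#else
  std::free(p);
#endif
}

With helpers like these, the training loop reads identically for both builds, e.g. dstate = buf_alloc(state.size()); followed by buf_to_device(dstate, state.data(), state.size()); and a single buf_free(dstate) at cleanup.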
From ab0abbfd21e1402d6d83596d887f66cf2ee180dd Mon Sep 17 00:00:00 2001
From: Josh Romero
Date: Tue, 3 Sep 2024 21:28:07 -0700
Subject: [PATCH 6/8] Throw exception if model/component is placed on GPU when
 GPU is not enabled. Update CPU-only Dockerfile. Fix generate_fcn_model.py
 with CPU-only PyTorch.

---
 ...{Dockerfile_cpu => Dockerfile_gnu_cpuonly} | 21 +++++++++++++++++--
 .../fortran/simulation/generate_fcn_model.py |  7 +++++--
 src/csrc/utils.cpp                            |  6 ++++--
 3 files changed, 28 insertions(+), 6 deletions(-)
 rename docker/{Dockerfile_cpu => Dockerfile_gnu_cpuonly} (72%)

diff --git a/docker/Dockerfile_cpu b/docker/Dockerfile_gnu_cpuonly
similarity index 72%
rename from docker/Dockerfile_cpu
rename to docker/Dockerfile_gnu_cpuonly
index f2121a9..b836636 100644
--- a/docker/Dockerfile_cpu
+++ b/docker/Dockerfile_gnu_cpuonly
@@ -6,7 +6,7 @@ RUN apt update -y && \
     apt install -y build-essential && \
     apt install -y curl unzip wget cmake python3 python-is-python3 python3-pip python3-pybind11 git vim gfortran doxygen libibverbs-dev
 
-# Download HPCX and compile with Fortran support
+# Download OpenMPI and compile with Fortran support
 RUN cd /opt && \
     wget https://download.open-mpi.org/release/open-mpi/v5.0/openmpi-5.0.5.tar.gz && \
     tar xzf openmpi-5.0.5.tar.gz && \
     cd openmpi-5.0.5 && \
@@ -23,7 +23,7 @@ ENV PATH /opt/openmpi/bin:$PATH
 ENV LD_LIBRARY_PATH /opt/openmpi/lib:$LD_LIBRARY_PATH
 
 # Install PyTorch
-RUN pip3 install torch==2.2.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+RUN pip3 install torch==2.4.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
 
 # Install yaml-cpp
 RUN git clone https://github.com/jbeder/yaml-cpp.git --branch 0.8.0 && \
@@ -52,4 +52,21 @@ ENV LD_LIBRARY_PATH /opt/hdf5/lib:$LD_LIBRARY_PATH
 # Install additional Python dependencies
 RUN pip3 install wandb ruamel-yaml h5py matplotlib pygame moviepy
 
+# Install TorchFort without GPU support
+ENV FC=gfortran
+ENV HDF5_ROOT=/opt/hdf5
+COPY . /torchfort
+RUN cd /torchfort && mkdir build && cd build && \
+    cmake -DCMAKE_INSTALL_PREFIX=/opt/torchfort \
+          -DTORCHFORT_YAML_CPP_ROOT=/opt/yaml-cpp \
+          -DTORCHFORT_ENABLE_GPU=0 \
+          -DTORCHFORT_BUILD_EXAMPLES=1 \
+          -DTORCHFORT_BUILD_TESTS=1 \
+          -DCMAKE_PREFIX_PATH="`python -c 'import torch;print(torch.utils.cmake_prefix_path)'`" \
+          .. && \
+    make -j$(nproc) install && \
+    cd / && rm -rf torchfort
+ENV LD_LIBRARY_PATH /opt/torchfort/lib:${LD_LIBRARY_PATH}
+ENV LD_LIBRARY_PATH /usr/local/lib/python3.10/dist-packages/torch/lib:${LD_LIBRARY_PATH}
+
 ENTRYPOINT bash

diff --git a/examples/fortran/simulation/generate_fcn_model.py b/examples/fortran/simulation/generate_fcn_model.py
index fd520d2..45ac5ed 100644
--- a/examples/fortran/simulation/generate_fcn_model.py
+++ b/examples/fortran/simulation/generate_fcn_model.py
@@ -42,8 +42,11 @@ def main():
     model = Net()
     print("FCN model:", model)
 
-    # Move model to GPU, JIT, and save
-    model.to("cuda")
+    try:
+        # Move model to GPU, JIT, and save
+        model.to("cuda")
+    except Exception:
+        print("PyTorch does not have CUDA support. Saving model on CPU.")
     model_jit = torch.jit.script(model)
     model_jit.save("fcn_torchscript.pt")
 
diff --git a/src/csrc/utils.cpp b/src/csrc/utils.cpp
index 81fdf0b..6956859 100644
--- a/src/csrc/utils.cpp
+++ b/src/csrc/utils.cpp
@@ -60,11 +60,13 @@ std::string filename_sanitize(std::string s) {
 
 torch::Device get_device(int device) {
   torch::Device device_torch(torch::kCPU);
-#ifdef ENABLE_GPU
   if (device != TORCHFORT_DEVICE_CPU) {
+#ifdef ENABLE_GPU
     device_torch = torch::Device(torch::kCUDA, device);
-  }
+#else
+    THROW_NOT_SUPPORTED("Attempted to place a model or other component on GPU but TorchFort was built without GPU support.");
 #endif
+  }
   return device_torch;
 }
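With PATCH 6 applied, requesting a GPU placement from a CPU-only build fails fast with a clear message instead of silently falling back. Roughly, from the caller's side; this is a sketch, the torchfort_create_model(name, config, device) entry point is assumed from the broader API, and the exact non-success result code surfaced by THROW_NOT_SUPPORTED is an assumption:

#include <cstdio>
#include "torchfort.h"

int main() {
  // Built with -DTORCHFORT_ENABLE_GPU=0, device index 0 (a GPU) cannot be honored.
  torchfort_result_t res = torchfort_create_model("net", "config.yaml", 0);
  if (res != TORCHFORT_RESULT_SUCCESS) {
    std::printf("create failed: GPU placement requested on a CPU-only build\n");
    return 1;
  }
  return 0;
}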
From 9a8d91672dca31a8b583181ae284946c28277431 Mon Sep 17 00:00:00 2001
From: Thorsten Kurth
Date: Wed, 4 Sep 2024 02:16:03 -0700
Subject: [PATCH 7/8] updating torch in Dockerfile

---
 docker/Dockerfile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docker/Dockerfile b/docker/Dockerfile
index 5d12b5e..3727082 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -32,7 +32,7 @@ RUN cd /opt && \
 ENV LD_LIBRARY_PATH /opt/nccl/build/lib:$LD_LIBRARY_PATH
 
 # Install PyTorch
-RUN pip3 install torch==2.2.1 torchvision torchaudio
+RUN pip3 install torch==2.4.0
 
 # Install yaml-cpp
 RUN git clone https://github.com/jbeder/yaml-cpp.git --branch 0.8.0 && \

From f6637a5f20bf7670a63ac36c122295b74b0ab357 Mon Sep 17 00:00:00 2001
From: Josh Romero
Date: Wed, 4 Sep 2024 11:10:53 -0700
Subject: [PATCH 8/8] Make Dockerfiles more consistent.

---
 docker/Dockerfile             |  4 ++--
 docker/Dockerfile_gnu         | 11 ++++++++---
 docker/Dockerfile_gnu_cpuonly |  7 +++++--
 3 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/docker/Dockerfile b/docker/Dockerfile
index 3727082..5424b35 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -24,7 +24,7 @@ ENV CUDA_HOME /opt/nvidia/hpc_sdk/Linux_x86_64/24.1/cuda
 
 RUN echo "source /opt/nvidia/hpc_sdk/Linux_x86_64/24.1/comm_libs/12.3/hpcx/latest/hpcx-init.sh; hpcx_load" >> /root/.bashrc
 
-# Install NCCL 2.20.3 for compatibility with PyTorch 2.2.1
+# Install newer NCCL for compatibility with PyTorch 2.2.1+
 RUN cd /opt && \
     git clone --branch v2.20.3-1 https://github.com/NVIDIA/nccl.git && \
     cd nccl && \
@@ -74,7 +74,7 @@ RUN cd /torchfort && mkdir build && cd build && \
           -DTORCHFORT_BUILD_TESTS=1 \
           -DCMAKE_PREFIX_PATH="`python -c 'import torch;print(torch.utils.cmake_prefix_path)'`" \
           .. && \
-    make VERBOSE=1 -j$(nproc) install && \
+    make -j$(nproc) install && \
     cd / && rm -rf torchfort
 ENV LD_LIBRARY_PATH /opt/torchfort/lib:${LD_LIBRARY_PATH}
 ENV LD_LIBRARY_PATH /usr/local/lib/python3.10/dist-packages/torch/lib:${LD_LIBRARY_PATH}

diff --git a/docker/Dockerfile_gnu b/docker/Dockerfile_gnu
index 770033e..fb4f5d9 100644
--- a/docker/Dockerfile_gnu
+++ b/docker/Dockerfile_gnu
@@ -1,8 +1,12 @@
 FROM nvcr.io/nvidia/cuda:12.3.1-devel-ubuntu22.04
 
 # Install System Dependencies
+ENV DEBIAN_FRONTEND noninteractive
 RUN apt update -y && \
-    DEBIAN_FRONTEND=noninteractive apt install -y curl unzip wget cmake python3 python-is-python3 python3-pip python3-pybind11 git vim gfortran doxygen libibverbs-dev
+    apt install -y curl unzip wget cmake && \
+    apt install -y python3 python-is-python3 python3-pip python3-pybind11 && \
+    apt install -y git vim gfortran doxygen && \
+    apt install -y libibverbs-dev ibverbs-utils numactl
 
 # Download HPCX and compile with Fortran support
 RUN cd /opt && \
@@ -32,7 +36,7 @@ ENV LD_LIBRARY_PATH /opt/hpcx/ompi/lib:$LD_LIBRARY_PATH
 
 RUN echo "source /opt/hpcx/hpcx-init.sh; hpcx_load" >> /root/.bashrc
 
-# Install NCCL 2.20.3 for compatibility with PyTorch 2.2.1
+# Install newer NCCL for compatibility with PyTorch 2.2.1+
 RUN cd /opt && \
     git clone --branch v2.20.3-1 https://github.com/NVIDIA/nccl.git && \
     cd nccl && \
@@ -40,7 +44,7 @@ RUN cd /opt && \
 ENV LD_LIBRARY_PATH /opt/nccl/build/lib:$LD_LIBRARY_PATH
 
 # Install PyTorch
-RUN pip3 install torch==2.2.1 torchvision torchaudio
+RUN pip3 install torch==2.4.0
 
 # Install yaml-cpp
 RUN git clone https://github.com/jbeder/yaml-cpp.git --branch 0.8.0 && \
@@ -78,6 +82,7 @@ RUN cd /torchfort && mkdir build && cd build && \
           -DTORCHFORT_YAML_CPP_ROOT=/opt/yaml-cpp \
           -DTORCHFORT_NCCL_ROOT=/opt/nccl/build \
           -DTORCHFORT_BUILD_EXAMPLES=1 \
+          -DTORCHFORT_BUILD_TESTS=1 \
           -DCMAKE_PREFIX_PATH="`python -c 'import torch;print(torch.utils.cmake_prefix_path)'`" \
           .. && \
     make -j$(nproc) install && \

diff --git a/docker/Dockerfile_gnu_cpuonly b/docker/Dockerfile_gnu_cpuonly
index b836636..239d9c9 100644
--- a/docker/Dockerfile_gnu_cpuonly
+++ b/docker/Dockerfile_gnu_cpuonly
@@ -4,7 +4,10 @@ FROM ubuntu:22.04
 ENV DEBIAN_FRONTEND noninteractive
 RUN apt update -y && \
     apt install -y build-essential && \
-    apt install -y curl unzip wget cmake python3 python-is-python3 python3-pip python3-pybind11 git vim gfortran doxygen libibverbs-dev
+    apt install -y curl unzip wget cmake && \
+    apt install -y python3 python-is-python3 python3-pip python3-pybind11 && \
+    apt install -y git vim gfortran doxygen && \
+    apt install -y libibverbs-dev ibverbs-utils numactl
 
 # Download OpenMPI and compile with Fortran support
 RUN cd /opt && \
@@ -23,7 +26,7 @@ ENV PATH /opt/openmpi/bin:$PATH
 ENV LD_LIBRARY_PATH /opt/openmpi/lib:$LD_LIBRARY_PATH
 
 # Install PyTorch
-RUN pip3 install torch==2.4.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+RUN pip3 install torch==2.4.0 --index-url https://download.pytorch.org/whl/cpu
 
 # Install yaml-cpp
 RUN git clone https://github.com/jbeder/yaml-cpp.git --branch 0.8.0 && \
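Taken together, the series is exercised roughly as follows; the install prefix and image tag are illustrative, while the flags mirror docker/Dockerfile_gnu_cpuonly above:

# CPU-only configure and build of TorchFort
mkdir build && cd build
cmake -DCMAKE_INSTALL_PREFIX=/opt/torchfort \
      -DTORCHFORT_ENABLE_GPU=0 \
      -DTORCHFORT_BUILD_EXAMPLES=1 \
      -DTORCHFORT_BUILD_TESTS=1 \
      -DCMAKE_PREFIX_PATH="`python -c 'import torch;print(torch.utils.cmake_prefix_path)'`" \
      ..
make -j$(nproc) install

# or build the CPU-only container image directly
docker build -t torchfort:cpu -f docker/Dockerfile_gnu_cpuonly .

With TORCHFORT_ENABLE_GPU=ON (the default), the same configure line additionally requires the CUDA toolkit and NCCL, as set up in the GPU Dockerfiles in this series.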