diff --git a/README.md b/README.md index 9aae7aca..bf13d7fb 100644 --- a/README.md +++ b/README.md @@ -254,6 +254,16 @@ See some details in [this thread](https://github.com/LaurentMazare/tch-rs/issues Check this [issue](https://github.com/LaurentMazare/tch-rs/issues/488). +### What if I get some errors not finding `cuda_runtime_api.h`? + +This may be caused by the cuda headers not being in your default include paths. +To get around this, you can try changing the `CPLUS_INCLUDE_PATH` environment +variable pointing it at the appropriate directory, e.g. + +```bash +CPLUS_INCLUDE_PATH=/usr/local/cuda/include:$CPLUS_INCLUDE_PATH +``` + ## License `tch-rs` is distributed under the terms of both the MIT license and the Apache license (version 2.0), at your option. diff --git a/examples/cuda_graph.rs b/examples/cuda_graph.rs new file mode 100644 index 00000000..be22b517 --- /dev/null +++ b/examples/cuda_graph.rs @@ -0,0 +1,25 @@ +use tch::Tensor; + +fn run() -> Result<(), tch::TchError> { + tch::maybe_init_cuda(); + let mut graph = tch::cuda_graph::CudaGraph::new()?; + let mut t = Tensor::of_slice(&[3.0]); + t.print(); + t += 0.1; + t.print(); + let stream = tch::cuda_stream::CudaStream::get_stream_from_pool(false, 0)?; + stream.set_current_stream()?; + graph.capture_begin()?; + t += 0.01; + graph.capture_end()?; + t.print(); + graph.replay()?; + graph.replay()?; + graph.replay()?; + t.print(); + Ok(()) +} + +fn main() { + run().unwrap() +} diff --git a/src/lib.rs b/src/lib.rs index fd3b724f..6d0517a1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,8 @@ mod error; pub use error::TchError; pub(crate) mod wrappers; +pub use wrappers::cuda_graph; +pub use wrappers::cuda_stream; pub use wrappers::device::{Cuda, Device}; pub use wrappers::jit::{self, CModule, IValue, TrainableCModule}; pub use wrappers::kind::{self, Kind}; diff --git a/src/wrappers/cuda_graph.rs b/src/wrappers/cuda_graph.rs new file mode 100644 index 00000000..eeb20924 --- /dev/null +++ b/src/wrappers/cuda_graph.rs @@ -0,0 +1,43 @@ +//! CUDA Graph API. + +use crate::TchError; + +pub struct CudaGraph { + c_ptr: *mut torch_sys::cuda::CCudaGraph, +} + +impl CudaGraph { + pub fn new() -> Result { + let c_ptr = unsafe_torch_err!(torch_sys::cuda::atcg_new()); + if c_ptr.is_null() { + return Err(TchError::Torch("CudaGraph::new() returned null".to_string())); + } + Ok(Self { c_ptr }) + } + + pub fn capture_begin(&mut self) -> Result<(), TchError> { + unsafe_torch_err!(torch_sys::cuda::atcg_capture_begin(self.c_ptr)); + Ok(()) + } + + pub fn capture_end(&mut self) -> Result<(), TchError> { + unsafe_torch_err!(torch_sys::cuda::atcg_capture_end(self.c_ptr)); + Ok(()) + } + + pub fn replay(&mut self) -> Result<(), TchError> { + unsafe_torch_err!(torch_sys::cuda::atcg_replay(self.c_ptr)); + Ok(()) + } + + pub fn reset(&mut self) -> Result<(), TchError> { + unsafe_torch_err!(torch_sys::cuda::atcg_reset(self.c_ptr)); + Ok(()) + } +} + +impl Drop for CudaGraph { + fn drop(&mut self) { + unsafe_torch!(torch_sys::cuda::atcg_free(self.c_ptr)) + } +} diff --git a/src/wrappers/cuda_stream.rs b/src/wrappers/cuda_stream.rs new file mode 100644 index 00000000..c4f147ea --- /dev/null +++ b/src/wrappers/cuda_stream.rs @@ -0,0 +1,54 @@ +//! CUDA Stream API. + +use crate::TchError; +use libc::c_int; + +pub struct CudaStream { + c_ptr: *mut torch_sys::cuda::CCudaStream, +} + +impl CudaStream { + pub fn get_stream_from_pool(high_priority: bool, device: usize) -> Result { + let c_ptr = unsafe_torch_err!(torch_sys::cuda::atcs_get_stream_from_pool( + high_priority as c_int, + device as c_int + )); + if c_ptr.is_null() { + return Err(TchError::Torch( + "CUDAStream::getStreamFromPool() returned null".to_string(), + )); + } + Ok(Self { c_ptr }) + } + + pub fn get_default_stream(device: usize) -> Result { + let c_ptr = unsafe_torch_err!(torch_sys::cuda::atcs_get_default_stream(device as c_int)); + if c_ptr.is_null() { + return Err(TchError::Torch( + "CUDAStream::getDefaultStream() returned null".to_string(), + )); + } + Ok(Self { c_ptr }) + } + + pub fn get_current_stream(device: usize) -> Result { + let c_ptr = unsafe_torch_err!(torch_sys::cuda::atcs_get_current_stream(device as c_int)); + if c_ptr.is_null() { + return Err(TchError::Torch( + "CUDAStream::getStreamFromPool() returned null".to_string(), + )); + } + Ok(Self { c_ptr }) + } + + pub fn set_current_stream(&self) -> Result<(), TchError> { + unsafe_torch_err!(torch_sys::cuda::atcs_set_current_stream(self.c_ptr)); + Ok(()) + } +} + +impl Drop for CudaStream { + fn drop(&mut self) { + unsafe_torch!(torch_sys::cuda::atcs_free(self.c_ptr)) + } +} diff --git a/src/wrappers/mod.rs b/src/wrappers/mod.rs index 17d55c5c..2d06b1da 100644 --- a/src/wrappers/mod.rs +++ b/src/wrappers/mod.rs @@ -6,6 +6,8 @@ pub use utils::{ set_num_threads, QEngine, }; +pub mod cuda_graph; +pub mod cuda_stream; pub(crate) mod device; pub(crate) mod image; pub mod jit; diff --git a/torch-sys/build.rs b/torch-sys/build.rs index 489f280f..051a9d37 100644 --- a/torch-sys/build.rs +++ b/torch-sys/build.rs @@ -151,7 +151,7 @@ fn make>(libtorch: P, use_cuda: bool, use_hip: bool) { .unwrap_or_else(|_| libtorch.as_ref().to_owned()); let cuda_dependency = if use_cuda || use_hip { - "libtch/dummy_cuda_dependency.cpp" + "libtch/cuda_dependency.cpp" } else { "libtch/fake_cuda_dependency.cpp" }; @@ -159,6 +159,9 @@ fn make>(libtorch: P, use_cuda: bool, use_hip: bool) { println!("cargo:rerun-if-changed=libtch/torch_api.h"); println!("cargo:rerun-if-changed=libtch/torch_api_generated.cpp.h"); println!("cargo:rerun-if-changed=libtch/torch_api_generated.h"); + println!("cargo:rerun-if-changed=libtch/cuda_dependency.cpp"); + println!("cargo:rerun-if-changed=libtch/cuda_dependency.h"); + println!("cargo:rerun-if-changed=libtch/fake_cuda_dependency.cpp"); println!("cargo:rerun-if-changed=libtch/stb_image_write.h"); println!("cargo:rerun-if-changed=libtch/stb_image_resize.h"); println!("cargo:rerun-if-changed=libtch/stb_image.h"); @@ -228,6 +231,7 @@ fn main() { println!("cargo:rustc-link-lib=static=tch"); if use_cuda { + println!("cargo:rustc-link-lib=c10_cuda"); println!("cargo:rustc-link-lib=torch_cuda"); } if use_cuda_cu { diff --git a/torch-sys/libtch/cuda_dependency.cpp b/torch-sys/libtch/cuda_dependency.cpp new file mode 100644 index 00000000..8e5fd54c --- /dev/null +++ b/torch-sys/libtch/cuda_dependency.cpp @@ -0,0 +1,96 @@ +#define __TCH_ACTUAL_CUDA_DEPENDENCY__ +#include "cuda_dependency.h" +#include "torch_api.h" + +#include +#include +#include +#include +using namespace std; +extern "C" { + void dummy_cuda_dependency(); +} + +struct cublasContext; + +namespace at { + namespace cuda { + cublasContext* getCurrentCUDABlasHandle(); + int warp_size(); + } +} +char * magma_strerror(int err); +void dummy_cuda_dependency() { + try { + at::cuda::getCurrentCUDABlasHandle(); + at::cuda::warp_size(); + } + catch (std::exception &e) { + std::cerr << "error initializing cuda: " << e.what() << std::endl; + } +} + +cuda_graph atcg_new() { + PROTECT( + return new at::cuda::CUDAGraph(); + ) + return nullptr; +} + +void atcg_free(cuda_graph c) { + delete c; +} +void atcg_capture_begin(cuda_graph c) { + PROTECT( + c->capture_begin(); + ) +} + +void atcg_capture_end(cuda_graph c) { + PROTECT( + c->capture_end(); + ) +} + +void atcg_replay(cuda_graph c) { + PROTECT( + c->replay(); + ) +} + +void atcg_reset(cuda_graph c) { + PROTECT( + c->reset(); + ) +} + +void atcs_free(cuda_stream s) { + delete s; +} + +cuda_stream atcs_get_stream_from_pool(int high_priority, int device) { + PROTECT ( + return new c10::cuda::CUDAStream(c10::cuda::getStreamFromPool(high_priority, device)); + ) + return nullptr; +} + +cuda_stream atcs_get_default_stream(int device) { + PROTECT ( + return new c10::cuda::CUDAStream(c10::cuda::getDefaultCUDAStream(device)); + ) + return nullptr; +} + +cuda_stream atcs_get_current_stream(int device) { + PROTECT ( + return new c10::cuda::CUDAStream(c10::cuda::getCurrentCUDAStream(device)); + ) + return nullptr; +} + +void atcs_set_current_stream(cuda_stream s) { + PROTECT ( + c10::cuda::setCurrentCUDAStream(*s); + ) +} diff --git a/torch-sys/libtch/cuda_dependency.h b/torch-sys/libtch/cuda_dependency.h new file mode 100644 index 00000000..9e58e774 --- /dev/null +++ b/torch-sys/libtch/cuda_dependency.h @@ -0,0 +1,44 @@ +#ifndef __TCH_CUDA_DEPENDENCY_H__ +#define __TCH_CUDA_DEPENDENCY_H__ + +#include +#include + +#ifdef __cplusplus + +#ifdef __TCH_ACTUAL_CUDA_DEPENDENCY__ +#include +#include +typedef at::cuda::CUDAGraph *cuda_graph; +typedef c10::cuda::CUDAStream *cuda_stream; +#else +typedef void *cuda_graph; +typedef void *cuda_stream; +#endif + +extern "C" { + +#else +typedef void *cuda_graph; +typedef void *cuda_stream; +#endif + +void dummy_cuda_dependency(); +cuda_graph atcg_new(); +void atcg_free(cuda_graph); +void atcg_capture_begin(cuda_graph); +void atcg_capture_end(cuda_graph); +void atcg_replay(cuda_graph); +void atcg_reset(cuda_graph); + +void atcs_free(cuda_stream); +cuda_stream atcs_get_stream_from_pool(int, int); +cuda_stream atcs_get_default_stream(int); +cuda_stream atcs_get_current_stream(int); +void atcs_set_current_stream(cuda_stream); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/torch-sys/libtch/dummy_cuda_dependency.cpp b/torch-sys/libtch/dummy_cuda_dependency.cpp deleted file mode 100644 index 70583473..00000000 --- a/torch-sys/libtch/dummy_cuda_dependency.cpp +++ /dev/null @@ -1,27 +0,0 @@ -#include -#include -#include -#include -using namespace std; -extern "C" { - void dummy_cuda_dependency(); -} - -struct cublasContext; - -namespace at { - namespace cuda { - cublasContext* getCurrentCUDABlasHandle(); - int warp_size(); - } -} -char * magma_strerror(int err); -void dummy_cuda_dependency() { - try { - at::cuda::getCurrentCUDABlasHandle(); - at::cuda::warp_size(); - } - catch (std::exception &e) { - std::cerr << "error initializing cuda: " << e.what() << std::endl; - } -} diff --git a/torch-sys/libtch/fake_cuda_dependency.cpp b/torch-sys/libtch/fake_cuda_dependency.cpp index 9db48039..32a6de8c 100644 --- a/torch-sys/libtch/fake_cuda_dependency.cpp +++ b/torch-sys/libtch/fake_cuda_dependency.cpp @@ -1,6 +1,26 @@ -extern "C" { - void dummy_cuda_dependency(); -} +#include "cuda_dependency.h" void dummy_cuda_dependency() { } + +cuda_graph atcg_new() { + return nullptr; +} + +void atcg_free(cuda_graph) {} +void atcg_capture_begin(cuda_graph) {} +void atcg_capture_end(cuda_graph) {} +void atcg_replay(cuda_graph) {} +void atcg_reset(cuda_graph) {} + +void atcs_free(cuda_stream) {} +cuda_stream atcs_get_stream_from_pool(int, int) { + return nullptr; +} +cuda_stream atcs_get_default_stream(int) { + return nullptr; +} +cuda_stream atcs_get_current_stream(int) { + return nullptr; +} +void atcs_set_current_stream(cuda_stream) {} diff --git a/torch-sys/libtch/torch_api.cpp b/torch-sys/libtch/torch_api.cpp index c38bf1ce..9f9bd246 100644 --- a/torch-sys/libtch/torch_api.cpp +++ b/torch-sys/libtch/torch_api.cpp @@ -1,16 +1,3 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include #include "torch_api.h" #define STB_IMAGE_IMPLEMENTATION @@ -23,6 +10,7 @@ #include "stb_image_resize.h" using namespace std; +thread_local char *torch_last_err = nullptr; char *get_and_reset_last_err() { char *tmp = torch_last_err; diff --git a/torch-sys/libtch/torch_api.h b/torch-sys/libtch/torch_api.h index ad6a9333..021ceca6 100644 --- a/torch-sys/libtch/torch_api.h +++ b/torch-sys/libtch/torch_api.h @@ -3,7 +3,21 @@ #include #ifdef __cplusplus -thread_local char *torch_last_err = nullptr; +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +extern thread_local char *torch_last_err; extern "C" { typedef torch::Tensor *tensor; diff --git a/torch-sys/src/cuda.rs b/torch-sys/src/cuda.rs index 7ea4bbef..4bd7b7b4 100644 --- a/torch-sys/src/cuda.rs +++ b/torch-sys/src/cuda.rs @@ -1,5 +1,15 @@ use libc::c_int; +#[repr(C)] +pub struct CCudaGraph { + _private: [u8; 0], +} + +#[repr(C)] +pub struct CCudaStream { + _private: [u8; 0], +} + extern "C" { /// Returns the number of CUDA devices available. pub fn atc_cuda_device_count() -> c_int; @@ -27,4 +37,17 @@ extern "C" { /// Sets CUDNN benchmark mode. pub fn atc_set_benchmark_cudnn(b: c_int); + + pub fn atcg_new() -> *mut CCudaGraph; + pub fn atcg_free(arg: *mut CCudaGraph); + pub fn atcg_replay(arg: *mut CCudaGraph); + pub fn atcg_reset(arg: *mut CCudaGraph); + pub fn atcg_capture_begin(arg: *mut CCudaGraph); + pub fn atcg_capture_end(arg: *mut CCudaGraph); + + pub fn atcs_free(arg: *mut CCudaStream); + pub fn atcs_get_stream_from_pool(high_priority: c_int, device: c_int) -> *mut CCudaStream; + pub fn atcs_get_default_stream(device: c_int) -> *mut CCudaStream; + pub fn atcs_get_current_stream(device: c_int) -> *mut CCudaStream; + pub fn atcs_set_current_stream(arg: *mut CCudaStream); }