Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some cuda graph api. #632

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,16 @@ See some details in [this thread](https://github.com/LaurentMazare/tch-rs/issues

Check this [issue](https://github.com/LaurentMazare/tch-rs/issues/488).

### What if I get some errors not finding `cuda_runtime_api.h`?

This may be caused by the CUDA headers not being in your default include paths.
To get around this, you can try setting the `CPLUS_INCLUDE_PATH` environment
variable so that it points at the appropriate directory, e.g.

```bash
export CPLUS_INCLUDE_PATH=/usr/local/cuda/include:$CPLUS_INCLUDE_PATH
```

## License
`tch-rs` is distributed under the terms of both the MIT license
and the Apache license (version 2.0), at your option.
Expand Down
25 changes: 25 additions & 0 deletions examples/cuda_graph.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use tch::Tensor;

/// Captures one in-place addition into a CUDA graph, replays the graph three
/// times, and prints the tensor at each stage.
///
/// # Errors
/// Propagates any `TchError` returned while creating the graph, selecting the
/// stream, capturing, or replaying.
fn run() -> Result<(), tch::TchError> {
    tch::maybe_init_cuda();
    let mut cuda_graph = tch::cuda_graph::CudaGraph::new()?;

    // NOTE(review): the tensor is created without an explicit device; confirm it
    // lives where the captured work is meant to run.
    let mut value = Tensor::of_slice(&[3.0]);
    value.print();
    value += 0.1;
    value.print();

    // Make a pooled (non high-priority, device 0) stream current before capturing.
    let capture_stream = tch::cuda_stream::CudaStream::get_stream_from_pool(false, 0)?;
    capture_stream.set_current_stream()?;

    // Only the single addition between begin/end is issued during capture.
    cuda_graph.capture_begin()?;
    value += 0.01;
    cuda_graph.capture_end()?;
    value.print();

    for _ in 0..3 {
        cuda_graph.replay()?;
    }
    value.print();
    Ok(())
}

/// Entry point: runs the example and panics if it reports an error.
fn main() {
    let outcome = run();
    outcome.unwrap()
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ mod error;
pub use error::TchError;

pub(crate) mod wrappers;
pub use wrappers::cuda_graph;
pub use wrappers::cuda_stream;
pub use wrappers::device::{Cuda, Device};
pub use wrappers::jit::{self, CModule, IValue, TrainableCModule};
pub use wrappers::kind::{self, Kind};
Expand Down
43 changes: 43 additions & 0 deletions src/wrappers/cuda_graph.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//! CUDA Graph API.

use crate::TchError;

/// Handle wrapping a raw pointer to a `torch_sys::cuda::CCudaGraph`.
///
/// The pointer is obtained from `atcg_new` in `CudaGraph::new` and handed to
/// `atcg_free` when the value is dropped.
pub struct CudaGraph {
/// Raw graph pointer passed to every `atcg_*` call; `new` rejects null.
c_ptr: *mut torch_sys::cuda::CCudaGraph,
}

impl CudaGraph {
    /// Allocates a new graph object through `atcg_new`.
    ///
    /// # Errors
    /// Returns `TchError::Torch` when the returned pointer is null. Other
    /// failures are expected to surface through `unsafe_torch_err!`, which is
    /// defined elsewhere in the crate.
    pub fn new() -> Result<Self, TchError> {
        let raw = unsafe_torch_err!(torch_sys::cuda::atcg_new());
        if raw.is_null() {
            Err(TchError::Torch("CudaGraph::new() returned null".to_string()))
        } else {
            Ok(Self { c_ptr: raw })
        }
    }

    /// Forwards to `atcg_capture_begin` for this graph.
    pub fn capture_begin(&mut self) -> Result<(), TchError> {
        let graph = self.c_ptr;
        unsafe_torch_err!(torch_sys::cuda::atcg_capture_begin(graph));
        Ok(())
    }

    /// Forwards to `atcg_capture_end` for this graph.
    pub fn capture_end(&mut self) -> Result<(), TchError> {
        let graph = self.c_ptr;
        unsafe_torch_err!(torch_sys::cuda::atcg_capture_end(graph));
        Ok(())
    }

    /// Forwards to `atcg_replay` for this graph.
    pub fn replay(&mut self) -> Result<(), TchError> {
        let graph = self.c_ptr;
        unsafe_torch_err!(torch_sys::cuda::atcg_replay(graph));
        Ok(())
    }

    /// Forwards to `atcg_reset` for this graph.
    pub fn reset(&mut self) -> Result<(), TchError> {
        let graph = self.c_ptr;
        unsafe_torch_err!(torch_sys::cuda::atcg_reset(graph));
        Ok(())
    }
}

/// Releases the wrapped graph object when the handle goes out of scope.
impl Drop for CudaGraph {
// Hands the stored pointer to `atcg_free`; the pointer is never used afterwards
// because the value is being dropped.
fn drop(&mut self) {
unsafe_torch!(torch_sys::cuda::atcg_free(self.c_ptr))
}
}
54 changes: 54 additions & 0 deletions src/wrappers/cuda_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
//! CUDA Stream API.

use crate::TchError;
use libc::c_int;

/// Handle wrapping a raw pointer to a `torch_sys::cuda::CCudaStream`.
///
/// Instances are produced by the `get_*` constructors and the pointer is handed
/// to `atcs_free` when the value is dropped.
pub struct CudaStream {
/// Raw stream pointer; the constructors reject null before storing it.
c_ptr: *mut torch_sys::cuda::CCudaStream,
}

impl CudaStream {
pub fn get_stream_from_pool(high_priority: bool, device: usize) -> Result<Self, TchError> {
let c_ptr = unsafe_torch_err!(torch_sys::cuda::atcs_get_stream_from_pool(
high_priority as c_int,
device as c_int
));
if c_ptr.is_null() {
return Err(TchError::Torch(
"CUDAStream::getStreamFromPool() returned null".to_string(),
));
}
Ok(Self { c_ptr })
}

pub fn get_default_stream(device: usize) -> Result<Self, TchError> {
let c_ptr = unsafe_torch_err!(torch_sys::cuda::atcs_get_default_stream(device as c_int));
if c_ptr.is_null() {
return Err(TchError::Torch(
"CUDAStream::getDefaultStream() returned null".to_string(),
));
}
Ok(Self { c_ptr })
}

pub fn get_current_stream(device: usize) -> Result<Self, TchError> {
let c_ptr = unsafe_torch_err!(torch_sys::cuda::atcs_get_current_stream(device as c_int));
if c_ptr.is_null() {
return Err(TchError::Torch(
"CUDAStream::getStreamFromPool() returned null".to_string(),
));
}
Ok(Self { c_ptr })
}

pub fn set_current_stream(&self) -> Result<(), TchError> {
unsafe_torch_err!(torch_sys::cuda::atcs_set_current_stream(self.c_ptr));
Ok(())
}
}

/// Releases the wrapped stream object when the handle goes out of scope.
impl Drop for CudaStream {
// Hands the stored pointer to `atcs_free`; the pointer is never used afterwards
// because the value is being dropped.
fn drop(&mut self) {
unsafe_torch!(torch_sys::cuda::atcs_free(self.c_ptr))
}
}
2 changes: 2 additions & 0 deletions src/wrappers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ pub use utils::{
set_num_threads, QEngine,
};

pub mod cuda_graph;
pub mod cuda_stream;
pub(crate) mod device;
pub(crate) mod image;
pub mod jit;
Expand Down
6 changes: 5 additions & 1 deletion torch-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,17 @@ fn make<P: AsRef<Path>>(libtorch: P, use_cuda: bool, use_hip: bool) {
.unwrap_or_else(|_| libtorch.as_ref().to_owned());

let cuda_dependency = if use_cuda || use_hip {
"libtch/dummy_cuda_dependency.cpp"
"libtch/cuda_dependency.cpp"
} else {
"libtch/fake_cuda_dependency.cpp"
};
println!("cargo:rerun-if-changed=libtch/torch_api.cpp");
println!("cargo:rerun-if-changed=libtch/torch_api.h");
println!("cargo:rerun-if-changed=libtch/torch_api_generated.cpp.h");
println!("cargo:rerun-if-changed=libtch/torch_api_generated.h");
println!("cargo:rerun-if-changed=libtch/cuda_dependency.cpp");
println!("cargo:rerun-if-changed=libtch/cuda_dependency.h");
println!("cargo:rerun-if-changed=libtch/fake_cuda_dependency.cpp");
println!("cargo:rerun-if-changed=libtch/stb_image_write.h");
println!("cargo:rerun-if-changed=libtch/stb_image_resize.h");
println!("cargo:rerun-if-changed=libtch/stb_image.h");
Expand Down Expand Up @@ -228,6 +231,7 @@ fn main() {

println!("cargo:rustc-link-lib=static=tch");
if use_cuda {
println!("cargo:rustc-link-lib=c10_cuda");
println!("cargo:rustc-link-lib=torch_cuda");
}
if use_cuda_cu {
Expand Down
96 changes: 96 additions & 0 deletions torch-sys/libtch/cuda_dependency.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#define __TCH_ACTUAL_CUDA_DEPENDENCY__
#include "cuda_dependency.h"
#include "torch_api.h"

#include<stdio.h>
#include<stdint.h>
#include<stdexcept>
#include<iostream>
using namespace std;
// C-linkage declaration of the function defined below.
extern "C" {
void dummy_cuda_dependency();
}

// Opaque type used only as the return type of getCurrentCUDABlasHandle.
struct cublasContext;

// Forward declarations of the two ATen CUDA functions that
// dummy_cuda_dependency() calls.
namespace at {
namespace cuda {
cublasContext* getCurrentCUDABlasHandle();
int warp_size();
}
}
// NOTE(review): declared but not called anywhere in this file — confirm whether
// it is still needed.
char * magma_strerror(int err);
// Calls two ATen CUDA functions and discards their results. Any std::exception
// is reported on stderr instead of propagating to the caller.
// NOTE(review): presumably exists to force a dependency on the CUDA libraries —
// confirm.
void dummy_cuda_dependency() {
  try {
    at::cuda::getCurrentCUDABlasHandle();
    at::cuda::warp_size();
  } catch (const std::exception &err) {
    std::cerr << "error initializing cuda: " << err.what() << std::endl;
  }
}

// Heap-allocates a new at::cuda::CUDAGraph and returns it to the caller, who
// releases it with atcg_free.
// The trailing `return nullptr` is reached only when the `return` inside
// PROTECT does not complete. PROTECT is defined outside this file; presumably
// it catches exceptions and records them — confirm in torch_api.h.
cuda_graph atcg_new() {
PROTECT(
return new at::cuda::CUDAGraph();
)
return nullptr;
}

// Deletes a graph object; atcg_new allocates these with `new`.
void atcg_free(cuda_graph c) {
delete c;
}
// Calls capture_begin() on the graph, wrapped in PROTECT (defined outside this
// file; presumably turns exceptions into a recorded error — confirm).
// `c` is dereferenced without a null check.
void atcg_capture_begin(cuda_graph c) {
PROTECT(
c->capture_begin();
)
}

// Calls capture_end() on the graph, wrapped in PROTECT (defined outside this
// file; presumably turns exceptions into a recorded error — confirm).
// `c` is dereferenced without a null check.
void atcg_capture_end(cuda_graph c) {
PROTECT(
c->capture_end();
)
}

// Calls replay() on the graph, wrapped in PROTECT (defined outside this file;
// presumably turns exceptions into a recorded error — confirm).
// `c` is dereferenced without a null check.
void atcg_replay(cuda_graph c) {
PROTECT(
c->replay();
)
}

// Calls reset() on the graph, wrapped in PROTECT (defined outside this file;
// presumably turns exceptions into a recorded error — confirm).
// `c` is dereferenced without a null check.
void atcg_reset(cuda_graph c) {
PROTECT(
c->reset();
)
}

// Deletes a stream object; the atcs_get_* functions in this file allocate
// these with `new`.
void atcs_free(cuda_stream s) {
delete s;
}

// Heap-allocates a copy of the stream returned by
// c10::cuda::getStreamFromPool(high_priority, device); release with atcs_free.
// The trailing `return nullptr` is reached only when the `return` inside
// PROTECT does not complete (PROTECT is defined outside this file; presumably
// it catches exceptions — confirm).
cuda_stream atcs_get_stream_from_pool(int high_priority, int device) {
PROTECT (
return new c10::cuda::CUDAStream(c10::cuda::getStreamFromPool(high_priority, device));
)
return nullptr;
}

// Heap-allocates a copy of the stream returned by
// c10::cuda::getDefaultCUDAStream(device); release with atcs_free.
// The trailing `return nullptr` is reached only when the `return` inside
// PROTECT does not complete (PROTECT is defined outside this file; presumably
// it catches exceptions — confirm).
cuda_stream atcs_get_default_stream(int device) {
PROTECT (
return new c10::cuda::CUDAStream(c10::cuda::getDefaultCUDAStream(device));
)
return nullptr;
}

// Heap-allocates a copy of the stream returned by
// c10::cuda::getCurrentCUDAStream(device); release with atcs_free.
// The trailing `return nullptr` is reached only when the `return` inside
// PROTECT does not complete (PROTECT is defined outside this file; presumably
// it catches exceptions — confirm).
cuda_stream atcs_get_current_stream(int device) {
PROTECT (
return new c10::cuda::CUDAStream(c10::cuda::getCurrentCUDAStream(device));
)
return nullptr;
}

// Passes the pointed-to stream to c10::cuda::setCurrentCUDAStream, wrapped in
// PROTECT (defined outside this file; presumably turns exceptions into a
// recorded error — confirm). `s` is dereferenced without a null check.
void atcs_set_current_stream(cuda_stream s) {
PROTECT (
c10::cuda::setCurrentCUDAStream(*s);
)
}
44 changes: 44 additions & 0 deletions torch-sys/libtch/cuda_dependency.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#ifndef __TCH_CUDA_DEPENDENCY_H__
#define __TCH_CUDA_DEPENDENCY_H__

/*
 * C-compatible declarations for the CUDA graph (atcg_*) and CUDA stream
 * (atcs_*) helper functions.
 *
 * The handle types cuda_graph / cuda_stream are typed pointers only when this
 * header is compiled as C++ with __TCH_ACTUAL_CUDA_DEPENDENCY__ defined;
 * in every other case they are plain `void *`.
 */

#include<stdio.h>
#include<stdint.h>

#ifdef __cplusplus

#ifdef __TCH_ACTUAL_CUDA_DEPENDENCY__
/* Real implementation: pull in the ATen / c10 CUDA headers. */
#include<ATen/cuda/CUDAGraph.h>
#include<c10/cuda/CUDAStream.h>
typedef at::cuda::CUDAGraph *cuda_graph;
typedef c10::cuda::CUDAStream *cuda_stream;
#else
/* C++ without the real CUDA implementation: opaque handles. */
typedef void *cuda_graph;
typedef void *cuda_stream;
#endif

extern "C" {

#else
/* Plain C: opaque handles. */
typedef void *cuda_graph;
typedef void *cuda_stream;
#endif

void dummy_cuda_dependency();
/* CUDA graph handle functions. */
cuda_graph atcg_new();
void atcg_free(cuda_graph);
void atcg_capture_begin(cuda_graph);
void atcg_capture_end(cuda_graph);
void atcg_replay(cuda_graph);
void atcg_reset(cuda_graph);

/* CUDA stream handle functions. */
void atcs_free(cuda_stream);
cuda_stream atcs_get_stream_from_pool(int, int);
cuda_stream atcs_get_default_stream(int);
cuda_stream atcs_get_current_stream(int);
void atcs_set_current_stream(cuda_stream);

#ifdef __cplusplus
}
#endif

#endif
27 changes: 0 additions & 27 deletions torch-sys/libtch/dummy_cuda_dependency.cpp

This file was deleted.

26 changes: 23 additions & 3 deletions torch-sys/libtch/fake_cuda_dependency.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,26 @@
extern "C" {
void dummy_cuda_dependency();
}
#include "cuda_dependency.h"

// Stub with an empty body: this version of the function does nothing.
void dummy_cuda_dependency() {
}

// Stub: no graph object is created; always returns nullptr.
cuda_graph atcg_new() {
return nullptr;
}

// Stubs for the remaining graph functions: each ignores its argument and does
// nothing.
void atcg_free(cuda_graph) {}
void atcg_capture_begin(cuda_graph) {}
void atcg_capture_end(cuda_graph) {}
void atcg_replay(cuda_graph) {}
void atcg_reset(cuda_graph) {}

// Stub: ignores its argument and does nothing.
void atcs_free(cuda_stream) {}
// Stub: ignores both arguments and always returns nullptr.
cuda_stream atcs_get_stream_from_pool(int, int) {
return nullptr;
}
// Stub: ignores its argument and always returns nullptr.
cuda_stream atcs_get_default_stream(int) {
return nullptr;
}
// Stub: ignores its argument and always returns nullptr.
cuda_stream atcs_get_current_stream(int) {
return nullptr;
}
// Stub: ignores its argument and does nothing.
void atcs_set_current_stream(cuda_stream) {}
Loading