diff --git a/src/lib.rs b/src/lib.rs
index 90c193da..c5cab519 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -7,7 +7,7 @@ mod error;
 pub use error::TchError;
 
 pub(crate) mod wrappers;
-pub use wrappers::device::{Cuda, Device};
+pub use wrappers::device::{Cuda, Device, Vulkan};
 pub use wrappers::jit::{self, CModule, IValue, TrainableCModule};
 pub use wrappers::kind::{self, Kind};
 pub use wrappers::optimizer::COptimizer;
diff --git a/src/wrappers/device.rs b/src/wrappers/device.rs
index 7b118524..0fdb77ef 100644
--- a/src/wrappers/device.rs
+++ b/src/wrappers/device.rs
@@ -9,6 +9,8 @@ pub enum Device {
     Cuda(usize),
     /// The main MPS device.
     Mps,
+    /// The main Vulkan device.
+    Vulkan,
 }
 
 /// Cuda related helper functions.
@@ -16,18 +18,32 @@ pub enum Cuda {}
 impl Cuda {
     /// Returns the number of CUDA devices available.
     pub fn device_count() -> i64 {
-        let res = unsafe_torch!(torch_sys::cuda::atc_cuda_device_count());
-        i64::from(res)
+        #[cfg(not(target_os = "android"))]
+        {
+            let res = unsafe_torch!(torch_sys::cuda::atc_cuda_device_count());
+            i64::from(res)
+        }
+
+        #[cfg(target_os = "android")]
+        0
     }
 
     /// Returns true if at least one CUDA device is available.
     pub fn is_available() -> bool {
-        unsafe_torch!(torch_sys::cuda::atc_cuda_is_available()) != 0
+        #[cfg(not(target_os = "android"))]
+        return unsafe_torch!(torch_sys::cuda::atc_cuda_is_available()) != 0;
+
+        #[cfg(target_os = "android")]
+        return false;
     }
 
     /// Returns true if CUDA is available, and CuDNN is available.
     pub fn cudnn_is_available() -> bool {
-        unsafe_torch!(torch_sys::cuda::atc_cudnn_is_available()) != 0
+        #[cfg(not(target_os = "android"))]
+        return unsafe_torch!(torch_sys::cuda::atc_cudnn_is_available()) != 0;
+
+        #[cfg(target_os = "android")]
+        return false;
     }
 
     /// Sets the seed for the current GPU.
@@ -36,6 +52,7 @@ impl Cuda {
     ///
     /// * `seed` - An unsigned 64bit int to be used as seed.
     pub fn manual_seed(seed: u64) {
+        #[cfg(not(target_os = "android"))]
         unsafe_torch!(torch_sys::cuda::atc_manual_seed(seed));
     }
 
@@ -45,6 +62,7 @@ impl Cuda {
     ///
     /// * `seed` - An unsigned 64bit int to be used as seed.
     pub fn manual_seed_all(seed: u64) {
+        #[cfg(not(target_os = "android"))]
         unsafe_torch!(torch_sys::cuda::atc_manual_seed_all(seed));
     }
 
@@ -54,6 +72,7 @@ impl Cuda {
     ///
     /// * `device_index` - A signed 64bit int to indice which device to wait for.
     pub fn synchronize(device_index: i64) {
+        #[cfg(not(target_os = "android"))]
         unsafe_torch!(torch_sys::cuda::atc_synchronize(device_index));
     }
 
@@ -61,11 +80,16 @@ impl Cuda {
     ///
     /// This does not indicate whether cudnn is actually usable.
     pub fn user_enabled_cudnn() -> bool {
-        unsafe_torch!(torch_sys::cuda::atc_user_enabled_cudnn()) != 0
+        #[cfg(not(target_os = "android"))]
+        return unsafe_torch!(torch_sys::cuda::atc_user_enabled_cudnn()) != 0;
+
+        #[cfg(target_os = "android")]
+        return false;
     }
 
     /// Enable or disable cudnn.
     pub fn set_user_enabled_cudnn(b: bool) {
+        #[cfg(not(target_os = "android"))]
         unsafe_torch!(torch_sys::cuda::atc_set_user_enabled_cudnn(i32::from(b)))
     }
 
@@ -76,16 +100,27 @@ impl Cuda {
     /// in the following runs. This can result in significant performance
     /// improvements.
     pub fn cudnn_set_benchmark(b: bool) {
+        #[cfg(not(target_os = "android"))]
         unsafe_torch!(torch_sys::cuda::atc_set_benchmark_cudnn(i32::from(b)))
     }
 }
 
+/// Vulkan related helper functions.
+pub enum Vulkan {}
+impl Vulkan {
+    /// Returns true if Vulkan is available.
+    pub fn is_available() -> bool {
+        unsafe_torch!(torch_sys::vulkan::atc_vulkan_is_available()) != 0
+    }
+}
+
 impl Device {
     pub(super) fn c_int(self) -> libc::c_int {
         match self {
             Device::Cpu => -1,
             Device::Cuda(device_index) => device_index as libc::c_int,
             Device::Mps => -2,
+            Device::Vulkan => -3,
         }
     }
 
@@ -98,7 +133,7 @@ impl Device {
         }
     }
 
-    /// Returns a GPU device if available, else default to CPU.
+    /// Returns a CUDA device if available, else default to CPU.
     pub fn cuda_if_available() -> Device {
         if Cuda::is_available() {
             Device::Cuda(0)
         } else {
@@ -107,11 +142,36 @@ impl Device {
         }
     }
 
+    /// Returns a Vulkan device if available, else default to CPU.
+    pub fn vulkan_if_available() -> Device {
+        if Vulkan::is_available() {
+            Device::Vulkan
+        } else {
+            Device::Cpu
+        }
+    }
+
+    pub fn gpu_if_available() -> Device {
+        if Cuda::is_available() {
+            Device::Cuda(0)
+        } else if Vulkan::is_available() {
+            Device::Vulkan
+        } else {
+            Device::Cpu
+        }
+    }
+
     pub fn is_cuda(self) -> bool {
         match self {
             Device::Cuda(_) => true,
-            Device::Cpu => false,
-            Device::Mps => false,
+            _ => false,
+        }
+    }
+
+    pub fn is_vulkan(self) -> bool {
+        match self {
+            Device::Vulkan => true,
+            _ => false,
         }
     }
 }
diff --git a/torch-sys/build.rs b/torch-sys/build.rs
index fc182e4a..09c22aa7 100644
--- a/torch-sys/build.rs
+++ b/torch-sys/build.rs
@@ -64,6 +64,17 @@ fn extract<P: AsRef<Path>>(filename: P, outpath: P) -> anyhow::Result<()> {
     Ok(())
 }
 
+fn env_var_target_specific(name: &str) -> Result<String, env::VarError> {
+    let target = env::var("TARGET").expect("Unable to get TARGET");
+
+    let name_with_target_hyphenated = name.to_owned() + "_" + &target;
+    let name_with_target_underscored = name.to_owned() + "_" + &target.replace("-", "_");
+
+    env_var_rerun(&name_with_target_hyphenated)
+        .or_else(|_| env_var_rerun(&name_with_target_underscored))
+        .or_else(|_| env_var_rerun(name))
+}
+
 fn env_var_rerun(name: &str) -> Result<String, env::VarError> {
     println!("cargo:rerun-if-env-changed={}", name);
     env::var(name)
@@ -101,7 +112,7 @@ fn prepare_libtorch_dir() -> PathBuf {
         Err(_) => "cpu".to_owned(),
     };
 
-    if let Ok(libtorch) = env_var_rerun("LIBTORCH") {
+    if let Ok(libtorch) = env_var_target_specific("LIBTORCH") {
         PathBuf::from(libtorch)
     } else if let Some(pathbuf) = check_system_location() {
         pathbuf
@@ -147,14 +158,15 @@ fn prepare_libtorch_dir() -> PathBuf {
     }
 }
 
-fn make<P: AsRef<Path>>(libtorch: P, use_cuda: bool, use_hip: bool) {
-    let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS");
-    let includes: PathBuf = env_var_rerun("LIBTORCH_INCLUDE")
-        .map(Into::into)
-        .unwrap_or_else(|_| libtorch.as_ref().to_owned());
-    let lib: PathBuf = env_var_rerun("LIBTORCH_LIB")
-        .map(Into::into)
-        .unwrap_or_else(|_| libtorch.as_ref().to_owned());
+fn make(
+    includes: impl AsRef<Path>,
+    lib: impl AsRef<Path>,
+    use_cuda: bool,
+    use_hip: bool,
+    os: &str,
+) {
+    let includes = includes.as_ref();
+    let lib = lib.as_ref();
 
     let cuda_dependency = if use_cuda || use_hip {
         "libtch/dummy_cuda_dependency.cpp"
@@ -168,17 +180,17 @@ fn make<P: AsRef<Path>>(libtorch: P, use_cuda: bool, use_hip: bool) {
     println!("cargo:rerun-if-changed=libtch/stb_image_write.h");
     println!("cargo:rerun-if-changed=libtch/stb_image_resize.h");
     println!("cargo:rerun-if-changed=libtch/stb_image.h");
-    match os.as_str() {
-        "linux" | "macos" => {
+    match os {
+        "linux" | "macos" | "android" => {
             let libtorch_cxx11_abi =
                 env_var_rerun("LIBTORCH_CXX11_ABI").unwrap_or_else(|_| "1".to_owned());
             cc::Build::new()
                 .cpp(true)
                 .pic(true)
                 .warnings(false)
-                .include(includes.join("include"))
-                .include(includes.join("include/torch/csrc/api/include"))
-                .flag(&format!("-Wl,-rpath={}", lib.join("lib").display()))
+                .include(includes)
+                .include(includes.join("torch/csrc/api/include"))
+                .flag(&format!("-Wl,-rpath={}", lib.display()))
                 .flag("-std=c++14")
                 .flag(&format!("-D_GLIBCXX_USE_CXX11_ABI={}", libtorch_cxx11_abi))
                 .file("libtch/torch_api.cpp")
@@ -193,19 +205,32 @@ fn make<P: AsRef<Path>>(libtorch: P, use_cuda: bool, use_hip: bool) {
                 .cpp(true)
                 .pic(true)
                 .warnings(false)
-                .include(includes.join("include"))
-                .include(includes.join("include/torch/csrc/api/include"))
+                .include(includes)
+                .include(includes.join("torch/csrc/api/include"))
                 .file("libtch/torch_api.cpp")
                 .file(cuda_dependency)
                 .compile("tch");
         }
-        _ => panic!("Unsupported OS"),
+        os => panic!("Unsupported OS: {}", os),
     };
 }
 
 fn main() {
     if !cfg!(feature = "doc-only") {
+        let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS");
+        let libtorch = prepare_libtorch_dir();
+
+        let libtorch_includes: PathBuf = env_var_target_specific("LIBTORCH_INCLUDE")
+            .map(Into::into)
+            .unwrap_or_else(|_| libtorch.join("include"));
+        let libtorch_lib: PathBuf = env_var_target_specific("LIBTORCH_LIB")
+            .map(Into::into)
+            .unwrap_or_else(|_| libtorch.join("lib"));
+        let libtorch_lite: bool = env_var_target_specific("LIBTORCH_LITE")
+            .map(|s| s.parse().unwrap_or(true))
+            .unwrap_or(true);
+
         // use_cuda is a hacky way to detect whether cuda is available and
         // if it's the case link to it by explicitly depending on a symbol
         // from the torch_cuda library.
@@ -220,42 +245,56 @@ fn main() {
         // This will be available starting from cargo 1.50 but will be a nightly
         // only option to start with.
         // https://github.com/rust-lang/cargo/blob/master/CHANGELOG.md
-        let use_cuda = libtorch.join("lib").join("libtorch_cuda.so").exists()
-            || libtorch.join("lib").join("torch_cuda.dll").exists();
-        let use_cuda_cu = libtorch.join("lib").join("libtorch_cuda_cu.so").exists()
-            || libtorch.join("lib").join("torch_cuda_cu.dll").exists();
-        let use_cuda_cpp = libtorch.join("lib").join("libtorch_cuda_cpp.so").exists()
-            || libtorch.join("lib").join("torch_cuda_cpp.dll").exists();
-        let use_hip = libtorch.join("lib").join("libtorch_hip.so").exists()
-            || libtorch.join("lib").join("torch_hip.dll").exists();
-        println!("cargo:rustc-link-search=native={}", libtorch.join("lib").display());
+        let use_cuda = libtorch_lib.join("libtorch_cuda.so").exists()
+            || libtorch_lib.join("torch_cuda.dll").exists();
+        let use_cuda_cu = libtorch_lib.join("libtorch_cuda_cu.so").exists()
+            || libtorch_lib.join("torch_cuda_cu.dll").exists();
+        let use_cuda_cpp = libtorch_lib.join("libtorch_cuda_cpp.so").exists()
+            || libtorch_lib.join("torch_cuda_cpp.dll").exists();
+        let use_hip = libtorch_lib.join("libtorch_hip.so").exists()
+            || libtorch_lib.join("torch_hip.dll").exists();
+
+        println!("cargo:rustc-link-search=native={}", libtorch_lib.display());
 
-        make(&libtorch, use_cuda, use_hip);
+        make(&libtorch_includes, &libtorch_lib, use_cuda, use_hip, &os);
 
         println!("cargo:rustc-link-lib=static=tch");
 
-        if use_cuda {
-            println!("cargo:rustc-link-lib=torch_cuda");
-        }
-        if use_cuda_cu {
-            println!("cargo:rustc-link-lib=torch_cuda_cu");
-        }
-        if use_cuda_cpp {
-            println!("cargo:rustc-link-lib=torch_cuda_cpp");
-        }
-        if use_hip {
-            println!("cargo:rustc-link-lib=torch_hip");
-        }
-        println!("cargo:rustc-link-lib=torch_cpu");
-        println!("cargo:rustc-link-lib=torch");
-        println!("cargo:rustc-link-lib=c10");
-        if use_hip {
-            println!("cargo:rustc-link-lib=c10_hip");
-        }
-        let target = env::var("TARGET").unwrap();
+        match os.as_str() {
+            "windows" | "linux" | "macos" => {
+                if use_cuda {
+                    println!("cargo:rustc-link-lib=torch_cuda");
+                }
+                if use_cuda_cu {
+                    println!("cargo:rustc-link-lib=torch_cuda_cu");
+                }
+                if use_cuda_cpp {
+                    println!("cargo:rustc-link-lib=torch_cuda_cpp");
+                }
+                if use_hip {
+                    println!("cargo:rustc-link-lib=torch_hip");
+                }
+                println!("cargo:rustc-link-lib=torch_cpu");
+                println!("cargo:rustc-link-lib=torch");
+                println!("cargo:rustc-link-lib=c10");
+                if use_hip {
+                    println!("cargo:rustc-link-lib=c10_hip");
+                }
+
+                let target = env::var("TARGET").unwrap();
 
-        if !target.contains("msvc") && !target.contains("apple") {
-            println!("cargo:rustc-link-lib=gomp");
+                if !target.contains("msvc") && !target.contains("apple") {
+                    println!("cargo:rustc-link-lib=gomp");
+                }
+            }
+            "android" => {
+                if libtorch_lite {
+                    println!("cargo:rustc-link-lib=pytorch_jni_lite");
+                } else {
+                    println!("cargo:rustc-link-lib=pytorch_jni");
+                }
+            }
+            other => panic!("unsupported OS: {}", other),
         }
     }
 }
diff --git a/torch-sys/libtch/torch_api.cpp b/torch-sys/libtch/torch_api.cpp
index aa280201..5bc1d14e 100644
--- a/torch-sys/libtch/torch_api.cpp
+++ b/torch-sys/libtch/torch_api.cpp
@@ -47,6 +47,7 @@ c10::List<c10::optional<torch::Tensor>> of_carray_tensor_opt(torch::Tensor **vs,
 }
 
 at::Device device_of_int(int d) {
+  if (d == -3) return at::Device(at::kVulkan);
   if (d == -2) return at::Device(at::kMPS);
   if (d < 0) return at::Device(at::kCPU);
   return at::Device(at::kCUDA, /*index=*/d);
@@ -926,6 +927,11 @@ int atc_cudnn_is_available() {
   return -1;
 }
 
+int atc_vulkan_is_available() {
+  PROTECT(return torch::is_vulkan_available();)
+  return -1;
+}
+
 void atc_manual_seed(uint64_t seed) {
   PROTECT(return torch::cuda::manual_seed(seed);)
 }
diff --git a/torch-sys/libtch/torch_api.h b/torch-sys/libtch/torch_api.h
index a05d4f2a..1ad7ad12 100644
--- a/torch-sys/libtch/torch_api.h
+++ b/torch-sys/libtch/torch_api.h
@@ -145,7 +145,6 @@ double ats_to_float(scalar);
 char *ats_to_string(scalar);
 void ats_free(scalar);
 
-
 /// Returns the number of CUDA devices available.
 int atc_cuda_device_count();
 
@@ -155,6 +154,9 @@ int atc_cuda_is_available();
 /// Returns true if CUDA is available, and CuDNN is available.
 int atc_cudnn_is_available();
 
+/// Returns true if Vulkan is available.
+int atc_vulkan_is_available();
+
 /// Sets the seed for the current GPU.
 void atc_manual_seed(uint64_t seed);
 
diff --git a/torch-sys/src/lib.rs b/torch-sys/src/lib.rs
index 17d587dd..253cbe30 100644
--- a/torch-sys/src/lib.rs
+++ b/torch-sys/src/lib.rs
@@ -1,5 +1,7 @@
+#[cfg(not(target_os = "android"))]
 pub mod cuda;
 pub mod io;
+pub mod vulkan;
 
 use libc::{c_char, c_int, c_uchar, c_void, size_t};
 
diff --git a/torch-sys/src/vulkan.rs b/torch-sys/src/vulkan.rs
new file mode 100644
index 00000000..291b6adf
--- /dev/null
+++ b/torch-sys/src/vulkan.rs
@@ -0,0 +1,6 @@
+use libc::c_int;
+
+extern "C" {
+    /// Returns true if Vulkan is available.
+    pub fn atc_vulkan_is_available() -> c_int;
+}