More flexible build script for Android support #594

Open · wants to merge 11 commits into base: main

src/lib.rs (1 addition, 1 deletion)

@@ -7,7 +7,7 @@ mod error;
pub use error::TchError;

pub(crate) mod wrappers;
pub use wrappers::device::{Cuda, Device};
pub use wrappers::device::{Cuda, Device, Vulkan};
pub use wrappers::jit::{self, CModule, IValue, TrainableCModule};
pub use wrappers::kind::{self, Kind};
pub use wrappers::optimizer::COptimizer;
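For downstream context, the Vulkan helper becomes importable from the crate root alongside Cuda. A minimal sketch of the new public surface (hedged; it assumes exactly the API exported in this diff):

    use tch::{Device, Vulkan};

    fn pick_device() -> Device {
        // Prefer the Vulkan backend when the linked libtorch supports it.
        if Vulkan::is_available() { Device::Vulkan } else { Device::Cpu }
    }
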
src/wrappers/device.rs (68 additions, 8 deletions)

@@ -9,25 +9,41 @@ pub enum Device {
Cuda(usize),
/// The main MPS device.
Mps,
/// The main Vulkan device.
Vulkan,
}

/// Cuda related helper functions.
pub enum Cuda {}
impl Cuda {
/// Returns the number of CUDA devices available.
pub fn device_count() -> i64 {
let res = unsafe_torch!(torch_sys::cuda::atc_cuda_device_count());
i64::from(res)
#[cfg(not(target_os = "android"))]
{
let res = unsafe_torch!(torch_sys::cuda::atc_cuda_device_count());
i64::from(res)
}

#[cfg(target_os = "android")]
0
}

/// Returns true if at least one CUDA device is available.
pub fn is_available() -> bool {
unsafe_torch!(torch_sys::cuda::atc_cuda_is_available()) != 0
#[cfg(not(target_os = "android"))]
return unsafe_torch!(torch_sys::cuda::atc_cuda_is_available()) != 0;

#[cfg(target_os = "android")]
return false;
}

/// Returns true if both CUDA and CuDNN are available.
pub fn cudnn_is_available() -> bool {
unsafe_torch!(torch_sys::cuda::atc_cudnn_is_available()) != 0
#[cfg(not(target_os = "android"))]
return unsafe_torch!(torch_sys::cuda::atc_cudnn_is_available()) != 0;

#[cfg(target_os = "android")]
return false;
}
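The gated bodies above rely on #[cfg] applying to expressions inside a function: on non-Android targets the FFI block is the tail expression, while on Android only the constant fallback is compiled, so no torch_sys::cuda symbol is referenced at all. A standalone sketch of the same pattern, with a stand-in value in place of the real FFI call:

    // Compiles on every target; only one arm survives cfg expansion.
    fn device_count() -> i64 {
        #[cfg(not(target_os = "android"))]
        {
            // Stand-in for unsafe_torch!(torch_sys::cuda::atc_cuda_device_count()).
            4
        }

        #[cfg(target_os = "android")]
        0
    }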

/// Sets the seed for the current GPU.
@@ -36,6 +52,7 @@ impl Cuda {
///
/// * `seed` - An unsigned 64bit int to be used as seed.
pub fn manual_seed(seed: u64) {
#[cfg(not(target_os = "android"))]
unsafe_torch!(torch_sys::cuda::atc_manual_seed(seed));
}

///
/// * `seed` - An unsigned 64bit int to be used as seed.
pub fn manual_seed_all(seed: u64) {
#[cfg(not(target_os = "android"))]
unsafe_torch!(torch_sys::cuda::atc_manual_seed_all(seed));
}

///
/// * `device_index` - A signed 64bit int to indicate which device to wait for.
pub fn synchronize(device_index: i64) {
#[cfg(not(target_os = "android"))]
unsafe_torch!(torch_sys::cuda::atc_synchronize(device_index));
}

/// Returns true if cudnn is enabled by the user.
///
/// This does not indicate whether cudnn is actually usable.
pub fn user_enabled_cudnn() -> bool {
unsafe_torch!(torch_sys::cuda::atc_user_enabled_cudnn()) != 0
#[cfg(not(target_os = "android"))]
return unsafe_torch!(torch_sys::cuda::atc_user_enabled_cudnn()) != 0;

#[cfg(target_os = "android")]
return false;
}

/// Enable or disable cudnn.
pub fn set_user_enabled_cudnn(b: bool) {
#[cfg(not(target_os = "android"))]
unsafe_torch!(torch_sys::cuda::atc_set_user_enabled_cudnn(i32::from(b)))
}

/// in the following runs. This can result in significant performance
/// improvements.
pub fn cudnn_set_benchmark(b: bool) {
#[cfg(not(target_os = "android"))]
unsafe_torch!(torch_sys::cuda::atc_set_benchmark_cudnn(i32::from(b)))
}
}

/// Vulkan related helper functions.
pub enum Vulkan {}
impl Vulkan {
/// Returns true if Vulkan is available.
pub fn is_available() -> bool {
unsafe_torch!(torch_sys::vulkan::atc_vulkan_is_available()) != 0
}
}

impl Device {
pub(super) fn c_int(self) -> libc::c_int {
match self {
Device::Cpu => -1,
Device::Cuda(device_index) => device_index as libc::c_int,
Device::Mps => -2,
Device::Vulkan => -3,
}
}
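The negative integers above are a small ad-hoc encoding shared with the C++ side; device_of_int in torch-sys/libtch/torch_api.cpp (further down in this diff) decodes them. A sketch of the inverse mapping, using a hypothetical helper name:

    // Hypothetical inverse of c_int(): -3 => Vulkan, -2 => Mps,
    // any other negative value => Cpu, non-negative => Cuda(index).
    fn device_from_c_int(d: libc::c_int) -> Device {
        match d {
            -3 => Device::Vulkan,
            -2 => Device::Mps,
            d if d < 0 => Device::Cpu,
            d => Device::Cuda(d as usize),
        }
    }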

@@ -98,7 +133,7 @@ impl Device {
}
}

/// Returns a GPU device if available, else default to CPU.
/// Returns a CUDA device if available, else default to CPU.
pub fn cuda_if_available() -> Device {
if Cuda::is_available() {
Device::Cuda(0)
@@ -107,11 +142,36 @@
} else {
Device::Cpu
}
}

/// Returns the Vulkan device if available, else default to CPU.
pub fn vulkan_if_available() -> Device {
if Vulkan::is_available() {
Device::Vulkan
} else {
Device::Cpu
}
}

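/// Returns a CUDA device if available, then the Vulkan device if available, else default to CPU.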
pub fn gpu_if_available() -> Device {
if Cuda::is_available() {
Device::Cuda(0)
} else if Vulkan::is_available() {
Device::Vulkan
} else {
Device::Cpu
}
}

pub fn is_cuda(self) -> bool {
match self {
Device::Cuda(_) => true,
Device::Cpu => false,
Device::Mps => false,
_ => false,
}
}

pub fn is_vulkan(self) -> bool {
match self {
Device::Vulkan => true,
_ => false,
}
}
}
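A hedged end-to-end sketch of device selection with these helpers (Tensor::zeros taking a shape and a (Kind, Device) pair is the existing tch convention):

    use tch::{Device, Kind, Tensor};

    fn main() {
        // Prefer CUDA, then Vulkan, then fall back to CPU.
        let device = Device::gpu_if_available();
        println!("device = {:?}, vulkan = {}", device, device.is_vulkan());

        // Allocate directly on the selected backend.
        let t = Tensor::zeros(&[2, 3], (Kind::Float, device));
        println!("{:?}", t);
    }
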
torch-sys/build.rs (87 additions, 48 deletions)

@@ -64,6 +64,17 @@ fn extract<P: AsRef<Path>>(filename: P, outpath: P) -> anyhow::Result<()> {
Ok(())
}

fn env_var_target_specific(name: &str) -> Result<String, env::VarError> {
let target = env::var("TARGET").expect("Unable to get TARGET");

let name_with_target_hyphenated = name.to_owned() + "_" + &target;
let name_with_target_underscored = name.to_owned() + "_" + &target.replace("-", "_");

env_var_rerun(&name_with_target_hyphenated)
.or_else(|_| env_var_rerun(&name_with_target_underscored))
.or_else(|_| env_var_rerun(name))
}
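The or_else chain scopes a variable to a single cross-compilation target. For TARGET=aarch64-linux-android, env_var_target_specific("LIBTORCH") checks LIBTORCH_aarch64-linux-android, then LIBTORCH_aarch64_linux_android (the underscored form is the one most shells can actually export), then plain LIBTORCH. A build-script fragment reusing the helper above:

    // A host build and an Android cross build can point at different
    // libtorch distributions without clobbering each other.
    let libtorch = env_var_target_specific("LIBTORCH")
        .expect("set LIBTORCH (optionally target-suffixed) to a libtorch directory");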

fn env_var_rerun(name: &str) -> Result<String, env::VarError> {
println!("cargo:rerun-if-env-changed={}", name);
env::var(name)
@@ -101,7 +112,7 @@ fn prepare_libtorch_dir() -> PathBuf {
Err(_) => "cpu".to_owned(),
};

if let Ok(libtorch) = env_var_rerun("LIBTORCH") {
if let Ok(libtorch) = env_var_target_specific("LIBTORCH") {
PathBuf::from(libtorch)
} else if let Some(pathbuf) = check_system_location() {
pathbuf
@@ -147,14 +158,15 @@
}
}

fn make<P: AsRef<Path>>(libtorch: P, use_cuda: bool, use_hip: bool) {
let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS");
let includes: PathBuf = env_var_rerun("LIBTORCH_INCLUDE")
.map(Into::into)
.unwrap_or_else(|_| libtorch.as_ref().to_owned());
let lib: PathBuf = env_var_rerun("LIBTORCH_LIB")
.map(Into::into)
.unwrap_or_else(|_| libtorch.as_ref().to_owned());
fn make(
includes: impl AsRef<Path>,
lib: impl AsRef<Path>,
use_cuda: bool,
use_hip: bool,
os: &str,
) {
let includes = includes.as_ref();
let lib = lib.as_ref();

let cuda_dependency = if use_cuda || use_hip {
"libtch/dummy_cuda_dependency.cpp"
@@ -168,17 +180,17 @@ fn make<P: AsRef<Path>>(libtorch: P, use_cuda: bool, use_hip: bool) {
println!("cargo:rerun-if-changed=libtch/stb_image_write.h");
println!("cargo:rerun-if-changed=libtch/stb_image_resize.h");
println!("cargo:rerun-if-changed=libtch/stb_image.h");
match os.as_str() {
"linux" | "macos" => {
match os {
"linux" | "macos" | "android" => {
let libtorch_cxx11_abi =
env_var_rerun("LIBTORCH_CXX11_ABI").unwrap_or_else(|_| "1".to_owned());
cc::Build::new()
.cpp(true)
.pic(true)
.warnings(false)
.include(includes.join("include"))
.include(includes.join("include/torch/csrc/api/include"))
.flag(&format!("-Wl,-rpath={}", lib.join("lib").display()))
.include(includes)
.include(includes.join("torch/csrc/api/include"))
.flag(&format!("-Wl,-rpath={}", lib.display()))
.flag("-std=c++14")
.flag(&format!("-D_GLIBCXX_USE_CXX11_ABI={}", libtorch_cxx11_abi))
.file("libtch/torch_api.cpp")
@@ -193,19 +205,32 @@
.cpp(true)
.pic(true)
.warnings(false)
.include(includes.join("include"))
.include(includes.join("include/torch/csrc/api/include"))
.include(includes)
.include(includes.join("torch/csrc/api/include"))
.file("libtch/torch_api.cpp")
.file(cuda_dependency)
.compile("tch");
}
_ => panic!("Unsupported OS"),
os => panic!("Unsupported OS: {}", os),
};
}
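Note the path-semantics change here: make() now receives the include and lib directories themselves, and main() below defaults them to libtorch/include and libtorch/lib. Under this diff, LIBTORCH_INCLUDE and LIBTORCH_LIB therefore point directly at an include/ and lib/ directory rather than at the libtorch root.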

fn main() {
if !cfg!(feature = "doc-only") {
let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS");

let libtorch = prepare_libtorch_dir();

let libtorch_includes: PathBuf = env_var_target_specific("LIBTORCH_INCLUDE")
.map(Into::into)
.unwrap_or_else(|_| libtorch.join("include"));
let libtorch_lib: PathBuf = env_var_target_specific("LIBTORCH_LIB")
.map(Into::into)
.unwrap_or_else(|_| libtorch.join("lib"));
let libtorch_lite: bool = env_var_target_specific("LIBTORCH_LITE")
.map(|s| s.parse().unwrap_or(true))
.unwrap_or(true);

// use_cuda is a hacky way to detect whether cuda is available and,
// if so, link to it by explicitly depending on a symbol from the
// torch_cuda library.
@@ -220,42 +245,56 @@
// This will be available starting from cargo 1.50 but will be a nightly
// only option to start with.
// https://github.com/rust-lang/cargo/blob/master/CHANGELOG.md
let use_cuda = libtorch.join("lib").join("libtorch_cuda.so").exists()
|| libtorch.join("lib").join("torch_cuda.dll").exists();
let use_cuda_cu = libtorch.join("lib").join("libtorch_cuda_cu.so").exists()
|| libtorch.join("lib").join("torch_cuda_cu.dll").exists();
let use_cuda_cpp = libtorch.join("lib").join("libtorch_cuda_cpp.so").exists()
|| libtorch.join("lib").join("torch_cuda_cpp.dll").exists();
let use_hip = libtorch.join("lib").join("libtorch_hip.so").exists()
|| libtorch.join("lib").join("torch_hip.dll").exists();
println!("cargo:rustc-link-search=native={}", libtorch.join("lib").display());
let use_cuda = libtorch_lib.join("libtorch_cuda.so").exists()
|| libtorch_lib.join("torch_cuda.dll").exists();
let use_cuda_cu = libtorch_lib.join("libtorch_cuda_cu.so").exists()
|| libtorch_lib.join("torch_cuda_cu.dll").exists();
let use_cuda_cpp = libtorch_lib.join("libtorch_cuda_cpp.so").exists()
|| libtorch_lib.join("torch_cuda_cpp.dll").exists();
let use_hip = libtorch_lib.join("libtorch_hip.so").exists()
|| libtorch_lib.join("torch_hip.dll").exists();

println!("cargo:rustc-link-search=native={}", libtorch_lib.display());

make(&libtorch, use_cuda, use_hip);
make(&libtorch_includes, &libtorch_lib, use_cuda, use_hip, &os);

println!("cargo:rustc-link-lib=static=tch");
if use_cuda {
println!("cargo:rustc-link-lib=torch_cuda");
}
if use_cuda_cu {
println!("cargo:rustc-link-lib=torch_cuda_cu");
}
if use_cuda_cpp {
println!("cargo:rustc-link-lib=torch_cuda_cpp");
}
if use_hip {
println!("cargo:rustc-link-lib=torch_hip");
}
println!("cargo:rustc-link-lib=torch_cpu");
println!("cargo:rustc-link-lib=torch");
println!("cargo:rustc-link-lib=c10");
if use_hip {
println!("cargo:rustc-link-lib=c10_hip");
}

let target = env::var("TARGET").unwrap();
match os.as_str() {
"windows" | "linux" | "macos" => {
if use_cuda {
println!("cargo:rustc-link-lib=torch_cuda");
}
if use_cuda_cu {
println!("cargo:rustc-link-lib=torch_cuda_cu");
}
if use_cuda_cpp {
println!("cargo:rustc-link-lib=torch_cuda_cpp");
}
if use_hip {
println!("cargo:rustc-link-lib=torch_hip");
}
println!("cargo:rustc-link-lib=torch_cpu");
println!("cargo:rustc-link-lib=torch");
println!("cargo:rustc-link-lib=c10");
if use_hip {
println!("cargo:rustc-link-lib=c10_hip");
}

let target = env::var("TARGET").unwrap();

if !target.contains("msvc") && !target.contains("apple") {
println!("cargo:rustc-link-lib=gomp");
if !target.contains("msvc") && !target.contains("apple") {
println!("cargo:rustc-link-lib=gomp");
}
}
"android" => {
if libtorch_lite {
println!("cargo:rustc-link-lib=pytorch_jni_lite");
} else {
println!("cargo:rustc-link-lib=pytorch_jni");
}
}
other => panic!("unsupported OS: {}", other),
}
}
}
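Taken together, an Android cross build under this PR could plausibly be driven by target-suffixed variables: LIBTORCH_INCLUDE_aarch64_linux_android pointing at the PyTorch-for-Android headers, LIBTORCH_LIB_aarch64_linux_android at the directory containing libpytorch_jni_lite.so from the pytorch_android_lite distribution, and LIBTORCH_LITE left at its default of true so pytorch_jni_lite is linked (LIBTORCH_LITE=false selects the full pytorch_jni). The variable values and library file names here are illustrative assumptions, not taken from this diff.
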
torch-sys/libtch/torch_api.cpp (6 additions, 0 deletions)

@@ -47,6 +47,7 @@ c10::List<c10::optional<torch::Tensor>> of_carray_tensor_opt(torch::Tensor **vs,
}

at::Device device_of_int(int d) {
if (d == -3) return at::Device(at::kVulkan);
if (d == -2) return at::Device(at::kMPS);
if (d < 0) return at::Device(at::kCPU);
return at::Device(at::kCUDA, /*index=*/d);
@@ -926,6 +927,11 @@ int atc_cudnn_is_available() {
return -1;
}

int atc_vulkan_is_available() {
PROTECT(return torch::is_vulkan_available();)
return -1;
}

void atc_manual_seed(uint64_t seed) {
PROTECT(return torch::cuda::manual_seed(seed);)
}