From 1c9891afb95120b86db3c42aed6aaf77838877e4 Mon Sep 17 00:00:00 2001 From: laurent Date: Sat, 13 May 2023 09:18:17 +0100 Subject: [PATCH 1/3] Rework the build script. --- Cargo.toml | 5 +- build.rs | 3 + torch-sys/Cargo.toml | 1 + torch-sys/build.rs | 303 +++++++++++++++++++++++++++---------------- 4 files changed, 197 insertions(+), 115 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2b92fb04..ba64610e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,8 @@ members = ["torch-sys"] [features] default = ["torch-sys/download-libtorch"] -python = ["cpython"] +python-libtorch = ["torch-sys/python-libtorch"] +rl_python = ["cpython"] doc-only = ["torch-sys/doc-only"] cuda-tests = [] @@ -48,7 +49,7 @@ features = [ "doc-only" ] [[example]] name = "reinforcement-learning" -required-features = ["python"] +required-features = ["rl_python"] [[example]] name = "stable-diffusion" diff --git a/build.rs b/build.rs index e78a7a0c..9d48dd37 100644 --- a/build.rs +++ b/build.rs @@ -2,6 +2,9 @@ fn main() { let os = std::env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS"); match os.as_str() { "linux" | "windows" => { + if let Some(lib_path) = std::env::var_os("DEP_TCH_LIBTORCH_LIB") { + println!("cargo:rustc-link-arg=-Wl,-rpath={}", lib_path.to_string_lossy()); + } println!("cargo:rustc-link-arg=-Wl,--no-as-needed"); println!("cargo:rustc-link-arg=-Wl,--copy-dt-needed-entries"); println!("cargo:rustc-link-arg=-ltorch"); diff --git a/torch-sys/Cargo.toml b/torch-sys/Cargo.toml index 4694bd52..4002f532 100644 --- a/torch-sys/Cargo.toml +++ b/torch-sys/Cargo.toml @@ -26,6 +26,7 @@ zip = "0.6" [features] download-libtorch = ["ureq", "serde", "serde_json"] doc-only = [] +python-libtorch = [] [package.metadata.docs.rs] features = [ "doc-only" ] diff --git a/torch-sys/build.rs b/torch-sys/build.rs index aebc49af..8cea689b 100644 --- a/torch-sys/build.rs +++ b/torch-sys/build.rs @@ -11,6 +11,32 @@ use std::path::{Path, PathBuf}; use std::{env, fs, io}; const TORCH_VERSION: &str = "2.0.0"; +const PYTHON_PRINT_PYTORCH_DETAILS: &str = r" +import torch +from torch.utils import cpp_extension +print('LIBTORCH_CXX11:', torch._C._GLIBCXX_USE_CXX11_ABI) +for include_path in cpp_extension.include_paths(): + print('LIBTORCH_INCLUDE:', include_path) +for library_path in cpp_extension.library_paths(): + print('LIBTORCH_LIB:', library_path) +"; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Os { + Linux, + Macos, + Windows, +} + +#[allow(dead_code)] +#[derive(Debug, Clone)] +struct SystemInfo { + os: Os, + python_interpreter: PathBuf, + cxx11_abi: String, + libtorch_include_dirs: Vec, + libtorch_lib_dir: PathBuf, +} #[cfg(feature = "ureq")] fn download>(source_url: &str, target_file: P) -> anyhow::Result<()> { @@ -104,48 +130,108 @@ fn env_var_rerun(name: &str) -> Result { env::var(name) } -fn check_system_location() -> Option { - let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS"); - - match os.as_str() { - "linux" => Path::new("/usr/lib/libtorch.so").exists().then(|| PathBuf::from("/usr")), - _ => None, +impl SystemInfo { + fn new() -> Self { + let os = match env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS").as_str() { + "linux" => Os::Linux, + "windows" => Os::Windows, + "macos" => Os::Macos, + os => panic!("unsupported TARGET_OS '{os}'"), + }; + // Locate the currently active Python binary, similar to: + // https://github.com/PyO3/maturin/blob/243b8ec91d07113f97a6fe74d9b2dcb88086e0eb/src/target.rs#L547 + let python_interpreter = match os { + Os::Windows => PathBuf::from("python.exe"), + Os::Linux | Os::Macos => { + if env::var_os("VIRTUAL_ENV").is_some() { + PathBuf::from("python") + } else { + PathBuf::from("python3") + } + } + }; + let mut libtorch_include_dirs = vec![]; + let mut libtorch_lib_dir = None; + let cxx11_abi = if env_var_rerun("LIBTORCH_USE_PYTORCH").is_ok() { + let output = std::process::Command::new(&python_interpreter) + .arg("-c") + .arg(PYTHON_PRINT_PYTORCH_DETAILS) + .output() + .map_err(|err| format!("error running {python_interpreter:?}: {err:?}")) + .unwrap(); + let mut cxx11_abi = None; + for line in String::from_utf8_lossy(&output.stdout).lines() { + match line.strip_prefix("LIBTORCH_CXX11: ") { + Some("True") => cxx11_abi = Some("1".to_owned()), + Some("False") => cxx11_abi = Some("0".to_owned()), + _ => {} + } + if let Some(path) = line.strip_prefix("LIBTORCH_INCLUDE: ") { + libtorch_include_dirs.push(PathBuf::from(path)) + } + if let Some(path) = line.strip_prefix("LIBTORCH_LIB: ") { + libtorch_lib_dir = Some(PathBuf::from(path)) + } + } + match cxx11_abi { + Some(cxx11_abi) => cxx11_abi, + None => panic!("no cxx11 abi returned by python {output:?}"), + } + } else { + let libtorch = Self::prepare_libtorch_dir(os); + let includes = env_var_rerun("LIBTORCH_INCLUDE") + .map(PathBuf::from) + .unwrap_or_else(|_| libtorch.clone()); + let lib = env_var_rerun("LIBTORCH_LIB") + .map(PathBuf::from) + .unwrap_or_else(|_| libtorch.clone()); + libtorch_include_dirs.push(includes.join("include")); + libtorch_include_dirs.push(includes.join("include/torch/csrc/api/include")); + libtorch_lib_dir = Some(lib.join("lib")); + env_var_rerun("LIBTORCH_CXX11_ABI").unwrap_or_else(|_| "1".to_owned()) + }; + let libtorch_lib_dir = libtorch_lib_dir.expect("no libtorch lib dir found"); + Self { os, python_interpreter, cxx11_abi, libtorch_include_dirs, libtorch_lib_dir } } -} -fn prepare_libtorch_dir() -> PathBuf { - let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS"); + fn check_system_location(os: Os) -> Option { + match os { + Os::Linux => Path::new("/usr/lib/libtorch.so").exists().then(|| PathBuf::from("/usr")), + _ => None, + } + } - let device = match env_var_rerun("TORCH_CUDA_VERSION") { - Ok(cuda_env) => match os.as_str() { - "linux" | "windows" => { - cuda_env.trim().to_lowercase().trim_start_matches("cu").split('.').take(2).fold( - "cu".to_owned(), - |mut acc, curr| { - acc += curr; - acc - }, - ) - } - os_str => panic!( - "CUDA was specified with `TORCH_CUDA_VERSION`, but pre-built \ - binaries with CUDA are only available for Linux and Windows, not: {}.", - os_str - ), - }, - Err(_) => "cpu".to_owned(), - }; + fn prepare_libtorch_dir(os: Os) -> PathBuf { + if let Ok(libtorch) = env_var_rerun("LIBTORCH") { + PathBuf::from(libtorch) + } else if let Some(pathbuf) = Self::check_system_location(os) { + pathbuf + } else { + let device = match env_var_rerun("TORCH_CUDA_VERSION") { + Ok(cuda_env) => match os { + Os::Linux | Os::Windows => cuda_env + .trim() + .to_lowercase() + .trim_start_matches("cu") + .split('.') + .take(2) + .fold("cu".to_owned(), |mut acc, curr| { + acc += curr; + acc + }), + os => panic!( + "CUDA was specified with `TORCH_CUDA_VERSION`, but pre-built \ + binaries with CUDA are only available for Linux and Windows, not: {os:?}.", + ), + }, + Err(_) => "cpu".to_owned(), + }; - if let Ok(libtorch) = env_var_rerun("LIBTORCH") { - PathBuf::from(libtorch) - } else if let Some(pathbuf) = check_system_location() { - pathbuf - } else { - let libtorch_dir = PathBuf::from(env::var("OUT_DIR").unwrap()).join("libtorch"); - if !libtorch_dir.exists() { - fs::create_dir(&libtorch_dir).unwrap_or_default(); - let libtorch_url = match os.as_str() { - "linux" => format!( + let libtorch_dir = PathBuf::from(env::var("OUT_DIR").unwrap()).join("libtorch"); + if !libtorch_dir.exists() { + fs::create_dir(&libtorch_dir).unwrap_or_default(); + let libtorch_url = match os { + Os::Linux => format!( "https://download.pytorch.org/libtorch/{}/libtorch-cxx11-abi-shared-with-deps-{}{}.zip", device, TORCH_VERSION, match device.as_ref() { "cpu" => "%2Bcpu", @@ -157,7 +243,7 @@ fn prepare_libtorch_dir() -> PathBuf { _ => panic!("unsupported device {}, TORCH_CUDA_VERSION may be set incorrectly?", device), } ), - "macos" => { + Os::Macos => { if env::var("CARGO_CFG_TARGET_ARCH") == Ok(String::from("aarch64")) { get_pypi_wheel_url_for_aarch64_macosx().expect( "Failed to retrieve torch from pypi. Pre-built version of libtorch for apple silicon are not available. @@ -171,7 +257,7 @@ fn prepare_libtorch_dir() -> PathBuf { format!("https://download.pytorch.org/libtorch/cpu/libtorch-macos-{TORCH_VERSION}.zip") } }, - "windows" => format!( + Os::Windows => format!( "https://download.pytorch.org/libtorch/{}/libtorch-win-shared-with-deps-{}{}.zip", device, TORCH_VERSION, match device.as_ref() { "cpu" => "%2Bcpu", @@ -182,78 +268,68 @@ fn prepare_libtorch_dir() -> PathBuf { "cu118" => "%2Bcu118", _ => "" }), - _ => panic!("Unsupported OS"), }; - let filename = libtorch_dir.join(format!("v{TORCH_VERSION}.zip")); - download(&libtorch_url, &filename).unwrap(); - extract(&filename, &libtorch_dir).unwrap(); + let filename = libtorch_dir.join(format!("v{TORCH_VERSION}.zip")); + download(&libtorch_url, &filename).unwrap(); + extract(&filename, &libtorch_dir).unwrap(); + } + libtorch_dir.join("libtorch") } - - libtorch_dir.join("libtorch") } -} -fn make>(libtorch: P, use_cuda: bool, use_hip: bool) { - let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS"); - let includes: PathBuf = env_var_rerun("LIBTORCH_INCLUDE") - .map(Into::into) - .unwrap_or_else(|_| libtorch.as_ref().to_owned()); - let lib: PathBuf = env_var_rerun("LIBTORCH_LIB") - .map(Into::into) - .unwrap_or_else(|_| libtorch.as_ref().to_owned()); - - let cuda_dependency = if use_cuda || use_hip { - "libtch/dummy_cuda_dependency.cpp" - } else { - "libtch/fake_cuda_dependency.cpp" - }; - println!("cargo:rerun-if-changed={}", cuda_dependency); - println!("cargo:rerun-if-changed=libtch/torch_api.cpp"); - println!("cargo:rerun-if-changed=libtch/torch_api.h"); - println!("cargo:rerun-if-changed=libtch/torch_api_generated.cpp.h"); - println!("cargo:rerun-if-changed=libtch/torch_api_generated.h"); - println!("cargo:rerun-if-changed=libtch/stb_image_write.h"); - println!("cargo:rerun-if-changed=libtch/stb_image_resize.h"); - println!("cargo:rerun-if-changed=libtch/stb_image.h"); - match os.as_str() { - "linux" | "macos" => { - let libtorch_cxx11_abi = - env_var_rerun("LIBTORCH_CXX11_ABI").unwrap_or_else(|_| "1".to_owned()); - cc::Build::new() - .cpp(true) - .pic(true) - .warnings(false) - .include(includes.join("include")) - .include(includes.join("include/torch/csrc/api/include")) - .flag(&format!("-Wl,-rpath={}", lib.join("lib").display())) - .flag("-std=c++14") - .flag(&format!("-D_GLIBCXX_USE_CXX11_ABI={libtorch_cxx11_abi}")) - .file("libtch/torch_api.cpp") - .file(cuda_dependency) - .compile("tch"); - } - "windows" => { - // TODO: Pass "/link" "LIBPATH:{}" to cl.exe in order to emulate rpath. - // Not yet supported by cc=rs. - // https://github.com/alexcrichton/cc-rs/issues/323 - cc::Build::new() - .cpp(true) - .pic(true) - .warnings(false) - .include(includes.join("include")) - .include(includes.join("include/torch/csrc/api/include")) - .file("libtch/torch_api.cpp") - .file(cuda_dependency) - .compile("tch"); - } - _ => panic!("Unsupported OS"), - }; + fn make(&self, use_cuda: bool, use_hip: bool) { + let cuda_dependency = if use_cuda || use_hip { + "libtch/dummy_cuda_dependency.cpp" + } else { + "libtch/fake_cuda_dependency.cpp" + }; + println!("cargo:rerun-if-changed={}", cuda_dependency); + println!("cargo:rerun-if-changed=libtch/torch_api.cpp"); + println!("cargo:rerun-if-changed=libtch/torch_api.h"); + println!("cargo:rerun-if-changed=libtch/torch_api_generated.cpp.h"); + println!("cargo:rerun-if-changed=libtch/torch_api_generated.h"); + println!("cargo:rerun-if-changed=libtch/stb_image_write.h"); + println!("cargo:rerun-if-changed=libtch/stb_image_resize.h"); + println!("cargo:rerun-if-changed=libtch/stb_image.h"); + match self.os { + Os::Linux | Os::Macos => { + // Pass the libtorch lib dir to crates that use torch-sys. This will be available + // as DEP_TORCH_SYS_LIBTORCH_LIB, see: + // https://doc.rust-lang.org/cargo/reference/build-scripts.html#the-links-manifest-key + println!("cargo:libtorch_lib={}", self.libtorch_lib_dir.display()); + cc::Build::new() + .cpp(true) + .pic(true) + .warnings(false) + .includes(&self.libtorch_include_dirs) + .flag(&format!("-Wl,-rpath={}", self.libtorch_lib_dir.display())) + .flag("-std=c++14") + .flag(&format!("-D_GLIBCXX_USE_CXX11_ABI={}", self.cxx11_abi)) + .file("libtch/torch_api.cpp") + .file(cuda_dependency) + .compile("tch"); + } + Os::Windows => { + // TODO: Pass "/link" "LIBPATH:{}" to cl.exe in order to emulate rpath. + // Not yet supported by cc=rs. + // https://github.com/alexcrichton/cc-rs/issues/323 + cc::Build::new() + .cpp(true) + .pic(true) + .warnings(false) + .includes(&self.libtorch_include_dirs) + .file("libtch/torch_api.cpp") + .file(cuda_dependency) + .compile("tch"); + } + }; + } } fn main() { if !cfg!(feature = "doc-only") { - let libtorch = prepare_libtorch_dir(); + let system_info = SystemInfo::new(); // use_cuda is a hacky way to detect whether cuda is available and // if it's the case link to it by explicitly depending on a symbol // from the torch_cuda library. @@ -268,17 +344,18 @@ fn main() { // This will be available starting from cargo 1.50 but will be a nightly // only option to start with. // https://github.com/rust-lang/cargo/blob/master/CHANGELOG.md - let use_cuda = libtorch.join("lib").join("libtorch_cuda.so").exists() - || libtorch.join("lib").join("torch_cuda.dll").exists(); - let use_cuda_cu = libtorch.join("lib").join("libtorch_cuda_cu.so").exists() - || libtorch.join("lib").join("torch_cuda_cu.dll").exists(); - let use_cuda_cpp = libtorch.join("lib").join("libtorch_cuda_cpp.so").exists() - || libtorch.join("lib").join("torch_cuda_cpp.dll").exists(); - let use_hip = libtorch.join("lib").join("libtorch_hip.so").exists() - || libtorch.join("lib").join("torch_hip.dll").exists(); - println!("cargo:rustc-link-search=native={}", libtorch.join("lib").display()); + let si_lib = &system_info.libtorch_lib_dir; + let use_cuda = + si_lib.join("libtorch_cuda.so").exists() || si_lib.join("torch_cuda.dll").exists(); + let use_cuda_cu = si_lib.join("libtorch_cuda_cu.so").exists() + || si_lib.join("torch_cuda_cu.dll").exists(); + let use_cuda_cpp = si_lib.join("libtorch_cuda_cpp.so").exists() + || si_lib.join("torch_cuda_cpp.dll").exists(); + let use_hip = + si_lib.join("libtorch_hip.so").exists() || si_lib.join("torch_hip.dll").exists(); + println!("cargo:rustc-link-search=native={}", si_lib.display()); - make(&libtorch, use_cuda, use_hip); + system_info.make(use_cuda, use_hip); println!("cargo:rustc-link-lib=static=tch"); if use_cuda { From 5959885081a666ccd336f47230e8fa1f058a1f2a Mon Sep 17 00:00:00 2001 From: laurent Date: Sat, 13 May 2023 09:25:02 +0100 Subject: [PATCH 2/3] Update the readme. --- README.md | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index b914e7fb..a910f8bd 100644 --- a/README.md +++ b/README.md @@ -23,14 +23,22 @@ your system. You can either: - Use the system-wide libtorch installation (default). - Install libtorch manually and let the build script know about it via the `LIBTORCH` environment variable. -- When a system-wide libtorch can't be found and `LIBTORCH` is not set, the build script will download a pre-built binary version -of libtorch. By default a CPU version is used. The `TORCH_CUDA_VERSION` environment variable -can be set to `cu117` in order to get a pre-built binary using CUDA 11.7. +- Use a Python PyTorch install, to do this set `LIBTORCH_USE_PYTORCH=1`. +- When a system-wide libtorch can't be found and `LIBTORCH` is not set, the + build script will download a pre-built binary version of libtorch. By default + a CPU version is used. The `TORCH_CUDA_VERSION` environment variable can be + set to `cu117` in order to get a pre-built binary using CUDA 11.7. ### System-wide Libtorch -The build script will look for a system-wide libtorch library in the following locations: -- In Linux: `/usr/lib/libtorch.so` +On linux platforms, the build script will look for a system-wide libtorch +library in `/usr/lib/libtorch.so`. + +### Python PyTorch Install + +If the `LIBTORCH_USE_PYTORCH` environment variable is set, the active python +interpreter is called to retrieve information about the torch python package. +This version is then linked against. ### Libtorch Manual Install From abcc7b8beeeb03c65774f1997ad3aaa01d7bc261 Mon Sep 17 00:00:00 2001 From: laurent Date: Sat, 13 May 2023 09:34:30 +0100 Subject: [PATCH 3/3] Add some version check. --- torch-sys/build.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torch-sys/build.rs b/torch-sys/build.rs index 8cea689b..f2954560 100644 --- a/torch-sys/build.rs +++ b/torch-sys/build.rs @@ -14,6 +14,7 @@ const TORCH_VERSION: &str = "2.0.0"; const PYTHON_PRINT_PYTORCH_DETAILS: &str = r" import torch from torch.utils import cpp_extension +print('LIBTORCH_VERSION:', torch.__version__.split('+')[0]) print('LIBTORCH_CXX11:', torch._C._GLIBCXX_USE_CXX11_ABI) for include_path in cpp_extension.include_paths(): print('LIBTORCH_INCLUDE:', include_path) @@ -161,6 +162,13 @@ impl SystemInfo { .unwrap(); let mut cxx11_abi = None; for line in String::from_utf8_lossy(&output.stdout).lines() { + if let Some(version) = line.strip_prefix("LIBTORCH_VERSION: ") { + if env_var_rerun("LIBTORCH_BYPASS_VERSION_CHECK").is_err() + && version != TORCH_VERSION + { + panic!("this tch version expects PyTorch {TORCH_VERSION}, got {version}") + } + } match line.strip_prefix("LIBTORCH_CXX11: ") { Some("True") => cxx11_abi = Some("1".to_owned()), Some("False") => cxx11_abi = Some("0".to_owned()),