diff --git a/Cargo.toml b/Cargo.toml index 2b92fb04..3716864a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,7 @@ safetensors = "0.3.0" cpython = { version = "0.7.1", optional = true } regex = { version = "1.6.0", optional = true } image = { version = "0.24.5", optional = true } +pyo3 = { version = "0.18.3", optional = true } [dev-dependencies] anyhow = "1" @@ -40,11 +41,12 @@ members = ["torch-sys"] [features] default = ["torch-sys/download-libtorch"] python = ["cpython"] +torch_python = ["torch-sys/python", "pyo3"] doc-only = ["torch-sys/doc-only"] cuda-tests = [] [package.metadata.docs.rs] -features = [ "doc-only" ] +features = ["doc-only"] [[example]] name = "reinforcement-learning" diff --git a/examples/python-entropy/.github/workflows/CI.yml b/examples/python-entropy/.github/workflows/CI.yml new file mode 100644 index 00000000..e35652a8 --- /dev/null +++ b/examples/python-entropy/.github/workflows/CI.yml @@ -0,0 +1,66 @@ +name: CI + +on: + push: + pull_request: + +jobs: + linux: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: messense/maturin-action@v1 + with: + manylinux: auto + command: build + args: --release -o dist + - name: Upload wheels + uses: actions/upload-artifact@v2 + with: + name: wheels + path: dist + + windows: + runs-on: windows-latest + steps: + - uses: actions/checkout@v2 + - uses: messense/maturin-action@v1 + with: + command: build + args: --release --no-sdist -o dist + - name: Upload wheels + uses: actions/upload-artifact@v2 + with: + name: wheels + path: dist + + macos: + runs-on: macos-latest + steps: + - uses: actions/checkout@v2 + - uses: messense/maturin-action@v1 + with: + command: build + args: --release --no-sdist -o dist --universal2 + - name: Upload wheels + uses: actions/upload-artifact@v2 + with: + name: wheels + path: dist + + release: + name: Release + runs-on: ubuntu-latest + if: "startsWith(github.ref, 'refs/tags/')" + needs: [ macos, windows, linux ] + steps: + - uses: actions/download-artifact@v2 + with: + name: wheels + - name: Publish to PyPI + uses: messense/maturin-action@v1 + env: + MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} + with: + command: upload + args: --skip-existing * \ No newline at end of file diff --git a/examples/python-entropy/.gitignore b/examples/python-entropy/.gitignore new file mode 100644 index 00000000..af3ca5ef --- /dev/null +++ b/examples/python-entropy/.gitignore @@ -0,0 +1,72 @@ +/target + +# Byte-compiled / optimized / DLL files +__pycache__/ +.pytest_cache/ +*.py[cod] + +# C extensions +*.so + +# Distribution / packaging +.Python +.venv/ +env/ +bin/ +build/ +develop-eggs/ +dist/ +eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +include/ +man/ +venv/ +*.egg-info/ +.installed.cfg +*.egg + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt +pip-selfcheck.json + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.cache +nosetests.xml +coverage.xml + +# Translations +*.mo + +# Mr Developer +.mr.developer.cfg +.project +.pydevproject + +# Rope +.ropeproject + +# Django stuff: +*.log +*.pot + +.DS_Store + +# Sphinx documentation +docs/_build/ + +# PyCharm +.idea/ + +# VSCode +.vscode/ + +# Pyenv +.python-version \ No newline at end of file diff --git a/examples/python-entropy/Cargo.toml b/examples/python-entropy/Cargo.toml new file mode 100644 index 00000000..fc07371e --- /dev/null +++ b/examples/python-entropy/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "python-entropy" +version = "0.1.0" +authors = ["Egor Dmitriev "] +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +name = "python_entropy" +crate-type = ["cdylib"] + +[dependencies] +pyo3 = { version = "0.18.3", features = ["extension-module"] } +tch = { path = "../../", features = ["torch_python"], default-features = false } +torch-sys = { path = "../../torch-sys", features = ["python"], default-features = false } \ No newline at end of file diff --git a/examples/python-entropy/README.md b/examples/python-entropy/README.md new file mode 100644 index 00000000..d1417e88 --- /dev/null +++ b/examples/python-entropy/README.md @@ -0,0 +1,20 @@ +# Python Extension Example + +## Instructions +Run the following commands to download the latest tch-rs version and build `python-entropy` extension: + +```bash +git clone https://github.com/LaurentMazare/tch-rs.git +cd tch-rs +``` + +```bash +cd examples/python-entropy +pip install maturin +maturin develop +``` + +Run `main.py` to test the extension: +```bash +python main.py +``` \ No newline at end of file diff --git a/examples/python-entropy/main.py b/examples/python-entropy/main.py new file mode 100644 index 00000000..3e8b3d23 --- /dev/null +++ b/examples/python-entropy/main.py @@ -0,0 +1,9 @@ +import python_entropy as pe +import torch + +metric = pe.EntropyMetric(2) +print(f'Initial counter: {metric.get_counter()}') +metric.update(torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])) +print(f'Updated counter: {metric.get_counter()}') +result = metric.compute() +print(f'Entropy: {result}') diff --git a/examples/python-entropy/pyproject.toml b/examples/python-entropy/pyproject.toml new file mode 100644 index 00000000..b88a6c00 --- /dev/null +++ b/examples/python-entropy/pyproject.toml @@ -0,0 +1,13 @@ +[build-system] +requires = ["maturin>=0.12,<0.13"] +build-backend = "maturin" + +[project] +name = "python-entropy" +requires-python = ">=3.6" +classifiers = [ + "Programming Language :: Rust", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] + diff --git a/examples/python-entropy/src/lib.rs b/examples/python-entropy/src/lib.rs new file mode 100644 index 00000000..532c33b7 --- /dev/null +++ b/examples/python-entropy/src/lib.rs @@ -0,0 +1,40 @@ +use pyo3::prelude::*; +use std::ops::{Div, Mul}; +use tch::{Device, Kind, Scalar, Tensor}; + +#[pyclass(name = "EntropyMetric")] +pub struct EntropyMetric { + pub counter: Tensor, +} + +#[pymethods] +impl EntropyMetric { + #[new] + fn __new__(n_classes: i64) -> Self { + EntropyMetric { counter: Tensor::zeros(&[n_classes], (Kind::Float, Device::Cpu)) } + } + + fn update(&mut self, x: Tensor) -> PyResult<()> { + let ones = Tensor::ones(&[x.size()[0]], (Kind::Float, Device::Cpu)); + self.counter = self.counter.scatter_add(0, &x, &ones); + + Ok(()) + } + + fn get_counter(&self) -> PyResult<&Tensor> { + Ok(&self.counter) + } + + fn compute(&self) -> PyResult { + let counts = self.counter.masked_select(&self.counter.gt(Scalar::from(0.0))); + let probs = counts.div(&self.counter.sum(Kind::Float)); + let entropy = (probs.neg().mul(probs.log())).sum(Kind::Float); + Ok(entropy) + } +} + +#[pymodule] +fn python_entropy(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + Ok(()) +} diff --git a/src/wrappers/tensor.rs b/src/wrappers/tensor.rs index 454479d8..621fac44 100644 --- a/src/wrappers/tensor.rs +++ b/src/wrappers/tensor.rs @@ -21,6 +21,8 @@ pub struct Tensor { unsafe impl Send for Tensor {} +unsafe impl Sync for Tensor {} + pub extern "C" fn add_callback(data: *mut c_void, name: *const c_char, c_tensor: *mut C_tensor) { let name = unsafe { std::ffi::CStr::from_ptr(name).to_str().unwrap() }; let name = name.replace('|', "."); @@ -879,3 +881,41 @@ impl Reduction { } } } + +#[cfg(feature = "torch_python")] +use pyo3::{AsPyPointer, PyObject, Python, ToPyObject}; + +#[cfg(feature = "torch_python")] +impl<'a> pyo3::FromPyObject<'a> for Tensor { + fn extract(ob: &'a pyo3::PyAny) -> pyo3::PyResult { + if unsafe_torch!(thp_variable_check(ob.as_ptr() as *mut c_void)) { + Ok(unsafe_torch!(Tensor::from_ptr(thp_variable_unpack(ob.as_ptr() as *mut c_void)))) + } else { + Err(pyo3::exceptions::PyTypeError::new_err("Expected a torch.Tensor")) + } + } +} + +#[cfg(feature = "torch_python")] +impl pyo3::ToPyObject for Tensor { + fn to_object(&self, py: Python) -> PyObject { + let pyobj = unsafe_torch!(thp_variable_wrap(self.as_ptr() as *mut torch_sys::C_tensor) + as *mut pyo3::ffi::PyObject); + + unsafe { pyo3::PyObject::from_owned_ptr(py, pyobj) } + } +} + +#[cfg(feature = "torch_python")] +impl<'a> pyo3::IntoPy for &'a Tensor { + fn into_py(self, py: Python) -> PyObject { + self.to_object(py) + } +} + +#[cfg(feature = "torch_python")] +impl pyo3::IntoPy for Tensor { + fn into_py(self, py: Python) -> PyObject { + self.to_object(py) + } +} diff --git a/torch-sys/Cargo.toml b/torch-sys/Cargo.toml index 4694bd52..5f024e6f 100644 --- a/torch-sys/Cargo.toml +++ b/torch-sys/Cargo.toml @@ -25,6 +25,7 @@ zip = "0.6" [features] download-libtorch = ["ureq", "serde", "serde_json"] +python = [] doc-only = [] [package.metadata.docs.rs] diff --git a/torch-sys/build.rs b/torch-sys/build.rs index aebc49af..cb937c7d 100644 --- a/torch-sys/build.rs +++ b/torch-sys/build.rs @@ -8,6 +8,7 @@ use anyhow::Context; use std::path::{Path, PathBuf}; +use std::process::Command; use std::{env, fs, io}; const TORCH_VERSION: &str = "2.0.0"; @@ -42,11 +43,13 @@ struct PyPiPackageUrl { url: String, filename: String, } + #[cfg(feature = "download-libtorch")] #[derive(serde::Deserialize, Debug)] struct PyPiPackage { urls: Vec, } + #[cfg(feature = "download-libtorch")] fn get_pypi_wheel_url_for_aarch64_macosx() -> anyhow::Result { let pypi_url = format!("https://pypi.org/pypi/torch/{TORCH_VERSION}/json"); @@ -113,6 +116,47 @@ fn check_system_location() -> Option { } } +fn find_python() -> String { + env::var("PYTHON3").ok().unwrap_or_else(|| { + let candidates = if cfg!(windows) { + ["python3.exe", "python.exe"] + } else { + ["python3", "python"] + }; + for &name in &candidates { + if Command::new(name) + .arg("--version") + .output() + .ok() + .map_or(false, |out| out.status.success()) + { + return name.to_owned(); + } + } + panic!( + "Can't find python (tried {})! Try fixing PATH or setting the PYTHON_INCLUDE_DIRS env var explicitly", + candidates.join(", ") + ) + }) +} + +fn find_python_include_dir() -> PathBuf { + if let Ok(python_dir) = env_var_rerun("PYTHON_INCLUDE_DIRS") { + PathBuf::from(python_dir) + } else { + let python = find_python(); + let output = Command::new(python) + .arg("-c") + .arg("from sysconfig import get_paths as gp; print(gp()['include'])") + .output() + .expect("Failed to run python") + .stdout; + let python_dir = + String::from_utf8(output).expect("Python output not utf8").trim().to_owned(); + PathBuf::from(python_dir) + } +} + fn prepare_libtorch_dir() -> PathBuf { let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS"); @@ -170,7 +214,7 @@ fn prepare_libtorch_dir() -> PathBuf { } else { format!("https://download.pytorch.org/libtorch/cpu/libtorch-macos-{TORCH_VERSION}.zip") } - }, + } "windows" => format!( "https://download.pytorch.org/libtorch/{}/libtorch-win-shared-with-deps-{}{}.zip", device, TORCH_VERSION, match device.as_ref() { @@ -194,7 +238,7 @@ fn prepare_libtorch_dir() -> PathBuf { } } -fn make>(libtorch: P, use_cuda: bool, use_hip: bool) { +fn make>(libtorch: P, use_cuda: bool, use_hip: bool, use_python: bool) { let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS"); let includes: PathBuf = env_var_rerun("LIBTORCH_INCLUDE") .map(Into::into) @@ -203,6 +247,14 @@ fn make>(libtorch: P, use_cuda: bool, use_hip: bool) { .map(Into::into) .unwrap_or_else(|_| libtorch.as_ref().to_owned()); + let python_includes = if use_python { + let python_include_dir = find_python_include_dir(); + vec![python_include_dir] + } else { + vec![] + }; + let use_python_flag = if use_python { "1".to_owned() } else { "0".to_owned() }; + let cuda_dependency = if use_cuda || use_hip { "libtch/dummy_cuda_dependency.cpp" } else { @@ -226,9 +278,11 @@ fn make>(libtorch: P, use_cuda: bool, use_hip: bool) { .warnings(false) .include(includes.join("include")) .include(includes.join("include/torch/csrc/api/include")) + .includes(python_includes) .flag(&format!("-Wl,-rpath={}", lib.join("lib").display())) .flag("-std=c++14") .flag(&format!("-D_GLIBCXX_USE_CXX11_ABI={libtorch_cxx11_abi}")) + .flag(&format!("-DWITH_PYTHON={}", use_python_flag)) .file("libtch/torch_api.cpp") .file(cuda_dependency) .compile("tch"); @@ -243,6 +297,8 @@ fn make>(libtorch: P, use_cuda: bool, use_hip: bool) { .warnings(false) .include(includes.join("include")) .include(includes.join("include/torch/csrc/api/include")) + .includes(python_includes) + .flag(&format!("-DWITH_PYTHON={}", use_python_flag)) .file("libtch/torch_api.cpp") .file(cuda_dependency) .compile("tch"); @@ -269,35 +325,54 @@ fn main() { // only option to start with. // https://github.com/rust-lang/cargo/blob/master/CHANGELOG.md let use_cuda = libtorch.join("lib").join("libtorch_cuda.so").exists() - || libtorch.join("lib").join("torch_cuda.dll").exists(); + || libtorch.join("lib").join("torch_cuda.dll").exists() + || libtorch.join("lib").join("libtorch_cuda.dylib").exists(); let use_cuda_cu = libtorch.join("lib").join("libtorch_cuda_cu.so").exists() - || libtorch.join("lib").join("torch_cuda_cu.dll").exists(); + || libtorch.join("lib").join("torch_cuda_cu.dll").exists() + || libtorch.join("lib").join("libtorch_cuda_cu.dylib").exists(); let use_cuda_cpp = libtorch.join("lib").join("libtorch_cuda_cpp.so").exists() - || libtorch.join("lib").join("torch_cuda_cpp.dll").exists(); + || libtorch.join("lib").join("torch_cuda_cpp.dll").exists() + || libtorch.join("lib").join("libtorch_cuda_cpp.dylib").exists(); let use_hip = libtorch.join("lib").join("libtorch_hip.so").exists() - || libtorch.join("lib").join("torch_hip.dll").exists(); + || libtorch.join("lib").join("torch_hip.dll").exists() + || libtorch.join("lib").join("libtorch_hip.dylib").exists(); + + let use_python = cfg!(feature = "python"); + if use_python + && !libtorch.join("lib").join("libtorch_python.so").exists() + && !libtorch.join("lib").join("torch_python.dll").exists() + && !libtorch.join("lib").join("libtorch_python.dylib").exists() + { + panic!("libtorch_python.so or torch_python.dll or libtorch_python.dylib not found in {}", libtorch.join("lib").display()); + } + println!("cargo:rustc-link-search=native={}", libtorch.join("lib").display()); - make(&libtorch, use_cuda, use_hip); + make(&libtorch, use_cuda, use_hip, use_python); + + let link_type = if use_python { "dylib" } else { "static" }; println!("cargo:rustc-link-lib=static=tch"); if use_cuda { - println!("cargo:rustc-link-lib=torch_cuda"); + println!("cargo:rustc-link-lib={}=torch_cuda", link_type); } if use_cuda_cu { - println!("cargo:rustc-link-lib=torch_cuda_cu"); + println!("cargo:rustc-link-lib={}=torch_cuda_cu", link_type); } if use_cuda_cpp { - println!("cargo:rustc-link-lib=torch_cuda_cpp"); + println!("cargo:rustc-link-lib={}=torch_cuda_cpp", link_type); } if use_hip { - println!("cargo:rustc-link-lib=torch_hip"); + println!("cargo:rustc-link-lib={}=torch_hip", link_type); } - println!("cargo:rustc-link-lib=torch_cpu"); - println!("cargo:rustc-link-lib=torch"); - println!("cargo:rustc-link-lib=c10"); + println!("cargo:rustc-link-lib={}=torch_cpu", link_type); + println!("cargo:rustc-link-lib={}=torch", link_type); + println!("cargo:rustc-link-lib={}=c10", link_type); if use_hip { - println!("cargo:rustc-link-lib=c10_hip"); + println!("cargo:rustc-link-lib={}=c10_hip", link_type); + } + if use_python { + println!("cargo:rustc-link-lib={}=torch_python", link_type); } let target = env::var("TARGET").unwrap(); diff --git a/torch-sys/libtch/torch_api.cpp b/torch-sys/libtch/torch_api.cpp index ad37a20e..9b953b0a 100644 --- a/torch-sys/libtch/torch_api.cpp +++ b/torch-sys/libtch/torch_api.cpp @@ -1,4 +1,7 @@ #include +#if WITH_PYTHON +#include +#endif #include #include #include @@ -1639,6 +1642,26 @@ void ati_free(ivalue i) { delete(i); } +#if WITH_PYTHON +const at::Tensor* thp_variable_unpack(PyObject *obj) { + PROTECT( + return new torch::Tensor(THPVariable_Unpack(obj)); + ) + return nullptr; +} + +bool thp_variable_check(PyObject* obj) { + return THPVariable_Check(obj); +} + +PyObject* thp_variable_wrap(tensor var){ + PROTECT( + return THPVariable_Wrap(*var); + ) + return nullptr; +} +#endif + void at_set_graph_executor_optimize(bool o) { torch::jit::setGraphExecutorOptimize(o); } diff --git a/torch-sys/libtch/torch_api.h b/torch-sys/libtch/torch_api.h index 970a6702..db9abb60 100644 --- a/torch-sys/libtch/torch_api.h +++ b/torch-sys/libtch/torch_api.h @@ -1,6 +1,9 @@ #ifndef __TORCH_API_H__ #define __TORCH_API_H__ #include +#if WITH_PYTHON +#include +#endif #ifdef __cplusplus thread_local char *torch_last_err = nullptr; @@ -280,6 +283,13 @@ bool tch_read_stream_seek_start(void *stream_ptr, uint64_t pos, uint64_t *new_po bool tch_read_stream_seek_end(void *stream_ptr, int64_t pos, uint64_t *new_pos); bool tch_read_stream_read(void *stream_ptr, uint8_t *buf, size_t size, size_t *new_pos); +// torch python +#if WITH_PYTHON +const at::Tensor* thp_variable_unpack(PyObject* obj); +bool thp_variable_check(PyObject* obj); +PyObject* thp_variable_wrap(tensor var); +#endif + #include "torch_api_generated.h" #ifdef __cplusplus diff --git a/torch-sys/src/lib.rs b/torch-sys/src/lib.rs index c0771b33..90adfe76 100644 --- a/torch-sys/src/lib.rs +++ b/torch-sys/src/lib.rs @@ -331,6 +331,15 @@ extern "C" { pub fn atm_get_tensor_expr_fuser_enabled() -> bool; } +#[cfg(feature = "python")] +extern "C" { + pub fn thp_variable_unpack(obj: *mut c_void) -> *mut C_tensor; + + pub fn thp_variable_check(obj: *mut c_void) -> bool; + + pub fn thp_variable_wrap(var: *mut C_tensor) -> *mut c_void; +} + extern "C" { pub fn dummy_cuda_dependency(); }