diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml index f4b533aa..87c790e6 100644 --- a/.github/workflows/rust-ci.yml +++ b/.github/workflows/rust-ci.yml @@ -71,4 +71,4 @@ jobs: - uses: actions-rs/cargo@v1 with: command: clippy - args: --examples --tests --all-features -- -D warnings + args: --examples --tests --all-features --all -- -D warnings diff --git a/CHANGELOG.md b/CHANGELOG.md index 55fe7061..dac5b03a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased ### Changed +- Add a `pyo3-tch` crate for interacting with Python via PyO3 + [730](https://github.com/LaurentMazare/tch-rs/pull/730). - Expose the cuda fuser enabled flag, [728](https://github.com/LaurentMazare/tch-rs/pull/728). - Improved the safetensor error wrapping, diff --git a/Cargo.toml b/Cargo.toml index 6fa84927..9ef07b3e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,7 +38,11 @@ memmap2 = { version = "0.6.1", optional = true } anyhow = "1" [workspace] -members = ["torch-sys", "examples/python-extension"] +members = [ + "torch-sys", + "pyo3-tch", + "examples/python-extension", +] [features] download-libtorch = ["torch-sys/download-libtorch"] diff --git a/examples/python-extension/Cargo.toml b/examples/python-extension/Cargo.toml index d1842fe9..eaf62184 100644 --- a/examples/python-extension/Cargo.toml +++ b/examples/python-extension/Cargo.toml @@ -18,5 +18,6 @@ crate-type = ["cdylib"] [dependencies] pyo3 = { version = "0.18.3", features = ["extension-module"] } +pyo3-tch = { path = "../../pyo3-tch", version = "0.13.0" } tch = { path = "../..", features = ["python-extension"], version = "0.13.0" } torch-sys = { path = "../../torch-sys", features = ["python-extension"], version = "0.13.0" } diff --git a/examples/python-extension/src/lib.rs b/examples/python-extension/src/lib.rs index aad87bf1..fc6f1c6e 100644 --- a/examples/python-extension/src/lib.rs +++ b/examples/python-extension/src/lib.rs @@ -1,43 +1,9 @@ use pyo3::prelude::*; -use pyo3::{ - exceptions::{PyTypeError, PyValueError}, - AsPyPointer, -}; - -struct PyTensor(tch::Tensor); - -fn wrap_tch_err(err: tch::TchError) -> PyErr { - PyErr::new::(format!("{err:?}")) -} - -impl<'source> FromPyObject<'source> for PyTensor { - fn extract(ob: &'source PyAny) -> PyResult { - let ptr = ob.as_ptr() as *mut tch::python::CPyObject; - let tensor = unsafe { tch::Tensor::pyobject_unpack(ptr) }; - tensor - .map_err(wrap_tch_err)? - .ok_or_else(|| { - let type_ = ob.get_type(); - PyErr::new::(format!("expected a torch.Tensor, got {type_}")) - }) - .map(PyTensor) - } -} - -impl IntoPy for PyTensor { - fn into_py(self, py: Python<'_>) -> PyObject { - // There is no fallible alternative to ToPyObject/IntoPy at the moment so we return - // None on errors. https://github.com/PyO3/pyo3/issues/1813 - self.0.pyobject_wrap().map_or_else( - |_| py.None(), - |ptr| unsafe { PyObject::from_owned_ptr(py, ptr as *mut pyo3::ffi::PyObject) }, - ) - } -} +use pyo3_tch::{wrap_tch_err, PyTensor}; #[pyfunction] fn add_one(tensor: PyTensor) -> PyResult { - let tensor = tensor.0.f_add_scalar(1.0).map_err(wrap_tch_err)?; + let tensor = tensor.f_add_scalar(1.0).map_err(wrap_tch_err)?; Ok(PyTensor(tensor)) } diff --git a/pyo3-tch/Cargo.toml b/pyo3-tch/Cargo.toml new file mode 100644 index 00000000..336056c5 --- /dev/null +++ b/pyo3-tch/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "pyo3-tch" +version = "0.13.0" +authors = ["Laurent Mazare "] +edition = "2021" +build = "build.rs" + +description = "Manipulate PyTorch tensors from a Python extension via PyO3/tch." +repository = "https://github.com/LaurentMazare/tch-rs" +keywords = ["pytorch", "deep-learning", "machine-learning"] +categories = ["science"] +license = "MIT/Apache-2.0" + +[dependencies] +tch = { path = "..", features = ["python-extension"], version = "0.13.0" } +torch-sys = { path = "../torch-sys", features = ["python-extension"], version = "0.13.0" } +pyo3 = { version = "0.18.3", features = ["extension-module"] } diff --git a/pyo3-tch/build.rs b/pyo3-tch/build.rs new file mode 100644 index 00000000..9d48dd37 --- /dev/null +++ b/pyo3-tch/build.rs @@ -0,0 +1,14 @@ +fn main() { + let os = std::env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS"); + match os.as_str() { + "linux" | "windows" => { + if let Some(lib_path) = std::env::var_os("DEP_TCH_LIBTORCH_LIB") { + println!("cargo:rustc-link-arg=-Wl,-rpath={}", lib_path.to_string_lossy()); + } + println!("cargo:rustc-link-arg=-Wl,--no-as-needed"); + println!("cargo:rustc-link-arg=-Wl,--copy-dt-needed-entries"); + println!("cargo:rustc-link-arg=-ltorch"); + } + _ => {} + } +} diff --git a/pyo3-tch/src/lib.rs b/pyo3-tch/src/lib.rs new file mode 100644 index 00000000..b95cdd2f --- /dev/null +++ b/pyo3-tch/src/lib.rs @@ -0,0 +1,46 @@ +use pyo3::prelude::*; +use pyo3::{ + exceptions::{PyTypeError, PyValueError}, + AsPyPointer, +}; +pub use tch; +pub use torch_sys; + +pub struct PyTensor(pub tch::Tensor); + +impl std::ops::Deref for PyTensor { + type Target = tch::Tensor; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +pub fn wrap_tch_err(err: tch::TchError) -> PyErr { + PyErr::new::(format!("{err:?}")) +} + +impl<'source> FromPyObject<'source> for PyTensor { + fn extract(ob: &'source PyAny) -> PyResult { + let ptr = ob.as_ptr() as *mut tch::python::CPyObject; + let tensor = unsafe { tch::Tensor::pyobject_unpack(ptr) }; + tensor + .map_err(wrap_tch_err)? + .ok_or_else(|| { + let type_ = ob.get_type(); + PyErr::new::(format!("expected a torch.Tensor, got {type_}")) + }) + .map(PyTensor) + } +} + +impl IntoPy for PyTensor { + fn into_py(self, py: Python<'_>) -> PyObject { + // There is no fallible alternative to ToPyObject/IntoPy at the moment so we return + // None on errors. https://github.com/PyO3/pyo3/issues/1813 + self.0.pyobject_wrap().map_or_else( + |_| py.None(), + |ptr| unsafe { PyObject::from_owned_ptr(py, ptr as *mut pyo3::ffi::PyObject) }, + ) + } +}