Skip to content

Commit

Permalink
Add an example of python extension (#704)
Browse files Browse the repository at this point in the history
* Expose the necessary functions to write a skeleton python extension.

* Add an example of python extension.

* Add a readme.

* Small cleanup.

* Changelog update.
  • Loading branch information
LaurentMazare authored May 13, 2023
1 parent 50803c7 commit e51a3c6
Show file tree
Hide file tree
Showing 19 changed files with 268 additions and 30 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@ __pycache__

*.ot
*.safetensors
*.so
*.dylib
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@ This documents the main changes to the `tch` crate.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## v0.13.0 - unreleased yet
### Added
- Expose some functions so that Python extensions that operates on PyTorch
tensors can be written with `tch`,
[704](https://github.com/LaurentMazare/tch-rs/pull/704).
- Rework the torch-sys build script making it easier to leverage a Python
PyTorch install as a source for libtorch,
[703](https://github.com/LaurentMazare/tch-rs/pull/703).

## v0.12.0 - 2023-05-10
### Changed
- EfficientNet models have been reworked, pre-trained models used `safetensors`
Expand Down
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ image = { version = "0.24.5", optional = true }
anyhow = "1"

[workspace]
members = ["torch-sys"]
members = ["torch-sys", "examples/python-extension"]

[features]
default = ["torch-sys/download-libtorch"]
python-libtorch = ["torch-sys/python-libtorch"]
python-extension = ["torch-sys/python-extension"]
rl_python = ["cpython"]
doc-only = ["torch-sys/doc-only"]
cuda-tests = []
Expand Down
14 changes: 14 additions & 0 deletions examples/python-extension/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "tch_ext"
version = "0.1.0"
edition = "2021"
build = "build.rs"

[lib]
name = "tch_ext"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.18.3", features = ["extension-module"] }
tch = { path = "../..", features = ["python-extension"] }
torch-sys = { path = "../../torch-sys", features = ["python-extension"] }
21 changes: 21 additions & 0 deletions examples/python-extension/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
## Python extensions using tch

This sample crate shows how to use `tch` to write a Python extension
that manipulates PyTorch tensors via [PyO3](https://github.com/PyO3/pyo3).

This is currently experimental hence requires some unsafe code until this has
been stabilized.

In order to build the extension and test the plugin, run the following in a
Python environment that has torch installed.

```bash
cd examples/python-extension
LIBTORCH_USE_PYTORCH=1 cargo build
python main.py
```

It is recommended to run the build with `LIBTORCH_USE_PYTORCH` set, this will
result in using the libtorch C++ library from the Python install in `tch` and
will ensure that this is at the proper version (having `tch` using a different
libtorch version from the one used by the Python runtime may result in segfaults).
14 changes: 14 additions & 0 deletions examples/python-extension/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
fn main() {
let os = std::env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS");
match os.as_str() {
"linux" | "windows" => {
if let Some(lib_path) = std::env::var_os("DEP_TCH_LIBTORCH_LIB") {
println!("cargo:rustc-link-arg=-Wl,-rpath={}", lib_path.to_string_lossy());
}
println!("cargo:rustc-link-arg=-Wl,--no-as-needed");
println!("cargo:rustc-link-arg=-Wl,--copy-dt-needed-entries");
println!("cargo:rustc-link-arg=-ltorch");
}
_ => {}
}
}
32 changes: 32 additions & 0 deletions examples/python-extension/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use pyo3::prelude::*;
use pyo3::{exceptions::PyValueError, AsPyPointer};

use tch;

fn wrap_tch_err(err: tch::TchError) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
}

#[pyfunction]
fn add_one(t: PyObject) -> PyResult<PyObject> {
let tensor = unsafe { tch::Tensor::pyobject_unpack(t.as_ptr() as *mut tch::python::CPyObject) };
let tensor = tensor.map_err(wrap_tch_err)?;
let tensor = match tensor {
Some(tensor) => tensor,
None => Err(PyErr::new::<PyValueError, _>("t is not a PyTorch tensor object"))?,
};
let tensor = tensor + 1.0;
let tensor_ptr = tensor.pyobject_wrap().map_err(wrap_tch_err)?;
let pyobject = Python::with_gil(|py| unsafe {
PyObject::from_owned_ptr(py, tensor_ptr as *mut pyo3::ffi::PyObject)
});
Ok(pyobject)
}

/// A Python module implemented in Rust using tch to manipulate PyTorch
/// objects.
#[pymodule]
fn tch_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(add_one, m)?)?;
Ok(())
}
19 changes: 19 additions & 0 deletions examples/python-extension/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import os
import shutil

# Copy the shared library to the current directory and rename it
# so that it's easy to import.
SHARED_LIB = "../../target/debug/libtch_ext.so"
TMP_LIB = "./tch_ext.so"

if os.path.exists(TMP_LIB):
os.remove(TMP_LIB)
shutil.copy(SHARED_LIB, TMP_LIB)

import torch
import tch_ext
print(tch_ext.__file__)

t = torch.tensor([[1., -1.], [1., -1.]])
print(t)
print(tch_ext.add_one(t))
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ pub use wrappers::jit::{self, CModule, IValue, TrainableCModule};
pub use wrappers::kind::{self, Kind};
pub use wrappers::layout::Layout;
pub use wrappers::optimizer::COptimizer;
#[cfg(feature = "python-extension")]
pub use wrappers::python;
pub use wrappers::scalar::Scalar;
pub use wrappers::utils;
pub use wrappers::{
Expand Down
2 changes: 2 additions & 0 deletions src/wrappers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ pub mod jit;
pub mod kind;
pub(crate) mod layout;
pub(crate) mod optimizer;
#[cfg(feature = "python-extension")]
pub mod python;
pub(crate) mod scalar;
pub(crate) mod stream;
pub(crate) mod tensor;
Expand Down
34 changes: 34 additions & 0 deletions src/wrappers/python.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use crate::{TchError, Tensor};
use torch_sys::python::{self, C_pyobject};

pub type CPyObject = C_pyobject;

/// Check whether an object is a wrapped tensor or not.
///
/// # Safety
/// Undefined behavior if the given pointer is not a valid PyObject.
pub unsafe fn pyobject_check(pyobject: *mut CPyObject) -> Result<bool, TchError> {
let v = unsafe_torch_err!(python::thp_variable_check(pyobject));
Ok(v)
}

impl Tensor {
/// Wrap a tensor in a Python object.
pub fn pyobject_wrap(&self) -> Result<*mut CPyObject, TchError> {
let v = unsafe_torch_err!(python::thp_variable_wrap(self.c_tensor));
Ok(v)
}

/// Unwrap a tensor stored in a Python object. This returns `Ok(None)` if
/// the object is not a wrapped tensor.
///
/// # Safety
/// Undefined behavior if the given pointer is not a valid PyObject.
pub unsafe fn pyobject_unpack(pyobject: *mut CPyObject) -> Result<Option<Self>, TchError> {
if !pyobject_check(pyobject)? {
return Ok(None);
}
let v = unsafe_torch_err!(python::thp_variable_unpack(pyobject));
Ok(Some(Tensor::from_ptr(v)))
}
}
4 changes: 2 additions & 2 deletions torch-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ libc = "0.2.0"
anyhow = "1.0"
cc = "1.0"
ureq = { version = "2.6", optional = true, features = ["json"] }
serde_json= { version = "1.0", optional = true }
serde_json = { version = "1.0", optional = true }
serde = { version = "1.0", optional = true, features = ["derive"] }
zip = "0.6"

[features]
download-libtorch = ["ureq", "serde", "serde_json"]
doc-only = []
python-libtorch = []
python-extension = []

[package.metadata.docs.rs]
features = [ "doc-only" ]
Loading

0 comments on commit e51a3c6

Please sign in to comment.