Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for moving PyTorch Tensors from python to rust via PyO3 #698

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ safetensors = "0.3.0"
cpython = { version = "0.7.1", optional = true }
regex = { version = "1.6.0", optional = true }
image = { version = "0.24.5", optional = true }
pyo3 = { version = "0.18.3", optional = true }

[dev-dependencies]
anyhow = "1"
Expand All @@ -40,11 +41,12 @@ members = ["torch-sys"]
[features]
default = ["torch-sys/download-libtorch"]
python = ["cpython"]
torch_python = ["torch-sys/python", "pyo3"]
doc-only = ["torch-sys/doc-only"]
cuda-tests = []

[package.metadata.docs.rs]
features = [ "doc-only" ]
features = ["doc-only"]

[[example]]
name = "reinforcement-learning"
Expand Down
66 changes: 66 additions & 0 deletions examples/python-entropy/.github/workflows/CI.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
name: CI

on:
push:
pull_request:

jobs:
linux:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: messense/maturin-action@v1
with:
manylinux: auto
command: build
args: --release -o dist
- name: Upload wheels
uses: actions/upload-artifact@v2
with:
name: wheels
path: dist

windows:
runs-on: windows-latest
steps:
- uses: actions/checkout@v2
- uses: messense/maturin-action@v1
with:
command: build
args: --release --no-sdist -o dist
- name: Upload wheels
uses: actions/upload-artifact@v2
with:
name: wheels
path: dist

macos:
runs-on: macos-latest
steps:
- uses: actions/checkout@v2
- uses: messense/maturin-action@v1
with:
command: build
args: --release --no-sdist -o dist --universal2
- name: Upload wheels
uses: actions/upload-artifact@v2
with:
name: wheels
path: dist

release:
name: Release
runs-on: ubuntu-latest
if: "startsWith(github.ref, 'refs/tags/')"
needs: [ macos, windows, linux ]
steps:
- uses: actions/download-artifact@v2
with:
name: wheels
- name: Publish to PyPI
uses: messense/maturin-action@v1
env:
MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
with:
command: upload
args: --skip-existing *
72 changes: 72 additions & 0 deletions examples/python-entropy/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/target

# Byte-compiled / optimized / DLL files
__pycache__/
.pytest_cache/
*.py[cod]

# C extensions
*.so

# Distribution / packaging
.Python
.venv/
env/
bin/
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
include/
man/
venv/
*.egg-info/
.installed.cfg
*.egg

# Installer logs
pip-log.txt
pip-delete-this-directory.txt
pip-selfcheck.json

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.cache
nosetests.xml
coverage.xml

# Translations
*.mo

# Mr Developer
.mr.developer.cfg
.project
.pydevproject

# Rope
.ropeproject

# Django stuff:
*.log
*.pot

.DS_Store

# Sphinx documentation
docs/_build/

# PyCharm
.idea/

# VSCode
.vscode/

# Pyenv
.python-version
15 changes: 15 additions & 0 deletions examples/python-entropy/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[package]
name = "python-entropy"
version = "0.1.0"
authors = ["Egor Dmitriev <egordmitriev2@gmail.com>"]
edition = "2018"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
name = "python_entropy"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.18.3", features = ["extension-module"] }
tch = { path = "../../", features = ["torch_python"], default-features = false }
torch-sys = { path = "../../torch-sys", features = ["python"], default-features = false }
20 changes: 20 additions & 0 deletions examples/python-entropy/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Python Extension Example

## Instructions
Run the following commands to download the latest tch-rs version and build `python-entropy` extension:

```bash
git clone https://github.com/LaurentMazare/tch-rs.git
cd tch-rs
```

```bash
cd examples/python-entropy
pip install maturin
maturin develop
```

Run `main.py` to test the extension:
```bash
python main.py
```
9 changes: 9 additions & 0 deletions examples/python-entropy/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import python_entropy as pe
import torch

metric = pe.EntropyMetric(2)
print(f'Initial counter: {metric.get_counter()}')
metric.update(torch.tensor([0, 0, 0, 0, 1, 1, 1, 1]))
print(f'Updated counter: {metric.get_counter()}')
result = metric.compute()
print(f'Entropy: {result}')
13 changes: 13 additions & 0 deletions examples/python-entropy/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[build-system]
requires = ["maturin>=0.12,<0.13"]
build-backend = "maturin"

[project]
name = "python-entropy"
requires-python = ">=3.6"
classifiers = [
"Programming Language :: Rust",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]

40 changes: 40 additions & 0 deletions examples/python-entropy/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use pyo3::prelude::*;
use std::ops::{Div, Mul};
use tch::{Device, Kind, Scalar, Tensor};

#[pyclass(name = "EntropyMetric")]
pub struct EntropyMetric {
pub counter: Tensor,
}

#[pymethods]
impl EntropyMetric {
#[new]
fn __new__(n_classes: i64) -> Self {
EntropyMetric { counter: Tensor::zeros(&[n_classes], (Kind::Float, Device::Cpu)) }
}

fn update(&mut self, x: Tensor) -> PyResult<()> {
let ones = Tensor::ones(&[x.size()[0]], (Kind::Float, Device::Cpu));
self.counter = self.counter.scatter_add(0, &x, &ones);

Ok(())
}

fn get_counter(&self) -> PyResult<&Tensor> {
Ok(&self.counter)
}

fn compute(&self) -> PyResult<Tensor> {
let counts = self.counter.masked_select(&self.counter.gt(Scalar::from(0.0)));
let probs = counts.div(&self.counter.sum(Kind::Float));
let entropy = (probs.neg().mul(probs.log())).sum(Kind::Float);
Ok(entropy)
}
}

#[pymodule]
fn python_entropy(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<EntropyMetric>()?;
Ok(())
}
40 changes: 40 additions & 0 deletions src/wrappers/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ pub struct Tensor {

unsafe impl Send for Tensor {}

unsafe impl Sync for Tensor {}

pub extern "C" fn add_callback(data: *mut c_void, name: *const c_char, c_tensor: *mut C_tensor) {
let name = unsafe { std::ffi::CStr::from_ptr(name).to_str().unwrap() };
let name = name.replace('|', ".");
Expand Down Expand Up @@ -879,3 +881,41 @@ impl Reduction {
}
}
}

#[cfg(feature = "torch_python")]
use pyo3::{AsPyPointer, PyObject, Python, ToPyObject};

#[cfg(feature = "torch_python")]
impl<'a> pyo3::FromPyObject<'a> for Tensor {
fn extract(ob: &'a pyo3::PyAny) -> pyo3::PyResult<Self> {
if unsafe_torch!(thp_variable_check(ob.as_ptr() as *mut c_void)) {
Ok(unsafe_torch!(Tensor::from_ptr(thp_variable_unpack(ob.as_ptr() as *mut c_void))))
} else {
Err(pyo3::exceptions::PyTypeError::new_err("Expected a torch.Tensor"))
}
}
}

#[cfg(feature = "torch_python")]
impl pyo3::ToPyObject for Tensor {
fn to_object(&self, py: Python) -> PyObject {
let pyobj = unsafe_torch!(thp_variable_wrap(self.as_ptr() as *mut torch_sys::C_tensor)
as *mut pyo3::ffi::PyObject);

unsafe { pyo3::PyObject::from_owned_ptr(py, pyobj) }
}
}

#[cfg(feature = "torch_python")]
impl<'a> pyo3::IntoPy<PyObject> for &'a Tensor {
fn into_py(self, py: Python) -> PyObject {
self.to_object(py)
}
}

#[cfg(feature = "torch_python")]
impl pyo3::IntoPy<PyObject> for Tensor {
fn into_py(self, py: Python) -> PyObject {
self.to_object(py)
}
}
1 change: 1 addition & 0 deletions torch-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ zip = "0.6"

[features]
download-libtorch = ["ureq", "serde", "serde_json"]
python = []
doc-only = []

[package.metadata.docs.rs]
Expand Down
Loading