-
Notifications
You must be signed in to change notification settings - Fork 114
/
lib.rs
156 lines (139 loc) · 4.65 KB
/
lib.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
use std::ops::Add;
use numpy::ndarray::{Array1, ArrayD, ArrayView1, ArrayViewD, ArrayViewMutD, Zip};
use numpy::{
datetime::{units, Timedelta},
Complex64, IntoPyArray, PyArray1, PyArrayDyn, PyArrayMethods, PyReadonlyArray1,
PyReadonlyArrayDyn, PyReadwriteArray1, PyReadwriteArrayDyn,
};
use pyo3::{
    exceptions::{PyIndexError, PyKeyError},
    pymodule,
    types::{PyAnyMethods, PyDict, PyDictMethods, PyModule},
    Bound, FromPyObject, PyAny, PyObject, PyResult, Python,
};
// Python extension module `rust_ext`.
//
// Each example is split into a pure-Rust function working on `ndarray` views and a thin
// `#[pyfn]` wrapper that converts between Python objects and those views.
#[pymodule]
fn rust_ext<'py>(m: &Bound<'py, PyModule>) -> PyResult<()> {
    // Example using generic `PyObject` elements: returns a new reference to the element
    // found by `x.get(0)`, or raises `IndexError` when there is none (e.g. an empty array).
    fn head(py: Python<'_>, x: ArrayViewD<'_, PyObject>) -> PyResult<PyObject> {
        x.get(0)
            .map(|obj| obj.clone_ref(py))
            .ok_or_else(|| PyIndexError::new_err("array index out of range"))
    }

    // Example using immutable borrows producing a new array: computes `a * x + y`.
    //
    // NOTE(review): ndarray's arithmetic operators panic when `x` and `y` cannot be
    // broadcast together; a shape check returning `ValueError` may be preferable here.
    fn axpy(a: f64, x: ArrayViewD<'_, f64>, y: ArrayViewD<'_, f64>) -> ArrayD<f64> {
        a * &x + &y
    }

    // Example using a mutable borrow to modify an array in place: scales every element by `a`.
    fn mult(a: f64, mut x: ArrayViewMutD<'_, f64>) {
        x *= a;
    }

    // Example using complex numbers: returns the element-wise complex conjugate of `x`.
    fn conj(x: ArrayViewD<'_, Complex64>) -> ArrayD<Complex64> {
        x.map(|c| c.conj())
    }

    // Example using generics: element-wise sum of two 1-D views of any `Copy` element type
    // that supports `+`.
    fn generic_add<T: Copy + Add<Output = T>>(
        x: ArrayView1<'_, T>,
        y: ArrayView1<'_, T>,
    ) -> Array1<T> {
        &x + &y
    }

    // Wrapper of `head`: borrows the object array read-only and forwards to the helper.
    #[pyfn(m)]
    #[pyo3(name = "head")]
    fn head_py<'py>(x: PyReadonlyArrayDyn<'py, PyObject>) -> PyResult<PyObject> {
        head(x.py(), x.as_array())
    }

    // Wrapper of `axpy`: borrows both inputs read-only and returns the result as a new
    // NumPy array owned by Python.
    #[pyfn(m)]
    #[pyo3(name = "axpy")]
    fn axpy_py<'py>(
        py: Python<'py>,
        a: f64,
        x: PyReadonlyArrayDyn<'py, f64>,
        y: PyReadonlyArrayDyn<'py, f64>,
    ) -> Bound<'py, PyArrayDyn<f64>> {
        let x = x.as_array();
        let y = y.as_array();
        let z = axpy(a, x, y);
        z.into_pyarray(py)
    }

    // Wrapper of `mult`: borrows the array read-write so the scaling is visible to the
    // Python caller; returns nothing.
    #[pyfn(m)]
    #[pyo3(name = "mult")]
    fn mult_py<'py>(a: f64, mut x: PyReadwriteArrayDyn<'py, f64>) {
        let x = x.as_array_mut();
        mult(a, x);
    }

    // Wrapper of `conj`: returns the conjugated values as a new NumPy array.
    #[pyfn(m)]
    #[pyo3(name = "conj")]
    fn conj_py<'py>(
        py: Python<'py>,
        x: PyReadonlyArrayDyn<'py, Complex64>,
    ) -> Bound<'py, PyArrayDyn<Complex64>> {
        conj(x.as_array()).into_pyarray(py)
    }

    // Example of how to extract an array from a dictionary: returns the sum of the 1-D
    // float64 array stored under the key "x".
    //
    // Invalid input is reported as a regular Python exception instead of a Rust panic
    // (which PyO3 would surface as `PanicException`, a `BaseException` subclass that
    // `except Exception` does not catch):
    // - `KeyError("x")` if `d` has no "x" entry (mirrors `d["x"]` in Python),
    // - `TypeError` if the value is not a 1-D float64 NumPy array,
    // - any error raised by the dictionary lookup itself.
    #[pyfn(m)]
    fn extract(d: &Bound<'_, PyDict>) -> PyResult<f64> {
        let x = d
            .get_item("x")?
            .ok_or_else(|| PyKeyError::new_err("x"))?
            .downcast_into::<PyArray1<f64>>()?;
        Ok(x.readonly().as_array().sum())
    }

    // Example using timedelta64 arrays: adds `y` (minutes) to `x` (seconds) in place,
    // element by element, converting minutes to seconds with a factor of 60.
    //
    // NOTE(review): `Zip::and` panics if `x` and `y` have different lengths, and the `i64`
    // arithmetic is unchecked (panics on overflow in debug builds, wraps in release).
    // NumPy's NaT (stored as i64::MIN) is not special-cased. Confirm whether callers need
    // stricter handling.
    #[pyfn(m)]
    fn add_minutes_to_seconds<'py>(
        mut x: PyReadwriteArray1<'py, Timedelta<units::Seconds>>,
        y: PyReadonlyArray1<'py, Timedelta<units::Minutes>>,
    ) {
        // NOTE(review): kept from upstream; presumably silences a deprecation warning from the
        // `Timedelta` <-> `i64` conversions below. Confirm, and drop it once no longer needed.
        #[allow(deprecated)]
        Zip::from(x.as_array_mut())
            .and(y.as_array())
            .for_each(|x, y| *x = (i64::from(*x) + 60 * i64::from(*y)).into());
    }

    // This crate follows a strongly-typed approach to wrapping NumPy arrays,
    // while Python APIs are often expected to work with multiple element types.
    //
    // That kind of limited polymorphism can be recovered by accepting an enumerated type
    // covering the supported element types and dispatching into a generic implementation.
    // The derived `FromPyObject` tries the variants in declaration order.
    #[derive(FromPyObject)]
    enum SupportedArray<'py> {
        F64(Bound<'py, PyArray1<f64>>),
        I64(Bound<'py, PyArray1<i64>>),
    }

    // Adds two 1-D arrays of float64 or int64 elements.
    //
    // If both operands share a dtype, the result has that dtype. For a mixed pair, the int64
    // operand is cast to float64 first (in either argument position), and the result is float64.
    #[pyfn(m)]
    fn polymorphic_add<'py>(
        x: SupportedArray<'py>,
        y: SupportedArray<'py>,
    ) -> PyResult<Bound<'py, PyAny>> {
        match (x, y) {
            (SupportedArray::F64(x), SupportedArray::F64(y)) => Ok(generic_add(
                x.readonly().as_array(),
                y.readonly().as_array(),
            )
            .into_pyarray(x.py())
            .into_any()),
            (SupportedArray::I64(x), SupportedArray::I64(y)) => Ok(generic_add(
                x.readonly().as_array(),
                y.readonly().as_array(),
            )
            .into_pyarray(x.py())
            .into_any()),
            // Or-pattern binds the float64 operand to `x` and the int64 operand to `y`
            // regardless of argument order; addition is commutative, so the result is the same.
            (SupportedArray::F64(x), SupportedArray::I64(y))
            | (SupportedArray::I64(y), SupportedArray::F64(x)) => {
                let y = y.cast::<f64>(false)?;
                Ok(
                    generic_add(x.readonly().as_array(), y.readonly().as_array())
                        .into_pyarray(x.py())
                        .into_any(),
                )
            }
        }
    }

    Ok(())
}