Skip to content

Commit

Permalink
implement Decimal to rust_decimal conversions
Browse files Browse the repository at this point in the history
Implement conversion between rust_decimal::Decimal and decimal.Decimal
from Python's stdlib. The C API does not appear to be exposed on the
Python side so we need to call into it via Python.
  • Loading branch information
cardoe committed Apr 21, 2023
1 parent 9ca94a1 commit 32540a8
Show file tree
Hide file tree
Showing 8 changed files with 288 additions and 1 deletion.
10 changes: 9 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ hashbrown = { version = ">= 0.9, < 0.14", optional = true }
indexmap = { version = "1.6", optional = true }
num-bigint = { version = "0.4", optional = true }
num-complex = { version = ">= 0.2, < 0.5", optional = true }
rust_decimal = { version = "1.18.0", default-features = false, optional = true }
serde = { version = "1.0", optional = true }

[dev-dependencies]
Expand All @@ -53,6 +54,7 @@ send_wrapper = "0.6"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.61"
rayon = "1.0.2"
rust_decimal = { version = "1.18.0", features = ["std"] }
widestring = "0.5.1"

[build-dependencies]
Expand Down Expand Up @@ -110,6 +112,7 @@ full = [
"eyre",
"anyhow",
"experimental-inspect",
"rust_decimal",
]

[[bench]]
Expand All @@ -120,6 +123,11 @@ harness = false
name = "bench_err"
harness = false

[[bench]]
name = "bench_decimal"
harness = false
required-features = ["rust_decimal"]

[[bench]]
name = "bench_dict"
harness = false
Expand Down Expand Up @@ -173,5 +181,5 @@ members = [

[package.metadata.docs.rs]
no-default-features = true
features = ["macros", "num-bigint", "num-complex", "hashbrown", "serde", "multiple-pymethods", "indexmap", "eyre", "chrono"]
features = ["macros", "num-bigint", "num-complex", "hashbrown", "serde", "multiple-pymethods", "indexmap", "eyre", "chrono", "rust_decimal"]
rustdoc-args = ["--cfg", "docsrs"]
35 changes: 35 additions & 0 deletions benches/bench_decimal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use criterion::{black_box, criterion_group, criterion_main, Bencher, Criterion};

use pyo3::prelude::*;
use pyo3::types::PyDict;
#[cfg(feature = "rust_decimal")]
use rust_decimal::Decimal;

#[cfg(feature = "rust_decimal")]
fn decimal_via_extract(b: &mut Bencher<'_>) {
Python::with_gil(|py| {
let locals = PyDict::new(py);
py.run(
r#"
import decimal
py_dec = decimal.Decimal("0.0")
"#,
None,
Some(locals),
)
.unwrap();
let py_dec = locals.get_item("py_dec").unwrap();

b.iter(|| {
let _: Decimal = black_box(py_dec).extract().unwrap();
});
})
}

fn criterion_benchmark(c: &mut Criterion) {
#[cfg(feature = "rust_decimal")]
c.bench_function("decimal_via_extract", decimal_via_extract);
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
4 changes: 4 additions & 0 deletions guide/src/features.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ Adds a dependency on [num-bigint](https://docs.rs/num-bigint) and enables conver

Adds a dependency on [num-complex](https://docs.rs/num-complex) and enables conversions into its [`Complex`](https://docs.rs/num-complex/latest/num_complex/struct.Complex.html) type.

### `rust_decimal`

Adds a dependency on [rust_decimal](https://docs.rs/rust_decimal) and enables conversions into its [`Decimal`](https://docs.rs/rust_decimal/latest/rust_decimal/struct.Decimal.html) type.

### `serde`

Enables (de)serialization of `Py<T>` objects via [serde](https://serde.rs/).
Expand Down
1 change: 1 addition & 0 deletions newsfragments/3016.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added support for converting to and from Python's `decimal.Decimal` and `rust_decimal::Decimal`.
4 changes: 4 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,10 @@ def set_minimal_package_versions(session: nox.Session, venv_backend="none"):
"examples/word-count",
)
min_pkg_versions = {
# newer versions of rust_decimal want newer arrayvec
"rust_decimal": "1.18.0",
# newer versions of arrayvec use const generics (Rust 1.51+)
"arrayvec": "0.5.2",
"csv": "1.1.6",
"indexmap": "1.6.2",
"inventory": "0.3.4",
Expand Down
1 change: 1 addition & 0 deletions src/conversions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ pub mod hashbrown;
pub mod indexmap;
pub mod num_bigint;
pub mod num_complex;
pub mod rust_decimal;
pub mod serde;
mod std;
229 changes: 229 additions & 0 deletions src/conversions/rust_decimal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
#![cfg(feature = "rust_decimal")]
//! Conversions to and from [rust_decimal](https://docs.rs/rust_decimal)'s [`Decimal`] type.
//!
//! This is useful for converting Python's decimal.Decimal into and from a native Rust type.
//!
//! # Setup
//!
//! To use this feature, add to your **`Cargo.toml`**:
//!
//! ```toml
//! [dependencies]
//! ## change * to the version you want to use, ideally the latest.
//! rust_decimal = "1.0"
// workaround for `extended_key_value_attributes`: https://github.com/rust-lang/rust/issues/82768#issuecomment-803935643
#![cfg_attr(docsrs, cfg_attr(docsrs, doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"), "\", features = [\"rust_decimal\"] }")))]
#![cfg_attr(
not(docsrs),
doc = "pyo3 = { version = \"*\", features = [\"rust_decimal\"] }"
)]
//! ```
//!
//! Note that you must use a compatible version of rust_decimal and PyO3.
//! The required rust_decimal version may vary based on the version of PyO3.
//!
//! # Example
//!
//! Rust code to create a function that adds one to a Decimal
//!
//! ```rust
//! use rust_decimal::Decimal;
//! use pyo3::prelude::*;
//!
//! #[pyfunction]
//! fn add_one(d: Decimal) -> Decimal {
//! d + Decimal::ONE
//! }
//!
//! #[pymodule]
//! fn my_module(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
//! m.add_function(wrap_pyfunction!(add_one, m)?)?;
//! Ok(())
//! }
//! ```
//!
//! Python code that validates the functionality
//!
//!
//! ```python
//! from my_module import add_one
//! from decimal import Decimal
//!
//! d = Decimal("2")
//! value = add_one(d)
//!
//! assert d + 1 == value
//! ```
use crate::conversion::AsPyPointer;
use crate::exceptions::PyValueError;
use crate::ffi;
use crate::once_cell::GILOnceCell;
use crate::types::PyType;
use crate::{intern, FromPyObject, IntoPy, Py, PyAny, PyObject, PyResult, Python, ToPyObject};
use rust_decimal::Decimal;
use std::str::FromStr;

impl FromPyObject<'_> for Decimal {
fn extract(obj: &PyAny) -> PyResult<Self> {
// use the string representation to not be lossy
unsafe {
let num_obj = ffi::PyNumber_Index(obj.as_ptr());
if !num_obj.is_null() {
let val = ffi::PyLong_AsLong(num_obj);
ffi::Py_DECREF(num_obj);
Ok(Decimal::new(val, 0))
} else {
Decimal::from_str(obj.str()?.to_str()?)
.map_err(|e| PyValueError::new_err(e.to_string()))
}
}
}
}

static DECIMAL_CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();

fn get_decimal_cls(py: Python<'_>) -> PyResult<&PyType> {
DECIMAL_CLS
.get_or_try_init(py, || {
py.import(intern!(py, "decimal"))?
.getattr(intern!(py, "Decimal"))?
.extract()
})
.map(|ty| ty.as_ref(py))
}

impl ToPyObject for Decimal {
fn to_object(&self, py: Python<'_>) -> PyObject {
// TODO: handle error gracefully when ToPyObject can error
// look up the decimal.Decimal
let dec_cls = get_decimal_cls(py).expect("failed to load decimal.Decimal");
// now call the constructor with the Rust Decimal string-ified
// to not be lossy
let ret = dec_cls
.call1((self.to_string(),))
.expect("failed to call decimal.Decimal(value)");
ret.to_object(py)
}
}

impl IntoPy<PyObject> for Decimal {
fn into_py(self, py: Python<'_>) -> PyObject {
self.to_object(py)
}
}

#[cfg(test)]
mod test_rust_decimal {
use super::*;
use crate::err::PyErr;
use crate::types::PyDict;
use rust_decimal::Decimal;

#[cfg(not(target_arch = "wasm32"))]
use proptest::prelude::*;

macro_rules! convert_constants {
($name:ident, $rs:expr, $py:literal) => {
#[test]
fn $name() {
Python::with_gil(|py| {
let rs_orig = $rs;
let rs_dec = rs_orig.into_py(py);
let locals = PyDict::new(py);
locals.set_item("rs_dec", &rs_dec).unwrap();
// Checks if Rust Decimal -> Python Decimal conversion is correct
py.run(
&format!(
"import decimal\npy_dec = decimal.Decimal({})\nassert py_dec == rs_dec",
$py
),
None,
Some(locals),
)
.unwrap();
// Checks if Python Decimal -> Rust Decimal conversion is correct
let py_dec = locals.get_item("py_dec").unwrap();
let py_result: Decimal = FromPyObject::extract(py_dec).unwrap();
assert_eq!(rs_orig, py_result);
})
}
};
}

convert_constants!(convert_zero, Decimal::ZERO, "0");
convert_constants!(convert_one, Decimal::ONE, "1");
convert_constants!(convert_neg_one, Decimal::NEGATIVE_ONE, "-1");
convert_constants!(convert_two, Decimal::TWO, "2");
convert_constants!(convert_ten, Decimal::TEN, "10");
convert_constants!(convert_one_hundred, Decimal::ONE_HUNDRED, "100");
convert_constants!(convert_one_thousand, Decimal::ONE_THOUSAND, "1000");

#[cfg(not(target_arch = "wasm32"))]
proptest! {
#[test]
fn test_roundtrip(
lo in any::<u32>(),
mid in any::<u32>(),
high in any::<u32>(),
negative in any::<bool>(),
scale in 0..28u32
) {
let num = Decimal::from_parts(lo, mid, high, negative, scale);
Python::with_gil(|py| {
let rs_dec = num.into_py(py);
let locals = PyDict::new(py);
locals.set_item("rs_dec", &rs_dec).unwrap();
py.run(
&format!(
"import decimal\npy_dec = decimal.Decimal(\"{}\")\nassert py_dec == rs_dec",
num),
None, Some(locals)).unwrap();
let roundtripped: Decimal = rs_dec.extract(py).unwrap();
assert_eq!(num, roundtripped);
})
}

#[test]
fn test_integers(num in any::<i64>()) {
Python::with_gil(|py| {
let py_num = num.into_py(py);
let roundtripped: Decimal = py_num.extract(py).unwrap();
let rs_dec = Decimal::new(num, 0);
assert_eq!(rs_dec, roundtripped);
})
}
}

#[test]
fn test_nan() {
Python::with_gil(|py| {
let locals = PyDict::new(py);
py.run(
"import decimal\npy_dec = decimal.Decimal(\"NaN\")",
None,
Some(locals),
)
.unwrap();
let py_dec = locals.get_item("py_dec").unwrap();
let roundtripped: Result<Decimal, PyErr> = FromPyObject::extract(py_dec);
assert!(roundtripped.is_err());
})
}

#[test]
fn test_infinity() {
Python::with_gil(|py| {
let locals = PyDict::new(py);
py.run(
"import decimal\npy_dec = decimal.Decimal(\"Infinity\")",
None,
Some(locals),
)
.unwrap();
let py_dec = locals.get_item("py_dec").unwrap();
let roundtripped: Result<Decimal, PyErr> = FromPyObject::extract(py_dec);
assert!(roundtripped.is_err());
})
}
}
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@
//! [`BigUint`] types.
//! - [`num-complex`]: Enables conversions between Python objects and [num-complex]'s [`Complex`]
//! type.
//! - [`rust_decimal`]: Enables conversions between Python's decimal.Decimal and [rust_decimal]'s
//! [`Decimal`] type.
//! - [`serde`]: Allows implementing [serde]'s [`Serialize`] and [`Deserialize`] traits for
//! [`Py`]`<T>` for all `T` that implement [`Serialize`] and [`Deserialize`].
//!
Expand Down Expand Up @@ -275,6 +277,9 @@
//! [`num-bigint`]: ./num_bigint/index.html "Documentation about the `num-bigint` feature."
//! [`num-complex`]: ./num_complex/index.html "Documentation about the `num-complex` feature."
//! [`pyo3-build-config`]: https://docs.rs/pyo3-build-config
//! [rust_decimal]: https://docs.rs/rust_decimal
//! [`rust_decimal`]: ./rust_decimal/index.html "Documenation about the `rust_decimal` feature."
//! [`Decimal`]: https://docs.rs/rust_decimal/latest/rust_decimal/struct.Decimal.html
//! [`serde`]: <./serde/index.html> "Documentation about the `serde` feature."
//! [calling_rust]: https://pyo3.rs/latest/python_from_rust.html "Calling Python from Rust - PyO3 user guide"
//! [examples subdirectory]: https://github.com/PyO3/pyo3/tree/main/examples
Expand Down

0 comments on commit 32540a8

Please sign in to comment.