Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement Decimal to rust_decimal conversions #3016

Merged
merged 1 commit into from
Apr 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ hashbrown = { version = ">= 0.9, < 0.14", optional = true }
indexmap = { version = "1.6", optional = true }
num-bigint = { version = "0.4", optional = true }
num-complex = { version = ">= 0.2, < 0.5", optional = true }
rust_decimal = { version = "1.0.0", default-features = false, optional = true }
serde = { version = "1.0", optional = true }

[dev-dependencies]
Expand All @@ -53,6 +54,7 @@ send_wrapper = "0.6"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.61"
rayon = "1.0.2"
rust_decimal = { version = "1.8.0", features = ["std"] }
widestring = "0.5.1"

[build-dependencies]
Expand Down Expand Up @@ -110,6 +112,7 @@ full = [
"eyre",
"anyhow",
"experimental-inspect",
"rust_decimal",
]

[[bench]]
Expand All @@ -120,6 +123,11 @@ harness = false
name = "bench_err"
harness = false

[[bench]]
name = "bench_decimal"
harness = false
required-features = ["rust_decimal"]

[[bench]]
name = "bench_dict"
harness = false
Expand Down Expand Up @@ -173,5 +181,5 @@ members = [

[package.metadata.docs.rs]
no-default-features = true
features = ["macros", "num-bigint", "num-complex", "hashbrown", "serde", "multiple-pymethods", "indexmap", "eyre", "chrono"]
features = ["macros", "num-bigint", "num-complex", "hashbrown", "serde", "multiple-pymethods", "indexmap", "eyre", "chrono", "rust_decimal"]
rustdoc-args = ["--cfg", "docsrs"]
32 changes: 32 additions & 0 deletions benches/bench_decimal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use criterion::{black_box, criterion_group, criterion_main, Bencher, Criterion};

use pyo3::prelude::*;
use pyo3::types::PyDict;
use rust_decimal::Decimal;

fn decimal_via_extract(b: &mut Bencher<'_>) {
Python::with_gil(|py| {
let locals = PyDict::new(py);
py.run(
r#"
import decimal
py_dec = decimal.Decimal("0.0")
"#,
None,
Some(locals),
)
.unwrap();
let py_dec = locals.get_item("py_dec").unwrap();

b.iter(|| {
let _: Decimal = black_box(py_dec).extract().unwrap();
});
})
}

fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("decimal_via_extract", decimal_via_extract);
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
4 changes: 4 additions & 0 deletions guide/src/features.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ Adds a dependency on [num-bigint](https://docs.rs/num-bigint) and enables conver

Adds a dependency on [num-complex](https://docs.rs/num-complex) and enables conversions into its [`Complex`](https://docs.rs/num-complex/latest/num_complex/struct.Complex.html) type.

### `rust_decimal`

Adds a dependency on [rust_decimal](https://docs.rs/rust_decimal) and enables conversions into its [`Decimal`](https://docs.rs/rust_decimal/latest/rust_decimal/struct.Decimal.html) type.

### `serde`

Enables (de)serialization of `Py<T>` objects via [serde](https://serde.rs/).
Expand Down
1 change: 1 addition & 0 deletions newsfragments/3016.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added support for converting to and from Python's `decimal.Decimal` and `rust_decimal::Decimal`.
4 changes: 4 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,10 @@ def set_minimal_package_versions(session: nox.Session, venv_backend="none"):
"examples/word-count",
)
min_pkg_versions = {
# newer versions of rust_decimal want newer arrayvec
"rust_decimal": "1.18.0",
# newer versions of arrayvec use const generics (Rust 1.51+)
"arrayvec": "0.5.2",
"csv": "1.1.6",
"indexmap": "1.6.2",
"inventory": "0.3.4",
Expand Down
1 change: 1 addition & 0 deletions src/conversions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ pub mod hashbrown;
pub mod indexmap;
pub mod num_bigint;
pub mod num_complex;
pub mod rust_decimal;
pub mod serde;
mod std;
221 changes: 221 additions & 0 deletions src/conversions/rust_decimal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
#![cfg(feature = "rust_decimal")]
//! Conversions to and from [rust_decimal](https://docs.rs/rust_decimal)'s [`Decimal`] type.
//!
//! This is useful for converting Python's decimal.Decimal into and from a native Rust type.
//!
//! # Setup
//!
//! To use this feature, add to your **`Cargo.toml`**:
//!
//! ```toml
//! [dependencies]
//! rust_decimal = "1.0"
// workaround for `extended_key_value_attributes`: https://github.com/rust-lang/rust/issues/82768#issuecomment-803935643
#![cfg_attr(docsrs, cfg_attr(docsrs, doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"), "\", features = [\"rust_decimal\"] }")))]
#![cfg_attr(
not(docsrs),
doc = "pyo3 = { version = ..., features = [\"rust_decimal\"] }"
)]
//! ```
//!
//! Note that you must use a compatible version of rust_decimal and PyO3.
//! The required rust_decimal version may vary based on the version of PyO3.
//!
//! # Example
//!
//! Rust code to create a function that adds one to a Decimal
//!
//! ```rust
//! use rust_decimal::Decimal;
//! use pyo3::prelude::*;
//!
//! #[pyfunction]
//! fn add_one(d: Decimal) -> Decimal {
//! d + Decimal::ONE
//! }
//!
//! #[pymodule]
//! fn my_module(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
//! m.add_function(wrap_pyfunction!(add_one, m)?)?;
//! Ok(())
//! }
//! ```
//!
//! Python code that validates the functionality
//!
//!
//! ```python
//! from my_module import add_one
//! from decimal import Decimal
//!
//! d = Decimal("2")
//! value = add_one(d)
//!
//! assert d + 1 == value
//! ```

use crate::exceptions::PyValueError;
use crate::once_cell::GILOnceCell;
use crate::types::PyType;
use crate::{intern, FromPyObject, IntoPy, Py, PyAny, PyObject, PyResult, Python, ToPyObject};
use rust_decimal::Decimal;
use std::str::FromStr;

impl FromPyObject<'_> for Decimal {
fn extract(obj: &PyAny) -> PyResult<Self> {
// use the string representation to not be lossy
if let Ok(val) = obj.extract() {
Ok(Decimal::new(val, 0))
} else {
cardoe marked this conversation as resolved.
Show resolved Hide resolved
Decimal::from_str(obj.str()?.to_str()?)
.map_err(|e| PyValueError::new_err(e.to_string()))
}
}
}

static DECIMAL_CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();

fn get_decimal_cls(py: Python<'_>) -> PyResult<&PyType> {
DECIMAL_CLS
.get_or_try_init(py, || {
py.import(intern!(py, "decimal"))?
.getattr(intern!(py, "Decimal"))?
.extract()
})
.map(|ty| ty.as_ref(py))
}

impl ToPyObject for Decimal {
fn to_object(&self, py: Python<'_>) -> PyObject {
// TODO: handle error gracefully when ToPyObject can error
// look up the decimal.Decimal
let dec_cls = get_decimal_cls(py).expect("failed to load decimal.Decimal");
// now call the constructor with the Rust Decimal string-ified
// to not be lossy
let ret = dec_cls
.call1((self.to_string(),))
.expect("failed to call decimal.Decimal(value)");
ret.to_object(py)
}
}

impl IntoPy<PyObject> for Decimal {
fn into_py(self, py: Python<'_>) -> PyObject {
self.to_object(py)
}
}

#[cfg(test)]
mod test_rust_decimal {
use super::*;
use crate::err::PyErr;
use crate::types::PyDict;
use rust_decimal::Decimal;

#[cfg(not(target_arch = "wasm32"))]
use proptest::prelude::*;

macro_rules! convert_constants {
cardoe marked this conversation as resolved.
Show resolved Hide resolved
($name:ident, $rs:expr, $py:literal) => {
#[test]
fn $name() {
Python::with_gil(|py| {
let rs_orig = $rs;
let rs_dec = rs_orig.into_py(py);
let locals = PyDict::new(py);
locals.set_item("rs_dec", &rs_dec).unwrap();
// Checks if Rust Decimal -> Python Decimal conversion is correct
py.run(
&format!(
"import decimal\npy_dec = decimal.Decimal({})\nassert py_dec == rs_dec",
$py
),
None,
Some(locals),
)
.unwrap();
// Checks if Python Decimal -> Rust Decimal conversion is correct
let py_dec = locals.get_item("py_dec").unwrap();
let py_result: Decimal = FromPyObject::extract(py_dec).unwrap();
assert_eq!(rs_orig, py_result);
})
}
};
}

convert_constants!(convert_zero, Decimal::ZERO, "0");
convert_constants!(convert_one, Decimal::ONE, "1");
convert_constants!(convert_neg_one, Decimal::NEGATIVE_ONE, "-1");
convert_constants!(convert_two, Decimal::TWO, "2");
convert_constants!(convert_ten, Decimal::TEN, "10");
convert_constants!(convert_one_hundred, Decimal::ONE_HUNDRED, "100");
convert_constants!(convert_one_thousand, Decimal::ONE_THOUSAND, "1000");
cardoe marked this conversation as resolved.
Show resolved Hide resolved

#[cfg(not(target_arch = "wasm32"))]
proptest! {
#[test]
fn test_roundtrip(
lo in any::<u32>(),
mid in any::<u32>(),
high in any::<u32>(),
negative in any::<bool>(),
scale in 0..28u32
) {
let num = Decimal::from_parts(lo, mid, high, negative, scale);
Python::with_gil(|py| {
let rs_dec = num.into_py(py);
let locals = PyDict::new(py);
locals.set_item("rs_dec", &rs_dec).unwrap();
py.run(
&format!(
"import decimal\npy_dec = decimal.Decimal(\"{}\")\nassert py_dec == rs_dec",
num),
None, Some(locals)).unwrap();
let roundtripped: Decimal = rs_dec.extract(py).unwrap();
assert_eq!(num, roundtripped);
})
}

#[test]
fn test_integers(num in any::<i64>()) {
Python::with_gil(|py| {
let py_num = num.into_py(py);
let roundtripped: Decimal = py_num.extract(py).unwrap();
let rs_dec = Decimal::new(num, 0);
assert_eq!(rs_dec, roundtripped);
})
}
}

#[test]
fn test_nan() {
Python::with_gil(|py| {
let locals = PyDict::new(py);
py.run(
"import decimal\npy_dec = decimal.Decimal(\"NaN\")",
None,
Some(locals),
)
.unwrap();
let py_dec = locals.get_item("py_dec").unwrap();
let roundtripped: Result<Decimal, PyErr> = FromPyObject::extract(py_dec);
assert!(roundtripped.is_err());
})
}

#[test]
fn test_infinity() {
Python::with_gil(|py| {
let locals = PyDict::new(py);
py.run(
"import decimal\npy_dec = decimal.Decimal(\"Infinity\")",
None,
Some(locals),
)
.unwrap();
let py_dec = locals.get_item("py_dec").unwrap();
let roundtripped: Result<Decimal, PyErr> = FromPyObject::extract(py_dec);
assert!(roundtripped.is_err());
})
}
}
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@
//! [`BigUint`] types.
//! - [`num-complex`]: Enables conversions between Python objects and [num-complex]'s [`Complex`]
//! type.
//! - [`rust_decimal`]: Enables conversions between Python's decimal.Decimal and [rust_decimal]'s
//! [`Decimal`] type.
//! - [`serde`]: Allows implementing [serde]'s [`Serialize`] and [`Deserialize`] traits for
//! [`Py`]`<T>` for all `T` that implement [`Serialize`] and [`Deserialize`].
//!
Expand Down Expand Up @@ -275,6 +277,9 @@
//! [`num-bigint`]: ./num_bigint/index.html "Documentation about the `num-bigint` feature."
//! [`num-complex`]: ./num_complex/index.html "Documentation about the `num-complex` feature."
//! [`pyo3-build-config`]: https://docs.rs/pyo3-build-config
//! [rust_decimal]: https://docs.rs/rust_decimal
//! [`rust_decimal`]: ./rust_decimal/index.html "Documenation about the `rust_decimal` feature."
//! [`Decimal`]: https://docs.rs/rust_decimal/latest/rust_decimal/struct.Decimal.html
//! [`serde`]: <./serde/index.html> "Documentation about the `serde` feature."
//! [calling_rust]: https://pyo3.rs/latest/python_from_rust.html "Calling Python from Rust - PyO3 user guide"
//! [examples subdirectory]: https://github.com/PyO3/pyo3/tree/main/examples
Expand Down