Skip to content

Commit

Permalink
Fixing BigDecimal conversion for PostgreSQL
Browse files Browse the repository at this point in the history
Now working properly with numbers, such as `0.01` and `0.012`.
  • Loading branch information
Julius de Bruijn committed Oct 30, 2020
1 parent e1e9946 commit 0c1a3c5
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 39 deletions.
77 changes: 38 additions & 39 deletions sqlx-core/src/postgres/types/bigdecimal.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::cmp;
use std::convert::{TryFrom, TryInto};

use bigdecimal::BigDecimal;
use bigdecimal::{BigDecimal, ToPrimitive, Zero};
use num_bigint::{BigInt, Sign};

use crate::decode::Decode;
Expand Down Expand Up @@ -77,65 +77,64 @@ impl TryFrom<&'_ BigDecimal> for PgNumeric {
type Error = BoxDynError;

fn try_from(decimal: &BigDecimal) -> Result<Self, BoxDynError> {
let base_10_to_10000 = |chunk: &[u8]| chunk.iter().fold(0i16, |a, &d| a * 10 + d as i16);
if decimal.is_zero() {
return Ok(PgNumeric::Number {
sign: PgNumericSign::Positive,
scale: 0,
weight: 0,
digits: vec![],
});
}

// NOTE: this unfortunately copies the BigInt internally
let (integer, exp) = decimal.as_bigint_and_exponent();

// this routine is specifically optimized for base-10
// FIXME: is there a way to iterate over the digits to avoid the Vec allocation
let (sign, base_10) = integer.to_radix_be(10);

// weight is positive power of 10000
// exp is the negative power of 10
let weight_10 = base_10.len() as i64 - exp;

// scale is only nonzero when we have fractional digits
// since `exp` is the _negative_ decimal exponent, it tells us
// exactly what our scale should be
let scale: i16 = cmp::max(0, exp).try_into()?;

// there's an implicit +1 offset in the interpretation
let weight: i16 = if weight_10 <= 0 {
weight_10 / 4 - 1
} else {
// the `-1` is a fix for an off by 1 error (4 digits should still be 0 weight)
(weight_10 - 1) / 4
}
.try_into()?;
let (sign, uint) = integer.into_parts();
let mut mantissa = uint.to_u128().unwrap();

let digits_len = if base_10.len() % 4 != 0 {
base_10.len() / 4 + 1
} else {
base_10.len() / 4
};
// If our scale is not a multiple of 4, we need to go to the next
// multiple.
let groups_diff = scale % 4;
if groups_diff > 0 {
let remainder = 4 - groups_diff as u32;
let power = 10u32.pow(remainder as u32) as u128;

let offset = weight_10.rem_euclid(4) as usize;
mantissa = mantissa * power;
}

let mut digits = Vec::with_capacity(digits_len);
// Array to store max mantissa of Decimal in Postgres decimal format.
let mut digits = Vec::with_capacity(8);

if let Some(first) = base_10.get(..offset) {
if offset != 0 {
digits.push(base_10_to_10000(first));
}
// Convert to base-10000.
while mantissa != 0 {
digits.push((mantissa % 10_000) as i16);
mantissa /= 10_000;
}

if let Some(rest) = base_10.get(offset..) {
digits.extend(
rest.chunks(4)
.map(|chunk| base_10_to_10000(chunk) * 10i16.pow(4 - chunk.len() as u32)),
);
}
// Change the endianness.
digits.reverse();

// Weight is number of digits on the left side of the decimal.
let digits_after_decimal = (scale + 3) as u16 / 4;
let weight = digits.len() as i16 - digits_after_decimal as i16 - 1;

// Remove non-significant zeroes.
while let Some(&0) = digits.last() {
digits.pop();
}

let sign = match sign {
Sign::Plus | Sign::NoSign => PgNumericSign::Positive,
Sign::Minus => PgNumericSign::Negative,
};

Ok(PgNumeric::Number {
sign: match sign {
Sign::Plus | Sign::NoSign => PgNumericSign::Positive,
Sign::Minus => PgNumericSign::Negative,
},
sign,
scale,
weight,
digits,
Expand Down
7 changes: 7 additions & 0 deletions tests/postgres/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,14 @@ test_type!(bigdecimal<sqlx::types::BigDecimal>(Postgres,
"10000::numeric" == "10000".parse::<sqlx::types::BigDecimal>().unwrap(),
"0.1::numeric" == "0.1".parse::<sqlx::types::BigDecimal>().unwrap(),
"0.01::numeric" == "0.01".parse::<sqlx::types::BigDecimal>().unwrap(),
"0.012::numeric" == "0.012".parse::<sqlx::types::BigDecimal>().unwrap(),
"0.0123::numeric" == "0.0123".parse::<sqlx::types::BigDecimal>().unwrap(),
"0.01234::numeric" == "0.01234".parse::<sqlx::types::BigDecimal>().unwrap(),
"0.012345::numeric" == "0.012345".parse::<sqlx::types::BigDecimal>().unwrap(),
"0.0123456::numeric" == "0.0123456".parse::<sqlx::types::BigDecimal>().unwrap(),
"0.01234567::numeric" == "0.01234567".parse::<sqlx::types::BigDecimal>().unwrap(),
"0.012345678::numeric" == "0.012345678".parse::<sqlx::types::BigDecimal>().unwrap(),
"0.0123456789::numeric" == "0.0123456789".parse::<sqlx::types::BigDecimal>().unwrap(),
"12.34::numeric" == "12.34".parse::<sqlx::types::BigDecimal>().unwrap(),
"12345.6789::numeric" == "12345.6789".parse::<sqlx::types::BigDecimal>().unwrap(),
));
Expand Down

0 comments on commit 0c1a3c5

Please sign in to comment.