Skip to content

Commit

Permalink
Various maintenance tasks (#460)
Browse files Browse the repository at this point in the history
* Const generics to reduce duplicate code for array ops
* Replace default op with `wrapping_sub` to better communicate intentions
* Avoid additional heap allocation for serde serialization
* Fixes documentation test
* Clippy cleanup
  • Loading branch information
c410-f3r authored Jan 10, 2022
1 parent caee892 commit 30eb444
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 45 deletions.
2 changes: 1 addition & 1 deletion src/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub const UNSIGN_MASK: u32 = 0x4FFF_FFFF;
// contain a value between 0 and 28 inclusive.
pub const SCALE_MASK: u32 = 0x00FF_0000;
pub const U8_MASK: u32 = 0x0000_00FF;
pub const U32_MASK: u64 = 0xFFFF_FFFF;
pub const U32_MASK: u64 = u32::MAX as _;

// Number of bits scale is shifted by.
pub const SCALE_SHIFT: u32 = 16;
Expand Down
4 changes: 2 additions & 2 deletions src/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1183,7 +1183,7 @@ impl Decimal {

/// Returns `Some(Decimal)` number rounded to the specified number of significant digits. If
/// the resulting number is unable to be represented by the `Decimal` number then `None` will
/// be returned.
/// be returned.
/// When the number of significant figures of the `Decimal` being rounded is greater than the requested
/// number of significant digits then rounding will be performed using `MidpointNearestEven` strategy.
///
Expand Down Expand Up @@ -1224,7 +1224,7 @@ impl Decimal {

/// Returns `Some(Decimal)` number rounded to the specified number of significant digits. If
/// the resulting number is unable to be represented by the `Decimal` number then `None` will
/// be returned.
/// be returned.
/// When the number of significant figures of the `Decimal` being rounded is greater than the requested
/// number of significant digits then rounding will be performed using the provided [RoundingStrategy].
///
Expand Down
2 changes: 1 addition & 1 deletion src/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ mod diesel_mysql {
// internal types.
let bytes = numeric.ok_or("Invalid decimal")?;
let s = std::str::from_utf8(bytes)?;
Decimal::from_str(&s).map_err(|e| e.into())
Decimal::from_str(s).map_err(|e| e.into())
}
}

Expand Down
77 changes: 42 additions & 35 deletions src/ops/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,42 +15,46 @@ pub(crate) fn rescale_internal(value: &mut [u32; 3], value_scale: &mut u32, new_
}

if *value_scale > new_scale {
let mut diff = *value_scale - new_scale;
let mut diff = value_scale.wrapping_sub(new_scale);
// Scaling further isn't possible since we got an overflow
// In this case we need to reduce the accuracy of the "side to keep"

// Now do the necessary rounding
let mut remainder = 0;
while diff > 0 {
while let Some(diff_minus_one) = diff.checked_sub(1) {
if is_all_zero(value) {
*value_scale = new_scale;
return;
}

diff -= 1;
diff = diff_minus_one;

// Any remainder is discarded if diff > 0 still (i.e. lost precision)
remainder = div_by_u32(value, 10);
}
if remainder >= 5 {
for part in value.iter_mut() {
let digit = u64::from(*part) + 1u64;
remainder = if digit > 0xFFFF_FFFF { 1 } else { 0 };
*part = (digit & 0xFFFF_FFFF) as u32;
remainder = if digit > U32_MASK { 1 } else { 0 };
*part = (digit & U32_MASK) as u32;
if remainder == 0 {
break;
}
}
}
*value_scale = new_scale;
} else {
let mut diff = new_scale - *value_scale;
let mut diff = new_scale.wrapping_sub(*value_scale);
let mut working = [value[0], value[1], value[2]];
while diff > 0 && mul_by_10(&mut working) == 0 {
value.copy_from_slice(&working);
diff -= 1;
while let Some(diff_minus_one) = diff.checked_sub(1) {
if mul_by_10(&mut working) == 0 {
value.copy_from_slice(&working);
diff = diff_minus_one;
} else {
break;
}
}
*value_scale = new_scale - diff;
*value_scale = new_scale.wrapping_sub(diff);
}
}

Expand Down Expand Up @@ -94,32 +98,35 @@ pub(crate) fn add_by_internal(value: &mut [u32], by: &[u32]) -> u32 {
}

pub(crate) fn add_by_internal_flattened(value: &mut [u32; 3], by: u32) -> u32 {
let mut carry: u64;
let mut sum: u64;
sum = u64::from(value[0]) + u64::from(by);
value[0] = (sum & U32_MASK) as u32;
carry = sum >> 32;
if carry > 0 {
sum = u64::from(value[1]) + carry;
value[1] = (sum & U32_MASK) as u32;
carry = sum >> 32;
if carry > 0 {
sum = u64::from(value[2]) + carry;
value[2] = (sum & U32_MASK) as u32;
carry = sum >> 32;
}
}
carry as u32
manage_add_by_internal(by, value)
}

#[inline]
pub(crate) fn add_one_internal(value: &mut [u32]) -> u32 {
let mut carry: u64 = 1; // Start with one, since adding one
let mut sum: u64;
for i in value.iter_mut() {
sum = (*i as u64) + carry;
*i = (sum & U32_MASK) as u32;
carry = sum >> 32;
pub(crate) fn add_one_internal(value: &mut [u32; 3]) -> u32 {
manage_add_by_internal(1, value)
}

// `u64 as u32` are safe because of widening and 32bits shifts
#[inline]
pub(crate) fn manage_add_by_internal<const N: usize>(initial_carry: u32, value: &mut [u32; N]) -> u32 {
let mut carry = u64::from(initial_carry);
let mut iter = 0..value.len();
let mut sum = 0;

let mut sum_fn = |local_carry: &mut u64, idx| {
sum = u64::from(value[idx]).wrapping_add(*local_carry);
value[idx] = (sum & U32_MASK) as u32;
*local_carry = sum.wrapping_shr(32);
};

if let Some(idx) = iter.next() {
sum_fn(&mut carry, idx);
}

for idx in iter {
if carry > 0 {
sum_fn(&mut carry, idx);
}
}

carry as u32
Expand Down Expand Up @@ -209,7 +216,7 @@ pub(crate) fn mul_part(left: u32, right: u32, high: u32) -> (u32, u32) {
}

// Returns remainder
pub(crate) fn div_by_u32(bits: &mut [u32], divisor: u32) -> u32 {
pub(crate) fn div_by_u32<const N: usize>(bits: &mut [u32; N], divisor: u32) -> u32 {
if divisor == 0 {
// Divide by zero
panic!("Internal error: divide by zero");
Expand Down Expand Up @@ -271,7 +278,7 @@ pub(crate) fn cmp_internal(left: &[u32; 3], right: &[u32; 3]) -> core::cmp::Orde
}

#[inline]
pub(crate) fn is_all_zero(bits: &[u32]) -> bool {
pub(crate) fn is_all_zero<const N: usize>(bits: &[u32; N]) -> bool {
bits.iter().all(|b| *b == 0)
}

Expand Down
9 changes: 7 additions & 2 deletions src/ops/legacy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ fn add_with_scale_internal(
let mut temp4 = [0u32, 0u32, 0u32, 0u32];
if *quotient_scale != *working_scale {
// TODO: Remove necessity for temp (without performance impact)
fn div_by_10(target: &mut [u32], temp: &mut [u32], scale: &mut i32, target_scale: i32) {
fn div_by_10<const N: usize>(target: &mut [u32], temp: &mut [u32; N], scale: &mut i32, target_scale: i32) {
// Copy to the temp array
temp.copy_from_slice(target);
// divide by 10 until target scale is reached
Expand Down Expand Up @@ -676,7 +676,12 @@ fn add_with_scale_internal(
// (ultimately losing significant digits)
if *quotient_scale != *working_scale {
// TODO: Remove necessity for temp (without performance impact)
fn div_by_10_lossy(target: &mut [u32], temp: &mut [u32], scale: &mut i32, target_scale: i32) {
fn div_by_10_lossy<const N: usize>(
target: &mut [u32],
temp: &mut [u32; N],
scale: &mut i32,
target_scale: i32,
) {
temp.copy_from_slice(target);
// divide by 10 until target scale is reached
while *scale > target_scale {
Expand Down
2 changes: 1 addition & 1 deletion src/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ mod diesel_postgres {
}

#[cfg(any(feature = "db-postgres", feature = "db-tokio-postgres"))]
mod postgres {
mod _postgres {
use super::*;
use ::postgres::types::{to_sql_checked, FromSql, IsNull, ToSql, Type};
use byteorder::{BigEndian, ReadBytesExt};
Expand Down
12 changes: 9 additions & 3 deletions src/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ pub mod float {
///
/// #[derive(Serialize, Deserialize)]
/// pub struct StringExample {
/// #[serde(with = "rust_decimal::serde::string")]
/// #[serde(with = "rust_decimal::serde::str")]
/// value: Decimal,
/// }
///
Expand All @@ -106,8 +106,12 @@ pub mod float {
/// ```
#[cfg(feature = "serde-with-str")]
pub mod str {
use crate::constants::MAX_STR_BUFFER_SIZE;

use super::*;
use serde::Serialize;
use arrayvec::ArrayString;
use core::convert::TryFrom;
use serde::{ser::Error, Serialize};

pub fn deserialize<'de, D>(deserializer: D) -> Result<Decimal, D::Error>
where
Expand All @@ -120,7 +124,9 @@ pub mod str {
where
S: serde::Serializer,
{
value.to_string().serialize(serializer)
ArrayString::<MAX_STR_BUFFER_SIZE>::try_from(format_args!("{}", value))
.map_err(S::Error::custom)?
.serialize(serializer)
}
}

Expand Down

0 comments on commit 30eb444

Please sign in to comment.