From c615542d64de2113391623a8b42526933eb87641 Mon Sep 17 00:00:00 2001 From: Caio Date: Sun, 9 Jan 2022 09:06:11 -0300 Subject: [PATCH] Maintenance --- src/constants.rs | 2 +- src/decimal.rs | 4 +-- src/mysql.rs | 2 +- src/ops/array.rs | 77 ++++++++++++++++++++++++++--------------------- src/ops/legacy.rs | 9 ++++-- src/postgres.rs | 2 +- src/serde.rs | 12 ++++++-- 7 files changed, 63 insertions(+), 45 deletions(-) diff --git a/src/constants.rs b/src/constants.rs index 2600f4fb..59f33665 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -9,7 +9,7 @@ pub const UNSIGN_MASK: u32 = 0x4FFF_FFFF; // contain a value between 0 and 28 inclusive. pub const SCALE_MASK: u32 = 0x00FF_0000; pub const U8_MASK: u32 = 0x0000_00FF; -pub const U32_MASK: u64 = 0xFFFF_FFFF; +pub const U32_MASK: u64 = u32::MAX as _; // Number of bits scale is shifted by. pub const SCALE_SHIFT: u32 = 16; diff --git a/src/decimal.rs b/src/decimal.rs index e8f49225..b8523b2c 100644 --- a/src/decimal.rs +++ b/src/decimal.rs @@ -1183,7 +1183,7 @@ impl Decimal { /// Returns `Some(Decimal)` number rounded to the specified number of significant digits. If /// the resulting number is unable to be represented by the `Decimal` number then `None` will - /// be returned. + /// be returned. /// When the number of significant figures of the `Decimal` being rounded is greater than the requested /// number of significant digits then rounding will be performed using `MidpointNearestEven` strategy. /// @@ -1224,7 +1224,7 @@ impl Decimal { /// Returns `Some(Decimal)` number rounded to the specified number of significant digits. If /// the resulting number is unable to be represented by the `Decimal` number then `None` will - /// be returned. + /// be returned. /// When the number of significant figures of the `Decimal` being rounded is greater than the requested /// number of significant digits then rounding will be performed using the provided [RoundingStrategy]. /// diff --git a/src/mysql.rs b/src/mysql.rs index edac42a2..b169fee5 100644 --- a/src/mysql.rs +++ b/src/mysql.rs @@ -26,7 +26,7 @@ mod diesel_mysql { // internal types. let bytes = numeric.ok_or("Invalid decimal")?; let s = std::str::from_utf8(bytes)?; - Decimal::from_str(&s).map_err(|e| e.into()) + Decimal::from_str(s).map_err(|e| e.into()) } } diff --git a/src/ops/array.rs b/src/ops/array.rs index fbea1e21..651ba55d 100644 --- a/src/ops/array.rs +++ b/src/ops/array.rs @@ -15,19 +15,19 @@ pub(crate) fn rescale_internal(value: &mut [u32; 3], value_scale: &mut u32, new_ } if *value_scale > new_scale { - let mut diff = *value_scale - new_scale; + let mut diff = value_scale.wrapping_sub(new_scale); // Scaling further isn't possible since we got an overflow // In this case we need to reduce the accuracy of the "side to keep" // Now do the necessary rounding let mut remainder = 0; - while diff > 0 { + while let Some(diff_minus_one) = diff.checked_sub(1) { if is_all_zero(value) { *value_scale = new_scale; return; } - diff -= 1; + diff = diff_minus_one; // Any remainder is discarded if diff > 0 still (i.e. lost precision) remainder = div_by_u32(value, 10); @@ -35,8 +35,8 @@ pub(crate) fn rescale_internal(value: &mut [u32; 3], value_scale: &mut u32, new_ if remainder >= 5 { for part in value.iter_mut() { let digit = u64::from(*part) + 1u64; - remainder = if digit > 0xFFFF_FFFF { 1 } else { 0 }; - *part = (digit & 0xFFFF_FFFF) as u32; + remainder = if digit > U32_MASK { 1 } else { 0 }; + *part = (digit & U32_MASK) as u32; if remainder == 0 { break; } @@ -44,13 +44,17 @@ pub(crate) fn rescale_internal(value: &mut [u32; 3], value_scale: &mut u32, new_ } *value_scale = new_scale; } else { - let mut diff = new_scale - *value_scale; + let mut diff = new_scale.wrapping_sub(*value_scale); let mut working = [value[0], value[1], value[2]]; - while diff > 0 && mul_by_10(&mut working) == 0 { - value.copy_from_slice(&working); - diff -= 1; + while let Some(diff_minus_one) = diff.checked_sub(1) { + if mul_by_10(&mut working) == 0 { + value.copy_from_slice(&working); + diff = diff_minus_one; + } else { + break; + } } - *value_scale = new_scale - diff; + *value_scale = new_scale.wrapping_sub(diff); } } @@ -94,32 +98,35 @@ pub(crate) fn add_by_internal(value: &mut [u32], by: &[u32]) -> u32 { } pub(crate) fn add_by_internal_flattened(value: &mut [u32; 3], by: u32) -> u32 { - let mut carry: u64; - let mut sum: u64; - sum = u64::from(value[0]) + u64::from(by); - value[0] = (sum & U32_MASK) as u32; - carry = sum >> 32; - if carry > 0 { - sum = u64::from(value[1]) + carry; - value[1] = (sum & U32_MASK) as u32; - carry = sum >> 32; - if carry > 0 { - sum = u64::from(value[2]) + carry; - value[2] = (sum & U32_MASK) as u32; - carry = sum >> 32; - } - } - carry as u32 + manage_add_by_internal(by, value) } #[inline] -pub(crate) fn add_one_internal(value: &mut [u32]) -> u32 { - let mut carry: u64 = 1; // Start with one, since adding one - let mut sum: u64; - for i in value.iter_mut() { - sum = (*i as u64) + carry; - *i = (sum & U32_MASK) as u32; - carry = sum >> 32; +pub(crate) fn add_one_internal(value: &mut [u32; 3]) -> u32 { + manage_add_by_internal(1, value) +} + +// `u64 as u32` are safe because of widening and 32bits shifts +#[inline] +pub(crate) fn manage_add_by_internal(initial_carry: u32, value: &mut [u32; N]) -> u32 { + let mut carry = u64::from(initial_carry); + let mut iter = 0..value.len(); + let mut sum = 0; + + let mut sum_fn = |local_carry: &mut u64, idx| { + sum = u64::from(value[idx]).wrapping_add(*local_carry); + value[idx] = (sum & U32_MASK) as u32; + *local_carry = sum.wrapping_shr(32); + }; + + if let Some(idx) = iter.next() { + sum_fn(&mut carry, idx); + } + + for idx in iter { + if carry > 0 { + sum_fn(&mut carry, idx); + } } carry as u32 @@ -209,7 +216,7 @@ pub(crate) fn mul_part(left: u32, right: u32, high: u32) -> (u32, u32) { } // Returns remainder -pub(crate) fn div_by_u32(bits: &mut [u32], divisor: u32) -> u32 { +pub(crate) fn div_by_u32(bits: &mut [u32; N], divisor: u32) -> u32 { if divisor == 0 { // Divide by zero panic!("Internal error: divide by zero"); @@ -271,7 +278,7 @@ pub(crate) fn cmp_internal(left: &[u32; 3], right: &[u32; 3]) -> core::cmp::Orde } #[inline] -pub(crate) fn is_all_zero(bits: &[u32]) -> bool { +pub(crate) fn is_all_zero(bits: &[u32; N]) -> bool { bits.iter().all(|b| *b == 0) } diff --git a/src/ops/legacy.rs b/src/ops/legacy.rs index 4bb2924e..49f39814 100644 --- a/src/ops/legacy.rs +++ b/src/ops/legacy.rs @@ -624,7 +624,7 @@ fn add_with_scale_internal( let mut temp4 = [0u32, 0u32, 0u32, 0u32]; if *quotient_scale != *working_scale { // TODO: Remove necessity for temp (without performance impact) - fn div_by_10(target: &mut [u32], temp: &mut [u32], scale: &mut i32, target_scale: i32) { + fn div_by_10(target: &mut [u32], temp: &mut [u32; N], scale: &mut i32, target_scale: i32) { // Copy to the temp array temp.copy_from_slice(target); // divide by 10 until target scale is reached @@ -676,7 +676,12 @@ fn add_with_scale_internal( // (ultimately losing significant digits) if *quotient_scale != *working_scale { // TODO: Remove necessity for temp (without performance impact) - fn div_by_10_lossy(target: &mut [u32], temp: &mut [u32], scale: &mut i32, target_scale: i32) { + fn div_by_10_lossy( + target: &mut [u32], + temp: &mut [u32; N], + scale: &mut i32, + target_scale: i32, + ) { temp.copy_from_slice(target); // divide by 10 until target scale is reached while *scale > target_scale { diff --git a/src/postgres.rs b/src/postgres.rs index 9ea09f88..e0744d0e 100644 --- a/src/postgres.rs +++ b/src/postgres.rs @@ -460,7 +460,7 @@ mod diesel_postgres { } #[cfg(any(feature = "db-postgres", feature = "db-tokio-postgres"))] -mod postgres { +mod _postgres { use super::*; use ::postgres::types::{to_sql_checked, FromSql, IsNull, ToSql, Type}; use byteorder::{BigEndian, ReadBytesExt}; diff --git a/src/serde.rs b/src/serde.rs index c149b85d..24b1dd80 100644 --- a/src/serde.rs +++ b/src/serde.rs @@ -94,7 +94,7 @@ pub mod float { /// /// #[derive(Serialize, Deserialize)] /// pub struct StringExample { -/// #[serde(with = "rust_decimal::serde::string")] +/// #[serde(with = "rust_decimal::serde::str")] /// value: Decimal, /// } /// @@ -106,8 +106,12 @@ pub mod float { /// ``` #[cfg(feature = "serde-with-str")] pub mod str { + use crate::constants::MAX_STR_BUFFER_SIZE; + use super::*; - use serde::Serialize; + use arrayvec::ArrayString; + use core::convert::TryFrom; + use serde::{ser::Error, Serialize}; pub fn deserialize<'de, D>(deserializer: D) -> Result where @@ -120,7 +124,9 @@ pub mod str { where S: serde::Serializer, { - value.to_string().serialize(serializer) + ArrayString::::try_from(format_args!("{}", value)) + .map_err(S::Error::custom)? + .serialize(serializer) } }