Skip to content

Commit a7cdd33

Browse files
committed
Use rescale_decimal for variant decimal scaling
1 parent 274a028 commit a7cdd33

File tree

5 files changed

+157
-184
lines changed

5 files changed

+157
-184
lines changed

arrow-cast/src/cast/decimal.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,23 @@ use crate::cast::*;
1919

2020
/// A utility trait that provides checked conversions between
2121
/// decimal types inspired by [`NumCast`]
22-
pub(crate) trait DecimalCast: Sized {
22+
pub trait DecimalCast: Sized {
23+
/// Convert the decimal to an i32
2324
fn to_i32(self) -> Option<i32>;
2425

26+
/// Convert the decimal to an i64
2527
fn to_i64(self) -> Option<i64>;
2628

29+
/// Convert the decimal to an i128
2730
fn to_i128(self) -> Option<i128>;
2831

32+
/// Convert the decimal to an i256
2933
fn to_i256(self) -> Option<i256>;
3034

35+
/// Convert another decimal value into this decimal type
3136
fn from_decimal<T: DecimalCast>(n: T) -> Option<Self>;
3237

38+
/// Convert an f64 into this decimal type
3339
fn from_f64(n: f64) -> Option<Self>;
3440
}
3541

@@ -141,7 +147,7 @@ impl DecimalCast for i256 {
141147

142148
/// Build a rescale function from (input_precision, input_scale) to (output_precision, output_scale)
143149
/// returning a closure `Fn(I::Native) -> Option<O::Native>` that performs the conversion.
144-
pub(crate) fn rescale_decimal<I, O>(
150+
pub fn rescale_decimal<I, O>(
145151
input_precision: u8,
146152
input_scale: i8,
147153
output_precision: u8,

arrow-cast/src/cast/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ use arrow_schema::*;
6767
use arrow_select::take::take;
6868
use num_traits::{NumCast, ToPrimitive, cast::AsPrimitive};
6969

70+
pub use decimal::{DecimalCast, rescale_decimal};
71+
7072
/// CastOptions provides a way to override the default cast behaviors
7173
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
7274
pub struct CastOptions<'a> {

parquet-variant-compute/src/type_conversion.rs

Lines changed: 57 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,14 @@
1717

1818
//! Module for transforming a typed arrow `Array` to `VariantArray`.
1919
20-
use arrow::datatypes::{self, ArrowPrimitiveType, Decimal32Type, DecimalType, MAX_DECIMAL32_FOR_EACH_PRECISION};
21-
use parquet_variant::{Variant, VariantDecimal4};
20+
use arrow::{
21+
compute::{DecimalCast, rescale_decimal},
22+
datatypes::{
23+
self, ArrowPrimitiveType, Decimal32Type, Decimal64Type, Decimal128Type, DecimalType,
24+
},
25+
};
26+
use arrow_schema::ArrowError;
27+
use parquet_variant::{Variant, VariantDecimal4, VariantDecimal8, VariantDecimal16};
2228

2329
/// Options for controlling the behavior of `cast_to_variant_with_options`.
2430
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -61,90 +67,61 @@ impl_primitive_from_variant!(datatypes::Float16Type, as_f16);
6167
impl_primitive_from_variant!(datatypes::Float32Type, as_f32);
6268
impl_primitive_from_variant!(datatypes::Float64Type, as_f64);
6369

64-
macro_rules! scale_variant_decimal {
65-
($variant:expr, $variant_method:ident, $to_int_ty:expr, $output_scale:expr, $precision:expr, $validate:path) => {{
66-
let variant = $variant.$variant_method()?;
67-
let input_scale = variant.scale() as i8;
68-
let variant = $to_int_ty(variant.integer());
69-
let ten = $to_int_ty(10);
70-
71-
let scaled = if input_scale == $output_scale {
72-
Some(variant)
73-
} else if input_scale < $output_scale {
74-
// scale_up means output has more fractional digits than input
75-
// multiply integer by 10^(output_scale - input_scale)
76-
let delta = ($output_scale - input_scale) as u32;
77-
let mul = ten.checked_pow(delta)?;
78-
variant.checked_mul(mul)
79-
} else {
80-
// scale_down means output has fewer fractional digits than input
81-
// divide by 10^(input_scale - output_scale) with rounding
82-
let delta = (input_scale - $output_scale) as u32;
83-
let div = ten.checked_pow(delta)?;
84-
let d = variant.checked_div(div)?;
85-
let r = variant % div;
86-
87-
// rounding in the same way as convert_to_smaller_scale_decimal in arrow-cast
88-
let half = div.checked_div($to_int_ty(2))?;
89-
let half_neg = half.checked_neg()?;
90-
let adjusted = match variant >= $to_int_ty(0) {
91-
true if r >= half => d.checked_add($to_int_ty(1))?,
92-
false if r <= half_neg => d.checked_sub($to_int_ty(1))?,
93-
_ => d,
94-
};
95-
Some(adjusted)
96-
};
97-
98-
scaled.filter(|v| $validate(*v, $precision))
99-
}};
100-
}
101-
pub(crate) use scale_variant_decimal;
102-
103-
fn variant_to_unscaled_decimal32(
104-
variant: Variant<'_, '_>,
70+
pub(crate) fn variant_to_unscaled_decimal<O>(
71+
variant: &Variant<'_, '_>,
10572
precision: u8,
10673
scale: i8,
107-
) -> Option<i32> {
108-
match variant {
109-
Variant::Int32(i) => scale_variant_decimal_new::<Decimal32Type, Decimal32Type>(i, VariantDecimal4::MAX_PRECISION, 0, precision, scale),
110-
Variant::Decimal4(d) => scale_variant_decimal_new::<Decimal32Type, Decimal32Type>(d.integer(), VariantDecimal4::MAX_PRECISION, d.scale() as i8, precision, scale),
111-
_ => None,
112-
}
113-
}
114-
115-
// fn rescale_variant(integer: i32, input_precision: u8, input_scale: i8, output_precision: u8, output_scale: i8) -> Option<i32> {
116-
// let input_precision = input_precision as i8;
117-
// let output_precision = output_precision as i8;
118-
// let mut input_integer_digits = input_precision - input_scale;
119-
// let output_integer_digits = output_precision - output_scale;
120-
121-
// //
122-
// if input_integer_digits > output_integer_digits {
123-
// if !Decimal32Type::is_valid_decimal_precision(integer, (output_integer_digits + input_scale) as u8) {
124-
// return None;
125-
// }
126-
// input_integer_digits = output_integer_digits;
127-
// }
128-
129-
// if input_integer_digits == output_integer_digits {
130-
// let rescaled =
131-
// }
132-
// }
133-
134-
135-
136-
fn scale_variant_decimal_new<I, O>(
137-
integer: I::Native,
138-
input_precision: u8,
139-
input_scale: i8,
140-
output_precision: u8,
141-
output_scale: i8,
14274
) -> Option<O::Native>
14375
where
144-
I: DecimalType,
14576
O: DecimalType,
77+
O::Native: DecimalCast,
14678
{
147-
return None;
79+
let maybe_rescaled = match variant {
80+
Variant::Int8(i) => {
81+
rescale_decimal::<Decimal32Type, O>(VariantDecimal4::MAX_PRECISION, 0, precision, scale)(
82+
*i as i32,
83+
)
84+
}
85+
Variant::Int16(i) => {
86+
rescale_decimal::<Decimal32Type, O>(VariantDecimal4::MAX_PRECISION, 0, precision, scale)(
87+
*i as i32,
88+
)
89+
}
90+
Variant::Int32(i) => {
91+
rescale_decimal::<Decimal32Type, O>(VariantDecimal4::MAX_PRECISION, 0, precision, scale)(
92+
*i,
93+
)
94+
}
95+
Variant::Int64(i) => {
96+
rescale_decimal::<Decimal64Type, O>(VariantDecimal8::MAX_PRECISION, 0, precision, scale)(
97+
*i,
98+
)
99+
}
100+
Variant::Decimal4(d) => rescale_decimal::<Decimal32Type, O>(
101+
VariantDecimal4::MAX_PRECISION,
102+
d.scale() as i8,
103+
precision,
104+
scale,
105+
)(d.integer()),
106+
Variant::Decimal8(d) => rescale_decimal::<Decimal64Type, O>(
107+
VariantDecimal8::MAX_PRECISION,
108+
d.scale() as i8,
109+
precision,
110+
scale,
111+
)(d.integer()),
112+
Variant::Decimal16(d) => rescale_decimal::<Decimal128Type, O>(
113+
VariantDecimal16::MAX_PRECISION,
114+
d.scale() as i8,
115+
precision,
116+
scale,
117+
)(d.integer()),
118+
_ => Err(ArrowError::InvalidArgumentError(format!(
119+
"Invalid variant type: {:?}",
120+
variant
121+
))),
122+
};
123+
124+
maybe_rescaled.ok()
148125
}
149126

150127
/// Convert the value at a specific index in the given array into a `Variant`.

0 commit comments

Comments
 (0)