|
17 | 17 |
|
18 | 18 | //! Module for transforming a typed arrow `Array` to `VariantArray`. |
19 | 19 |
|
20 | | -use arrow::datatypes::{self, ArrowPrimitiveType, Decimal32Type, DecimalType, MAX_DECIMAL32_FOR_EACH_PRECISION}; |
21 | | -use parquet_variant::{Variant, VariantDecimal4}; |
| 20 | +use arrow::{ |
| 21 | + compute::{DecimalCast, rescale_decimal}, |
| 22 | + datatypes::{ |
| 23 | + self, ArrowPrimitiveType, Decimal32Type, Decimal64Type, Decimal128Type, DecimalType, |
| 24 | + }, |
| 25 | +}; |
| 26 | +use arrow_schema::ArrowError; |
| 27 | +use parquet_variant::{Variant, VariantDecimal4, VariantDecimal8, VariantDecimal16}; |
22 | 28 |
|
23 | 29 | /// Options for controlling the behavior of `cast_to_variant_with_options`. |
24 | 30 | #[derive(Debug, Clone, PartialEq, Eq)] |
@@ -61,90 +67,61 @@ impl_primitive_from_variant!(datatypes::Float16Type, as_f16); |
61 | 67 | impl_primitive_from_variant!(datatypes::Float32Type, as_f32); |
62 | 68 | impl_primitive_from_variant!(datatypes::Float64Type, as_f64); |
63 | 69 |
|
64 | | -macro_rules! scale_variant_decimal { |
65 | | - ($variant:expr, $variant_method:ident, $to_int_ty:expr, $output_scale:expr, $precision:expr, $validate:path) => {{ |
66 | | - let variant = $variant.$variant_method()?; |
67 | | - let input_scale = variant.scale() as i8; |
68 | | - let variant = $to_int_ty(variant.integer()); |
69 | | - let ten = $to_int_ty(10); |
70 | | - |
71 | | - let scaled = if input_scale == $output_scale { |
72 | | - Some(variant) |
73 | | - } else if input_scale < $output_scale { |
74 | | - // scale_up means output has more fractional digits than input |
75 | | - // multiply integer by 10^(output_scale - input_scale) |
76 | | - let delta = ($output_scale - input_scale) as u32; |
77 | | - let mul = ten.checked_pow(delta)?; |
78 | | - variant.checked_mul(mul) |
79 | | - } else { |
80 | | - // scale_down means output has fewer fractional digits than input |
81 | | - // divide by 10^(input_scale - output_scale) with rounding |
82 | | - let delta = (input_scale - $output_scale) as u32; |
83 | | - let div = ten.checked_pow(delta)?; |
84 | | - let d = variant.checked_div(div)?; |
85 | | - let r = variant % div; |
86 | | - |
87 | | - // rounding in the same way as convert_to_smaller_scale_decimal in arrow-cast |
88 | | - let half = div.checked_div($to_int_ty(2))?; |
89 | | - let half_neg = half.checked_neg()?; |
90 | | - let adjusted = match variant >= $to_int_ty(0) { |
91 | | - true if r >= half => d.checked_add($to_int_ty(1))?, |
92 | | - false if r <= half_neg => d.checked_sub($to_int_ty(1))?, |
93 | | - _ => d, |
94 | | - }; |
95 | | - Some(adjusted) |
96 | | - }; |
97 | | - |
98 | | - scaled.filter(|v| $validate(*v, $precision)) |
99 | | - }}; |
100 | | -} |
101 | | -pub(crate) use scale_variant_decimal; |
102 | | - |
103 | | -fn variant_to_unscaled_decimal32( |
104 | | - variant: Variant<'_, '_>, |
| 70 | +pub(crate) fn variant_to_unscaled_decimal<O>( |
| 71 | + variant: &Variant<'_, '_>, |
105 | 72 | precision: u8, |
106 | 73 | scale: i8, |
107 | | -) -> Option<i32> { |
108 | | - match variant { |
109 | | - Variant::Int32(i) => scale_variant_decimal_new::<Decimal32Type, Decimal32Type>(i, VariantDecimal4::MAX_PRECISION, 0, precision, scale), |
110 | | - Variant::Decimal4(d) => scale_variant_decimal_new::<Decimal32Type, Decimal32Type>(d.integer(), VariantDecimal4::MAX_PRECISION, d.scale() as i8, precision, scale), |
111 | | - _ => None, |
112 | | - } |
113 | | -} |
114 | | - |
115 | | -// fn rescale_variant(integer: i32, input_precision: u8, input_scale: i8, output_precision: u8, output_scale: i8) -> Option<i32> { |
116 | | -// let input_precision = input_precision as i8; |
117 | | -// let output_precision = output_precision as i8; |
118 | | -// let mut input_integer_digits = input_precision - input_scale; |
119 | | -// let output_integer_digits = output_precision - output_scale; |
120 | | - |
121 | | -// // |
122 | | -// if input_integer_digits > output_integer_digits { |
123 | | -// if !Decimal32Type::is_valid_decimal_precision(integer, (output_integer_digits + input_scale) as u8) { |
124 | | -// return None; |
125 | | -// } |
126 | | -// input_integer_digits = output_integer_digits; |
127 | | -// } |
128 | | - |
129 | | -// if input_integer_digits == output_integer_digits { |
130 | | -// let rescaled = |
131 | | -// } |
132 | | -// } |
133 | | - |
134 | | - |
135 | | - |
136 | | -fn scale_variant_decimal_new<I, O>( |
137 | | - integer: I::Native, |
138 | | - input_precision: u8, |
139 | | - input_scale: i8, |
140 | | - output_precision: u8, |
141 | | - output_scale: i8, |
142 | 74 | ) -> Option<O::Native> |
143 | 75 | where |
144 | | - I: DecimalType, |
145 | 76 | O: DecimalType, |
| 77 | + O::Native: DecimalCast, |
146 | 78 | { |
147 | | - return None; |
| 79 | + let maybe_rescaled = match variant { |
| 80 | + Variant::Int8(i) => { |
| 81 | + rescale_decimal::<Decimal32Type, O>(VariantDecimal4::MAX_PRECISION, 0, precision, scale)( |
| 82 | + *i as i32, |
| 83 | + ) |
| 84 | + } |
| 85 | + Variant::Int16(i) => { |
| 86 | + rescale_decimal::<Decimal32Type, O>(VariantDecimal4::MAX_PRECISION, 0, precision, scale)( |
| 87 | + *i as i32, |
| 88 | + ) |
| 89 | + } |
| 90 | + Variant::Int32(i) => { |
| 91 | + rescale_decimal::<Decimal32Type, O>(VariantDecimal4::MAX_PRECISION, 0, precision, scale)( |
| 92 | + *i, |
| 93 | + ) |
| 94 | + } |
| 95 | + Variant::Int64(i) => { |
| 96 | + rescale_decimal::<Decimal64Type, O>(VariantDecimal8::MAX_PRECISION, 0, precision, scale)( |
| 97 | + *i, |
| 98 | + ) |
| 99 | + } |
| 100 | + Variant::Decimal4(d) => rescale_decimal::<Decimal32Type, O>( |
| 101 | + VariantDecimal4::MAX_PRECISION, |
| 102 | + d.scale() as i8, |
| 103 | + precision, |
| 104 | + scale, |
| 105 | + )(d.integer()), |
| 106 | + Variant::Decimal8(d) => rescale_decimal::<Decimal64Type, O>( |
| 107 | + VariantDecimal8::MAX_PRECISION, |
| 108 | + d.scale() as i8, |
| 109 | + precision, |
| 110 | + scale, |
| 111 | + )(d.integer()), |
| 112 | + Variant::Decimal16(d) => rescale_decimal::<Decimal128Type, O>( |
| 113 | + VariantDecimal16::MAX_PRECISION, |
| 114 | + d.scale() as i8, |
| 115 | + precision, |
| 116 | + scale, |
| 117 | + )(d.integer()), |
| 118 | + _ => Err(ArrowError::InvalidArgumentError(format!( |
| 119 | + "Invalid variant type: {:?}", |
| 120 | + variant |
| 121 | + ))), |
| 122 | + }; |
| 123 | + |
| 124 | + maybe_rescaled.ok() |
148 | 125 | } |
149 | 126 |
|
150 | 127 | /// Convert the value at a specific index in the given array into a `Variant`. |
|
0 commit comments