Skip to content

Commit c163a91

Browse files
committed
Refactor common logic
1 parent 1f19580 commit c163a91

File tree

1 file changed

+85
-78
lines changed

1 file changed

+85
-78
lines changed

arrow-cast/src/cast/decimal.rs

Lines changed: 85 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -142,66 +142,95 @@ impl DecimalCast for i256 {
142142
/// Build a rescale function from (input_precision, input_scale) to (output_precision, output_scale)
143143
/// returning a closure `Fn(I::Native) -> Option<O::Native>` that performs the conversion.
144144
pub(crate) fn rescale_decimal<I, O>(
145-
_input_precision: u8,
145+
input_precision: u8,
146146
input_scale: i8,
147-
_output_precision: u8,
147+
output_precision: u8,
148148
output_scale: i8,
149-
) -> impl Fn(I::Native) -> Option<O::Native>
149+
) -> impl Fn(I::Native) -> Result<O::Native, ArrowError>
150150
where
151151
I: DecimalType,
152152
O: DecimalType,
153153
I::Native: DecimalCast + ArrowNativeTypeOp,
154154
O::Native: DecimalCast + ArrowNativeTypeOp,
155155
{
156156
let delta_scale = output_scale - input_scale;
157-
158-
// Precompute parameters and capture them in a single closure type
159-
let mul_opt = if delta_scale > 0 {
160-
O::Native::from_decimal(10_i128)
161-
.and_then(|t| t.pow_checked(delta_scale as u32).ok())
157+
let input_precision_i8 = input_precision as i8;
158+
let output_precision_i8 = output_precision as i8;
159+
160+
// Determine if the cast is infallible based on precision/scale math
161+
let is_infallible_cast = input_precision_i8 + delta_scale < output_precision_i8;
162+
163+
// Build a single mode once and use a thin closure that calls into it
164+
enum RescaleMode<I, O> {
165+
SameScale,
166+
Up { mul: O },
167+
Down { div: I, half: I, half_neg: I },
168+
Invalid,
169+
}
170+
171+
let mode = if delta_scale == 0 {
172+
RescaleMode::SameScale
173+
} else if delta_scale > 0 {
174+
match O::Native::from_decimal(10_i128).and_then(|t| t.pow_checked(delta_scale as u32).ok())
175+
{
176+
Some(mul) => RescaleMode::Up { mul },
177+
None => RescaleMode::Invalid,
178+
}
162179
} else {
163-
None
180+
// delta_scale < 0
181+
match I::Native::from_decimal(10_i128)
182+
.and_then(|t| t.pow_checked(delta_scale.unsigned_abs() as u32).ok())
183+
{
184+
Some(div) => {
185+
let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
186+
let half_neg = half.neg_wrapping();
187+
RescaleMode::Down {
188+
div,
189+
half,
190+
half_neg,
191+
}
192+
}
193+
None => RescaleMode::Invalid,
194+
}
164195
};
165196

166-
let (div_opt, half_opt, half_neg_opt) = if delta_scale < 0 {
167-
let div = I::Native::from_decimal(10_i128)
168-
.and_then(|t| t.pow_checked(delta_scale.unsigned_abs() as u32).ok());
169-
if let Some(div) = div {
170-
let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
171-
let half_neg = half.neg_wrapping();
172-
(Some(div), Some(half), Some(half_neg))
173-
} else {
174-
(None, None, None)
197+
let f = move |x: I::Native| {
198+
match &mode {
199+
RescaleMode::SameScale => O::Native::from_decimal(x),
200+
RescaleMode::Up { mul } => {
201+
O::Native::from_decimal(x).and_then(|x| x.mul_checked(*mul).ok())
202+
}
203+
RescaleMode::Down {
204+
div,
205+
half,
206+
half_neg,
207+
} => {
208+
// div is >= 10 and so this cannot overflow
209+
let d = x.div_wrapping(*div);
210+
let r = x.mod_wrapping(*div);
211+
212+
// Round result
213+
let adjusted = match x >= I::Native::ZERO {
214+
true if r >= *half => d.add_wrapping(I::Native::ONE),
215+
false if r <= *half_neg => d.sub_wrapping(I::Native::ONE),
216+
_ => d,
217+
};
218+
O::Native::from_decimal(adjusted)
219+
}
220+
RescaleMode::Invalid => None,
175221
}
176-
} else {
177-
(None, None, None)
178222
};
179223

180-
move |x: I::Native| {
181-
if delta_scale == 0 {
182-
return O::Native::from_decimal(x);
183-
}
224+
let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
184225

185-
if let Some(mul) = mul_opt {
186-
return O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok());
226+
move |x| {
227+
if is_infallible_cast {
228+
f(x).ok_or_else(|| error(x))
229+
} else {
230+
f(x).ok_or_else(|| error(x)).and_then(|v| {
231+
O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v)
232+
})
187233
}
188-
189-
// Decrease scale path
190-
let div = div_opt.unwrap();
191-
let half = half_opt.unwrap();
192-
let half_neg = half_neg_opt.unwrap();
193-
194-
// div is >= 10 and so this cannot overflow
195-
let d = x.div_wrapping(div);
196-
let r = x.mod_wrapping(div);
197-
198-
// Round result
199-
let adjusted = match x >= I::Native::ZERO {
200-
true if r >= half => d.add_wrapping(I::Native::ONE),
201-
false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
202-
_ => d,
203-
};
204-
O::Native::from_decimal(adjusted)
205234
}
206235
}
207236

@@ -240,7 +269,8 @@ where
240269
I::Native: DecimalCast + ArrowNativeTypeOp,
241270
O::Native: DecimalCast + ArrowNativeTypeOp,
242271
{
243-
let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
272+
// make sure we don't perform calculations that don't make sense w/o validation
273+
validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
244274
let delta_scale = input_scale - output_scale;
245275
// if the reduction of the input number through scaling (dividing) is greater
246276
// than a possible precision loss (plus potential increase via rounding)
@@ -254,27 +284,15 @@ where
254284
// [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible
255285
let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8);
256286

257-
let f = rescale_decimal::<I, O>(
258-
input_precision,
259-
input_scale,
260-
output_precision,
261-
output_scale,
262-
);
287+
let f = rescale_decimal::<I, O>(input_precision, input_scale, output_precision, output_scale);
263288

264289
Ok(if is_infallible_cast {
265-
// make sure we don't perform calculations that don't make sense w/o validation
266-
validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
267-
let g = |x: I::Native| f(x).unwrap(); // unwrapping is safe since the result is guaranteed
268-
// to fit into the target type
269-
array.unary(g)
290+
// unwrapping is safe since the result is guaranteed to fit into the target type
291+
array.unary(|x| f(x).unwrap())
270292
} else if cast_options.safe {
271-
array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
293+
array.unary_opt(|x| f(x).ok())
272294
} else {
273-
array.try_unary(|x| {
274-
f(x).ok_or_else(|| error(x)).and_then(|v| {
275-
O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v)
276-
})
277-
})?
295+
array.try_unary(|x| f(x))?
278296
})
279297
}
280298

@@ -292,7 +310,8 @@ where
292310
I::Native: DecimalCast + ArrowNativeTypeOp,
293311
O::Native: DecimalCast + ArrowNativeTypeOp,
294312
{
295-
let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
313+
// make sure we don't perform calculations that don't make sense w/o validation
314+
validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
296315
let delta_scale = output_scale - input_scale;
297316

298317
// if the gain in precision (digits) is greater than the multiplication due to scaling
@@ -302,27 +321,15 @@ where
302321
// [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type
303322
// needs to provide at least 8 digits precision
304323
let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8);
305-
let f = rescale_decimal::<I, O>(
306-
input_precision,
307-
input_scale,
308-
output_precision,
309-
output_scale,
310-
);
324+
let f = rescale_decimal::<I, O>(input_precision, input_scale, output_precision, output_scale);
311325

312326
Ok(if is_infallible_cast {
313-
// make sure we don't perform calculations that don't make sense w/o validation
314-
validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
315327
// unwrapping is safe since the result is guaranteed to fit into the target type
316-
let f = |x: I::Native| f(x).unwrap();
317-
array.unary(f)
328+
array.unary(|x| f(x).unwrap())
318329
} else if cast_options.safe {
319-
array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
330+
array.unary_opt(|x| f(x).ok())
320331
} else {
321-
array.try_unary(|x| {
322-
f(x).ok_or_else(|| error(x)).and_then(|v| {
323-
O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v)
324-
})
325-
})?
332+
array.try_unary(|x| f(x))?
326333
})
327334
}
328335

0 commit comments

Comments
 (0)