Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Added support to cast decimal (#761)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao authored Jan 13, 2022
1 parent 6b7af9f commit 1c6e8bd
Show file tree
Hide file tree
Showing 4 changed files with 396 additions and 6 deletions.
136 changes: 136 additions & 0 deletions src/compute/cast/decimal_to.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
use num_traits::{AsPrimitive, Float, NumCast};

use crate::error::Result;
use crate::types::NativeType;
use crate::{array::*, datatypes::DataType};

/// Applies `op` to every non-null value of `from` and returns the results as a
/// `Decimal(to_precision, to_scale)` array.
///
/// A slot of the result is null when:
/// * the corresponding slot of `from` is null,
/// * `op` returns `None` for the value, or
/// * the value returned by `op` needs more than `to_precision` decimal digits.
#[inline]
fn decimal_to_decimal_impl<F: Fn(i128) -> Option<i128>>(
    from: &PrimitiveArray<i128>,
    op: F,
    to_precision: usize,
    to_scale: usize,
) -> PrimitiveArray<i128> {
    // The largest magnitude that fits in `to_precision` digits is `10^to_precision - 1`
    // (e.g. 99 for a precision of 2). When that power does not fit in an `i128`, the
    // precision cannot be exceeded by any storable value, so the bound becomes `i128::MAX`.
    let max_for_precision = 10_i128
        .checked_pow(to_precision as u32)
        .map_or(i128::MAX, |bound| bound - 1);
    // `max_for_precision` is never negative, so this negation cannot overflow.
    let min_for_precision = -max_for_precision;

    let values = from.iter().map(|x| {
        x.and_then(|x| op(*x))
            // out-of-range values are turned into nulls rather than stored truncated
            .filter(|x| (min_for_precision..=max_for_precision).contains(x))
    });
    PrimitiveArray::<i128>::from_trusted_len_iter(values)
        .to(DataType::Decimal(to_precision, to_scale))
}

/// Casts a decimal array to `Decimal(to_precision, to_scale)`.
///
/// When the scale is unchanged and the precision does not shrink, only the logical type
/// of a clone of `from` is replaced. Otherwise every value is rescaled by a power of ten
/// (integer division when the scale decreases, checked multiplication when it increases)
/// and handed to `decimal_to_decimal_impl`, which nulls values that overflow.
///
/// # Panics
/// Panics if the logical type of `from` is not `DataType::Decimal`.
pub fn decimal_to_decimal(
    from: &PrimitiveArray<i128>,
    to_precision: usize,
    to_scale: usize,
) -> PrimitiveArray<i128> {
    let (from_precision, from_scale) = match from.data_type().to_logical_type() {
        DataType::Decimal(p, s) => (*p, *s),
        _ => panic!("internal error: i128 is always a decimal"),
    };

    let same_scale = to_scale == from_scale;
    if same_scale && to_precision >= from_precision {
        // fast path: the stored integers are already valid for the target type
        return from.clone().to(DataType::Decimal(to_precision, to_scale));
    }
    // todo: another fast path is an increase of both scale and precision that can
    // never overflow (validity is preserved)

    if from_scale > to_scale {
        // fewer fractional digits: drop the excess ones
        let divisor = 10_i128.pow((from_scale - to_scale) as u32);
        let scale_down = |value: i128| value.checked_div(divisor);
        decimal_to_decimal_impl(from, scale_down, to_precision, to_scale)
    } else {
        // more (or equally many) fractional digits: pad with zeros, `None` on overflow
        let multiplier = 10_i128.pow((to_scale - from_scale) as u32);
        let scale_up = |value: i128| value.checked_mul(multiplier);
        decimal_to_decimal_impl(from, scale_up, to_precision, to_scale)
    }
}

/// Type-erased entry point for [`decimal_to_decimal`].
///
/// # Panics
/// Panics if `from` cannot be downcast to the array type [`decimal_to_decimal`] expects.
pub(super) fn decimal_to_decimal_dyn(
    from: &dyn Array,
    to_precision: usize,
    to_scale: usize,
) -> Result<Box<dyn Array>> {
    let decimals = from.as_any().downcast_ref().unwrap();
    let casted = decimal_to_decimal(decimals, to_precision, to_scale);
    Ok(Box::new(casted))
}

/// Returns a [`PrimitiveArray<T>`] of floats built from a decimal array.
///
/// Each stored integer is converted to `f64`, divided by `10^scale` and then converted
/// to `T`. The validity of `from` is carried over unchanged.
///
/// # Panics
/// Panics if the logical type of `from` is not `DataType::Decimal`.
pub fn decimal_to_float<T>(from: &PrimitiveArray<i128>) -> PrimitiveArray<T>
where
    T: NativeType + Float,
    f64: AsPrimitive<T>,
{
    // only the scale matters here; the precision is ignored
    let from_scale = match from.data_type().to_logical_type() {
        DataType::Decimal(_, s) => *s,
        _ => panic!("internal error: i128 is always a decimal"),
    };

    let div = 10_f64.powi(from_scale as i32);
    let values = from
        .values()
        .iter()
        .map(|&unscaled| {
            let scaled = unscaled as f64 / div;
            scaled.as_()
        })
        .collect();

    let validity = from.validity().cloned();
    PrimitiveArray::<T>::from_data(T::PRIMITIVE.into(), values, validity)
}

/// Type-erased entry point for [`decimal_to_float`].
///
/// # Panics
/// Panics if `from` cannot be downcast to the array type [`decimal_to_float`] expects.
pub(super) fn decimal_to_float_dyn<T>(from: &dyn Array) -> Result<Box<dyn Array>>
where
    T: NativeType + Float,
    f64: AsPrimitive<T>,
{
    let decimals = from.as_any().downcast_ref().unwrap();
    let floats = decimal_to_float::<T>(decimals);
    Ok(Box::new(floats))
}

/// Returns a [`PrimitiveArray<T>`] of integers built from a decimal array.
///
/// Each stored integer is divided by `10^scale` (integer division, so the fractional
/// digits are dropped) and converted to `T` through [`NumCast`]. A slot is null when it
/// is null in `from` or when the conversion to `T` returns `None`.
///
/// # Panics
/// Panics if the logical type of `from` is not `DataType::Decimal`.
pub fn decimal_to_integer<T>(from: &PrimitiveArray<i128>) -> PrimitiveArray<T>
where
    T: NativeType + NumCast,
{
    // only the scale matters here; the precision is ignored
    let from_scale = match from.data_type().to_logical_type() {
        DataType::Decimal(_, s) => *s,
        _ => panic!("internal error: i128 is always a decimal"),
    };

    let factor = 10_i128.pow(from_scale as u32);
    let values = from.iter().map(|slot| match slot {
        Some(unscaled) => T::from(*unscaled / factor),
        None => None,
    });

    PrimitiveArray::from_trusted_len_iter(values)
}

/// Type-erased entry point for [`decimal_to_integer`].
///
/// # Panics
/// Panics if `from` cannot be downcast to the array type [`decimal_to_integer`] expects.
pub(super) fn decimal_to_integer_dyn<T>(from: &dyn Array) -> Result<Box<dyn Array>>
where
    T: NativeType + NumCast,
{
    let decimals = from.as_any().downcast_ref().unwrap();
    let integers = decimal_to_integer::<T>(decimals);
    Ok(Box::new(integers))
}
61 changes: 55 additions & 6 deletions src/compute/cast/mod.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
//! Defines different casting operators such as [`cast`] or [`primitive_to_binary`].

use crate::{
array::*,
datatypes::*,
error::{ArrowError, Result},
};

mod binary_to;
mod boolean_to;
mod decimal_to;
mod dictionary_to;
mod primitive_to;
mod utf8_to;

pub use binary_to::*;
pub use boolean_to::*;
pub use decimal_to::*;
pub use dictionary_to::*;
pub use primitive_to::*;
pub use utf8_to::*;

use crate::{
array::*,
datatypes::*,
error::{ArrowError, Result},
};

/// options defining how Cast kernels behave
#[derive(Clone, Copy, Debug, Default)]
pub struct CastOptions {
Expand Down Expand Up @@ -143,6 +145,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(UInt8, Int64) => true,
(UInt8, Float32) => true,
(UInt8, Float64) => true,
(UInt8, Decimal(_, _)) => true,

(UInt16, UInt8) => true,
(UInt16, UInt32) => true,
Expand All @@ -153,6 +156,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(UInt16, Int64) => true,
(UInt16, Float32) => true,
(UInt16, Float64) => true,
(UInt16, Decimal(_, _)) => true,

(UInt32, UInt8) => true,
(UInt32, UInt16) => true,
Expand All @@ -163,6 +167,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(UInt32, Int64) => true,
(UInt32, Float32) => true,
(UInt32, Float64) => true,
(UInt32, Decimal(_, _)) => true,

(UInt64, UInt8) => true,
(UInt64, UInt16) => true,
Expand All @@ -173,6 +178,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(UInt64, Int64) => true,
(UInt64, Float32) => true,
(UInt64, Float64) => true,
(UInt64, Decimal(_, _)) => true,

(Int8, UInt8) => true,
(Int8, UInt16) => true,
Expand All @@ -183,6 +189,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Int8, Int64) => true,
(Int8, Float32) => true,
(Int8, Float64) => true,
(Int8, Decimal(_, _)) => true,

(Int16, UInt8) => true,
(Int16, UInt16) => true,
Expand All @@ -193,6 +200,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Int16, Int64) => true,
(Int16, Float32) => true,
(Int16, Float64) => true,
(Int16, Decimal(_, _)) => true,

(Int32, UInt8) => true,
(Int32, UInt16) => true,
Expand All @@ -203,6 +211,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Int32, Int64) => true,
(Int32, Float32) => true,
(Int32, Float64) => true,
(Int32, Decimal(_, _)) => true,

(Int64, UInt8) => true,
(Int64, UInt16) => true,
Expand All @@ -213,6 +222,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Int64, Int32) => true,
(Int64, Float32) => true,
(Int64, Float64) => true,
(Int64, Decimal(_, _)) => true,

(Float32, UInt8) => true,
(Float32, UInt16) => true,
Expand All @@ -223,6 +233,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Float32, Int32) => true,
(Float32, Int64) => true,
(Float32, Float64) => true,
(Float32, Decimal(_, _)) => true,

(Float64, UInt8) => true,
(Float64, UInt16) => true,
Expand All @@ -233,6 +244,22 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Float64, Int32) => true,
(Float64, Int64) => true,
(Float64, Float32) => true,
(Float64, Decimal(_, _)) => true,

(
Decimal(_, _),
UInt8
| UInt16
| UInt32
| UInt64
| Int8
| Int16
| Int32
| Int64
| Float32
| Float64
| Decimal(_, _),
) => true,
// end numeric casts

// temporal casts
Expand Down Expand Up @@ -649,6 +676,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
(UInt8, Int64) => primitive_to_primitive_dyn::<u8, i64>(array, to_type, options),
(UInt8, Float32) => primitive_to_primitive_dyn::<u8, f32>(array, to_type, as_options),
(UInt8, Float64) => primitive_to_primitive_dyn::<u8, f64>(array, to_type, as_options),
(UInt8, Decimal(p, s)) => integer_to_decimal_dyn::<u8>(array, *p, *s),

(UInt16, UInt8) => primitive_to_primitive_dyn::<u16, u8>(array, to_type, options),
(UInt16, UInt32) => primitive_to_primitive_dyn::<u16, u32>(array, to_type, as_options),
Expand All @@ -659,6 +687,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
(UInt16, Int64) => primitive_to_primitive_dyn::<u16, i64>(array, to_type, options),
(UInt16, Float32) => primitive_to_primitive_dyn::<u16, f32>(array, to_type, as_options),
(UInt16, Float64) => primitive_to_primitive_dyn::<u16, f64>(array, to_type, as_options),
(UInt16, Decimal(p, s)) => integer_to_decimal_dyn::<u16>(array, *p, *s),

(UInt32, UInt8) => primitive_to_primitive_dyn::<u32, u8>(array, to_type, options),
(UInt32, UInt16) => primitive_to_primitive_dyn::<u32, u16>(array, to_type, options),
Expand All @@ -669,6 +698,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
(UInt32, Int64) => primitive_to_primitive_dyn::<u32, i64>(array, to_type, options),
(UInt32, Float32) => primitive_to_primitive_dyn::<u32, f32>(array, to_type, as_options),
(UInt32, Float64) => primitive_to_primitive_dyn::<u32, f64>(array, to_type, as_options),
(UInt32, Decimal(p, s)) => integer_to_decimal_dyn::<u32>(array, *p, *s),

(UInt64, UInt8) => primitive_to_primitive_dyn::<u64, u8>(array, to_type, options),
(UInt64, UInt16) => primitive_to_primitive_dyn::<u64, u16>(array, to_type, options),
Expand All @@ -679,6 +709,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
(UInt64, Int64) => primitive_to_primitive_dyn::<u64, i64>(array, to_type, options),
(UInt64, Float32) => primitive_to_primitive_dyn::<u64, f32>(array, to_type, as_options),
(UInt64, Float64) => primitive_to_primitive_dyn::<u64, f64>(array, to_type, as_options),
(UInt64, Decimal(p, s)) => integer_to_decimal_dyn::<u64>(array, *p, *s),

(Int8, UInt8) => primitive_to_primitive_dyn::<i8, u8>(array, to_type, options),
(Int8, UInt16) => primitive_to_primitive_dyn::<i8, u16>(array, to_type, options),
Expand All @@ -689,6 +720,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
(Int8, Int64) => primitive_to_primitive_dyn::<i8, i64>(array, to_type, as_options),
(Int8, Float32) => primitive_to_primitive_dyn::<i8, f32>(array, to_type, as_options),
(Int8, Float64) => primitive_to_primitive_dyn::<i8, f64>(array, to_type, as_options),
(Int8, Decimal(p, s)) => integer_to_decimal_dyn::<i8>(array, *p, *s),

(Int16, UInt8) => primitive_to_primitive_dyn::<i16, u8>(array, to_type, options),
(Int16, UInt16) => primitive_to_primitive_dyn::<i16, u16>(array, to_type, options),
Expand All @@ -699,6 +731,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
(Int16, Int64) => primitive_to_primitive_dyn::<i16, i64>(array, to_type, as_options),
(Int16, Float32) => primitive_to_primitive_dyn::<i16, f32>(array, to_type, as_options),
(Int16, Float64) => primitive_to_primitive_dyn::<i16, f64>(array, to_type, as_options),
(Int16, Decimal(p, s)) => integer_to_decimal_dyn::<i16>(array, *p, *s),

(Int32, UInt8) => primitive_to_primitive_dyn::<i32, u8>(array, to_type, options),
(Int32, UInt16) => primitive_to_primitive_dyn::<i32, u16>(array, to_type, options),
Expand All @@ -709,6 +742,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
(Int32, Int64) => primitive_to_primitive_dyn::<i32, i64>(array, to_type, as_options),
(Int32, Float32) => primitive_to_primitive_dyn::<i32, f32>(array, to_type, as_options),
(Int32, Float64) => primitive_to_primitive_dyn::<i32, f64>(array, to_type, as_options),
(Int32, Decimal(p, s)) => integer_to_decimal_dyn::<i32>(array, *p, *s),

(Int64, UInt8) => primitive_to_primitive_dyn::<i64, u8>(array, to_type, options),
(Int64, UInt16) => primitive_to_primitive_dyn::<i64, u16>(array, to_type, options),
Expand All @@ -719,6 +753,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
(Int64, Int32) => primitive_to_primitive_dyn::<i64, i32>(array, to_type, options),
(Int64, Float32) => primitive_to_primitive_dyn::<i64, f32>(array, to_type, options),
(Int64, Float64) => primitive_to_primitive_dyn::<i64, f64>(array, to_type, as_options),
(Int64, Decimal(p, s)) => integer_to_decimal_dyn::<i64>(array, *p, *s),

(Float32, UInt8) => primitive_to_primitive_dyn::<f32, u8>(array, to_type, options),
(Float32, UInt16) => primitive_to_primitive_dyn::<f32, u16>(array, to_type, options),
Expand All @@ -729,6 +764,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
(Float32, Int32) => primitive_to_primitive_dyn::<f32, i32>(array, to_type, options),
(Float32, Int64) => primitive_to_primitive_dyn::<f32, i64>(array, to_type, options),
(Float32, Float64) => primitive_to_primitive_dyn::<f32, f64>(array, to_type, as_options),
(Float32, Decimal(p, s)) => float_to_decimal_dyn::<f32>(array, *p, *s),

(Float64, UInt8) => primitive_to_primitive_dyn::<f64, u8>(array, to_type, options),
(Float64, UInt16) => primitive_to_primitive_dyn::<f64, u16>(array, to_type, options),
Expand All @@ -739,6 +775,19 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
(Float64, Int32) => primitive_to_primitive_dyn::<f64, i32>(array, to_type, options),
(Float64, Int64) => primitive_to_primitive_dyn::<f64, i64>(array, to_type, options),
(Float64, Float32) => primitive_to_primitive_dyn::<f64, f32>(array, to_type, options),
(Float64, Decimal(p, s)) => float_to_decimal_dyn::<f64>(array, *p, *s),

(Decimal(_, _), UInt8) => decimal_to_integer_dyn::<u8>(array),
(Decimal(_, _), UInt16) => decimal_to_integer_dyn::<u16>(array),
(Decimal(_, _), UInt32) => decimal_to_integer_dyn::<u32>(array),
(Decimal(_, _), UInt64) => decimal_to_integer_dyn::<u64>(array),
(Decimal(_, _), Int8) => decimal_to_integer_dyn::<i8>(array),
(Decimal(_, _), Int16) => decimal_to_integer_dyn::<i16>(array),
(Decimal(_, _), Int32) => decimal_to_integer_dyn::<i32>(array),
(Decimal(_, _), Int64) => decimal_to_integer_dyn::<i64>(array),
(Decimal(_, _), Float32) => decimal_to_float_dyn::<f32>(array),
(Decimal(_, _), Float64) => decimal_to_float_dyn::<f64>(array),
(Decimal(_, _), Decimal(to_p, to_s)) => decimal_to_decimal_dyn(array, *to_p, *to_s),
// end numeric casts

// temporal casts
Expand Down
Loading

0 comments on commit 1c6e8bd

Please sign in to comment.