diff --git a/src/compute/cast/dictionary_to.rs b/src/compute/cast/dictionary_to.rs index 2f77a948380..9e053c9d9a0 100644 --- a/src/compute/cast/dictionary_to.rs +++ b/src/compute/cast/dictionary_to.rs @@ -1,7 +1,7 @@ -use super::{cast, primitive_to_primitive}; +use super::{primitive_as_primitive, primitive_to_primitive, CastOptions}; use crate::{ array::{Array, DictionaryArray, DictionaryKey, PrimitiveArray}, - compute::take::take, + compute::{cast::cast_with_options, take::take}, datatypes::DataType, error::{ArrowError, Result}, }; @@ -32,7 +32,20 @@ pub fn dictionary_to_dictionary_values( let keys = from.keys(); let values = from.values(); - let values = cast(values.as_ref(), values_type)?.into(); + let values = cast_with_options(values.as_ref(), values_type, CastOptions::default())?.into(); + Ok(DictionaryArray::from_data(keys.clone(), values)) +} + +/// Similar to dictionary_to_dictionary_values, but overflowing cast is wrapped +pub fn wrapping_dictionary_to_dictionary_values( + from: &DictionaryArray, + values_type: &DataType, +) -> Result> { + let keys = from.keys(); + let values = from.values(); + + let values = + cast_with_options(values.as_ref(), values_type, CastOptions { wrapped: true })?.into(); Ok(DictionaryArray::from_data(keys.clone(), values)) } @@ -60,9 +73,30 @@ where } } +/// Similar to dictionary_to_dictionary_keys, but overflowing cast is wrapped +pub fn wrapping_dictionary_to_dictionary_keys( + from: &DictionaryArray, +) -> Result> +where + K1: DictionaryKey + num::traits::AsPrimitive, + K2: DictionaryKey, +{ + let keys = from.keys(); + let values = from.values(); + + let casted_keys = primitive_as_primitive::(keys, &K2::DATA_TYPE); + + if casted_keys.null_count() > keys.null_count() { + Err(ArrowError::KeyOverflowError) + } else { + Ok(DictionaryArray::from_data(casted_keys, values.clone())) + } +} + pub(super) fn dictionary_cast_dyn( array: &dyn Array, to_type: &DataType, + options: CastOptions, ) -> Result> { let array = array.as_any().downcast_ref::>().unwrap(); let keys = array.keys(); @@ -70,7 +104,7 @@ pub(super) fn dictionary_cast_dyn( match to_type { DataType::Dictionary(to_keys_type, to_values_type) => { - let values = cast(values.as_ref(), to_values_type)?.into(); + let values = cast_with_options(values.as_ref(), to_values_type, options)?.into(); // create the appropriate array type match to_keys_type.as_ref() { @@ -85,7 +119,7 @@ pub(super) fn dictionary_cast_dyn( _ => unreachable!(), } } - _ => unpack_dictionary::(keys, values.as_ref(), to_type), + _ => unpack_dictionary::(keys, values.as_ref(), to_type, options), } } @@ -94,13 +128,14 @@ fn unpack_dictionary( keys: &PrimitiveArray, values: &dyn Array, to_type: &DataType, + options: CastOptions, ) -> Result> where K: DictionaryKey, { // attempt to cast the dict values to the target type // use the take kernel to expand out the dictionary - let values = cast(values, to_type)?; + let values = cast_with_options(values, to_type, options)?; // take requires first casting i32 let indices = primitive_to_primitive::<_, i32>(keys, &DataType::Int32); diff --git a/src/compute/cast/mod.rs b/src/compute/cast/mod.rs index 366a9a3e850..2a99c899c31 100644 --- a/src/compute/cast/mod.rs +++ b/src/compute/cast/mod.rs @@ -36,6 +36,29 @@ pub use primitive_to::*; pub use timestamps::*; pub use utf8_to::*; +/// options defining how Cast kernels behave +#[derive(Clone, Copy, Debug)] +struct CastOptions { + /// default to false + /// whether an overflowing cast should be converted to `None` (default), or be wrapped (i.e. `256i16 as u8 = 0` vectorized). + /// Settings this to `true` is 5-6x faster for numeric types. + wrapped: bool, +} + +impl Default for CastOptions { + fn default() -> Self { + Self { wrapped: false } + } +} + +impl CastOptions { + fn with_wrapped(&self, v: bool) -> Self { + let mut option = self.clone(); + option.wrapped = v; + option + } +} + /// Returns true if this type is numeric: (UInt*, Unit*, or Float*). fn is_numeric(t: &DataType) -> bool { use DataType::*; @@ -239,9 +262,18 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { } } -fn cast_list(array: &ListArray, to_type: &DataType) -> Result> { +fn cast_list( + array: &ListArray, + to_type: &DataType, + options: CastOptions, +) -> Result> { let values = array.values(); - let new_values = cast(values.as_ref(), ListArray::::get_child_type(to_type))?.into(); + let new_values = cast_with_options( + values.as_ref(), + ListArray::::get_child_type(to_type), + options, + )? + .into(); Ok(ListArray::::from_data( to_type.clone(), @@ -281,6 +313,7 @@ fn cast_large_to_list(array: &ListArray, to_type: &DataType) -> ListArray '1', `false` => `0` /// * Utf8 to numeric: strings that can't be parsed to numbers return null, float strings /// in integer casts return null @@ -291,13 +324,28 @@ fn cast_large_to_list(array: &ListArray, to_type: &DataType) -> ListArray Result> { + cast_with_options(array, to_type, CastOptions { wrapped: false }) +} + +/// Similar to [`cast`], but overflowing cast is wrapped +/// Behavior: +/// * PrimitiveArray to PrimitiveArray: overflowing cast will be wrapped (i.e. `256i16 as u8 = 0` vectorized). +pub fn wrapping_cast(array: &dyn Array, to_type: &DataType) -> Result> { + cast_with_options(array, to_type, CastOptions { wrapped: true }) +} + +#[inline] +fn cast_with_options( + array: &dyn Array, + to_type: &DataType, + options: CastOptions, +) -> Result> { use DataType::*; let from_type = array.data_type(); @@ -305,6 +353,8 @@ pub fn cast(array: &dyn Array, to_type: &DataType) -> Result> { if from_type == to_type { return Ok(clone(array)); } + + let as_options = options.with_wrapped(true); match (from_type, to_type) { (Struct(_), _) => Err(ArrowError::NotYetImplemented( "Cannot cast from struct to other types".to_string(), @@ -312,10 +362,12 @@ pub fn cast(array: &dyn Array, to_type: &DataType) -> Result> { (_, Struct(_)) => Err(ArrowError::NotYetImplemented( "Cannot cast to struct from other types".to_string(), )), - (List(_), List(_)) => cast_list::(array.as_any().downcast_ref().unwrap(), to_type) - .map(|x| Box::new(x) as Box), + (List(_), List(_)) => { + cast_list::(array.as_any().downcast_ref().unwrap(), to_type, options) + .map(|x| Box::new(x) as Box) + } (LargeList(_), LargeList(_)) => { - cast_list::(array.as_any().downcast_ref().unwrap(), to_type) + cast_list::(array.as_any().downcast_ref().unwrap(), to_type, options) .map(|x| Box::new(x) as Box) } (List(lhs), LargeList(rhs)) if lhs == rhs => Ok(cast_list_to_large_list( @@ -331,7 +383,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType) -> Result> { (_, List(to)) => { // cast primitive to list's primitive - let values = cast(array, to.data_type())?.into(); + let values = cast_with_options(array, to.data_type(), options)?.into(); // create offsets, where if array.len() = 2, we have [0,1,2] let offsets = unsafe { Buffer::from_trusted_len_iter_unchecked(0..=array.len() as i32) }; @@ -343,25 +395,25 @@ pub fn cast(array: &dyn Array, to_type: &DataType) -> Result> { } (Dictionary(index_type, _), _) => match **index_type { - DataType::Int8 => dictionary_cast_dyn::(array, to_type), - DataType::Int16 => dictionary_cast_dyn::(array, to_type), - DataType::Int32 => dictionary_cast_dyn::(array, to_type), - DataType::Int64 => dictionary_cast_dyn::(array, to_type), - DataType::UInt8 => dictionary_cast_dyn::(array, to_type), - DataType::UInt16 => dictionary_cast_dyn::(array, to_type), - DataType::UInt32 => dictionary_cast_dyn::(array, to_type), - DataType::UInt64 => dictionary_cast_dyn::(array, to_type), + DataType::Int8 => dictionary_cast_dyn::(array, to_type, options), + DataType::Int16 => dictionary_cast_dyn::(array, to_type, options), + DataType::Int32 => dictionary_cast_dyn::(array, to_type, options), + DataType::Int64 => dictionary_cast_dyn::(array, to_type, options), + DataType::UInt8 => dictionary_cast_dyn::(array, to_type, options), + DataType::UInt16 => dictionary_cast_dyn::(array, to_type, options), + DataType::UInt32 => dictionary_cast_dyn::(array, to_type, options), + DataType::UInt64 => dictionary_cast_dyn::(array, to_type, options), _ => unreachable!(), }, (_, Dictionary(index_type, value_type)) => match **index_type { - DataType::Int8 => cast_to_dictionary::(array, value_type), - DataType::Int16 => cast_to_dictionary::(array, value_type), - DataType::Int32 => cast_to_dictionary::(array, value_type), - DataType::Int64 => cast_to_dictionary::(array, value_type), - DataType::UInt8 => cast_to_dictionary::(array, value_type), - DataType::UInt16 => cast_to_dictionary::(array, value_type), - DataType::UInt32 => cast_to_dictionary::(array, value_type), - DataType::UInt64 => cast_to_dictionary::(array, value_type), + DataType::Int8 => cast_to_dictionary::(array, value_type, options), + DataType::Int16 => cast_to_dictionary::(array, value_type, options), + DataType::Int32 => cast_to_dictionary::(array, value_type, options), + DataType::Int64 => cast_to_dictionary::(array, value_type, options), + DataType::UInt8 => cast_to_dictionary::(array, value_type, options), + DataType::UInt16 => cast_to_dictionary::(array, value_type, options), + DataType::UInt32 => cast_to_dictionary::(array, value_type, options), + DataType::UInt64 => cast_to_dictionary::(array, value_type, options), _ => Err(ArrowError::NotYetImplemented(format!( "Casting from type {:?} to dictionary type {:?} not supported", from_type, to_type, @@ -509,105 +561,105 @@ pub fn cast(array: &dyn Array, to_type: &DataType) -> Result> { .map(|x| Box::new(x) as Box), // start numeric casts - (UInt8, UInt16) => primitive_to_primitive_dyn::(array, to_type), - (UInt8, UInt32) => primitive_to_primitive_dyn::(array, to_type), - (UInt8, UInt64) => primitive_to_primitive_dyn::(array, to_type), - (UInt8, Int8) => primitive_to_primitive_dyn::(array, to_type), - (UInt8, Int16) => primitive_to_primitive_dyn::(array, to_type), - (UInt8, Int32) => primitive_to_primitive_dyn::(array, to_type), - (UInt8, Int64) => primitive_to_primitive_dyn::(array, to_type), - (UInt8, Float32) => primitive_to_primitive_dyn::(array, to_type), - (UInt8, Float64) => primitive_to_primitive_dyn::(array, to_type), - - (UInt16, UInt8) => primitive_to_primitive_dyn::(array, to_type), - (UInt16, UInt32) => primitive_to_primitive_dyn::(array, to_type), - (UInt16, UInt64) => primitive_to_primitive_dyn::(array, to_type), - (UInt16, Int8) => primitive_to_primitive_dyn::(array, to_type), - (UInt16, Int16) => primitive_to_primitive_dyn::(array, to_type), - (UInt16, Int32) => primitive_to_primitive_dyn::(array, to_type), - (UInt16, Int64) => primitive_to_primitive_dyn::(array, to_type), - (UInt16, Float32) => primitive_to_primitive_dyn::(array, to_type), - (UInt16, Float64) => primitive_to_primitive_dyn::(array, to_type), - - (UInt32, UInt8) => primitive_to_primitive_dyn::(array, to_type), - (UInt32, UInt16) => primitive_to_primitive_dyn::(array, to_type), - (UInt32, UInt64) => primitive_to_primitive_dyn::(array, to_type), - (UInt32, Int8) => primitive_to_primitive_dyn::(array, to_type), - (UInt32, Int16) => primitive_to_primitive_dyn::(array, to_type), - (UInt32, Int32) => primitive_to_primitive_dyn::(array, to_type), - (UInt32, Int64) => primitive_to_primitive_dyn::(array, to_type), - (UInt32, Float32) => primitive_to_primitive_dyn::(array, to_type), - (UInt32, Float64) => primitive_to_primitive_dyn::(array, to_type), - - (UInt64, UInt8) => primitive_to_primitive_dyn::(array, to_type), - (UInt64, UInt16) => primitive_to_primitive_dyn::(array, to_type), - (UInt64, UInt32) => primitive_to_primitive_dyn::(array, to_type), - (UInt64, Int8) => primitive_to_primitive_dyn::(array, to_type), - (UInt64, Int16) => primitive_to_primitive_dyn::(array, to_type), - (UInt64, Int32) => primitive_to_primitive_dyn::(array, to_type), - (UInt64, Int64) => primitive_to_primitive_dyn::(array, to_type), - (UInt64, Float32) => primitive_to_primitive_dyn::(array, to_type), - (UInt64, Float64) => primitive_to_primitive_dyn::(array, to_type), - - (Int8, UInt8) => primitive_to_primitive_dyn::(array, to_type), - (Int8, UInt16) => primitive_to_primitive_dyn::(array, to_type), - (Int8, UInt32) => primitive_to_primitive_dyn::(array, to_type), - (Int8, UInt64) => primitive_to_primitive_dyn::(array, to_type), - (Int8, Int16) => primitive_to_primitive_dyn::(array, to_type), - (Int8, Int32) => primitive_to_primitive_dyn::(array, to_type), - (Int8, Int64) => primitive_to_primitive_dyn::(array, to_type), - (Int8, Float32) => primitive_to_primitive_dyn::(array, to_type), - (Int8, Float64) => primitive_to_primitive_dyn::(array, to_type), - - (Int16, UInt8) => primitive_to_primitive_dyn::(array, to_type), - (Int16, UInt16) => primitive_to_primitive_dyn::(array, to_type), - (Int16, UInt32) => primitive_to_primitive_dyn::(array, to_type), - (Int16, UInt64) => primitive_to_primitive_dyn::(array, to_type), - (Int16, Int8) => primitive_to_primitive_dyn::(array, to_type), - (Int16, Int32) => primitive_to_primitive_dyn::(array, to_type), - (Int16, Int64) => primitive_to_primitive_dyn::(array, to_type), - (Int16, Float32) => primitive_to_primitive_dyn::(array, to_type), - (Int16, Float64) => primitive_to_primitive_dyn::(array, to_type), - - (Int32, UInt8) => primitive_to_primitive_dyn::(array, to_type), - (Int32, UInt16) => primitive_to_primitive_dyn::(array, to_type), - (Int32, UInt32) => primitive_to_primitive_dyn::(array, to_type), - (Int32, UInt64) => primitive_to_primitive_dyn::(array, to_type), - (Int32, Int8) => primitive_to_primitive_dyn::(array, to_type), - (Int32, Int16) => primitive_to_primitive_dyn::(array, to_type), - (Int32, Int64) => primitive_to_primitive_dyn::(array, to_type), - (Int32, Float32) => primitive_to_primitive_dyn::(array, to_type), - (Int32, Float64) => primitive_to_primitive_dyn::(array, to_type), - - (Int64, UInt8) => primitive_to_primitive_dyn::(array, to_type), - (Int64, UInt16) => primitive_to_primitive_dyn::(array, to_type), - (Int64, UInt32) => primitive_to_primitive_dyn::(array, to_type), - (Int64, UInt64) => primitive_to_primitive_dyn::(array, to_type), - (Int64, Int8) => primitive_to_primitive_dyn::(array, to_type), - (Int64, Int16) => primitive_to_primitive_dyn::(array, to_type), - (Int64, Int32) => primitive_to_primitive_dyn::(array, to_type), - (Int64, Float32) => primitive_to_primitive_dyn::(array, to_type), - (Int64, Float64) => primitive_to_primitive_dyn::(array, to_type), - - (Float32, UInt8) => primitive_to_primitive_dyn::(array, to_type), - (Float32, UInt16) => primitive_to_primitive_dyn::(array, to_type), - (Float32, UInt32) => primitive_to_primitive_dyn::(array, to_type), - (Float32, UInt64) => primitive_to_primitive_dyn::(array, to_type), - (Float32, Int8) => primitive_to_primitive_dyn::(array, to_type), - (Float32, Int16) => primitive_to_primitive_dyn::(array, to_type), - (Float32, Int32) => primitive_to_primitive_dyn::(array, to_type), - (Float32, Int64) => primitive_to_primitive_dyn::(array, to_type), - (Float32, Float64) => primitive_to_primitive_dyn::(array, to_type), - - (Float64, UInt8) => primitive_to_primitive_dyn::(array, to_type), - (Float64, UInt16) => primitive_to_primitive_dyn::(array, to_type), - (Float64, UInt32) => primitive_to_primitive_dyn::(array, to_type), - (Float64, UInt64) => primitive_to_primitive_dyn::(array, to_type), - (Float64, Int8) => primitive_to_primitive_dyn::(array, to_type), - (Float64, Int16) => primitive_to_primitive_dyn::(array, to_type), - (Float64, Int32) => primitive_to_primitive_dyn::(array, to_type), - (Float64, Int64) => primitive_to_primitive_dyn::(array, to_type), - (Float64, Float32) => primitive_to_primitive_dyn::(array, to_type), + (UInt8, UInt16) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt8, UInt32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt8, UInt64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt8, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt8, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt8, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt8, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt8, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt8, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + + (UInt16, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt16, UInt32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt16, UInt64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt16, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt16, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt16, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt16, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt16, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt16, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + + (UInt32, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, UInt64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt32, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt32, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + + (UInt64, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt64, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + + (Int8, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int8, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int8, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Int8, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Int8, Int16) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int8, Int32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int8, Int64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int8, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int8, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + + (Int16, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int16, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int16, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Int16, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Int16, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int16, Int32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int16, Int64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int16, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int16, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + + (Int32, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, Int64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int32, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int32, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + + (Int64, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, Float32) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + + (Float32, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + + (Float64, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Float32) => primitive_to_primitive_dyn::(array, to_type, options), // end numeric casts // temporal casts @@ -677,8 +729,9 @@ pub fn cast(array: &dyn Array, to_type: &DataType) -> Result> { fn cast_to_dictionary( array: &dyn Array, dict_value_type: &DataType, + options: CastOptions, ) -> Result> { - let array = cast(array, dict_value_type)?; + let array = cast_with_options(array, dict_value_type, options)?; let array = array.as_ref(); match *dict_value_type { DataType::Int8 => primitive_to_dictionary_dyn::(array), @@ -715,6 +768,139 @@ mod tests { assert!((9.0 - c.value(4)).abs() < f64::EPSILON); } + #[test] + fn i32_as_f64_no_overflow() { + let array = Int32Array::from_slice(&[5, 6, 7, 8, 9]); + let b = wrapping_cast(&array, &DataType::Float64).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert!((5.0 - c.value(0)).abs() < f64::EPSILON); + assert!((6.0 - c.value(1)).abs() < f64::EPSILON); + assert!((7.0 - c.value(2)).abs() < f64::EPSILON); + assert!((8.0 - c.value(3)).abs() < f64::EPSILON); + assert!((9.0 - c.value(4)).abs() < f64::EPSILON); + } + + #[test] + fn u16_as_u8_overflow() { + let array = UInt16Array::from_slice(&[255, 256, 257, 258, 259]); + let b = wrapping_cast(&array, &DataType::UInt8).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + let values = c.values().as_slice(); + + println!("{}", 255u8.wrapping_add(10)); + + assert_eq!(values, &[255, 0, 1, 2, 3]) + } + + #[test] + fn u16_as_u8_no_overflow() { + let array = UInt16Array::from_slice(&[1, 2, 3, 4, 5]); + let b = wrapping_cast(&array, &DataType::UInt8).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + let values = c.values().as_slice(); + assert_eq!(values, &[1, 2, 3, 4, 5]) + } + + #[test] + fn float_range_max() { + //floats to integers + let u: Option = num::cast(f32::MAX); + assert_eq!(u, None); + let u: Option = num::cast(f32::MAX); + assert_eq!(u, None); + + let u: Option = num::cast(f32::MAX); + assert_eq!(u, None); + let u: Option = num::cast(f32::MAX); + assert_eq!(u, None); + + let u: Option = num::cast(f64::MAX); + assert_eq!(u, None); + let u: Option = num::cast(f64::MAX); + assert_eq!(u, None); + + let u: Option = num::cast(f64::MAX); + assert_eq!(u, None); + let u: Option = num::cast(f64::MAX); + assert_eq!(u, None); + + //integers to floats + let u: Option = num::cast(u32::MAX); + assert!(u.is_some()); + let u: Option = num::cast(u32::MAX); + assert!(u.is_some()); + + let u: Option = num::cast(i32::MAX); + assert!(u.is_some()); + let u: Option = num::cast(i32::MAX); + assert!(u.is_some()); + + let u: Option = num::cast(i64::MAX); + assert!(u.is_some()); + let u: Option = num::cast(u64::MAX); + assert!(u.is_some()); + + let u: Option = num::cast(f32::MAX); + assert!(u.is_some()); + } + + #[test] + fn float_range_min() { + //floats to integers + let u: Option = num::cast(f32::MIN); + assert_eq!(u, None); + let u: Option = num::cast(f32::MIN); + assert_eq!(u, None); + + let u: Option = num::cast(f32::MIN); + assert_eq!(u, None); + let u: Option = num::cast(f32::MIN); + assert_eq!(u, None); + + let u: Option = num::cast(f64::MIN); + assert_eq!(u, None); + let u: Option = num::cast(f64::MIN); + assert_eq!(u, None); + + let u: Option = num::cast(f64::MIN); + assert_eq!(u, None); + let u: Option = num::cast(f64::MIN); + assert_eq!(u, None); + + //integers to floats + let u: Option = num::cast(u32::MIN); + assert!(u.is_some()); + let u: Option = num::cast(u32::MIN); + assert!(u.is_some()); + + let u: Option = num::cast(i32::MIN); + assert!(u.is_some()); + let u: Option = num::cast(i32::MIN); + assert!(u.is_some()); + + let u: Option = num::cast(i64::MIN); + assert!(u.is_some()); + let u: Option = num::cast(u64::MIN); + assert!(u.is_some()); + + let u: Option = num::cast(f32::MIN); + assert!(u.is_some()); + } + + #[test] + fn f32_as_u8_overflow() { + let array = Float32Array::from_slice(&[1.1, 5000.0]); + let b = cast(&array, &DataType::UInt8).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + let expected = UInt8Array::from(&[Some(1), None]); + assert_eq!(c, &expected); + + let b = cast_with_options(&array, &DataType::UInt8, CastOptions { wrapped: true }).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + let expected = UInt8Array::from(&[Some(1), Some(255)]); + assert_eq!(c, &expected); + } + #[test] fn i32_to_u8() { let array = Int32Array::from_slice(&[-5, 6, -7, 8, 100000000]); @@ -901,12 +1087,19 @@ mod tests { panic!("Cast should have not failed") } } else { - assert!(cast(array.as_ref(), &d2).is_err()); + assert!( + cast_with_options(array.as_ref(), &d2, CastOptions::default()).is_err() + ); } }); } - fn test_primitive_to_primitive(lhs: &[I], lhs_type: DataType, expected: &[O], expected_type: DataType) { + fn test_primitive_to_primitive( + lhs: &[I], + lhs_type: DataType, + expected: &[O], + expected_type: DataType, + ) { let a = PrimitiveArray::::from_slice(lhs).to(lhs_type); let b = cast(&a, &expected_type).unwrap(); let b = b.as_any().downcast_ref::>().unwrap(); @@ -916,42 +1109,82 @@ mod tests { #[test] fn date32_to_date64() { - test_primitive_to_primitive(&[10000i32, 17890], DataType::Date32, &[864000000000i64, 1545696000000], DataType::Date64); + test_primitive_to_primitive( + &[10000i32, 17890], + DataType::Date32, + &[864000000000i64, 1545696000000], + DataType::Date64, + ); } #[test] fn date64_to_date32() { - test_primitive_to_primitive(&[864000000005i64, 1545696000001], DataType::Date64, &[10000i32, 17890], DataType::Date32); + test_primitive_to_primitive( + &[864000000005i64, 1545696000001], + DataType::Date64, + &[10000i32, 17890], + DataType::Date32, + ); } #[test] fn date32_to_int32() { - test_primitive_to_primitive(&[10000i32, 17890], DataType::Date32, &[10000i32, 17890], DataType::Int32); + test_primitive_to_primitive( + &[10000i32, 17890], + DataType::Date32, + &[10000i32, 17890], + DataType::Int32, + ); } #[test] fn int32_to_date32() { - test_primitive_to_primitive(&[10000i32, 17890], DataType::Int32, &[10000i32, 17890], DataType::Date32); + test_primitive_to_primitive( + &[10000i32, 17890], + DataType::Int32, + &[10000i32, 17890], + DataType::Date32, + ); } #[test] fn timestamp_to_date32() { - test_primitive_to_primitive(&[864000000005i64, 1545696000001], DataType::Timestamp(TimeUnit::Millisecond, Some(String::from("UTC"))), &[10000i32, 17890], DataType::Date32); + test_primitive_to_primitive( + &[864000000005i64, 1545696000001], + DataType::Timestamp(TimeUnit::Millisecond, Some(String::from("UTC"))), + &[10000i32, 17890], + DataType::Date32, + ); } #[test] fn timestamp_to_date64() { - test_primitive_to_primitive(&[864000000005i64, 1545696000001], DataType::Timestamp(TimeUnit::Millisecond, Some(String::from("UTC"))), &[864000000005i64, 1545696000001i64], DataType::Date64); + test_primitive_to_primitive( + &[864000000005i64, 1545696000001], + DataType::Timestamp(TimeUnit::Millisecond, Some(String::from("UTC"))), + &[864000000005i64, 1545696000001i64], + DataType::Date64, + ); } #[test] fn timestamp_to_i64() { - test_primitive_to_primitive(&[864000000005i64, 1545696000001], DataType::Timestamp(TimeUnit::Millisecond, Some(String::from("UTC"))), &[864000000005i64, 1545696000001i64], DataType::Int64); + test_primitive_to_primitive( + &[864000000005i64, 1545696000001], + DataType::Timestamp(TimeUnit::Millisecond, Some(String::from("UTC"))), + &[864000000005i64, 1545696000001i64], + DataType::Int64, + ); } #[test] fn timestamp_to_timestamp() { - test_primitive_to_primitive(&[864000003005i64, 1545696002001], DataType::Timestamp(TimeUnit::Millisecond, None), &[864000003i64, 1545696002], DataType::Timestamp(TimeUnit::Second, None)); + test_primitive_to_primitive( + &[864000003005i64, 1545696002001], + DataType::Timestamp(TimeUnit::Millisecond, None), + &[864000003i64, 1545696002], + DataType::Timestamp(TimeUnit::Second, None), + ); } #[test] @@ -963,7 +1196,9 @@ mod tests { let result = cast(&array, &cast_type).expect("cast failed"); let mut expected = MutableDictionaryArray::>::new(); - expected.try_extend([Some("one"), None, Some("three"), Some("one")]).unwrap(); + expected + .try_extend([Some("one"), None, Some("three"), Some("one")]) + .unwrap(); let expected: DictionaryArray = expected.into(); assert_eq!(expected, result.as_ref()); } @@ -971,7 +1206,9 @@ mod tests { #[test] fn dict_to_utf8() { let mut array = MutableDictionaryArray::>::new(); - array.try_extend([Some("one"), None, Some("three"), Some("one")]).unwrap(); + array + .try_extend([Some("one"), None, Some("three"), Some("one")]) + .unwrap(); let array: DictionaryArray = array.into(); let result = cast(&array, &DataType::Utf8).expect("cast failed"); @@ -990,7 +1227,9 @@ mod tests { let result = cast(&array, &cast_type).expect("cast failed"); let mut expected = MutableDictionaryArray::>::new(); - expected.try_extend([Some(1), None, Some(3), Some(1)]).unwrap(); + expected + .try_extend([Some(1), None, Some(3), Some(1)]) + .unwrap(); let expected: DictionaryArray = expected.into(); assert_eq!(expected, result.as_ref()); } @@ -1002,7 +1241,9 @@ mod tests { Some(vec![Some(4), None, Some(6)]), ]; - let expected_data = data.iter().map(|x| x.as_ref().map(|x| x.iter().map(|x| x.map(|x| x as u16)))); + let expected_data = data + .iter() + .map(|x| x.as_ref().map(|x| x.iter().map(|x| x.map(|x| x as u16)))); let mut array = MutableListArray::>::new(); array.try_extend(data.clone()).unwrap(); @@ -1037,8 +1278,8 @@ mod tests { let array: ArrayRef = Arc::new(builder.finish()); let cast_type = Dictionary(Box::new(Int8), Box::new(Utf8)); - let res = cast(&array, &cast_type); - assert!(res.is_err()); + let res = cast_with_options(&array, &cast_type); + assert, CastOptions::default())!(res.is_err()); let actual_error = format!("{:?}", res); let expected_error = "Could not convert 72 dictionary indexes from Int32 to Int8"; assert!( @@ -1069,8 +1310,8 @@ mod tests { let array: ArrayRef = Arc::new(builder.finish()); let cast_type = Dictionary(Box::new(Int8), Box::new(Utf8)); - let res = cast(&array, &cast_type); - assert!(res.is_err()); + let res = cast_with_options(&array, &cast_type); + assert, CastOptions::default())!(res.is_err()); let actual_error = format!("{:?}", res); let expected_error = "Could not convert 72 dictionary indexes from Int32 to Int8"; assert!( @@ -1095,7 +1336,7 @@ mod tests { "2000", // just a year is invalid ]); let array = Arc::new(a) as ArrayRef; - let b = cast(&array, &DataType::Date32).unwrap(); + let b = cast_with_options(&array, &DataType::Date32, CastOptions::default()).unwrap(); let c = b.as_any().downcast_ref::().unwrap(); // test valid inputs @@ -1126,7 +1367,7 @@ mod tests { "2000-01-01", // just a date is invalid ]); let array = Arc::new(a) as ArrayRef; - let b = cast(&array, &DataType::Date64).unwrap(); + let b = cast_with_options(&array, &DataType::Date64, CastOptions::default()).unwrap(); let c = b.as_any().downcast_ref::().unwrap(); // test valid inputs diff --git a/src/compute/cast/primitive_to.rs b/src/compute/cast/primitive_to.rs index dd34d93461a..862af498e45 100644 --- a/src/compute/cast/primitive_to.rs +++ b/src/compute/cast/primitive_to.rs @@ -10,6 +10,8 @@ use crate::{ }; use crate::{error::Result, util::lexical_to_string}; +use super::CastOptions; + /// Returns a [`BooleanArray`] where every element is different from zero. /// Validity is preserved. pub fn primitive_to_boolean(from: &PrimitiveArray) -> BooleanArray { @@ -48,13 +50,18 @@ where pub(super) fn primitive_to_primitive_dyn( from: &dyn Array, to_type: &DataType, + options: CastOptions, ) -> Result> where - I: NativeType + num::NumCast, + I: NativeType + num::NumCast + num::traits::AsPrimitive, O: NativeType + num::NumCast, { let from = from.as_any().downcast_ref::>().unwrap(); - Ok(Box::new(primitive_to_primitive::(from, to_type))) + if options.wrapped { + Ok(Box::new(primitive_as_primitive::(from, to_type))) + } else { + Ok(Box::new(primitive_to_primitive::(from, to_type))) + } } /// Cast [`PrimitiveArray`] to a [`PrimitiveArray`] of another physical type via numeric conversion. @@ -72,6 +79,23 @@ where PrimitiveArray::::from_trusted_len_iter(iter).to(to_type.clone()) } +/// Cast [`PrimitiveArray`] as a [`PrimitiveArray`] +/// Same as `number as to_number_type` in rust +pub fn primitive_as_primitive( + from: &PrimitiveArray, + to_type: &DataType, +) -> PrimitiveArray +where + I: NativeType + num::traits::AsPrimitive, + O: NativeType, +{ + unary( + from, + |x| num::traits::AsPrimitive::::as_(x), + to_type.clone(), + ) +} + /// Cast [`PrimitiveArray`] to a [`PrimitiveArray`] of the same physical type. /// This is O(1). pub fn primitive_to_same_primitive(