From 8482150fd2823574139a5e5bb9e7ef86101c065c Mon Sep 17 00:00:00 2001
From: Arttu
Date: Sat, 27 Jul 2024 14:38:05 +0200
Subject: [PATCH] chore: make Cast's logic reusable for other projects (#716)

---
 native/spark-expr/src/cast.rs | 1099 ++++++++++++++++-----------------
 native/spark-expr/src/lib.rs  |    2 +-
 2 files changed, 550 insertions(+), 551 deletions(-)

diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs
index 9a47cc873..ae0818970 100644
--- a/native/spark-expr/src/cast.rs
+++ b/native/spark-expr/src/cast.rs
@@ -502,158 +502,166 @@ impl Cast {
             eval_mode,
         }
     }
+}
 
-    fn cast_array(&self, array: ArrayRef) -> DataFusionResult<ArrayRef> {
-        let to_type = &self.data_type;
-        let array = array_with_timezone(array, self.timezone.clone(), Some(to_type))?;
-        let from_type = array.data_type().clone();
-        let array = match &from_type {
-            DataType::Dictionary(key_type, value_type)
-                if key_type.as_ref() == &DataType::Int32
-                    && (value_type.as_ref() == &DataType::Utf8
-                        || value_type.as_ref() == &DataType::LargeUtf8) =>
-            {
-                let dict_array = array
-                    .as_any()
-                    .downcast_ref::<DictionaryArray<Int32Type>>()
-                    .expect("Expected a dictionary array");
-
-                let casted_dictionary = DictionaryArray::<Int32Type>::new(
-                    dict_array.keys().clone(),
-                    self.cast_array(dict_array.values().clone())?,
-                );
-
-                let casted_result = match to_type {
-                    DataType::Dictionary(_, _) => Arc::new(casted_dictionary.clone()),
-                    _ => take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?,
-                };
-                return Ok(spark_cast(casted_result, &from_type, to_type));
-            }
-            _ => array,
-        };
-        let from_type = array.data_type();
-
-        let cast_result = match (from_type, to_type) {
-            (DataType::Utf8, DataType::Boolean) => {
-                Self::spark_cast_utf8_to_boolean::<i32>(&array, self.eval_mode)
-            }
-            (DataType::LargeUtf8, DataType::Boolean) => {
-                Self::spark_cast_utf8_to_boolean::<i64>(&array, self.eval_mode)
-            }
-            (DataType::Utf8, DataType::Timestamp(_, _)) => {
-                Self::cast_string_to_timestamp(&array, to_type, self.eval_mode)
-            }
-            (DataType::Utf8, DataType::Date32) => {
-                Self::cast_string_to_date(&array, to_type, self.eval_mode)
-            }
-            (DataType::Int64, DataType::Int32)
-            | (DataType::Int64, DataType::Int16)
-            | (DataType::Int64, DataType::Int8)
-            | (DataType::Int32, DataType::Int16)
-            | (DataType::Int32, DataType::Int8)
-            | (DataType::Int16, DataType::Int8)
-                if self.eval_mode != EvalMode::Try =>
-            {
-                Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type)
-            }
-            (
-                DataType::Utf8,
-                DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
-            ) => Self::cast_string_to_int::<i32>(to_type, &array, self.eval_mode),
-            (
-                DataType::LargeUtf8,
-                DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
-            ) => Self::cast_string_to_int::<i64>(to_type, &array, self.eval_mode),
-            (DataType::Float64, DataType::Utf8) => {
-                Self::spark_cast_float64_to_utf8::<i32>(&array, self.eval_mode)
-            }
-            (DataType::Float64, DataType::LargeUtf8) => {
-                Self::spark_cast_float64_to_utf8::<i64>(&array, self.eval_mode)
-            }
-            (DataType::Float32, DataType::Utf8) => {
-                Self::spark_cast_float32_to_utf8::<i32>(&array, self.eval_mode)
-            }
-            (DataType::Float32, DataType::LargeUtf8) => {
-                Self::spark_cast_float32_to_utf8::<i64>(&array, self.eval_mode)
-            }
-            (DataType::Float32, DataType::Decimal128(precision, scale)) => {
-                Self::cast_float32_to_decimal128(&array, *precision, *scale, self.eval_mode)
-            }
-            (DataType::Float64, DataType::Decimal128(precision, scale)) => {
-                Self::cast_float64_to_decimal128(&array, *precision, *scale, self.eval_mode)
-            }
-            (DataType::Float32, DataType::Int8)
-            | (DataType::Float32, DataType::Int16)
-            | (DataType::Float32, DataType::Int32)
-            | (DataType::Float32, DataType::Int64)
-            | (DataType::Float64, DataType::Int8)
-            | (DataType::Float64, DataType::Int16)
-            | (DataType::Float64, DataType::Int32)
-            | (DataType::Float64, DataType::Int64)
-            | (DataType::Decimal128(_, _), DataType::Int8)
-            | (DataType::Decimal128(_, _), DataType::Int16)
-            | (DataType::Decimal128(_, _), DataType::Int32)
-            | (DataType::Decimal128(_, _), DataType::Int64)
-                if self.eval_mode != EvalMode::Try =>
-            {
-                Self::spark_cast_nonintegral_numeric_to_integral(
-                    &array,
-                    self.eval_mode,
-                    from_type,
-                    to_type,
-                )
-            }
-            _ if Self::is_datafusion_spark_compatible(from_type, to_type) => {
-                // use DataFusion cast only when we know that it is compatible with Spark
-                Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
-            }
-            _ => {
-                // we should never reach this code because the Scala code should be checking
-                // for supported cast operations and falling back to Spark for anything that
-                // is not yet supported
-                Err(SparkError::Internal(format!(
-                    "Native cast invoked for unsupported cast from {from_type:?} to {to_type:?}"
-                )))
-            }
-        };
-        Ok(spark_cast(cast_result?, from_type, to_type))
-    }
-
-    /// Determines if DataFusion supports the given cast in a way that is
-    /// compatible with Spark
-    fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool {
-        if from_type == to_type {
-            return true;
-        }
-        match from_type {
-            DataType::Boolean => matches!(
-                to_type,
-                DataType::Int8
-                    | DataType::Int16
-                    | DataType::Int32
-                    | DataType::Int64
-                    | DataType::Float32
-                    | DataType::Float64
-                    | DataType::Utf8
-            ),
-            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
-                // note that the cast from Int32/Int64 -> Decimal128 here is actually
-                // not compatible with Spark (no overflow checks) but we have tests that
-                // rely on this cast working so we have to leave it here for now
-                matches!(
-                    to_type,
-                    DataType::Boolean
-                        | DataType::Int8
-                        | DataType::Int16
-                        | DataType::Int32
-                        | DataType::Int64
-                        | DataType::Float32
-                        | DataType::Float64
-                        | DataType::Decimal128(_, _)
-                        | DataType::Utf8
-                )
-            }
-            DataType::Float32 | DataType::Float64 => matches!(
+/// Spark-compatible cast implementation. Defers to DataFusion's cast where that is known
+/// to be compatible, and returns an error when an unsupported cast that is not
+/// DataFusion-compatible is requested.
+pub fn spark_cast(
+    arg: ColumnarValue,
+    data_type: &DataType,
+    eval_mode: EvalMode,
+    timezone: String,
+) -> DataFusionResult<ColumnarValue> {
+    match arg {
+        ColumnarValue::Array(array) => Ok(ColumnarValue::Array(cast_array(
+            array,
+            data_type,
+            eval_mode,
+            timezone.to_owned(),
+        )?)),
+        ColumnarValue::Scalar(scalar) => {
+            // Note that normally CAST(scalar) should be folded on the Spark JVM side. However,
+            // for some cases, e.g., a scalar subquery, Spark will not fold it, so we need to
+            // handle it here.
+            let array = scalar.to_array()?;
+            let scalar = ScalarValue::try_from_array(
+                &cast_array(array, data_type, eval_mode, timezone.to_owned())?,
+                0,
+            )?;
+            Ok(ColumnarValue::Scalar(scalar))
+        }
+    }
+}
+
+fn cast_array(
+    array: ArrayRef,
+    to_type: &DataType,
+    eval_mode: EvalMode,
+    timezone: String,
+) -> DataFusionResult<ArrayRef> {
+    let array = array_with_timezone(array, timezone.clone(), Some(to_type))?;
+    let from_type = array.data_type().clone();
+    let array = match &from_type {
+        DataType::Dictionary(key_type, value_type)
+            if key_type.as_ref() == &DataType::Int32
+                && (value_type.as_ref() == &DataType::Utf8
+                    || value_type.as_ref() == &DataType::LargeUtf8) =>
+        {
+            let dict_array = array
+                .as_any()
+                .downcast_ref::<DictionaryArray<Int32Type>>()
+                .expect("Expected a dictionary array");
+
+            let casted_dictionary = DictionaryArray::<Int32Type>::new(
+                dict_array.keys().clone(),
+                cast_array(dict_array.values().clone(), to_type, eval_mode, timezone)?,
+            );
+
+            let casted_result = match to_type {
+                DataType::Dictionary(_, _) => Arc::new(casted_dictionary.clone()),
+                _ => take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?,
+            };
+            return Ok(spark_cast_postprocess(casted_result, &from_type, to_type));
+        }
+        _ => array,
+    };
+    let from_type = array.data_type();
+
+    let cast_result = match (from_type, to_type) {
+        (DataType::Utf8, DataType::Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
+        (DataType::LargeUtf8, DataType::Boolean) => {
+            spark_cast_utf8_to_boolean::<i64>(&array, eval_mode)
+        }
+        (DataType::Utf8, DataType::Timestamp(_, _)) => {
+            cast_string_to_timestamp(&array, to_type, eval_mode)
+        }
+        (DataType::Utf8, DataType::Date32) => cast_string_to_date(&array, to_type, eval_mode),
+        (DataType::Int64, DataType::Int32)
+        | (DataType::Int64, DataType::Int16)
+        | (DataType::Int64, DataType::Int8)
+        | (DataType::Int32, DataType::Int16)
+        | (DataType::Int32, DataType::Int8)
+        | (DataType::Int16, DataType::Int8)
+            if eval_mode != EvalMode::Try =>
+        {
+            spark_cast_int_to_int(&array, eval_mode, from_type, to_type)
+        }
+        (DataType::Utf8, DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64) => {
+            cast_string_to_int::<i32>(to_type, &array, eval_mode)
+        }
+        (
+            DataType::LargeUtf8,
+            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
+        ) => cast_string_to_int::<i64>(to_type, &array, eval_mode),
+        (DataType::Float64, DataType::Utf8) => spark_cast_float64_to_utf8::<i32>(&array, eval_mode),
+        (DataType::Float64, DataType::LargeUtf8) => {
+            spark_cast_float64_to_utf8::<i64>(&array, eval_mode)
+        }
+        (DataType::Float32, DataType::Utf8) => spark_cast_float32_to_utf8::<i32>(&array, eval_mode),
+        (DataType::Float32, DataType::LargeUtf8) => {
+            spark_cast_float32_to_utf8::<i64>(&array, eval_mode)
+        }
+        (DataType::Float32, DataType::Decimal128(precision, scale)) => {
+            cast_float32_to_decimal128(&array, *precision, *scale, eval_mode)
+        }
+        (DataType::Float64, DataType::Decimal128(precision, scale)) => {
+            cast_float64_to_decimal128(&array, *precision, *scale, eval_mode)
+        }
+        (DataType::Float32, DataType::Int8)
+        | (DataType::Float32, DataType::Int16)
+        | (DataType::Float32, DataType::Int32)
+        | (DataType::Float32, DataType::Int64)
+        | (DataType::Float64, DataType::Int8)
+        | (DataType::Float64, DataType::Int16)
+        | (DataType::Float64, DataType::Int32)
+        | (DataType::Float64, DataType::Int64)
+        | (DataType::Decimal128(_, _), DataType::Int8)
+        | (DataType::Decimal128(_, _), DataType::Int16)
+        | (DataType::Decimal128(_, _), DataType::Int32)
+        | (DataType::Decimal128(_, _), DataType::Int64)
+            if eval_mode != EvalMode::Try =>
+        {
+            spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, from_type, to_type)
+        }
+        _ if is_datafusion_spark_compatible(from_type, to_type) => {
+            // use DataFusion cast only when we know that it is compatible with Spark
+            Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
+        }
+        _ => {
+            // we should never reach this code because the Scala code should be checking
+            // for supported cast operations and falling back to Spark for anything that
+            // is not yet supported
+            Err(SparkError::Internal(format!(
+                "Native cast invoked for unsupported cast from {from_type:?} to {to_type:?}"
+            )))
+        }
+    };
+    Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
+}
+
+/// Determines if DataFusion supports the given cast in a way that is
+/// compatible with Spark
+fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool {
+    if from_type == to_type {
+        return true;
+    }
+    match from_type {
+        DataType::Boolean => matches!(
+            to_type,
+            DataType::Int8
+                | DataType::Int16
+                | DataType::Int32
+                | DataType::Int64
+                | DataType::Float32
+                | DataType::Float64
+                | DataType::Utf8
+        ),
+        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
+            // note that the cast from Int32/Int64 -> Decimal128 here is actually
+            // not compatible with Spark (no overflow checks) but we have tests that
+            // rely on this cast working so we have to leave it here for now
+            matches!(
                 to_type,
                 DataType::Boolean
                     | DataType::Int8
@@ -662,182 +670,180 @@
                     | DataType::Int64
                     | DataType::Float32
                     | DataType::Float64
-            ),
-            DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!(
-                to_type,
-                DataType::Int8
-                    | DataType::Int16
-                    | DataType::Int32
-                    | DataType::Int64
-                    | DataType::Float32
-                    | DataType::Float64
-                    | DataType::Decimal128(_, _)
-                    | DataType::Decimal256(_, _)
-            ),
-            DataType::Utf8 => matches!(to_type, DataType::Binary),
-            DataType::Date32 => matches!(to_type, DataType::Utf8),
-            DataType::Timestamp(_, _) => {
-                matches!(
-                    to_type,
-                    DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _)
-                )
-            }
-            DataType::Binary => {
-                // note that this is not completely Spark compatible because
-                // DataFusion only supports binary data containing valid UTF-8 strings
-                matches!(to_type, DataType::Utf8)
-            }
-            _ => false,
-        }
+                | DataType::Decimal128(_, _)
+                | DataType::Utf8
+            )
+        }
+        DataType::Float32 | DataType::Float64 => matches!(
+            to_type,
+            DataType::Boolean
+                | DataType::Int8
+                | DataType::Int16
+                | DataType::Int32
+                | DataType::Int64
+                | DataType::Float32
+                | DataType::Float64
+        ),
+        DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!(
+            to_type,
+            DataType::Int8
+                | DataType::Int16
+                | DataType::Int32
+                | DataType::Int64
+                | DataType::Float32
+                | DataType::Float64
+                | DataType::Decimal128(_, _)
+                | DataType::Decimal256(_, _)
+        ),
+        DataType::Utf8 => matches!(to_type, DataType::Binary),
+        DataType::Date32 => matches!(to_type, DataType::Utf8),
+        DataType::Timestamp(_, _) => {
+            matches!(
+                to_type,
+                DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _)
+            )
+        }
+        DataType::Binary => {
+            // note that this is not completely Spark compatible because
+            // DataFusion only supports binary data containing valid UTF-8 strings
+            matches!(to_type, DataType::Utf8)
+        }
+        _ => false,
     }
+}
 
-    fn cast_string_to_int<OffsetSize: OffsetSizeTrait>(
-        to_type: &DataType,
-        array: &ArrayRef,
-        eval_mode: EvalMode,
-    ) -> SparkResult<ArrayRef> {
-        let string_array = array
-            .as_any()
-            .downcast_ref::<GenericStringArray<OffsetSize>>()
-            .expect("cast_string_to_int expected a string array");
-
-        let cast_array: ArrayRef = match to_type {
-            DataType::Int8 => {
-                cast_utf8_to_int!(string_array, eval_mode, Int8Type, cast_string_to_i8)?
-            }
-            DataType::Int16 => {
-                cast_utf8_to_int!(string_array, eval_mode, Int16Type, cast_string_to_i16)?
-            }
-            DataType::Int32 => {
-                cast_utf8_to_int!(string_array, eval_mode, Int32Type, cast_string_to_i32)?
-            }
-            DataType::Int64 => {
-                cast_utf8_to_int!(string_array, eval_mode, Int64Type, cast_string_to_i64)?
-            }
-            dt => unreachable!(
-                "{}",
-                format!("invalid integer type {dt} in cast from string")
-            ),
-        };
-        Ok(cast_array)
-    }
+fn cast_string_to_int<OffsetSize: OffsetSizeTrait>(
+    to_type: &DataType,
+    array: &ArrayRef,
+    eval_mode: EvalMode,
+) -> SparkResult<ArrayRef> {
+    let string_array = array
+        .as_any()
+        .downcast_ref::<GenericStringArray<OffsetSize>>()
+        .expect("cast_string_to_int expected a string array");
+
+    let cast_array: ArrayRef = match to_type {
+        DataType::Int8 => cast_utf8_to_int!(string_array, eval_mode, Int8Type, cast_string_to_i8)?,
+        DataType::Int16 => {
+            cast_utf8_to_int!(string_array, eval_mode, Int16Type, cast_string_to_i16)?
+        }
+        DataType::Int32 => {
+            cast_utf8_to_int!(string_array, eval_mode, Int32Type, cast_string_to_i32)?
+        }
+        DataType::Int64 => {
+            cast_utf8_to_int!(string_array, eval_mode, Int64Type, cast_string_to_i64)?
+        }
+        dt => unreachable!(
+            "{}",
+            format!("invalid integer type {dt} in cast from string")
+        ),
+    };
+    Ok(cast_array)
+}
 
-    fn cast_string_to_date(
-        array: &ArrayRef,
-        to_type: &DataType,
-        eval_mode: EvalMode,
-    ) -> SparkResult<ArrayRef> {
-        let string_array = array
-            .as_any()
-            .downcast_ref::<GenericStringArray<i32>>()
-            .expect("Expected a string array");
-
-        if to_type != &DataType::Date32 {
-            unreachable!("Invalid data type {:?} in cast from string", to_type);
-        }
-
-        let len = string_array.len();
-        let mut cast_array = PrimitiveArray::<Date32Type>::builder(len);
-
-        for i in 0..len {
-            let value = if string_array.is_null(i) {
-                None
-            } else {
-                match date_parser(string_array.value(i), eval_mode) {
-                    Ok(Some(cast_value)) => Some(cast_value),
-                    Ok(None) => None,
-                    Err(e) => return Err(e),
-                }
-            };
-
-            match value {
-                Some(cast_value) => cast_array.append_value(cast_value),
-                None => cast_array.append_null(),
-            }
-        }
-
-        Ok(Arc::new(cast_array.finish()) as ArrayRef)
-    }
+fn cast_string_to_date(
+    array: &ArrayRef,
+    to_type: &DataType,
+    eval_mode: EvalMode,
+) -> SparkResult<ArrayRef> {
+    let string_array = array
+        .as_any()
+        .downcast_ref::<GenericStringArray<i32>>()
+        .expect("Expected a string array");
+
+    if to_type != &DataType::Date32 {
+        unreachable!("Invalid data type {:?} in cast from string", to_type);
+    }
+
+    let len = string_array.len();
+    let mut cast_array = PrimitiveArray::<Date32Type>::builder(len);
+
+    for i in 0..len {
+        let value = if string_array.is_null(i) {
+            None
+        } else {
+            match date_parser(string_array.value(i), eval_mode) {
+                Ok(Some(cast_value)) => Some(cast_value),
+                Ok(None) => None,
+                Err(e) => return Err(e),
+            }
+        };
+
+        match value {
+            Some(cast_value) => cast_array.append_value(cast_value),
+            None => cast_array.append_null(),
+        }
+    }
+
+    Ok(Arc::new(cast_array.finish()) as ArrayRef)
+}
 
-    fn cast_string_to_timestamp(
-        array: &ArrayRef,
-        to_type: &DataType,
-        eval_mode: EvalMode,
-    ) -> SparkResult<ArrayRef> {
-        let string_array = array
-            .as_any()
-            .downcast_ref::<GenericStringArray<i32>>()
-            .expect("Expected a string array");
-
-        let cast_array: ArrayRef = match to_type {
-            DataType::Timestamp(_, _) => {
-                cast_utf8_to_timestamp!(
-                    string_array,
-                    eval_mode,
-                    TimestampMicrosecondType,
-                    timestamp_parser
-                )
-            }
-            _ => unreachable!("Invalid data type {:?} in cast from string", to_type),
-        };
-        Ok(cast_array)
-    }
+fn cast_string_to_timestamp(
+    array: &ArrayRef,
+    to_type: &DataType,
+    eval_mode: EvalMode,
+) -> SparkResult<ArrayRef> {
+    let string_array = array
+        .as_any()
+        .downcast_ref::<GenericStringArray<i32>>()
+        .expect("Expected a string array");
+
+    let cast_array: ArrayRef = match to_type {
+        DataType::Timestamp(_, _) => {
+            cast_utf8_to_timestamp!(
+                string_array,
+                eval_mode,
+                TimestampMicrosecondType,
+                timestamp_parser
+            )
+        }
+        _ => unreachable!("Invalid data type {:?} in cast from string", to_type),
+    };
+    Ok(cast_array)
+}
 
-    fn cast_float64_to_decimal128(
-        array: &dyn Array,
-        precision: u8,
-        scale: i8,
-        eval_mode: EvalMode,
-    ) -> SparkResult<ArrayRef> {
-        Self::cast_floating_point_to_decimal128::<Float64Type>(array, precision, scale, eval_mode)
-    }
+fn cast_float64_to_decimal128(
+    array: &dyn Array,
+    precision: u8,
+    scale: i8,
+    eval_mode: EvalMode,
+) -> SparkResult<ArrayRef> {
+    cast_floating_point_to_decimal128::<Float64Type>(array, precision, scale, eval_mode)
+}
 
-    fn cast_float32_to_decimal128(
-        array: &dyn Array,
-        precision: u8,
-        scale: i8,
-        eval_mode: EvalMode,
-    ) -> SparkResult<ArrayRef> {
-        Self::cast_floating_point_to_decimal128::<Float32Type>(array, precision, scale, eval_mode)
-    }
+fn cast_float32_to_decimal128(
+    array: &dyn Array,
+    precision: u8,
+    scale: i8,
+    eval_mode: EvalMode,
+) -> SparkResult<ArrayRef> {
+    cast_floating_point_to_decimal128::<Float32Type>(array, precision, scale, eval_mode)
+}
 
-    fn cast_floating_point_to_decimal128<T: ArrowPrimitiveType>(
-        array: &dyn Array,
-        precision: u8,
-        scale: i8,
-        eval_mode: EvalMode,
-    ) -> SparkResult<ArrayRef>
-    where
-        <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
-    {
-        let input = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
-        let mut cast_array = PrimitiveArray::<Decimal128Type>::builder(input.len());
-
-        let mul = 10_f64.powi(scale as i32);
-
-        for i in 0..input.len() {
-            if input.is_null(i) {
-                cast_array.append_null();
-            } else {
-                let input_value = input.value(i).as_();
-                let value = (input_value * mul).round().to_i128();
-
-                match value {
-                    Some(v) => {
-                        if Decimal128Type::validate_decimal_precision(v, precision).is_err() {
-                            if eval_mode == EvalMode::Ansi {
-                                return Err(SparkError::NumericValueOutOfRange {
-                                    value: input_value.to_string(),
-                                    precision,
-                                    scale,
-                                });
-                            } else {
-                                cast_array.append_null();
-                            }
-                        }
-                        cast_array.append_value(v);
-                    }
-                    None => {
+fn cast_floating_point_to_decimal128<T: ArrowPrimitiveType>(
+    array: &dyn Array,
+    precision: u8,
+    scale: i8,
+    eval_mode: EvalMode,
+) -> SparkResult<ArrayRef>
+where
+    <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
+{
+    let input = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let mut cast_array = PrimitiveArray::<Decimal128Type>::builder(input.len());
+
+    let mul = 10_f64.powi(scale as i32);
+
+    for i in 0..input.len() {
+        if input.is_null(i) {
+            cast_array.append_null();
+        } else {
+            let input_value = input.value(i).as_();
+            let value = (input_value * mul).round().to_i128();
+
+            match value {
+                Some(v) => {
+                    if Decimal128Type::validate_decimal_precision(v, precision).is_err() {
                         if eval_mode == EvalMode::Ansi {
                             return Err(SparkError::NumericValueOutOfRange {
                                 value: input_value.to_string(),
@@ -848,240 +854,252 @@ impl Cast {
                                 precision,
                                 scale,
                             });
                         } else {
                             cast_array.append_null();
                         }
                     }
+                    cast_array.append_value(v);
+                }
+                None => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(SparkError::NumericValueOutOfRange {
+                            value: input_value.to_string(),
+                            precision,
+                            scale,
+                        });
+                    } else {
+                        cast_array.append_null();
+                    }
-                    }
                 }
             }
         }
+    }
-
-        let res = Arc::new(
-            cast_array
-                .with_precision_and_scale(precision, scale)?
-                .finish(),
-        ) as ArrayRef;
-        Ok(res)
-    }
+
+    let res = Arc::new(
+        cast_array
+            .with_precision_and_scale(precision, scale)?
+            .finish(),
+    ) as ArrayRef;
+    Ok(res)
+}
 
-    fn spark_cast_float64_to_utf8<OffsetSize>(
-        from: &dyn Array,
-        _eval_mode: EvalMode,
-    ) -> SparkResult<ArrayRef>
-    where
-        OffsetSize: OffsetSizeTrait,
-    {
-        cast_float_to_string!(from, _eval_mode, f64, Float64Array, OffsetSize)
-    }
+fn spark_cast_float64_to_utf8<OffsetSize>(
+    from: &dyn Array,
+    _eval_mode: EvalMode,
+) -> SparkResult<ArrayRef>
+where
+    OffsetSize: OffsetSizeTrait,
+{
+    cast_float_to_string!(from, _eval_mode, f64, Float64Array, OffsetSize)
+}
 
-    fn spark_cast_float32_to_utf8<OffsetSize>(
-        from: &dyn Array,
-        _eval_mode: EvalMode,
-    ) -> SparkResult<ArrayRef>
-    where
-        OffsetSize: OffsetSizeTrait,
-    {
-        cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize)
-    }
+fn spark_cast_float32_to_utf8<OffsetSize>(
+    from: &dyn Array,
+    _eval_mode: EvalMode,
+) -> SparkResult<ArrayRef>
+where
+    OffsetSize: OffsetSizeTrait,
+{
+    cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize)
+}
 
-    fn spark_cast_int_to_int(
-        array: &dyn Array,
-        eval_mode: EvalMode,
-        from_type: &DataType,
-        to_type: &DataType,
-    ) -> SparkResult<ArrayRef> {
-        match (from_type, to_type) {
-            (DataType::Int64, DataType::Int32) => cast_int_to_int_macro!(
-                array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT"
-            ),
-            (DataType::Int64, DataType::Int16) => cast_int_to_int_macro!(
-                array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT"
-            ),
-            (DataType::Int64, DataType::Int8) => cast_int_to_int_macro!(
-                array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT"
-            ),
-            (DataType::Int32, DataType::Int16) => cast_int_to_int_macro!(
-                array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT"
-            ),
-            (DataType::Int32, DataType::Int8) => cast_int_to_int_macro!(
-                array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT"
-            ),
-            (DataType::Int16, DataType::Int8) => cast_int_to_int_macro!(
-                array, eval_mode, Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT"
-            ),
-            _ => unreachable!(
-                "{}",
-                format!("invalid integer type {to_type} in cast from {from_type}")
-            ),
-        }
-    }
+fn spark_cast_int_to_int(
+    array: &dyn Array,
+    eval_mode: EvalMode,
+    from_type: &DataType,
+    to_type: &DataType,
+) -> SparkResult<ArrayRef> {
+    match (from_type, to_type) {
+        (DataType::Int64, DataType::Int32) => cast_int_to_int_macro!(
+            array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT"
+        ),
+        (DataType::Int64, DataType::Int16) => cast_int_to_int_macro!(
+            array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT"
+        ),
+        (DataType::Int64, DataType::Int8) => cast_int_to_int_macro!(
+            array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT"
+        ),
+        (DataType::Int32, DataType::Int16) => cast_int_to_int_macro!(
+            array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT"
+        ),
+        (DataType::Int32, DataType::Int8) => cast_int_to_int_macro!(
+            array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT"
+        ),
+        (DataType::Int16, DataType::Int8) => cast_int_to_int_macro!(
+            array, eval_mode, Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT"
+        ),
+        _ => unreachable!(
+            "{}",
+            format!("invalid integer type {to_type} in cast from {from_type}")
+        ),
+    }
+}
 
-    fn spark_cast_utf8_to_boolean<OffsetSize>(
-        from: &dyn Array,
-        eval_mode: EvalMode,
-    ) -> SparkResult<ArrayRef>
-    where
-        OffsetSize: OffsetSizeTrait,
-    {
-        let array = from
-            .as_any()
-            .downcast_ref::<GenericStringArray<OffsetSize>>()
-            .unwrap();
-
-        let output_array = array
-            .iter()
-            .map(|value| match value {
-                Some(value) => match value.to_ascii_lowercase().trim() {
-                    "t" | "true" | "y" | "yes" | "1" => Ok(Some(true)),
"f" | "false" | "n" | "no" | "0" => Ok(Some(false)), - _ if eval_mode == EvalMode::Ansi => Err(SparkError::CastInvalidValue { - value: value.to_string(), - from_type: "STRING".to_string(), - to_type: "BOOLEAN".to_string(), - }), - _ => Ok(None), - }, + let output_array = array + .iter() + .map(|value| match value { + Some(value) => match value.to_ascii_lowercase().trim() { + "t" | "true" | "y" | "yes" | "1" => Ok(Some(true)), + "f" | "false" | "n" | "no" | "0" => Ok(Some(false)), + _ if eval_mode == EvalMode::Ansi => Err(SparkError::CastInvalidValue { + value: value.to_string(), + from_type: "STRING".to_string(), + to_type: "BOOLEAN".to_string(), + }), _ => Ok(None), - }) - .collect::>()?; + }, + _ => Ok(None), + }) + .collect::>()?; - Ok(Arc::new(output_array)) - } + Ok(Arc::new(output_array)) +} - fn spark_cast_nonintegral_numeric_to_integral( - array: &dyn Array, - eval_mode: EvalMode, - from_type: &DataType, - to_type: &DataType, - ) -> SparkResult { - match (from_type, to_type) { - (DataType::Float32, DataType::Int8) => cast_float_to_int16_down!( - array, - eval_mode, - Float32Array, - Int8Array, - f32, - i8, - "FLOAT", - "TINYINT", - "{:e}" - ), - (DataType::Float32, DataType::Int16) => cast_float_to_int16_down!( - array, - eval_mode, - Float32Array, - Int16Array, - f32, - i16, - "FLOAT", - "SMALLINT", - "{:e}" - ), - (DataType::Float32, DataType::Int32) => cast_float_to_int32_up!( - array, - eval_mode, - Float32Array, - Int32Array, - f32, - i32, - "FLOAT", - "INT", - i32::MAX, - "{:e}" - ), - (DataType::Float32, DataType::Int64) => cast_float_to_int32_up!( - array, - eval_mode, - Float32Array, - Int64Array, - f32, - i64, - "FLOAT", - "BIGINT", - i64::MAX, - "{:e}" - ), - (DataType::Float64, DataType::Int8) => cast_float_to_int16_down!( - array, - eval_mode, - Float64Array, - Int8Array, - f64, - i8, - "DOUBLE", - "TINYINT", - "{:e}D" - ), - (DataType::Float64, DataType::Int16) => cast_float_to_int16_down!( - array, - eval_mode, - Float64Array, - Int16Array, - f64, - i16, - "DOUBLE", - "SMALLINT", - "{:e}D" - ), - (DataType::Float64, DataType::Int32) => cast_float_to_int32_up!( +fn spark_cast_nonintegral_numeric_to_integral( + array: &dyn Array, + eval_mode: EvalMode, + from_type: &DataType, + to_type: &DataType, +) -> SparkResult { + match (from_type, to_type) { + (DataType::Float32, DataType::Int8) => cast_float_to_int16_down!( + array, + eval_mode, + Float32Array, + Int8Array, + f32, + i8, + "FLOAT", + "TINYINT", + "{:e}" + ), + (DataType::Float32, DataType::Int16) => cast_float_to_int16_down!( + array, + eval_mode, + Float32Array, + Int16Array, + f32, + i16, + "FLOAT", + "SMALLINT", + "{:e}" + ), + (DataType::Float32, DataType::Int32) => cast_float_to_int32_up!( + array, + eval_mode, + Float32Array, + Int32Array, + f32, + i32, + "FLOAT", + "INT", + i32::MAX, + "{:e}" + ), + (DataType::Float32, DataType::Int64) => cast_float_to_int32_up!( + array, + eval_mode, + Float32Array, + Int64Array, + f32, + i64, + "FLOAT", + "BIGINT", + i64::MAX, + "{:e}" + ), + (DataType::Float64, DataType::Int8) => cast_float_to_int16_down!( + array, + eval_mode, + Float64Array, + Int8Array, + f64, + i8, + "DOUBLE", + "TINYINT", + "{:e}D" + ), + (DataType::Float64, DataType::Int16) => cast_float_to_int16_down!( + array, + eval_mode, + Float64Array, + Int16Array, + f64, + i16, + "DOUBLE", + "SMALLINT", + "{:e}D" + ), + (DataType::Float64, DataType::Int32) => cast_float_to_int32_up!( + array, + eval_mode, + Float64Array, + Int32Array, + f64, + i32, + "DOUBLE", + "INT", + i32::MAX, + "{:e}D" + ), + 
(DataType::Float64, DataType::Int64) => cast_float_to_int32_up!( + array, + eval_mode, + Float64Array, + Int64Array, + f64, + i64, + "DOUBLE", + "BIGINT", + i64::MAX, + "{:e}D" + ), + (DataType::Decimal128(precision, scale), DataType::Int8) => { + cast_decimal_to_int16_down!( + array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale + ) + } + (DataType::Decimal128(precision, scale), DataType::Int16) => { + cast_decimal_to_int16_down!( + array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale + ) + } + (DataType::Decimal128(precision, scale), DataType::Int32) => { + cast_decimal_to_int32_up!( array, eval_mode, - Float64Array, Int32Array, - f64, i32, - "DOUBLE", "INT", i32::MAX, - "{:e}D" - ), - (DataType::Float64, DataType::Int64) => cast_float_to_int32_up!( + *precision, + *scale + ) + } + (DataType::Decimal128(precision, scale), DataType::Int64) => { + cast_decimal_to_int32_up!( array, eval_mode, - Float64Array, Int64Array, - f64, i64, - "DOUBLE", "BIGINT", i64::MAX, - "{:e}D" - ), - (DataType::Decimal128(precision, scale), DataType::Int8) => { - cast_decimal_to_int16_down!( - array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale - ) - } - (DataType::Decimal128(precision, scale), DataType::Int16) => { - cast_decimal_to_int16_down!( - array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale - ) - } - (DataType::Decimal128(precision, scale), DataType::Int32) => { - cast_decimal_to_int32_up!( - array, - eval_mode, - Int32Array, - i32, - "INT", - i32::MAX, - *precision, - *scale - ) - } - (DataType::Decimal128(precision, scale), DataType::Int64) => { - cast_decimal_to_int32_up!( - array, - eval_mode, - Int64Array, - i64, - "BIGINT", - i64::MAX, - *precision, - *scale - ) - } - _ => unreachable!( - "{}", - format!("invalid cast from non-integral numeric type: {from_type} to integral numeric type: {to_type}") - ), + *precision, + *scale + ) } + _ => unreachable!( + "{}", + format!("invalid cast from non-integral numeric type: {from_type} to integral numeric type: {to_type}") + ), } } @@ -1294,17 +1312,7 @@ impl PhysicalExpr for Cast { fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { let arg = self.child.evaluate(batch)?; - match arg { - ColumnarValue::Array(array) => Ok(ColumnarValue::Array(self.cast_array(array)?)), - ColumnarValue::Scalar(scalar) => { - // Note that normally CAST(scalar) should be fold in Spark JVM side. However, for - // some cases e.g., scalar subquery, Spark will not fold it, so we need to handle it - // here. - let array = scalar.to_array()?; - let scalar = ScalarValue::try_from_array(&self.cast_array(array)?, 0)?; - Ok(ColumnarValue::Scalar(scalar)) - } - } + spark_cast(arg, &self.data_type, self.eval_mode, self.timezone.clone()) } fn children(&self) -> Vec<&Arc> { @@ -1660,7 +1668,7 @@ fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult> /// Dictionary arrays are already unpacked by the DataFusion cast() since Spark cannot specify /// Dictionary as to_type. The from_type is taken before the DataFusion cast() runs in /// expressions/cast.rs, so it can be still Dictionary. 
-fn spark_cast(array: ArrayRef, from_type: &DataType, to_type: &DataType) -> ArrayRef {
+fn spark_cast_postprocess(array: ArrayRef, from_type: &DataType, to_type: &DataType) -> ArrayRef {
     match (from_type, to_type) {
         (DataType::Timestamp(_, _), DataType::Int64) => {
             // See Spark's `Cast` expression
@@ -1739,8 +1747,6 @@ mod tests {
     use arrow_array::StringArray;
     use arrow_schema::TimeUnit;
 
-    use datafusion_physical_expr::expressions::Column;
-
     use super::*;
 
     #[test]
@@ -1819,18 +1825,14 @@
         ]));
         let dict_array = Arc::new(DictionaryArray::new(keys, values));
 
-        // prepare cast expression
         let timezone = "UTC".to_string();
-        let expr = Arc::new(Column::new("a", 0)); // this is not used by the test
-        let cast = Cast::new(
-            expr,
-            DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.clone().into())),
+        // test casting string dictionary array to timestamp array
+        let result = cast_array(
+            dict_array,
+            &DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.clone().into())),
             EvalMode::Legacy,
             timezone.clone(),
-        );
-
-        // test casting string dictionary array to timestamp array
-        let result = cast.cast_array(dict_array)?;
+        )?;
         assert_eq!(
             *result.data_type(),
             DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.into()))
@@ -1912,8 +1914,7 @@
             Some("2020-01-01T"),
         ]));
 
-        let result =
-            Cast::cast_string_to_date(&array, &DataType::Date32, EvalMode::Legacy).unwrap();
+        let result = cast_string_to_date(&array, &DataType::Date32, EvalMode::Legacy).unwrap();
 
         let date32_array = result
            .as_any()
@@ -1939,7 +1940,7 @@
 
        for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] {
             let result =
-                Cast::cast_string_to_date(&array_with_invalid_date, &DataType::Date32, *eval_mode)
+                cast_string_to_date(&array_with_invalid_date, &DataType::Date32, *eval_mode)
                     .unwrap();
 
             let date32_array = result
@@ -1971,7 +1972,7 @@
 
         for eval_mode in &[EvalMode::Legacy, EvalMode::Try] {
             let result =
-                Cast::cast_string_to_date(&array_with_invalid_date, &DataType::Date32, *eval_mode)
+                cast_string_to_date(&array_with_invalid_date, &DataType::Date32, *eval_mode)
                     .unwrap();
 
             let date32_array = result
@@ -1995,7 +1996,7 @@
         }
 
         let result =
-            Cast::cast_string_to_date(&array_with_invalid_date, &DataType::Date32, EvalMode::Ansi);
+            cast_string_to_date(&array_with_invalid_date, &DataType::Date32, EvalMode::Ansi);
         match result {
             Err(e) => assert!(
                 e.to_string().contains(
@@ -2035,26 +2036,24 @@
     fn test_cast_unsupported_timestamp_to_date() {
         // Since datafusion uses chrono::Datetime internally not all dates representable by TimestampMicrosecondType are supported
         let timestamps: PrimitiveArray<TimestampMicrosecondType> = vec![i64::MAX].into();
-        let cast = Cast::new(
-            Arc::new(Column::new("a", 0)),
-            DataType::Date32,
+        let result = cast_array(
+            Arc::new(timestamps.with_timezone("Europe/Copenhagen")),
+            &DataType::Date32,
             EvalMode::Legacy,
             "UTC".to_owned(),
         );
-        let result = cast.cast_array(Arc::new(timestamps.with_timezone("Europe/Copenhagen")));
         assert!(result.is_err())
     }
 
     #[test]
     fn test_cast_invalid_timezone() {
         let timestamps: PrimitiveArray<TimestampMicrosecondType> = vec![i64::MAX].into();
-        let cast = Cast::new(
-            Arc::new(Column::new("a", 0)),
-            DataType::Date32,
+        let result = cast_array(
+            Arc::new(timestamps.with_timezone("Europe/Copenhagen")),
+            &DataType::Date32,
            EvalMode::Legacy,
             "Not a valid timezone".to_owned(),
         );
-        let result = cast.cast_array(Arc::new(timestamps.with_timezone("Europe/Copenhagen")));
         assert!(result.is_err())
     }
 }
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index 336201f48..22628978d 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -24,7 +24,7 @@ mod temporal;
 pub mod timezone;
 pub mod utils;
 
-pub use cast::Cast;
+pub use cast::{spark_cast, Cast};
 pub use error::{SparkError, SparkResult};
 pub use if_expr::IfExpr;
 pub use temporal::{DateTruncExpr, HourExpr, MinuteExpr, SecondExpr, TimestampTruncExpr};
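
Usage sketch of the API this patch makes public. This is illustrative only, not part of the patch: the crate name (`datafusion_comet_spark_expr`), the `EvalMode` import path (it may only be reachable as `cast::EvalMode`), and the DataFusion re-export paths are assumptions for the sake of the example.

    use std::sync::Arc;

    use arrow_array::{ArrayRef, Int64Array};
    use arrow_schema::DataType;
    use datafusion::logical_expr::ColumnarValue;
    // Hypothetical import paths; adjust to wherever the crate exposes these.
    use datafusion_comet_spark_expr::{spark_cast, EvalMode};

    fn main() {
        // CAST(BIGINT AS INT) with Spark ANSI semantics: i64::MAX does not fit in
        // an i32, so the cast reports an overflow error instead of silently wrapping.
        let array: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), None, Some(i64::MAX)]));
        let result = spark_cast(
            ColumnarValue::Array(array),
            &DataType::Int32,
            EvalMode::Ansi,
            "UTC".to_string(),
        );
        assert!(result.is_err());
    }

Note that, unlike the previous `Cast::cast_array` method, `spark_cast` needs no `PhysicalExpr` child or `RecordBatch`, which is what makes it callable from other projects that only have a `ColumnarValue` in hand.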