diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index 63e1c04762..e7c238f7eb 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -99,3 +99,7 @@ harness = false [[test]] name = "test_udf_registration" path = "tests/spark_expr_reg.rs" + +[[bench]] +name = "cast_from_boolean" +harness = false diff --git a/native/spark-expr/benches/cast_from_boolean.rs b/native/spark-expr/benches/cast_from_boolean.rs new file mode 100644 index 0000000000..dbb986df91 --- /dev/null +++ b/native/spark-expr/benches/cast_from_boolean.rs @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+use arrow::array::{BooleanBuilder, RecordBatch};
+use arrow::datatypes::{DataType, Field, Schema};
+use criterion::{criterion_group, criterion_main, Criterion};
+use datafusion::physical_expr::expressions::Column;
+use datafusion::physical_expr::PhysicalExpr;
+use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
+use std::sync::Arc;
+
+/// Number of rows in the benchmark input batch.
+const NUM_ROWS: usize = 1000;
+
+/// Registers the `cast_bool` benchmark group.
+///
+/// One benchmark is registered per target type; each one evaluates a `Cast` of column `a`
+/// (index 0) over the same Boolean batch using `EvalMode::Legacy` cast options.
+fn criterion_benchmark(c: &mut Criterion) {
+    let expr = Arc::new(Column::new("a", 0));
+    let boolean_batch = create_boolean_batch();
+    let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
+
+    // Benchmark id paired with the cast target type.
+    let targets = [
+        ("i8", DataType::Int8),
+        ("i16", DataType::Int16),
+        ("i32", DataType::Int32),
+        ("i64", DataType::Int64),
+        ("f32", DataType::Float32),
+        ("f64", DataType::Float64),
+        ("str", DataType::Utf8),
+        ("decimal", DataType::Decimal128(10, 4)),
+    ];
+
+    let mut group = c.benchmark_group("cast_bool");
+    for (id, to_type) in targets {
+        // The expression is built outside `iter` so only `evaluate` is measured.
+        let cast = Cast::new(expr.clone(), to_type, spark_cast_options.clone());
+        group.bench_function(id, |b| {
+            b.iter(|| cast.evaluate(&boolean_batch).unwrap());
+        });
+    }
+    // Finish explicitly so the group summary is produced here rather than on drop.
+    group.finish();
+}
+
+/// Builds the benchmark input: a single nullable Boolean column `a` with `NUM_ROWS` rows.
+///
+/// Every tenth row (indices 0, 10, 20, ...) is null; all other rows hold a value drawn
+/// from `rand::random`, so the concrete values differ between runs.
+fn create_boolean_batch() -> RecordBatch {
+    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, true)]));
+    let mut builder = BooleanBuilder::with_capacity(NUM_ROWS);
+    for i in 0..NUM_ROWS {
+        if i % 10 == 0 {
+            builder.append_null();
+        } else {
+            builder.append_value(rand::random::<bool>());
+        }
+    }
+    RecordBatch::try_new(schema, vec![Arc::new(builder.finish())]).unwrap()
+}
+
+/// Criterion configuration used by the `benches` group (library defaults).
+fn config() -> Criterion {
+    Criterion::default()
+}
+
+criterion_group! {
+    name = benches;
+    config = config();
+    targets = criterion_benchmark
+}
+criterion_main!(benches);
diff --git a/native/spark-expr/src/conversion_funcs/boolean.rs b/native/spark-expr/src/conversion_funcs/boolean.rs
new file mode 100644
index 0000000000..db288fa32a
--- /dev/null
+++ b/native/spark-expr/src/conversion_funcs/boolean.rs
@@ -0,0 +1,196 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::SparkResult;
+use arrow::array::{ArrayRef, AsArray, Decimal128Array};
+use arrow::datatypes::DataType;
+use arrow::error::ArrowError;
+use std::sync::Arc;
+
+/// Reports whether a cast from Boolean to `to_type` is one of the casts treated as
+/// DataFusion/Spark compatible.
+///
+/// Returns `true` for `Int8`, `Int16`, `Int32`, `Int64`, `Float32`, `Float64` and `Utf8`,
+/// and `false` for every other type.
+pub fn is_df_cast_from_bool_spark_compatible(to_type: &DataType) -> bool {
+    use DataType::*;
+    matches!(
+        to_type,
+        Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8
+    )
+}
+
+/// Casts a Boolean array to `Decimal128(precision, scale)`.
+///
+/// `true` becomes the unscaled representation of `1` (that is, `10^scale`), `false` becomes
+/// `0`, and null entries stay null. This is the Boolean cast that is not covered by
+/// [`is_df_cast_from_bool_spark_compatible`].
+///
+/// With a negative `scale` the value `1` is smaller than the smallest representable step
+/// (`10^|scale|`), so it rounds to an unscaled value of `0`.
+///
+/// # Errors
+/// Returns an error if `10^scale` does not fit in an `i128`, or if Arrow rejects the
+/// `precision`/`scale` combination.
+///
+/// # Panics
+/// Panics if `array` is not a Boolean array.
+pub fn cast_boolean_to_decimal(
+    array: &ArrayRef,
+    precision: u8,
+    scale: i8,
+) -> SparkResult<ArrayRef> {
+    let bool_array = array.as_boolean();
+    // A plain `scale as u32` maps a negative scale to a huge exponent and makes `pow`
+    // overflow, so convert with a range check and use checked arithmetic.
+    let one_unscaled: i128 = match u32::try_from(scale) {
+        Ok(exponent) => 10_i128.checked_pow(exponent).ok_or_else(|| {
+            ArrowError::InvalidArgumentError(format!(
+                "Cannot represent 1 as Decimal128({}, {}): 10^{} overflows i128",
+                precision, scale, scale
+            ))
+        })?,
+        Err(_) => 0,
+    };
+    // NOTE(review): the produced values are not checked against `precision`; when
+    // `scale >= precision` the value for `true` has more digits than `precision` allows.
+    // Confirm whether the caller is expected to handle that case.
+    let result: Decimal128Array = bool_array
+        .iter()
+        .map(|v| v.map(|b| if b { one_unscaled } else { 0 }))
+        .collect();
+    Ok(Arc::new(result.with_precision_and_scale(precision, scale)?))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::cast::cast_array;
+    use crate::{EvalMode, SparkCastOptions};
+    use arrow::array::{
+        Array, ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
+        Int64Array, Int8Array, StringArray,
+    };
+    use arrow::datatypes::DataType::Decimal128;
+    use std::sync::Arc;
+
+    /// Input shared by all tests: `[true, false, null]`.
+    fn test_input_bool_array() -> ArrayRef {
+        Arc::new(BooleanArray::from(vec![Some(true), Some(false), None]))
+    }
+
+    fn test_input_spark_opts() -> SparkCastOptions {
+        SparkCastOptions::new(EvalMode::Legacy, "Asia/Kolkata", false)
+    }
+
+    /// Runs `cast_array` on the shared input with the shared options.
+    fn cast_test_input(to_type: &DataType) -> ArrayRef {
+        cast_array(test_input_bool_array(), to_type, &test_input_spark_opts()).unwrap()
+    }
+
+    #[test]
+    fn test_is_df_cast_from_bool_spark_compatible() {
+        assert!(!is_df_cast_from_bool_spark_compatible(&DataType::Boolean));
+        assert!(is_df_cast_from_bool_spark_compatible(&DataType::Int8));
+        assert!(is_df_cast_from_bool_spark_compatible(&DataType::Int16));
+        assert!(is_df_cast_from_bool_spark_compatible(&DataType::Int32));
+        assert!(is_df_cast_from_bool_spark_compatible(&DataType::Int64));
+        assert!(is_df_cast_from_bool_spark_compatible(&DataType::Float32));
+        assert!(is_df_cast_from_bool_spark_compatible(&DataType::Float64));
+        assert!(is_df_cast_from_bool_spark_compatible(&DataType::Utf8));
+        assert!(!is_df_cast_from_bool_spark_compatible(
+            &DataType::Decimal128(10, 4)
+        ));
+        assert!(!is_df_cast_from_bool_spark_compatible(&DataType::Null));
+    }
+
+    #[test]
+    fn test_bool_to_int8_cast() {
+        let result = cast_test_input(&DataType::Int8);
+        let arr = result.as_any().downcast_ref::<Int8Array>().unwrap();
+        assert_eq!(arr.value(0), 1);
+        assert_eq!(arr.value(1), 0);
+        assert!(arr.is_null(2));
+    }
+
+    #[test]
+    fn test_bool_to_int16_cast() {
+        let result = cast_test_input(&DataType::Int16);
+        let arr = result.as_any().downcast_ref::<Int16Array>().unwrap();
+        assert_eq!(arr.value(0), 1);
+        assert_eq!(arr.value(1), 0);
+        assert!(arr.is_null(2));
+    }
+
+    #[test]
+    fn test_bool_to_int32_cast() {
+        let result = cast_test_input(&DataType::Int32);
+        let arr = result.as_any().downcast_ref::<Int32Array>().unwrap();
+        assert_eq!(arr.value(0), 1);
+        assert_eq!(arr.value(1), 0);
+        assert!(arr.is_null(2));
+    }
+
+    #[test]
+    fn test_bool_to_int64_cast() {
+        let result = cast_test_input(&DataType::Int64);
+        let arr = result.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(arr.value(0), 1);
+        assert_eq!(arr.value(1), 0);
+        assert!(arr.is_null(2));
+    }
+
+    #[test]
+    fn test_bool_to_float32_cast() {
+        let result = cast_test_input(&DataType::Float32);
+        let arr = result.as_any().downcast_ref::<Float32Array>().unwrap();
+        assert_eq!(arr.value(0), 1.0);
+        assert_eq!(arr.value(1), 0.0);
+        assert!(arr.is_null(2));
+    }
+
+    #[test]
+    fn test_bool_to_float64_cast() {
+        let result = cast_test_input(&DataType::Float64);
+        let arr = result.as_any().downcast_ref::<Float64Array>().unwrap();
+        assert_eq!(arr.value(0), 1.0);
+        assert_eq!(arr.value(1), 0.0);
+        assert!(arr.is_null(2));
+    }
+
+    #[test]
+    fn test_bool_to_string_cast() {
+        let result = cast_test_input(&DataType::Utf8);
+        let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
+        assert_eq!(arr.value(0), "true");
+        assert_eq!(arr.value(1), "false");
+        assert!(arr.is_null(2));
+    }
+
+    #[test]
+    fn test_bool_to_decimal_cast() {
+        let result = cast_test_input(&Decimal128(10, 4));
+        let expected_arr = Decimal128Array::from(vec![10000_i128, 0_i128])
+            .with_precision_and_scale(10, 4)
+            .unwrap();
+        let arr = result.as_any().downcast_ref::<Decimal128Array>().unwrap();
+        assert_eq!(arr.value(0), expected_arr.value(0));
+        assert_eq!(arr.value(1), expected_arr.value(1));
+        assert!(arr.is_null(2));
+    }
+
+    #[test]
+    fn test_bool_to_decimal_negative_scale() {
+        // Must neither panic nor produce a wrapped power of ten.
+        let result = cast_boolean_to_decimal(&test_input_bool_array(), 10, -2).unwrap();
+        let arr = result.as_any().downcast_ref::<Decimal128Array>().unwrap();
+        assert_eq!(arr.value(0), 0);
+        assert_eq!(arr.value(1), 0);
+        assert!(arr.is_null(2));
+    }
+
+    #[test]
+    fn test_bool_to_decimal_scale_overflow() {
+        // 10^39 does not fit in an i128, so this must be an error rather than a panic.
+        assert!(cast_boolean_to_decimal(&test_input_bool_array(), 38, 39).is_err());
+    }
+}
diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index f5ab83b8a5..004668b8f2 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -15,6 +15,11 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::conversion_funcs::boolean::{ + cast_boolean_to_decimal, is_df_cast_from_bool_spark_compatible, +}; +use crate::conversion_funcs::utils::spark_cast_postprocess; +use crate::conversion_funcs::utils::{cast_overflow, invalid_value}; use crate::utils::array_with_timezone; use crate::EvalMode::Legacy; use crate::{timezone, BinaryOutputStyle}; @@ -37,7 +42,7 @@ use arrow::{ GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array, OffsetSizeTrait, PrimitiveArray, }, - compute::{cast_with_options, take, unary, CastOptions}, + compute::{cast_with_options, take, CastOptions}, datatypes::{ is_validate_decimal_precision, ArrowPrimitiveType, Decimal128Type, Float32Type, Float64Type, Int64Type, TimestampMicrosecondType, @@ -48,16 +53,10 @@ use arrow::{ }; use base64::prelude::*; use chrono::{DateTime, NaiveDate, TimeZone, Timelike}; -use datafusion::common::{ - cast::as_generic_string_array, internal_err, DataFusionError, Result as DataFusionResult, - ScalarValue, -}; +use datafusion::common::{internal_err, DataFusionError, Result as DataFusionResult, ScalarValue}; use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_plan::ColumnarValue; -use num::{ - cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, ToPrimitive, - Zero, -}; +use num::{cast::AsPrimitive, traits::CheckedNeg, CheckedSub, Integer, ToPrimitive, Zero}; use regex::Regex; use std::str::FromStr; use std::{ @@ -70,7 +69,7 @@ use std::{ static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f"); -const MICROS_PER_SECOND: i64 = 1000000; +pub(crate) const MICROS_PER_SECOND: i64 = 1000000; static CAST_OPTIONS: CastOptions = CastOptions { safe: true, @@ -776,7 +775,7 @@ fn dict_from_values( Ok(Arc::new(dict_array)) } -fn cast_array( +pub(crate) fn cast_array( array: ArrayRef, to_type: &DataType, cast_options: &SparkCastOptions, @@ -1018,16 +1017,6 @@ fn cast_date_to_timestamp( )) } -fn cast_boolean_to_decimal(array: &ArrayRef, precision: u8, scale: 
i8) -> SparkResult { - let bool_array = array.as_boolean(); - let scaled_val = 10_i128.pow(scale as u32); - let result: Decimal128Array = bool_array - .iter() - .map(|v| v.map(|b| if b { scaled_val } else { 0 })) - .collect(); - Ok(Arc::new(result.with_precision_and_scale(precision, scale)?)) -} - fn cast_string_to_float( array: &ArrayRef, to_type: &DataType, @@ -1186,16 +1175,7 @@ fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> b DataType::Null => { matches!(to_type, DataType::List(_)) } - DataType::Boolean => matches!( - to_type, - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 - | DataType::Utf8 - ), + DataType::Boolean => is_df_cast_from_bool_spark_compatible(to_type), DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { matches!( to_type, @@ -2437,24 +2417,6 @@ fn parse_decimal_str( Ok((final_mantissa, final_scale)) } -#[inline] -fn invalid_value(value: &str, from_type: &str, to_type: &str) -> SparkError { - SparkError::CastInvalidValue { - value: value.to_string(), - from_type: from_type.to_string(), - to_type: to_type.to_string(), - } -} - -#[inline] -fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> SparkError { - SparkError::CastOverFlow { - value: value.to_string(), - from_type: from_type.to_string(), - to_type: to_type.to_string(), - } -} - impl Display for Cast { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!( @@ -2852,84 +2814,6 @@ fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult> } } -/// This takes for special casting cases of Spark. E.g., Timestamp to Long. -/// This function runs as a post process of the DataFusion cast(). By the time it arrives here, -/// Dictionary arrays are already unpacked by the DataFusion cast() since Spark cannot specify -/// Dictionary as to_type. 
The from_type is taken before the DataFusion cast() runs in -/// expressions/cast.rs, so it can be still Dictionary. -fn spark_cast_postprocess(array: ArrayRef, from_type: &DataType, to_type: &DataType) -> ArrayRef { - match (from_type, to_type) { - (DataType::Timestamp(_, _), DataType::Int64) => { - // See Spark's `Cast` expression - unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap() - } - (DataType::Dictionary(_, value_type), DataType::Int64) - if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) => - { - // See Spark's `Cast` expression - unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap() - } - (DataType::Timestamp(_, _), DataType::Utf8) => remove_trailing_zeroes(array), - (DataType::Dictionary(_, value_type), DataType::Utf8) - if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) => - { - remove_trailing_zeroes(array) - } - _ => array, - } -} - -/// A fork & modified version of Arrow's `unary_dyn` which is being deprecated -fn unary_dyn(array: &ArrayRef, op: F) -> Result -where - T: ArrowPrimitiveType, - F: Fn(T::Native) -> T::Native, -{ - if let Some(d) = array.as_any_dictionary_opt() { - let new_values = unary_dyn::(d.values(), op)?; - return Ok(Arc::new(d.with_values(Arc::new(new_values)))); - } - - match array.as_primitive_opt::() { - Some(a) if PrimitiveArray::::is_compatible(a.data_type()) => { - Ok(Arc::new(unary::( - array.as_any().downcast_ref::>().unwrap(), - op, - ))) - } - _ => Err(ArrowError::NotYetImplemented(format!( - "Cannot perform unary operation of type {} on array of type {}", - T::DATA_TYPE, - array.data_type() - ))), - } -} - -/// Remove any trailing zeroes in the string if they occur after in the fractional seconds, -/// to match Spark behavior -/// example: -/// "1970-01-01 05:29:59.900" => "1970-01-01 05:29:59.9" -/// "1970-01-01 05:29:59.990" => "1970-01-01 05:29:59.99" -/// "1970-01-01 05:29:59.999" => "1970-01-01 05:29:59.999" -/// "1970-01-01 05:30:00" => 
"1970-01-01 05:30:00" -/// "1970-01-01 05:30:00.001" => "1970-01-01 05:30:00.001" -fn remove_trailing_zeroes(array: ArrayRef) -> ArrayRef { - let string_array = as_generic_string_array::(&array).unwrap(); - let result = string_array - .iter() - .map(|s| s.map(trim_end)) - .collect::>(); - Arc::new(result) as ArrayRef -} - -fn trim_end(s: &str) -> &str { - if s.rfind('.').is_some() { - s.trim_end_matches('0') - } else { - s - } -} - #[cfg(test)] mod tests { use arrow::array::StringArray; diff --git a/native/spark-expr/src/conversion_funcs/mod.rs b/native/spark-expr/src/conversion_funcs/mod.rs index f2c6f7ca36..190c115204 100644 --- a/native/spark-expr/src/conversion_funcs/mod.rs +++ b/native/spark-expr/src/conversion_funcs/mod.rs @@ -15,4 +15,6 @@ // specific language governing permissions and limitations // under the License. +mod boolean; pub mod cast; +mod utils; diff --git a/native/spark-expr/src/conversion_funcs/utils.rs b/native/spark-expr/src/conversion_funcs/utils.rs new file mode 100644 index 0000000000..8b8d974ffe --- /dev/null +++ b/native/spark-expr/src/conversion_funcs/utils.rs @@ -0,0 +1,128 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+use crate::cast::MICROS_PER_SECOND;
+use crate::SparkError;
+use arrow::array::{
+    Array, ArrayRef, ArrowPrimitiveType, AsArray, GenericStringArray, PrimitiveArray,
+};
+use arrow::compute::unary;
+use arrow::datatypes::{DataType, Int64Type};
+use arrow::error::ArrowError;
+use datafusion::common::cast::as_generic_string_array;
+use num::integer::div_floor;
+use std::sync::Arc;
+
+/// A fork & modified version of Arrow's `unary_dyn` which is being deprecated.
+///
+/// Applies `op` to every value of `array`. A dictionary array is handled by applying `op`
+/// to its values and rebuilding the dictionary around the new values; a primitive array
+/// of type `T` is mapped directly.
+///
+/// # Errors
+/// Returns `ArrowError::NotYetImplemented` when `array` is neither a dictionary nor a
+/// primitive array compatible with `T`.
+pub fn unary_dyn<F, T>(array: &ArrayRef, op: F) -> Result<ArrayRef, ArrowError>
+where
+    T: ArrowPrimitiveType,
+    F: Fn(T::Native) -> T::Native,
+{
+    if let Some(dict) = array.as_any_dictionary_opt() {
+        let new_values = unary_dyn::<F, T>(dict.values(), op)?;
+        // `with_values` already returns an `ArrayRef`; no extra `Arc` layer is needed.
+        return Ok(dict.with_values(new_values));
+    }
+
+    match array.as_primitive_opt::<T>() {
+        // Reuse the array that was just downcast instead of downcasting a second time.
+        Some(primitive) if PrimitiveArray::<T>::is_compatible(primitive.data_type()) => {
+            Ok(Arc::new(unary::<T, F, T>(primitive, op)))
+        }
+        _ => Err(ArrowError::NotYetImplemented(format!(
+            "Cannot perform unary operation of type {} on array of type {}",
+            T::DATA_TYPE,
+            array.data_type()
+        ))),
+    }
+}
+
+/// Handles the special casting cases of Spark, e.g. Timestamp to Long.
+///
+/// This function runs as a post process of the DataFusion cast(). By the time it arrives
+/// here, Dictionary arrays are already unpacked by the DataFusion cast() since Spark cannot
+/// specify Dictionary as `to_type`. The `from_type` is taken before the DataFusion cast()
+/// runs in expressions/cast.rs, so it can still be Dictionary.
+///
+/// * Timestamp (or Dictionary of Timestamp) to `Int64`: every value is floor-divided by
+///   `MICROS_PER_SECOND` (see Spark's `Cast` expression).
+/// * Timestamp (or Dictionary of Timestamp) to `Utf8`: trailing zeroes of the fractional
+///   seconds are removed.
+/// * Any other combination returns `array` unchanged.
+///
+/// # Panics
+/// Panics if `array` does not have the physical type implied by `to_type`.
+pub fn spark_cast_postprocess(
+    array: ArrayRef,
+    from_type: &DataType,
+    to_type: &DataType,
+) -> ArrayRef {
+    let from_timestamp = match from_type {
+        DataType::Timestamp(_, _) => true,
+        DataType::Dictionary(_, value_type) => {
+            matches!(value_type.as_ref(), DataType::Timestamp(_, _))
+        }
+        _ => false,
+    };
+    if !from_timestamp {
+        return array;
+    }
+
+    match to_type {
+        // See Spark's `Cast` expression
+        DataType::Int64 => unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND))
+            .expect("result of a cast to Int64 must be an Int64 array"),
+        DataType::Utf8 => remove_trailing_zeroes(array),
+        _ => array,
+    }
+}
+
+/// Remove any trailing zeroes in the string if they occur in the fractional seconds,
+/// to match Spark behavior
+/// example:
+/// "1970-01-01 05:29:59.900" => "1970-01-01 05:29:59.9"
+/// "1970-01-01 05:29:59.990" => "1970-01-01 05:29:59.99"
+/// "1970-01-01 05:29:59.999" => "1970-01-01 05:29:59.999"
+/// "1970-01-01 05:30:00" => "1970-01-01 05:30:00"
+/// "1970-01-01 05:30:00.001" => "1970-01-01 05:30:00.001"
+///
+/// # Panics
+/// Panics if `array` is not a `Utf8` string array.
+fn remove_trailing_zeroes(array: ArrayRef) -> ArrayRef {
+    let string_array = as_generic_string_array::<i32>(&array)
+        .expect("result of a cast to Utf8 must be a string array");
+    let result = string_array
+        .iter()
+        .map(|s| s.map(trim_end))
+        .collect::<GenericStringArray<i32>>();
+    Arc::new(result) as ArrayRef
+}
+
+/// Trims trailing `'0'` characters from `s` when it contains a `'.'`.
+///
+/// Strings without a `'.'` are returned untouched, so the zeroes of a whole-second value
+/// such as "05:30:00" are kept. If every fractional digit was a zero, the dangling `'.'`
+/// is removed as well.
+fn trim_end(s: &str) -> &str {
+    if !s.contains('.') {
+        return s;
+    }
+    let trimmed = s.trim_end_matches('0');
+    trimmed.strip_suffix('.').unwrap_or(trimmed)
+}
+
+/// Builds a `SparkError::CastOverFlow` from the offending value and the type names.
+#[inline]
+pub fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> SparkError {
+    SparkError::CastOverFlow {
+        value: value.to_string(),
+        from_type: from_type.to_string(),
+        to_type: to_type.to_string(),
+    }
+}
+
+/// Builds a `SparkError::CastInvalidValue` from the offending value and the type names.
+#[inline]
+pub fn invalid_value(value: &str, from_type: &str, to_type: &str) -> SparkError {
+    SparkError::CastInvalidValue {
+        value: value.to_string(),
+        from_type: from_type.to_string(),
+        to_type: to_type.to_string(),
+    }
+}