diff --git a/datafusion/spark/src/function/bitwise/bit_shift.rs b/datafusion/spark/src/function/bitwise/bit_shift.rs new file mode 100644 index 000000000000..79f62587c0dd --- /dev/null +++ b/datafusion/spark/src/function/bitwise/bit_shift.rs @@ -0,0 +1,740 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray}; +use arrow::compute; +use arrow::datatypes::{ + ArrowNativeType, DataType, Int32Type, Int64Type, UInt32Type, UInt64Type, +}; +use datafusion_common::{plan_err, Result}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +use crate::function::error_utils::{ + invalid_arg_count_exec_err, unsupported_data_type_exec_err, +}; + +/// Performs a bitwise left shift on each element of the `value` array by the corresponding amount in the `shift` array. +/// The shift amount is normalized to the bit width of the type, matching Spark/Java semantics for negative and large shifts. +/// +/// # Arguments +/// * `value` - The array of values to shift. 
+/// * `shift` - The array of shift amounts (must be Int32).
+///
+/// # Returns
+/// A new array with the shifted values.
+///
+fn shift_left<T: ArrowPrimitiveType>(
+    value: &PrimitiveArray<T>,
+    shift: &PrimitiveArray<Int32Type>,
+) -> Result<PrimitiveArray<T>>
+where
+    T::Native: ArrowNativeType + std::ops::Shl<i32, Output = T::Native>,
+{
+    // Number of bits in the native value type.
+    let bit_num = (T::Native::get_byte_width() * 8) as i32;
+    let result = compute::binary::<_, Int32Type, _, _>(
+        value,
+        shift,
+        |value: T::Native, shift: i32| {
+            // `rem_euclid` always yields a value in `0..bit_num`, so negative and
+            // oversized shift amounts wrap around instead of overflowing the shift.
+            value << shift.rem_euclid(bit_num)
+        },
+    )?;
+    Ok(result)
+}
+
+/// Performs a bitwise right shift on each element of the `value` array by the corresponding amount in the `shift` array.
+/// The shift amount is normalized to the bit width of the type, matching Spark/Java semantics for negative and large shifts.
+///
+/// # Arguments
+/// * `value` - The array of values to shift.
+/// * `shift` - The array of shift amounts (must be Int32).
+///
+/// # Returns
+/// A new array with the shifted values.
+///
+fn shift_right<T: ArrowPrimitiveType>(
+    value: &PrimitiveArray<T>,
+    shift: &PrimitiveArray<Int32Type>,
+) -> Result<PrimitiveArray<T>>
+where
+    T::Native: ArrowNativeType + std::ops::Shr<i32, Output = T::Native>,
+{
+    // Number of bits in the native value type.
+    let bit_num = (T::Native::get_byte_width() * 8) as i32;
+    let result = compute::binary::<_, Int32Type, _, _>(
+        value,
+        shift,
+        |value: T::Native, shift: i32| {
+            // `rem_euclid` always yields a value in `0..bit_num`, so negative and
+            // oversized shift amounts wrap around instead of overflowing the shift.
+            value >> shift.rem_euclid(bit_num)
+        },
+    )?;
+    Ok(result)
+}
+
+/// Trait for performing an unsigned right shift (logical shift right).
+/// This is used to mimic Java's `>>>` operator, which does not exist in Rust.
+/// For unsigned types, this is just the normal right shift.
+/// For signed types, this casts to the unsigned type, shifts, then casts back.
+trait UShr<Rhs> {
+    /// Shifts `self` right by `rhs` bits, filling the vacated high bits with zeros.
+    fn ushr(self, rhs: Rhs) -> Self;
+}
+
+impl UShr<i32> for u32 {
+    fn ushr(self, rhs: i32) -> Self {
+        self >> rhs
+    }
+}
+
+impl UShr<i32> for u64 {
+    fn ushr(self, rhs: i32) -> Self {
+        self >> rhs
+    }
+}
+
+impl UShr<i32> for i32 {
+    fn ushr(self, rhs: i32) -> Self {
+        // Reinterpret as unsigned so the shift does not replicate the sign bit.
+        ((self as u32) >> rhs) as i32
+    }
+}
+
+impl UShr<i32> for i64 {
+    fn ushr(self, rhs: i32) -> Self {
+        // Reinterpret as unsigned so the shift does not replicate the sign bit.
+        ((self as u64) >> rhs) as i64
+    }
+}
+
+/// Performs a bitwise unsigned right shift on each element of the `value` array by the corresponding amount in the `shift` array.
+/// The shift amount is normalized to the bit width of the type, matching Spark/Java semantics for negative and large shifts.
+///
+/// # Arguments
+/// * `value` - The array of values to shift.
+/// * `shift` - The array of shift amounts (must be Int32).
+///
+/// # Returns
+/// A new array with the shifted values.
+///
+fn shift_right_unsigned<T: ArrowPrimitiveType>(
+    value: &PrimitiveArray<T>,
+    shift: &PrimitiveArray<Int32Type>,
+) -> Result<PrimitiveArray<T>>
+where
+    T::Native: ArrowNativeType + UShr<i32>,
+{
+    // Number of bits in the native value type.
+    let bit_num = (T::Native::get_byte_width() * 8) as i32;
+    let result = compute::binary::<_, Int32Type, _, _>(
+        value,
+        shift,
+        |value: T::Native, shift: i32| {
+            // `rem_euclid` always yields a value in `0..bit_num`, so negative and
+            // oversized shift amounts wrap around instead of overflowing the shift.
+            value.ushr(shift.rem_euclid(bit_num))
+        },
+    )?;
+    Ok(result)
+}
+
+/// Shared behaviour of the Spark bit-shift UDFs: each implementor supplies its
+/// element-wise `shift` kernel and inherits the type dispatch in `spark_shift`.
+trait BitShiftUDF: ScalarUDFImpl {
+    /// Applies this function's shift operation to every element of `value`,
+    /// using the corresponding amount in `shift`.
+    fn shift<T: ArrowPrimitiveType>(
+        &self,
+        value: &PrimitiveArray<T>,
+        shift: &PrimitiveArray<Int32Type>,
+    ) -> Result<PrimitiveArray<T>>
+    where
+        T::Native: ArrowNativeType
+            + std::ops::Shl<i32, Output = T::Native>
+            + std::ops::Shr<i32, Output = T::Native>
+            + UShr<i32>;
+
+    /// Checks that the shift amounts (`arrays[1]`) are `Int32`, then dispatches on
+    /// the data type of the value array (`arrays[0]`) to call [`Self::shift`].
+    fn spark_shift(&self, arrays: &[ArrayRef]) -> Result<ArrayRef> {
+        let value_array = arrays[0].as_ref();
+        let shift_array = arrays[1].as_ref();
+
+        // Ensure shift array is Int32
+        if shift_array.data_type() != &DataType::Int32 {
+            return plan_err!("{} shift amount must be Int32", self.name());
+        }
+        let shift_array = shift_array.as_primitive::<Int32Type>();
+
+        match value_array.data_type() {
+            DataType::Int32 => {
+                let value_array = value_array.as_primitive::<Int32Type>();
+                Ok(Arc::new(self.shift(value_array, shift_array)?))
+            }
+            DataType::Int64 => {
+                let value_array = value_array.as_primitive::<Int64Type>();
+                Ok(Arc::new(self.shift(value_array, shift_array)?))
+            }
+            DataType::UInt32 => {
+                let value_array = value_array.as_primitive::<UInt32Type>();
+                Ok(Arc::new(self.shift(value_array, shift_array)?))
+            }
+            DataType::UInt64 => {
+                let value_array = value_array.as_primitive::<UInt64Type>();
+                Ok(Arc::new(self.shift(value_array, shift_array)?))
+            }
+            _ => {
+                plan_err!(
+                    "{} function does not support data type: {:?}",
+                    self.name(),
+                    value_array.data_type()
+                )
+            }
+        }
+    }
+}
+
+/// Validates and coerces the argument types shared by the Spark bit-shift functions.
+///
+/// Exactly two arguments are required and both must be integer (or null) types.
+/// The value type is widened to `Int32`/`UInt32` when it is narrower (null becomes
+/// `Int32`); the shift amount is always coerced to `Int32`.
+fn bit_shift_coerce_types(arg_types: &[DataType], func: &str) -> Result<Vec<DataType>> {
+    if arg_types.len() != 2 {
+        return Err(invalid_arg_count_exec_err(func, (2, 2), arg_types.len()));
+    }
+    // The value is checked before the shift amount, so the first offending
+    // argument is the one reported.
+    for arg_type in arg_types {
+        if !arg_type.is_integer() && !arg_type.is_null() {
+            return Err(unsupported_data_type_exec_err(
+                func,
+                "Integer Type",
+                arg_type,
+            ));
+        }
+    }
+
+    // Widen narrow signed integers (and NULL) to Int32 and narrow unsigned integers
+    // to UInt32; every other integer type is kept as it is.
+    let coerced_first = match &arg_types[0] {
+        DataType::Int8 | DataType::Int16 | DataType::Null => DataType::Int32,
+        DataType::UInt8 | DataType::UInt16 => DataType::UInt32,
+        _ => arg_types[0].clone(),
+    };
+
+    Ok(vec![coerced_first, DataType::Int32])
+}
+
+/// Scalar UDF named `shiftleft`; its element-wise kernel is `shift_left`.
+#[derive(Debug, Hash, Eq, PartialEq)]
+pub struct SparkShiftLeft {
+    signature: Signature,
+}
+
+impl Default for SparkShiftLeft {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkShiftLeft {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+        }
+    }
+}
+
+impl BitShiftUDF for SparkShiftLeft {
+    fn shift<T: ArrowPrimitiveType>(
+        &self,
+        value: &PrimitiveArray<T>,
+        shift: &PrimitiveArray<Int32Type>,
+    ) -> Result<PrimitiveArray<T>>
+    where
+        T::Native: ArrowNativeType
+            + std::ops::Shl<i32, Output = T::Native>
+            + std::ops::Shr<i32, Output = T::Native>
+            + UShr<i32>,
+    {
+        shift_left(value, shift)
+    }
+}
+
+impl ScalarUDFImpl for SparkShiftLeft {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self)
-> &str {
+        "shiftleft"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        bit_shift_coerce_types(arg_types, "shiftleft")
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        if arg_types.len() != 2 {
+            return plan_err!("shiftleft expects exactly 2 arguments");
+        }
+        // Return type is the same as the first argument (the value to shift)
+        Ok(arg_types[0].clone())
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        if args.args.len() != 2 {
+            return plan_err!("shiftleft expects exactly 2 arguments");
+        }
+        // `make_scalar_function` adapts the array-based kernel to `ColumnarValue` arguments.
+        let inner = |arr: &[ArrayRef]| -> Result<ArrayRef> { self.spark_shift(arr) };
+        make_scalar_function(inner, vec![])(&args.args)
+    }
+}
+
+/// Scalar UDF named `shiftrightunsigned`; its element-wise kernel is
+/// `shift_right_unsigned`.
+#[derive(Debug, Hash, Eq, PartialEq)]
+pub struct SparkShiftRightUnsigned {
+    signature: Signature,
+}
+
+impl Default for SparkShiftRightUnsigned {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkShiftRightUnsigned {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+        }
+    }
+}
+
+impl BitShiftUDF for SparkShiftRightUnsigned {
+    fn shift<T: ArrowPrimitiveType>(
+        &self,
+        value: &PrimitiveArray<T>,
+        shift: &PrimitiveArray<Int32Type>,
+    ) -> Result<PrimitiveArray<T>>
+    where
+        T::Native: ArrowNativeType
+            + std::ops::Shl<i32, Output = T::Native>
+            + std::ops::Shr<i32, Output = T::Native>
+            + UShr<i32>,
+    {
+        shift_right_unsigned(value, shift)
+    }
+}
+
+impl ScalarUDFImpl for SparkShiftRightUnsigned {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "shiftrightunsigned"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        bit_shift_coerce_types(arg_types, "shiftrightunsigned")
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        if arg_types.len() != 2 {
+            return plan_err!("shiftrightunsigned expects exactly 2 arguments");
+        }
+        // Return type is the same as the first argument (the value to shift)
+        Ok(arg_types[0].clone())
+    }
+
+    fn invoke_with_args(&self,
args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        if args.args.len() != 2 {
+            return plan_err!("shiftrightunsigned expects exactly 2 arguments");
+        }
+        // `make_scalar_function` adapts the array-based kernel to `ColumnarValue` arguments.
+        let inner = |arr: &[ArrayRef]| -> Result<ArrayRef> { self.spark_shift(arr) };
+        make_scalar_function(inner, vec![])(&args.args)
+    }
+}
+
+/// Scalar UDF named `shiftright`; its element-wise kernel is `shift_right`.
+#[derive(Debug, Hash, Eq, PartialEq)]
+pub struct SparkShiftRight {
+    signature: Signature,
+}
+
+impl Default for SparkShiftRight {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkShiftRight {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+        }
+    }
+}
+
+impl BitShiftUDF for SparkShiftRight {
+    fn shift<T: ArrowPrimitiveType>(
+        &self,
+        value: &PrimitiveArray<T>,
+        shift: &PrimitiveArray<Int32Type>,
+    ) -> Result<PrimitiveArray<T>>
+    where
+        T::Native: ArrowNativeType
+            + std::ops::Shl<i32, Output = T::Native>
+            + std::ops::Shr<i32, Output = T::Native>
+            + UShr<i32>,
+    {
+        shift_right(value, shift)
+    }
+}
+
+impl ScalarUDFImpl for SparkShiftRight {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "shiftright"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        bit_shift_coerce_types(arg_types, "shiftright")
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        if arg_types.len() != 2 {
+            return plan_err!("shiftright expects exactly 2 arguments");
+        }
+        // Return type is the same as the first argument (the value to shift)
+        Ok(arg_types[0].clone())
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        if args.args.len() != 2 {
+            return plan_err!("shiftright expects exactly 2 arguments");
+        }
+        // `make_scalar_function` adapts the array-based kernel to `ColumnarValue` arguments.
+        let inner = |arr: &[ArrayRef]| -> Result<ArrayRef> { self.spark_shift(arr) };
+        make_scalar_function(inner, vec![])(&args.args)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::{Array, Int32Array, Int64Array, UInt32Array, UInt64Array};
+
+    #[test]
+    fn test_shift_right_unsigned_int32() {
+        let value_array = Arc::new(Int32Array::from(vec![4, 8, 16, 32]));
+        let shift_array = Arc::new(Int32Array::from(vec![1, 2,
3, 4]));
+        let result = SparkShiftRightUnsigned::new()
+            .spark_shift(&[value_array, shift_array])
+            .unwrap();
+        let arr = result.as_primitive::<Int32Type>();
+        assert_eq!(arr.value(0), 2); // 4 >>> 1 = 2
+        assert_eq!(arr.value(1), 2); // 8 >>> 2 = 2
+        assert_eq!(arr.value(2), 2); // 16 >>> 3 = 2
+        assert_eq!(arr.value(3), 2); // 32 >>> 4 = 2
+    }
+
+    #[test]
+    fn test_shift_right_unsigned_int64() {
+        let value_array = Arc::new(Int64Array::from(vec![4i64, 8, 16]));
+        let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3]));
+        let result = SparkShiftRightUnsigned::new()
+            .spark_shift(&[value_array, shift_array])
+            .unwrap();
+        let arr = result.as_primitive::<Int64Type>();
+        assert_eq!(arr.value(0), 2); // 4 >>> 1 = 2
+        assert_eq!(arr.value(1), 2); // 8 >>> 2 = 2
+        assert_eq!(arr.value(2), 2); // 16 >>> 3 = 2
+    }
+
+    #[test]
+    fn test_shift_right_unsigned_uint32() {
+        let value_array = Arc::new(UInt32Array::from(vec![4u32, 8, 16]));
+        let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3]));
+        let result = SparkShiftRightUnsigned::new()
+            .spark_shift(&[value_array, shift_array])
+            .unwrap();
+        let arr = result.as_primitive::<UInt32Type>();
+        assert_eq!(arr.value(0), 2); // 4 >>> 1 = 2
+        assert_eq!(arr.value(1), 2); // 8 >>> 2 = 2
+        assert_eq!(arr.value(2), 2); // 16 >>> 3 = 2
+    }
+
+    #[test]
+    fn test_shift_right_unsigned_uint64() {
+        let value_array = Arc::new(UInt64Array::from(vec![4u64, 8, 16]));
+        let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3]));
+        let result = SparkShiftRightUnsigned::new()
+            .spark_shift(&[value_array, shift_array])
+            .unwrap();
+        let arr = result.as_primitive::<UInt64Type>();
+        assert_eq!(arr.value(0), 2); // 4 >>> 1 = 2
+        assert_eq!(arr.value(1), 2); // 8 >>> 2 = 2
+        assert_eq!(arr.value(2), 2); // 16 >>> 3 = 2
+    }
+
+    #[test]
+    fn test_shift_right_unsigned_nulls() {
+        let value_array = Arc::new(Int32Array::from(vec![Some(4), None, Some(8)]));
+        let shift_array = Arc::new(Int32Array::from(vec![Some(1), Some(2), None]));
+        let result = SparkShiftRightUnsigned::new()
+
.spark_shift(&[value_array, shift_array])
+            .unwrap();
+        let arr = result.as_primitive::<Int32Type>();
+        assert_eq!(arr.value(0), 2); // 4 >>> 1 = 2
+        assert!(arr.is_null(1)); // null >>> 2 = null
+        assert!(arr.is_null(2)); // 8 >>> null = null
+    }
+
+    #[test]
+    fn test_shift_right_unsigned_negative_shift() {
+        let value_array = Arc::new(Int32Array::from(vec![4, 8, 16]));
+        let shift_array = Arc::new(Int32Array::from(vec![-1, -2, -3]));
+        let result = SparkShiftRightUnsigned::new()
+            .spark_shift(&[value_array, shift_array])
+            .unwrap();
+        let arr = result.as_primitive::<Int32Type>();
+        assert_eq!(arr.value(0), 0); // 4 >>> -1 = 0
+        assert_eq!(arr.value(1), 0); // 8 >>> -2 = 0
+        assert_eq!(arr.value(2), 0); // 16 >>> -3 = 0
+    }
+
+    #[test]
+    fn test_shift_right_unsigned_negative_values() {
+        let value_array = Arc::new(Int32Array::from(vec![-4, -8, -16]));
+        let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3]));
+        let result = SparkShiftRightUnsigned::new()
+            .spark_shift(&[value_array, shift_array])
+            .unwrap();
+        let arr = result.as_primitive::<Int32Type>();
+        // For unsigned right shift, negative values are treated as large positive values
+        // -4 as u32 = 4294967292, -4 >>> 1 = 2147483646
+        assert_eq!(arr.value(0), 2147483646);
+        // -8 as u32 = 4294967288, -8 >>> 2 = 1073741822
+        assert_eq!(arr.value(1), 1073741822);
+        // -16 as u32 = 4294967280, -16 >>> 3 = 536870910
+        assert_eq!(arr.value(2), 536870910);
+    }
+
+    #[test]
+    fn test_shift_right_int32() {
+        let value_array = Arc::new(Int32Array::from(vec![4, 8, 16, 32]));
+        let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
+        let result = SparkShiftRight::new()
+            .spark_shift(&[value_array, shift_array])
+            .unwrap();
+        let arr = result.as_primitive::<Int32Type>();
+        assert_eq!(arr.value(0), 2); // 4 >> 1 = 2
+        assert_eq!(arr.value(1), 2); // 8 >> 2 = 2
+        assert_eq!(arr.value(2), 2); // 16 >> 3 = 2
+        assert_eq!(arr.value(3), 2); // 32 >> 4 = 2
+    }
+
+    #[test]
+    fn test_shift_right_int64() {
+        let value_array =
Arc::new(Int64Array::from(vec![4i64, 8, 16]));
+        let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3]));
+        let result = SparkShiftRight::new()
+            .spark_shift(&[value_array, shift_array])
+            .unwrap();
+        let arr = result.as_primitive::<Int64Type>();
+        assert_eq!(arr.value(0), 2); // 4 >> 1 = 2
+        assert_eq!(arr.value(1), 2); // 8 >> 2 = 2
+        assert_eq!(arr.value(2), 2); // 16 >> 3 = 2
+    }
+
+    #[test]
+    fn test_shift_right_uint32() {
+        let value_array = Arc::new(UInt32Array::from(vec![4u32, 8, 16]));
+        let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3]));
+        let result = SparkShiftRight::new()
+            .spark_shift(&[value_array, shift_array])
+            .unwrap();
+        let arr = result.as_primitive::<UInt32Type>();
+        assert_eq!(arr.value(0), 2); // 4 >> 1 = 2
+        assert_eq!(arr.value(1), 2); // 8 >> 2 = 2
+        assert_eq!(arr.value(2), 2); // 16 >> 3 = 2
+    }
+
+    #[test]
+    fn test_shift_right_uint64() {
+        let value_array = Arc::new(UInt64Array::from(vec![4u64, 8, 16]));
+        let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3]));
+        let result = SparkShiftRight::new()
+            .spark_shift(&[value_array, shift_array])
+            .unwrap();
+        let arr = result.as_primitive::<UInt64Type>();
+        assert_eq!(arr.value(0), 2); // 4 >> 1 = 2
+        assert_eq!(arr.value(1), 2); // 8 >> 2 = 2
+        assert_eq!(arr.value(2), 2); // 16 >> 3 = 2
+    }
+
+    #[test]
+    fn test_shift_right_nulls() {
+        let value_array = Arc::new(Int32Array::from(vec![Some(4), None, Some(8)]));
+        let shift_array = Arc::new(Int32Array::from(vec![Some(1), Some(2), None]));
+        let result = SparkShiftRight::new()
+            .spark_shift(&[value_array, shift_array])
+            .unwrap();
+        let arr = result.as_primitive::<Int32Type>();
+        assert_eq!(arr.value(0), 2); // 4 >> 1 = 2
+        assert!(arr.is_null(1)); // null >> 2 = null
+        assert!(arr.is_null(2)); // 8 >> null = null
+    }
+
+    #[test]
+    fn test_shift_right_large_shift() {
+        let value_array = Arc::new(Int32Array::from(vec![1, 2, 3]));
+        let shift_array = Arc::new(Int32Array::from(vec![32, 33, 64]));
+        let result = SparkShiftRight::new()
+            .spark_shift(&[value_array,
shift_array])
+            .unwrap();
+        let arr = result.as_primitive::<Int32Type>();
+        assert_eq!(arr.value(0), 1); // 1 >> 32 = 1
+        assert_eq!(arr.value(1), 1); // 2 >> 33 = 1
+        assert_eq!(arr.value(2), 3); // 3 >> 64 = 3
+    }
+
+    #[test]
+    fn test_shift_right_negative_shift() {
+        let value_array = Arc::new(Int32Array::from(vec![4, 8, 16]));
+        let shift_array = Arc::new(Int32Array::from(vec![-1, -2, -3]));
+        let result = SparkShiftRight::new()
+            .spark_shift(&[value_array, shift_array])
+            .unwrap();
+        let arr = result.as_primitive::<Int32Type>();
+        assert_eq!(arr.value(0), 0); // 4 >> -1 = 0
+        assert_eq!(arr.value(1), 0); // 8 >> -2 = 0
+        assert_eq!(arr.value(2), 0); // 16 >> -3 = 0
+    }
+
+    #[test]
+    fn test_shift_right_negative_values() {
+        let value_array = Arc::new(Int32Array::from(vec![-4, -8, -16]));
+        let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3]));
+        let result = SparkShiftRight::new()
+            .spark_shift(&[value_array, shift_array])
+            .unwrap();
+        let arr = result.as_primitive::<Int32Type>();
+        // For signed integers, right shift preserves the sign bit
+        assert_eq!(arr.value(0), -2); // -4 >> 1 = -2
+        assert_eq!(arr.value(1), -2); // -8 >> 2 = -2
+        assert_eq!(arr.value(2), -2); // -16 >> 3 = -2
+    }
+
+    #[test]
+    fn test_shift_left_int32() {
+        let value_array = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
+        let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
+        let result = SparkShiftLeft::new()
+            .spark_shift(&[value_array, shift_array])
+            .unwrap();
+        let arr = result.as_primitive::<Int32Type>();
+        assert_eq!(arr.value(0), 2); // 1 << 1 = 2
+        assert_eq!(arr.value(1), 8); // 2 << 2 = 8
+        assert_eq!(arr.value(2), 24); // 3 << 3 = 24
+        assert_eq!(arr.value(3), 64); // 4 << 4 = 64
+    }
+
+    #[test]
+    fn test_shift_left_int64() {
+        let value_array = Arc::new(Int64Array::from(vec![1i64, 2, 3]));
+        let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3]));
+        let result = SparkShiftLeft::new()
+            .spark_shift(&[value_array, shift_array])
+            .unwrap();
+        let arr = result.as_primitive::<Int64Type>();
+
assert_eq!(arr.value(0), 2); // 1 << 1 = 2
+        assert_eq!(arr.value(1), 8); // 2 << 2 = 8
+        assert_eq!(arr.value(2), 24); // 3 << 3 = 24
+    }
+
+    #[test]
+    fn test_shift_left_uint32() {
+        let value_array = Arc::new(UInt32Array::from(vec![1u32, 2, 3]));
+        let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3]));
+        let result = SparkShiftLeft::new()
+            .spark_shift(&[value_array, shift_array])
+            .unwrap();
+        let arr = result.as_primitive::<UInt32Type>();
+        assert_eq!(arr.value(0), 2); // 1 << 1 = 2
+        assert_eq!(arr.value(1), 8); // 2 << 2 = 8
+        assert_eq!(arr.value(2), 24); // 3 << 3 = 24
+    }
+
+    #[test]
+    fn test_shift_left_uint64() {
+        let value_array = Arc::new(UInt64Array::from(vec![1u64, 2, 3]));
+        let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3]));
+        let result = SparkShiftLeft::new()
+            .spark_shift(&[value_array, shift_array])
+            .unwrap();
+        let arr = result.as_primitive::<UInt64Type>();
+        assert_eq!(arr.value(0), 2); // 1 << 1 = 2
+        assert_eq!(arr.value(1), 8); // 2 << 2 = 8
+        assert_eq!(arr.value(2), 24); // 3 << 3 = 24
+    }
+
+    #[test]
+    fn test_shift_left_nulls() {
+        let value_array = Arc::new(Int32Array::from(vec![Some(2), None, Some(3)]));
+        let shift_array = Arc::new(Int32Array::from(vec![Some(1), Some(2), None]));
+        let result = SparkShiftLeft::new()
+            .spark_shift(&[value_array, shift_array])
+            .unwrap();
+        let arr = result.as_primitive::<Int32Type>();
+        assert_eq!(arr.value(0), 4); // 2 << 1 = 4
+        assert!(arr.is_null(1)); // null << 2 = null
+        assert!(arr.is_null(2)); // 3 << null = null
+    }
+
+    #[test]
+    fn test_shift_left_large_shift() {
+        let value_array = Arc::new(Int32Array::from(vec![1, 2, 3]));
+        let shift_array = Arc::new(Int32Array::from(vec![32, 33, 64]));
+        let result = SparkShiftLeft::new()
+            .spark_shift(&[value_array, shift_array])
+            .unwrap();
+        let arr = result.as_primitive::<Int32Type>();
+        // Shift amounts wrap modulo the 32-bit width rather than shifting every bit out.
+        assert_eq!(arr.value(0), 1); // 1 << 32 = 1 (shift wraps to 0)
+        assert_eq!(arr.value(1), 4); // 2 << 33 = 4 (shift wraps to 1)
+        assert_eq!(arr.value(2), 3); // 3 << 64 = 3 (shift wraps to 0)
+    }
+
+    #[test]
+    fn test_shift_left_negative_shift() {
+        let value_array = Arc::new(Int32Array::from(vec![4, 8, 16]));
+        let shift_array = Arc::new(Int32Array::from(vec![-1, -2, -3]));
+        let result = SparkShiftLeft::new()
+            .spark_shift(&[value_array, shift_array])
+            .unwrap();
+        let arr = result.as_primitive::<Int32Type>();
+        assert_eq!(arr.value(0), 0); // 4 << -1 = 0
+        assert_eq!(arr.value(1), 0); // 8 << -2 = 0
+        assert_eq!(arr.value(2), 0); // 16 << -3 = 0
+    }
+}
diff --git a/datafusion/spark/src/function/bitwise/mod.rs b/datafusion/spark/src/function/bitwise/mod.rs index f8131176ff31..b5603c191440 100644 --- a/datafusion/spark/src/function/bitwise/mod.rs +++ b/datafusion/spark/src/function/bitwise/mod.rs @@ -17,11 +17,15 @@ pub mod bit_count; pub mod bit_get; +pub mod bit_shift; use datafusion_expr::ScalarUDF; use datafusion_functions::make_udf_function; use std::sync::Arc; +make_udf_function!(bit_shift::SparkShiftLeft, shiftleft); +make_udf_function!(bit_shift::SparkShiftRight, shiftright); +make_udf_function!(bit_shift::SparkShiftRightUnsigned, shiftrightunsigned); make_udf_function!(bit_get::SparkBitGet, bit_get); make_udf_function!(bit_count::SparkBitCount, bit_count); @@ -34,8 +38,29 @@ pub mod expr_fn { "Returns the number of bits set in the binary representation of the argument.", col )); + export_functions!(( + shiftleft, + "Shifts the bits of the first argument left by the number of positions specified by the second argument. If the shift amount is negative or greater than or equal to the bit width, it is normalized to the bit width (i.e., pmod(shift, bit_width)).", + value shift + )); + export_functions!(( + shiftright, + "Shifts the bits of the first argument right by the number of positions specified by the second argument (arithmetic/signed shift).
If the shift amount is negative or greater than or equal to the bit width, it is normalized to the bit width (i.e., pmod(shift, bit_width)).", + value shift + )); + export_functions!(( + shiftrightunsigned, + "Shifts the bits of the first argument right by the number of positions specified by the second argument (logical/unsigned shift). If the shift amount is negative or greater than or equal to the bit width, it is normalized to the bit width (i.e., pmod(shift, bit_width)).", + value shift + )); } pub fn functions() -> Vec> { - vec![bit_get(), bit_count()] + vec![ + bit_get(), + bit_count(), + shiftleft(), + shiftright(), + shiftrightunsigned(), + ] } diff --git a/datafusion/sqllogictest/test_files/spark/bitwise/shiftright.slt b/datafusion/sqllogictest/test_files/spark/bitwise/shiftright.slt index 2956fd5a6333..3587bcc7ca52 100644 --- a/datafusion/sqllogictest/test_files/spark/bitwise/shiftright.slt +++ b/datafusion/sqllogictest/test_files/spark/bitwise/shiftright.slt @@ -23,5 +23,125 @@ ## Original Query: SELECT shiftright(4, 1); ## PySpark 3.5.5 Result: {'shiftright(4, 1)': 2, 'typeof(shiftright(4, 1))': 'int', 'typeof(4)': 'int', 'typeof(1)': 'int'} -#query -#SELECT shiftright(4::int, 1::int); + +# Basic shiftright tests +query I +SELECT shiftright(4::int, 1::int); +---- +2 + +query I +SELECT shiftright(8::int, 2::int); +---- +2 + +query I +SELECT shiftright(16::int, 3::int); +---- +2 + +# Different data types +query I +SELECT shiftright(4::bigint, 1::int); +---- +2 + +query I +SELECT shiftright(8::bigint, 2::int); +---- +2 + +query I +SELECT shiftright(4::int, 1::bigint); +---- +2 + +# Large shifts (should handle modulo correctly) +query I +SELECT shiftright(1::int, 32::int); +---- +1 + +query I +SELECT shiftright(2::int, 33::int); +---- +1 + +query I +SELECT shiftright(3::int, 64::int); +---- +3 + +# Negative shifts +query I +SELECT shiftright(4::int, -1::int); +---- +0 + +query I +SELECT shiftright(8::int, -2::int); +---- +0 + +query I +SELECT 
shiftright(16::int, -3::int); +---- +0 + +# Zero shifts +query I +SELECT shiftright(5::int, 0::int); +---- +5 + +query I +SELECT shiftright(0::int, 5::int); +---- +0 + +# Edge cases - signed right shift preserves sign +query I +SELECT shiftright(-4::int, 1::int); +---- +-2 + +query I +SELECT shiftright(-8::int, 2::int); +---- +-2 + +query I +SELECT shiftright(-16::int, 3::int); +---- +-2 + +query I +SELECT shiftright(2147483647::int, 1::int); +---- +1073741823 + +# Null handling +query I +SELECT shiftright(NULL::int, 1::int); +---- +NULL + +query I +SELECT shiftright(1::int, NULL::int); +---- +NULL + +query I +SELECT shiftright(NULL::int, NULL::int); +---- +NULL + +query I +select shiftright(3::int,-31); +---- +1 + +query I +select shiftright(3::int,-32); +---- +3 diff --git a/datafusion/sqllogictest/test_files/spark/bitwise/shiftrightunsigned.slt b/datafusion/sqllogictest/test_files/spark/bitwise/shiftrightunsigned.slt index 134246957e3c..b0d4cfaec702 100644 --- a/datafusion/sqllogictest/test_files/spark/bitwise/shiftrightunsigned.slt +++ b/datafusion/sqllogictest/test_files/spark/bitwise/shiftrightunsigned.slt @@ -23,5 +23,126 @@ ## Original Query: SELECT shiftrightunsigned(4, 1); ## PySpark 3.5.5 Result: {'shiftrightunsigned(4, 1)': 2, 'typeof(shiftrightunsigned(4, 1))': 'int', 'typeof(4)': 'int', 'typeof(1)': 'int'} -#query -#SELECT shiftrightunsigned(4::int, 1::int); + +# Basic shiftrightunsigned tests +query I +SELECT shiftrightunsigned(4::int, 1::int); +---- +2 + +query I +SELECT shiftrightunsigned(8::int, 2::int); +---- +2 + +query I +SELECT shiftrightunsigned(16::int, 3::int); +---- +2 + +# Different data types +query I +SELECT shiftrightunsigned(4::bigint, 1::int); +---- +2 + +query I +SELECT shiftrightunsigned(8::bigint, 2::int); +---- +2 + +query I +SELECT shiftrightunsigned(4::int, 1::bigint); +---- +2 + +# Large shifts (should handle modulo correctly) +query I +SELECT shiftrightunsigned(1::int, 32::int); +---- +1 + +query I +SELECT 
shiftrightunsigned(2::int, 33::int); +---- +1 + +query I +SELECT shiftrightunsigned(3::int, 64::int); +---- +3 + +# Negative shifts +query I +SELECT shiftrightunsigned(4::int, -1::int); +---- +0 + +query I +SELECT shiftrightunsigned(8::int, -2::int); +---- +0 + +query I +SELECT shiftrightunsigned(16::int, -3::int); +---- +0 + +# Zero shifts +query I +SELECT shiftrightunsigned(5::int, 0::int); +---- +5 + +query I +SELECT shiftrightunsigned(0::int, 5::int); +---- +0 + +# Edge cases - unsigned right shift treats negative values as large positive +query I +SELECT shiftrightunsigned(-4::int, 1::int); +---- +2147483646 + +query I +SELECT shiftrightunsigned(-8::int, 2::int); +---- +1073741822 + +query I +SELECT shiftrightunsigned(-16::int, 3::int); +---- +536870910 + +query I +SELECT shiftrightunsigned(2147483647::int, 1::int); +---- +1073741823 + + +# Null handling +query I +SELECT shiftrightunsigned(NULL::int, 1::int); +---- +NULL + +query I +SELECT shiftrightunsigned(1::int, NULL::int); +---- +NULL + +query I +SELECT shiftrightunsigned(NULL::int, NULL::int); +---- +NULL + +query I +select shiftrightunsigned(3::int,-31); +---- +1 + +query I +select shiftrightunsigned(3::int,-32); +---- +3 diff --git a/datafusion/sqllogictest/test_files/spark/math/shiftleft.slt b/datafusion/sqllogictest/test_files/spark/math/shiftleft.slt index 69a8e85565e9..3676e4c18153 100644 --- a/datafusion/sqllogictest/test_files/spark/math/shiftleft.slt +++ b/datafusion/sqllogictest/test_files/spark/math/shiftleft.slt @@ -23,5 +23,124 @@ ## Original Query: SELECT shiftleft(2, 1); ## PySpark 3.5.5 Result: {'shiftleft(2, 1)': 4, 'typeof(shiftleft(2, 1))': 'int', 'typeof(2)': 'int', 'typeof(1)': 'int'} -#query -#SELECT shiftleft(2::int, 1::int); + +# Basic shiftleft tests +query I +SELECT shiftleft(2::int, 1::int); +---- +4 + +query I +SELECT shiftleft(1::int, 2::int); +---- +4 + +query I +SELECT shiftleft(3::int, 3::int); +---- +24 + +# Different data types +query I +SELECT shiftleft(2::bigint, 
1::int); +---- +4 + +query I +SELECT shiftleft(1::bigint, 2::int); +---- +4 + +query I +SELECT shiftleft(2::int, 1::bigint); +---- +4 + +# Large shifts (should handle modulo correctly) +query I +SELECT shiftleft(1::int, 32::int); +---- +1 + +query I +SELECT shiftleft(2::int, 33::int); +---- +4 + +query I +SELECT shiftleft(3::int, 64::int); +---- +3 + +# Negative shifts +query I +SELECT shiftleft(4::int, -1::int); +---- +0 + +query I +SELECT shiftleft(8::int, -2::int); +---- +0 + +query I +SELECT shiftleft(16::int, -3::int); +---- +0 + +# Zero shifts +query I +SELECT shiftleft(5::int, 0::int); +---- +5 + +query I +SELECT shiftleft(0::int, 5::int); +---- +0 + +# Edge cases +query I +SELECT shiftleft(2147483647::int, 1::int); +---- +-2 + +query I +SELECT shiftleft(-1::int, 1::int); +---- +-2 + +# Multiple values in a table +query I +SELECT shiftleft(value, shift) FROM (VALUES (1, 1), (2, 2), (3, 3), (4, 4)) AS t(value, shift); +---- +2 +8 +24 +64 + +# Null handling +query I +SELECT shiftleft(NULL::int, 1::int); +---- +NULL + +query I +SELECT shiftleft(1::int, NULL::int); +---- +NULL + +query I +SELECT shiftleft(NULL::int, NULL::int); +---- +NULL + +query I +select shiftleft(3::int,-31); +---- +6 + +query I +select shiftleft(3::int,-32); +---- +3