Refactor Spark bitshift signature #18649
Conversation
```diff
-/// Performs a bitwise left shift on each element of the `value` array by the corresponding amount in the `shift` array.
-/// The shift amount is normalized to the bit width of the type, matching Spark/Java semantics for negative and large shifts.
-///
-/// # Arguments
-/// * `value` - The array of values to shift.
-/// * `shift` - The array of shift amounts (must be Int32).
-///
-/// # Returns
-/// A new array with the shifted values.
-fn shift_left<T: ArrowPrimitiveType>(
+/// Bitwise left shift on elements in `value` by corresponding `shift` amount.
+/// The shift amount is normalized to the bit width of the type, matching Spark/Java
+/// semantics for negative and large shifts.
+fn shift_left<T>(
     value: &PrimitiveArray<T>,
-    shift: &PrimitiveArray<Int32Type>,
+    shift: &Int32Array,
 ) -> Result<PrimitiveArray<T>>
 where
-    T::Native: ArrowNativeType + std::ops::Shl<i32, Output = T::Native>,
+    T: ArrowPrimitiveType,
+    T::Native: std::ops::Shl<i32, Output = T::Native>,
```
Rewrote some comments to be more succinct, and also cleaned up some function signatures (removed some unused bounds, used type aliases, etc.)
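As a standalone illustration of the Spark/Java normalization the new doc comment describes (a hedged sketch at the scalar level, not the PR's actual array-based implementation): Java's shift operators effectively wrap the shift amount into the type's bit width, so shifting a 32-bit value by 33 behaves like shifting by 1, and a negative shift wraps around.

```rust
// Illustrative sketch of Spark/Java shift-amount normalization; the PR's
// real code operates on Arrow PrimitiveArrays, not scalars.
fn normalize_shift(shift: i32, bit_num: i32) -> i32 {
    // rem_euclid maps any shift (including negatives) into [0, bit_num)
    shift.rem_euclid(bit_num)
}

fn shift_left_i32(value: i32, shift: i32) -> i32 {
    value << normalize_shift(shift, 32)
}

fn main() {
    assert_eq!(shift_left_i32(1, 33), 2); // 33 normalizes to 1
    assert_eq!(shift_left_i32(1, -1), 1 << 31); // -1 normalizes to 31
    assert_eq!(shift_left_i32(5, 0), 5);
    println!("ok");
}
```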
```diff
 }

 #[derive(Debug, Hash, Eq, PartialEq)]
 pub struct SparkShiftLeft {
```
Folded these structs into a single common SparkBitShift struct
```diff
 );
 Self {
-    signature: Signature::user_defined(Volatility::Immutable),
+    signature: Signature::one_of(
```
The new signature is defined here
```diff
 }

 #[cfg(test)]
 mod tests {
```
Moved these all to SLTs
```diff
         std::sync::Arc::clone(&INSTANCE)
     }
 };
 ($UDF:ty, $NAME:ident, $CTOR:path) => {
```
This is to accommodate being able to use a single struct (e.g. SparkBitShift) for multiple different functions; similar to how we allow it for window functions:
datafusion/datafusion/functions-window/src/macros.rs
Lines 96 to 116 in e661b33
```rust
macro_rules! get_or_init_udwf {
    ($UDWF:ident, $OUT_FN_NAME:ident, $DOC:expr) => {
        get_or_init_udwf!($UDWF, $OUT_FN_NAME, $DOC, $UDWF::default);
    };
    ($UDWF:ident, $OUT_FN_NAME:ident, $DOC:expr, $CTOR:path) => {
        paste::paste! {
            #[doc = concat!(" Returns a [`WindowUDF`](datafusion_expr::WindowUDF) for [`", stringify!($OUT_FN_NAME), "`].")]
            #[doc = ""]
            #[doc = concat!(" ", $DOC)]
            pub fn [<$OUT_FN_NAME _udwf>]() -> std::sync::Arc<datafusion_expr::WindowUDF> {
                // Singleton instance of UDWF, ensures it is only created once.
                static INSTANCE: std::sync::LazyLock<std::sync::Arc<datafusion_expr::WindowUDF>> =
                    std::sync::LazyLock::new(|| {
                        std::sync::Arc::new(datafusion_expr::WindowUDF::from($CTOR()))
                    });
                std::sync::Arc::clone(&INSTANCE)
            }
        }
    };
}
```
I had a bit of difficulty trying to reduce it to something simpler like:

```rust
macro_rules! make_udf_function {
    ($UDF:ty, $NAME:ident) => {
        make_udf_function!($UDF, $NAME, $UDF::new); // Error on :: token
    };
    ($UDF:ty, $NAME:ident, $CTOR:path) => {
        #[allow(rustdoc::redundant_explicit_links)]
        #[doc = concat!("Return a [`ScalarUDF`](datafusion_expr::ScalarUDF) implementation of ", stringify!($NAME))]
        pub fn $NAME() -> std::sync::Arc<datafusion_expr::ScalarUDF> {
            // Singleton instance of the function
            static INSTANCE: std::sync::LazyLock<
                std::sync::Arc<datafusion_expr::ScalarUDF>,
            > = std::sync::LazyLock::new(|| {
                std::sync::Arc::new(datafusion_expr::ScalarUDF::new_from_impl(
                    $CTOR(),
                ))
            });
            std::sync::Arc::clone(&INSTANCE)
        }
    };
}
```

to reduce duplication. For now I just kept this minor duplication of the arms.
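For what it's worth, the `:: token` error is a general macro limitation rather than anything specific to this macro: a matched `$UDF:ty` fragment can't be followed directly by `::new`, but wrapping it in angle brackets to form a qualified path does compile. A minimal sketch (`Foo` and `default_ctor` are made up for illustration):

```rust
// A `ty` fragment can't be written as `$T::new`, but `<$T>::new` is a
// qualified path and is accepted by the compiler.
macro_rules! default_ctor {
    ($T:ty) => {
        <$T>::new()
    };
}

struct Foo(i32);

impl Foo {
    fn new() -> Self {
        Foo(7)
    }
}

fn main() {
    let f: Foo = default_ctor!(Foo);
    assert_eq!(f.0, 7);
    println!("ok");
}
```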
comphead left a comment:
Thanks @Jefffrey
```rust
use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, Int32Array, PrimitiveArray};
use arrow::compute;
use arrow::datatypes::{
    ArrowNativeType, DataType, Int32Type, Int64Type, UInt32Type, UInt64Type,
```
```diff
-    ArrowNativeType, DataType, Int32Type, Int64Type, UInt32Type, UInt64Type,
+    DataType, Int32Type, Int64Type, UInt32Type, UInt64Type,
```
The trait is still needed for calling `get_byte_width` in the shift functions:

```rust
let bit_num = (T::Native::get_byte_width() * 8) as i32;
```

```diff
     &arg_types[1],
 ));
 if value_array.data_type().is_null() || shift_array.data_type().is_null() {
     return Ok(Arc::new(Int32Array::new_null(value_array.len())));
```
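The `get_byte_width`-based bit-width computation mentioned above is straightforward; a plain-Rust equivalent using `std::mem::size_of` instead of the Arrow trait method (just an illustrative sketch):

```rust
// Equivalent of `(T::Native::get_byte_width() * 8) as i32`, but using
// size_of rather than Arrow's ArrowNativeType trait.
fn bit_num<T>() -> i32 {
    (std::mem::size_of::<T>() * 8) as i32
}

fn main() {
    assert_eq!(bit_num::<i32>(), 32);
    assert_eq!(bit_num::<i64>(), 64);
    assert_eq!(bit_num::<u32>(), 32);
    println!("ok");
}
```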
Why always Int32Array?
If shift_array.data_type().is_null() then I think you need to use the type returned by value_array.data_type(), which could be Int64, for example.
If value_array.data_type().is_null() then fall back to Int32Array.
That's a good point, I'll fix it
I removed the explicit null handling, as it seems coercion will coerce to int types for us; added tests to confirm this.
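The null behavior being relied on here, sketched at the scalar level (Spark's bit shifts return NULL when either input is NULL; a hedged illustration, not the array-level kernel or its coercion machinery):

```rust
// Scalar sketch of NULL propagation: if either side is NULL (None),
// the result is NULL, otherwise the normalized shift is applied.
fn shift_left_nullable(value: Option<i32>, shift: Option<i32>) -> Option<i32> {
    match (value, shift) {
        (Some(v), Some(s)) => Some(v << s.rem_euclid(32)),
        _ => None,
    }
}

fn main() {
    assert_eq!(shift_left_nullable(Some(2), Some(1)), Some(4));
    assert_eq!(shift_left_nullable(None, Some(1)), None);
    assert_eq!(shift_left_nullable(Some(2), None), None);
    println!("ok");
}
```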
Which issue does this PR close?
Part of #12725
Rationale for this change
Prefer to avoid user_defined for consistency in function definitions.
What changes are included in this PR?
Refactor signature of Spark bit shift functions (left, right, right unsigned) to use coercion API instead of being user defined.
Also refactor the bit shift code to have a common base struct.
Move the Rust unit tests to SLTs.
Are these changes tested?
Existing tests.
Are there any user-facing changes?
No.