From 8f83f9ce659bac4262f1d387583f2ed5ba1554eb Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sat, 15 Nov 2025 06:13:46 +0800 Subject: [PATCH 1/6] Refactor InListExpr to support structs by re-using existing hashing infrastructure Co-authored-by: David Hewitt --- datafusion/common/src/hash_utils.rs | 178 +- .../src/expressions/comparator.rs | 1444 +++++++++++++++ .../physical-expr/src/expressions/in_list.rs | 1574 +++++++++++++++-- .../physical-expr/src/expressions/mod.rs | 1 + datafusion/physical-plan/src/joins/utils.rs | 2 +- datafusion/sqllogictest/test_files/array.slt | 10 +- datafusion/sqllogictest/test_files/expr.slt | 207 +++ .../test_files/tpch/plans/q19.slt.part | 4 +- .../test_files/tpch/plans/q22.slt.part | 4 +- 9 files changed, 3238 insertions(+), 186 deletions(-) create mode 100644 datafusion/physical-expr/src/expressions/comparator.rs diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index d60189fb6fa3f..e52203244321f 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -31,8 +31,8 @@ use crate::cast::{ as_string_array, as_string_view_array, as_struct_array, }; use crate::error::Result; -#[cfg(not(feature = "force_hash_collisions"))] -use crate::error::_internal_err; +use crate::error::{_internal_datafusion_err, _internal_err}; +use std::cell::RefCell; // Combines two hashes into one hash #[inline] @@ -41,6 +41,94 @@ pub fn combine_hashes(l: u64, r: u64) -> u64 { hash.wrapping_mul(37).wrapping_add(r) } +/// Maximum size for the thread-local hash buffer before truncation (4MB = 524,288 u64 elements). +/// The goal of this is to avoid unbounded memory growth that would appear as a memory leak. +/// We allow temporary allocations beyond this size, but after use the buffer is truncated +/// to this size. +const MAX_BUFFER_SIZE: usize = 524_288; + +thread_local! { + /// Thread-local buffer for hash computations to avoid repeated allocations. + /// The buffer is reused across calls and truncated if it exceeds MAX_BUFFER_SIZE. + /// Defaults to a capacity of 8192 u64 elements which is the default batch size. + /// This corresponds to 64KB of memory. + static HASH_BUFFER: RefCell> = RefCell::new(Vec::with_capacity(8192)); +} + +/// Creates hashes for the given arrays using a thread-local buffer, then calls the provided callback +/// with an immutable reference to the computed hashes. +/// +/// This function manages a thread-local buffer to avoid repeated allocations. The buffer is automatically +/// truncated if it exceeds `MAX_BUFFER_SIZE` after use. +/// +/// # Arguments +/// * `arrays` - The arrays to hash (must contain at least one array) +/// * `random_state` - The random state for hashing +/// * `callback` - A function that receives an immutable reference to the hash slice and returns a result +/// +/// # Errors +/// Returns an error if: +/// - No arrays are provided +/// - The function is called reentrantly (i.e., the callback invokes `with_hashes` again on the same thread) +/// - The function is called during or after thread destruction +/// +/// # Example +/// ```ignore +/// use datafusion_common::hash_utils::{with_hashes, RandomState}; +/// use arrow::array::{Int32Array, ArrayRef}; +/// use std::sync::Arc; +/// +/// let array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); +/// let random_state = RandomState::new(); +/// +/// let result = with_hashes([&array], &random_state, |hashes| { +/// // Use the hashes here +/// Ok(hashes.len()) +/// })?; +/// ``` +pub fn with_hashes( + arrays: I, + random_state: &RandomState, + callback: F, +) -> Result +where + I: IntoIterator, + T: AsDynArray, + F: FnOnce(&[u64]) -> Result, +{ + // Peek at the first array to determine buffer size without fully collecting + let mut iter = arrays.into_iter().peekable(); + + // Get the required size from the first array + let required_size = match iter.peek() { + Some(arr) => arr.as_dyn_array().len(), + None => return _internal_err!("with_hashes requires at least one array"), + }; + + HASH_BUFFER.try_with(|cell| { + let mut buffer = cell.try_borrow_mut() + .map_err(|_| _internal_datafusion_err!("with_hashes cannot be called reentrantly on the same thread"))?; + + // Ensure buffer has sufficient length, clearing old values + buffer.clear(); + buffer.resize(required_size, 0); + + // Create hashes in the buffer - this consumes the iterator + create_hashes(iter, random_state, &mut buffer[..required_size])?; + + // Execute the callback with an immutable slice + let result = callback(&buffer[..required_size])?; + + // Cleanup: truncate if buffer grew too large + if buffer.capacity() > MAX_BUFFER_SIZE { + buffer.truncate(MAX_BUFFER_SIZE); + buffer.shrink_to_fit(); + } + + Ok(result) + }).map_err(|_| _internal_datafusion_err!("with_hashes cannot access thread-local storage during or after thread destruction"))? +} + #[cfg(not(feature = "force_hash_collisions"))] fn hash_null(random_state: &RandomState, hashes_buffer: &'_ mut [u64], mul_col: bool) { if mul_col { @@ -478,8 +566,8 @@ impl AsDynArray for &ArrayRef { pub fn create_hashes<'a, I, T>( arrays: I, random_state: &RandomState, - hashes_buffer: &'a mut Vec, -) -> Result<&'a mut Vec> + hashes_buffer: &'a mut [u64], +) -> Result<&'a mut [u64]> where I: IntoIterator, T: AsDynArray, @@ -522,7 +610,7 @@ mod tests { fn create_hashes_for_empty_fixed_size_lit() -> Result<()> { let empty_array = FixedSizeListBuilder::new(StringBuilder::new(), 1).finish(); let random_state = RandomState::with_seeds(0, 0, 0, 0); - let hashes_buff = &mut vec![0; 0]; + let hashes_buff = &mut [0; 0]; let hashes = create_hashes( &[Arc::new(empty_array) as ArrayRef], &random_state, @@ -1000,4 +1088,84 @@ mod tests { assert_eq!(hashes1, hashes2); } + + #[test] + fn test_with_hashes() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4])); + let random_state = RandomState::with_seeds(0, 0, 0, 0); + + // Test that with_hashes produces the same results as create_hashes + let mut expected_hashes = vec![0; array.len()]; + create_hashes([&array], &random_state, &mut expected_hashes).unwrap(); + + let result = with_hashes([&array], &random_state, |hashes| { + assert_eq!(hashes.len(), 4); + // Verify hashes match expected values + assert_eq!(hashes, &expected_hashes[..]); + // Return a copy of the hashes + Ok(hashes.to_vec()) + }) + .unwrap(); + + // Verify callback result is returned correctly + assert_eq!(result, expected_hashes); + } + + #[test] + fn test_with_hashes_multi_column() { + let int_array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + let str_array: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"])); + let random_state = RandomState::with_seeds(0, 0, 0, 0); + + // Test multi-column hashing + let mut expected_hashes = vec![0; int_array.len()]; + create_hashes( + [&int_array, &str_array], + &random_state, + &mut expected_hashes, + ) + .unwrap(); + + with_hashes([&int_array, &str_array], &random_state, |hashes| { + assert_eq!(hashes.len(), 3); + assert_eq!(hashes, &expected_hashes[..]); + Ok(()) + }) + .unwrap(); + } + + #[test] + fn test_with_hashes_empty_arrays() { + let random_state = RandomState::with_seeds(0, 0, 0, 0); + + // Test that passing no arrays returns an error + let empty: [&ArrayRef; 0] = []; + let result = with_hashes(empty, &random_state, |_hashes| Ok(())); + + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("requires at least one array")); + } + + #[test] + fn test_with_hashes_reentrancy() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + let array2: ArrayRef = Arc::new(Int32Array::from(vec![4, 5, 6])); + let random_state = RandomState::with_seeds(0, 0, 0, 0); + + // Test that reentrant calls return an error instead of panicking + let result = with_hashes([&array], &random_state, |_hashes| { + // Try to call with_hashes again inside the callback + with_hashes([&array2], &random_state, |_inner_hashes| Ok(())) + }); + + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("reentrantly") || err_msg.contains("cannot be called"), + "Error message should mention reentrancy: {err_msg}", + ); + } } diff --git a/datafusion/physical-expr/src/expressions/comparator.rs b/datafusion/physical-expr/src/expressions/comparator.rs new file mode 100644 index 0000000000000..d0cf0ffc045db --- /dev/null +++ b/datafusion/physical-expr/src/expressions/comparator.rs @@ -0,0 +1,1444 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Enum-based comparator that eliminates dynamic dispatch for scalar types +//! +//! This module provides an optimized comparator implementation that uses an enum +//! with variants for each scalar Arrow type, eliminating the overhead of dynamic +//! dispatch for common comparison operations. Complex recursive types (List, Struct, +//! Map, Dictionary) fall back to dynamic dispatch. +//! +//! While we are implementing this in DataFusion for now we hope to upstream this into arrow-rs +//! and replace the existing completely dynamic comparator there with this more efficient one. + +use arrow::array::types::*; +use arrow::array::{make_comparator as arrow_make_comparator, *}; +use arrow::buffer::{BooleanBuffer, NullBuffer, ScalarBuffer}; +use arrow::compute::SortOptions; +use arrow::datatypes::{ + i256, DataType, IntervalDayTime, IntervalMonthDayNano, IntervalUnit, TimeUnit, +}; +use arrow::error::ArrowError; +use std::cmp::Ordering; + +// Type alias for dynamic comparator (same as arrow_ord::ord::DynComparator) +type DynComparator = Box Ordering + Send + Sync>; + +/// Comparator that uses enum dispatch for scalar types and dynamic dispatch for complex types +pub(crate) enum Comparator { + // Primitive integer types + Int8 { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + Int16 { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + Int32 { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + Int64 { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + + // Unsigned integer types + UInt8 { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + UInt16 { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + UInt32 { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + UInt64 { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + + // Floating point types + Float16 { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + Float32 { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + Float64 { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + + // Date and time types + Date32 { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + Date64 { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + Time32Second { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + Time32Millisecond { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + Time64Microsecond { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + Time64Nanosecond { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + + // Timestamp types + TimestampSecond { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + TimestampMillisecond { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + TimestampMicrosecond { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + TimestampNanosecond { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + + // Duration types + DurationSecond { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + DurationMillisecond { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + DurationMicrosecond { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + DurationNanosecond { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + + // Interval types + IntervalYearMonth { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + IntervalDayTime { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + IntervalMonthDayNano { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + + // Decimal types + Decimal128 { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + Decimal256 { + left: ScalarBuffer, + right: ScalarBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + + // Boolean type + Boolean { + left: BooleanBuffer, + right: BooleanBuffer, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + + // String types + Utf8 { + left: GenericByteArray, + right: GenericByteArray, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + LargeUtf8 { + left: GenericByteArray, + right: GenericByteArray, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + Utf8View { + left: GenericByteViewArray, + right: GenericByteViewArray, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + + // Binary types + Binary { + left: GenericByteArray, + right: GenericByteArray, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + LargeBinary { + left: GenericByteArray, + right: GenericByteArray, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + BinaryView { + left: GenericByteViewArray, + right: GenericByteViewArray, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + + // FixedSizeBinary + FixedSizeBinary { + left: FixedSizeBinaryArray, + right: FixedSizeBinaryArray, + left_nulls: Option, + right_nulls: Option, + opts: SortOptions, + }, + + // Dynamic fallback for recursive/complex types: + // - List, LargeList, FixedSizeList + // - Struct + // - Map + // - Dictionary + Dynamic(DynComparator), +} + +/// Helper macro to reduce duplication for float comparisons using total_cmp +macro_rules! compare_float { + ($left:expr, $right:expr, $left_nulls:expr, $right_nulls:expr, $opts:expr, $i:expr, $j:expr) => {{ + let ord = match ( + $left_nulls.as_ref().is_some_and(|n| n.is_null($i)), + $right_nulls.as_ref().is_some_and(|n| n.is_null($j)), + ) { + (true, true) => return Ordering::Equal, + (true, false) => { + return if $opts.nulls_first { + Ordering::Less + } else { + Ordering::Greater + }; + } + (false, true) => { + return if $opts.nulls_first { + Ordering::Greater + } else { + Ordering::Less + }; + } + (false, false) => { + let left_slice = $left.as_ref(); + let right_slice = $right.as_ref(); + left_slice[$i].total_cmp(&right_slice[$j]) + } + }; + + if $opts.descending { + ord.reverse() + } else { + ord + } + }}; +} + +impl Comparator { + /// Compare elements at indices i (from left array) and j (from right array) + #[inline] + pub fn compare(&self, i: usize, j: usize) -> Ordering { + match self { + Self::Int8 { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::Int16 { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::Int32 { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::Int64 { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::UInt8 { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::UInt16 { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::UInt32 { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::UInt64 { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::Float16 { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_float!(left, right, left_nulls, right_nulls, opts, i, j), + Self::Float32 { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_float!(left, right, left_nulls, right_nulls, opts, i, j), + Self::Float64 { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_float!(left, right, left_nulls, right_nulls, opts, i, j), + Self::Date32 { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::Date64 { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::Time32Second { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::Time32Millisecond { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::Time64Microsecond { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::Time64Nanosecond { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::TimestampSecond { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::TimestampMillisecond { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::TimestampMicrosecond { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::TimestampNanosecond { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::DurationSecond { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::DurationMillisecond { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::DurationMicrosecond { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::DurationNanosecond { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::IntervalYearMonth { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::IntervalDayTime { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::IntervalMonthDayNano { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::Decimal128 { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::Decimal256 { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::Boolean { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_boolean_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::Utf8 { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_bytes_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::LargeUtf8 { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_bytes_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::Utf8View { + left, + right, + left_nulls, + right_nulls, + opts, + } => { + compare_byte_view_values(left, right, left_nulls, right_nulls, opts, i, j) + } + Self::Binary { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_bytes_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::LargeBinary { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_bytes_values(left, right, left_nulls, right_nulls, opts, i, j), + Self::BinaryView { + left, + right, + left_nulls, + right_nulls, + opts, + } => { + compare_byte_view_values(left, right, left_nulls, right_nulls, opts, i, j) + } + Self::FixedSizeBinary { + left, + right, + left_nulls, + right_nulls, + opts, + } => compare_fixed_binary_values( + left, + right, + left_nulls, + right_nulls, + opts, + i, + j, + ), + Self::Dynamic(cmp) => cmp(i, j), + } + } +} + +// Helper functions for comparing values with null handling +use arrow::datatypes::ArrowNativeType; + +/// Compare values using Ord::cmp for types that implement Ord (integers, decimals, intervals, etc.) +#[inline] +fn compare_ord_values( + left: &ScalarBuffer, + right: &ScalarBuffer, + left_nulls: &Option, + right_nulls: &Option, + opts: &SortOptions, + i: usize, + j: usize, +) -> Ordering { + // Check nulls first + let ord = match ( + left_nulls.as_ref().is_some_and(|n| n.is_null(i)), + right_nulls.as_ref().is_some_and(|n| n.is_null(j)), + ) { + (true, true) => return Ordering::Equal, + (true, false) => { + return if opts.nulls_first { + Ordering::Less + } else { + Ordering::Greater + }; + } + (false, true) => { + return if opts.nulls_first { + Ordering::Greater + } else { + Ordering::Less + }; + } + (false, false) => { + let left_slice: &[T] = left.as_ref(); + let right_slice: &[T] = right.as_ref(); + left_slice[i].cmp(&right_slice[j]) + } + }; + + if opts.descending { + ord.reverse() + } else { + ord + } +} + +#[inline] +fn compare_boolean_values( + left: &BooleanBuffer, + right: &BooleanBuffer, + left_nulls: &Option, + right_nulls: &Option, + opts: &SortOptions, + i: usize, + j: usize, +) -> Ordering { + // Check nulls first + let ord = match ( + left_nulls.as_ref().is_some_and(|n| n.is_null(i)), + right_nulls.as_ref().is_some_and(|n| n.is_null(j)), + ) { + (true, true) => return Ordering::Equal, + (true, false) => { + return if opts.nulls_first { + Ordering::Less + } else { + Ordering::Greater + }; + } + (false, true) => { + return if opts.nulls_first { + Ordering::Greater + } else { + Ordering::Less + }; + } + (false, false) => left.value(i).cmp(&right.value(j)), + }; + + if opts.descending { + ord.reverse() + } else { + ord + } +} + +#[inline] +fn compare_bytes_values( + left: &GenericByteArray, + right: &GenericByteArray, + left_nulls: &Option, + right_nulls: &Option, + opts: &SortOptions, + i: usize, + j: usize, +) -> Ordering { + // Check nulls first + let ord = match ( + left_nulls.as_ref().is_some_and(|n| n.is_null(i)), + right_nulls.as_ref().is_some_and(|n| n.is_null(j)), + ) { + (true, true) => return Ordering::Equal, + (true, false) => { + return if opts.nulls_first { + Ordering::Less + } else { + Ordering::Greater + }; + } + (false, true) => { + return if opts.nulls_first { + Ordering::Greater + } else { + Ordering::Less + }; + } + (false, false) => { + let l: &[u8] = left.value(i).as_ref(); + let r: &[u8] = right.value(j).as_ref(); + l.cmp(r) + } + }; + + if opts.descending { + ord.reverse() + } else { + ord + } +} + +#[inline] +fn compare_byte_view_values( + left: &GenericByteViewArray, + right: &GenericByteViewArray, + left_nulls: &Option, + right_nulls: &Option, + opts: &SortOptions, + i: usize, + j: usize, +) -> Ordering { + // Check nulls first + let ord = match ( + left_nulls.as_ref().is_some_and(|n| n.is_null(i)), + right_nulls.as_ref().is_some_and(|n| n.is_null(j)), + ) { + (true, true) => return Ordering::Equal, + (true, false) => { + return if opts.nulls_first { + Ordering::Less + } else { + Ordering::Greater + }; + } + (false, true) => { + return if opts.nulls_first { + Ordering::Greater + } else { + Ordering::Less + }; + } + (false, false) => { + let l: &[u8] = left.value(i).as_ref(); + let r: &[u8] = right.value(j).as_ref(); + l.cmp(r) + } + }; + + if opts.descending { + ord.reverse() + } else { + ord + } +} + +#[inline] +fn compare_fixed_binary_values( + left: &FixedSizeBinaryArray, + right: &FixedSizeBinaryArray, + left_nulls: &Option, + right_nulls: &Option, + opts: &SortOptions, + i: usize, + j: usize, +) -> Ordering { + // Check nulls first + let ord = match ( + left_nulls.as_ref().is_some_and(|n| n.is_null(i)), + right_nulls.as_ref().is_some_and(|n| n.is_null(j)), + ) { + (true, true) => return Ordering::Equal, + (true, false) => { + return if opts.nulls_first { + Ordering::Less + } else { + Ordering::Greater + }; + } + (false, true) => { + return if opts.nulls_first { + Ordering::Greater + } else { + Ordering::Less + }; + } + (false, false) => left.value(i).cmp(right.value(j)), + }; + + if opts.descending { + ord.reverse() + } else { + ord + } +} + +/// Create a comparator for the given arrays and sort options. +/// +/// This wraps Arrow's `make_comparator` but returns our enum-based `Comparator` +/// for scalar types, falling back to dynamic dispatch for complex types. +/// +/// # Errors +/// If the data types of the arrays are not supported for comparison. +pub(crate) fn make_comparator( + left: &dyn Array, + right: &dyn Array, + opts: SortOptions, +) -> Result { + use DataType::*; + + let left_nulls = left.nulls().filter(|x| x.null_count() > 0).cloned(); + let right_nulls = right.nulls().filter(|x| x.null_count() > 0).cloned(); + + Ok(match (left.data_type(), right.data_type()) { + (Int8, Int8) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::Int8 { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Int16, Int16) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::Int16 { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Int32, Int32) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::Int32 { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Int64, Int64) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::Int64 { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (UInt8, UInt8) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::UInt8 { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (UInt16, UInt16) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::UInt16 { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (UInt32, UInt32) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::UInt32 { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (UInt64, UInt64) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::UInt64 { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Float16, Float16) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::Float16 { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Float32, Float32) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::Float32 { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Float64, Float64) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::Float64 { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Date32, Date32) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::Date32 { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Date64, Date64) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::Date64 { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Time32(TimeUnit::Second), Time32(TimeUnit::Second)) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::Time32Second { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Time32(TimeUnit::Millisecond), Time32(TimeUnit::Millisecond)) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::Time32Millisecond { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Time64(TimeUnit::Microsecond), Time64(TimeUnit::Microsecond)) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::Time64Microsecond { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Time64(TimeUnit::Nanosecond), Time64(TimeUnit::Nanosecond)) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::Time64Nanosecond { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Timestamp(TimeUnit::Second, _), Timestamp(TimeUnit::Second, _)) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::TimestampSecond { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Timestamp(TimeUnit::Millisecond, _), Timestamp(TimeUnit::Millisecond, _)) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::TimestampMillisecond { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Timestamp(TimeUnit::Microsecond, _), Timestamp(TimeUnit::Microsecond, _)) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::TimestampMicrosecond { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Timestamp(TimeUnit::Nanosecond, _), Timestamp(TimeUnit::Nanosecond, _)) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::TimestampNanosecond { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Duration(TimeUnit::Second), Duration(TimeUnit::Second)) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::DurationSecond { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Duration(TimeUnit::Millisecond), Duration(TimeUnit::Millisecond)) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::DurationMillisecond { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Duration(TimeUnit::Microsecond), Duration(TimeUnit::Microsecond)) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::DurationMicrosecond { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Duration(TimeUnit::Nanosecond), Duration(TimeUnit::Nanosecond)) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::DurationNanosecond { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Interval(IntervalUnit::YearMonth), Interval(IntervalUnit::YearMonth)) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::IntervalYearMonth { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Interval(IntervalUnit::DayTime), Interval(IntervalUnit::DayTime)) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::IntervalDayTime { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Interval(IntervalUnit::MonthDayNano), Interval(IntervalUnit::MonthDayNano)) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::IntervalMonthDayNano { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Decimal128(_, _), Decimal128(_, _)) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::Decimal128 { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Decimal256(_, _), Decimal256(_, _)) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + Comparator::Decimal256 { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Boolean, Boolean) => { + let left = left.as_boolean(); + let right = right.as_boolean(); + Comparator::Boolean { + left: left.values().clone(), + right: right.values().clone(), + left_nulls, + right_nulls, + opts, + } + } + (Utf8, Utf8) => { + let left = left.as_string::(); + let right = right.as_string::(); + Comparator::Utf8 { + left: left.clone(), + right: right.clone(), + left_nulls, + right_nulls, + opts, + } + } + (LargeUtf8, LargeUtf8) => { + let left = left.as_string::(); + let right = right.as_string::(); + Comparator::LargeUtf8 { + left: left.clone(), + right: right.clone(), + left_nulls, + right_nulls, + opts, + } + } + (Utf8View, Utf8View) => { + let left = left.as_string_view(); + let right = right.as_string_view(); + Comparator::Utf8View { + left: left.clone(), + right: right.clone(), + left_nulls, + right_nulls, + opts, + } + } + (Binary, Binary) => { + let left = left.as_binary::(); + let right = right.as_binary::(); + Comparator::Binary { + left: left.clone(), + right: right.clone(), + left_nulls, + right_nulls, + opts, + } + } + (LargeBinary, LargeBinary) => { + let left = left.as_binary::(); + let right = right.as_binary::(); + Comparator::LargeBinary { + left: left.clone(), + right: right.clone(), + left_nulls, + right_nulls, + opts, + } + } + (BinaryView, BinaryView) => { + let left = left.as_binary_view(); + let right = right.as_binary_view(); + Comparator::BinaryView { + left: left.clone(), + right: right.clone(), + left_nulls, + right_nulls, + opts, + } + } + (FixedSizeBinary(_), FixedSizeBinary(_)) => { + let left = left.as_fixed_size_binary(); + let right = right.as_fixed_size_binary(); + Comparator::FixedSizeBinary { + left: left.clone(), + right: right.clone(), + left_nulls, + right_nulls, + opts, + } + } + // Fall back to dynamic dispatch for complex types + _ => { + let cmp = arrow_make_comparator(left, right, opts)?; + Comparator::Dynamic(cmp) + } + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ + BooleanArray, Date32Array, Float64Array, Int32Array, StringArray, + }; + + #[test] + fn test_int32_compare() { + let left = Int32Array::from(vec![1, 2, 3]); + let right = Int32Array::from(vec![2, 2, 1]); + + let cmp = make_comparator(&left, &right, SortOptions::default()).unwrap(); + + assert_eq!(cmp.compare(0, 0), Ordering::Less); // 1 < 2 + assert_eq!(cmp.compare(1, 1), Ordering::Equal); // 2 == 2 + assert_eq!(cmp.compare(2, 2), Ordering::Greater); // 3 > 1 + } + + #[test] + fn test_int32_compare_with_nulls() { + let left = Int32Array::from(vec![Some(1), None, Some(3)]); + let right = Int32Array::from(vec![Some(2), Some(2), None]); + + let cmp = make_comparator(&left, &right, SortOptions::default()).unwrap(); + + assert_eq!(cmp.compare(0, 0), Ordering::Less); // 1 < 2 + assert_eq!(cmp.compare(1, 1), Ordering::Less); // null < 2 (nulls_first=true) + assert_eq!(cmp.compare(2, 2), Ordering::Greater); // 3 > null + } + + #[test] + fn test_int32_descending() { + let left = Int32Array::from(vec![1, 2, 3]); + let right = Int32Array::from(vec![2, 2, 1]); + + let cmp = make_comparator( + &left, + &right, + SortOptions { + descending: true, + nulls_first: false, + }, + ) + .unwrap(); + + assert_eq!(cmp.compare(0, 0), Ordering::Greater); // 1 > 2 (descending) + assert_eq!(cmp.compare(1, 1), Ordering::Equal); // 2 == 2 + assert_eq!(cmp.compare(2, 2), Ordering::Less); // 3 < 1 (descending) + } + + #[test] + fn test_float64_compare() { + let left = Float64Array::from(vec![1.5, 2.5, f64::NAN]); + let right = Float64Array::from(vec![2.5, 2.5, 1.5]); + + let cmp = make_comparator(&left, &right, SortOptions::default()).unwrap(); + + assert_eq!(cmp.compare(0, 0), Ordering::Less); // 1.5 < 2.5 + assert_eq!(cmp.compare(1, 1), Ordering::Equal); // 2.5 == 2.5 + assert_eq!(cmp.compare(2, 2), Ordering::Greater); // NaN > 1.5 (using total_cmp) + } + + #[test] + fn test_string_compare() { + let left = StringArray::from(vec!["a", "b", "c"]); + let right = StringArray::from(vec!["b", "b", "a"]); + + let cmp = make_comparator(&left, &right, SortOptions::default()).unwrap(); + + assert_eq!(cmp.compare(0, 0), Ordering::Less); // "a" < "b" + assert_eq!(cmp.compare(1, 1), Ordering::Equal); // "b" == "b" + assert_eq!(cmp.compare(2, 2), Ordering::Greater); // "c" > "a" + } + + #[test] + fn test_boolean_compare() { + let left = BooleanArray::from(vec![false, true, false]); + let right = BooleanArray::from(vec![true, true, false]); + + let cmp = make_comparator(&left, &right, SortOptions::default()).unwrap(); + + assert_eq!(cmp.compare(0, 0), Ordering::Less); // false < true + assert_eq!(cmp.compare(1, 1), Ordering::Equal); // true == true + assert_eq!(cmp.compare(2, 2), Ordering::Equal); // false == false + } + + #[test] + fn test_date32_compare() { + let left = Date32Array::from(vec![100, 200, 300]); + let right = Date32Array::from(vec![200, 200, 100]); + + let cmp = make_comparator(&left, &right, SortOptions::default()).unwrap(); + + assert_eq!(cmp.compare(0, 0), Ordering::Less); // 100 < 200 + assert_eq!(cmp.compare(1, 1), Ordering::Equal); // 200 == 200 + assert_eq!(cmp.compare(2, 2), Ordering::Greater); // 300 > 100 + } + + #[test] + fn test_nulls_last() { + let left = Int32Array::from(vec![Some(1), None, Some(3)]); + let right = Int32Array::from(vec![Some(2), Some(2), None]); + + let cmp = make_comparator( + &left, + &right, + SortOptions { + descending: false, + nulls_first: false, + }, + ) + .unwrap(); + + assert_eq!(cmp.compare(0, 0), Ordering::Less); // 1 < 2 + assert_eq!(cmp.compare(1, 1), Ordering::Greater); // null > 2 (nulls_first=false) + assert_eq!(cmp.compare(2, 2), Ordering::Less); // 3 < null + } +} diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 4bcfbe35d0185..5da7e18df9131 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -25,12 +25,13 @@ use std::sync::Arc; use crate::physical_expr::physical_exprs_bag_equal; use crate::PhysicalExpr; -use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; +use super::comparator::make_comparator; use arrow::array::*; use arrow::buffer::BooleanBuffer; use arrow::compute::kernels::boolean::{not, or_kleene}; -use arrow::compute::take; +use arrow::compute::{take, SortOptions}; use arrow::datatypes::*; +use arrow::downcast_dictionary_array; use arrow::util::bit_iterator::BitIndexIterator; use arrow::{downcast_dictionary_array, downcast_primitive_array}; use datafusion_common::cast::{ @@ -43,17 +44,27 @@ use datafusion_common::{ }; use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::datum::compare_with_eq; +use datafusion_common::hash_utils::with_hashes; +use datafusion_common::{exec_err, internal_err, DFSchema, Result, ScalarValue}; +use datafusion_expr::{expr_vec_fmt, ColumnarValue}; use ahash::RandomState; use datafusion_common::HashMap; use hashbrown::hash_map::RawEntryMut; +/// Static filter for InList that stores the array and hash set for O(1) lookups +#[derive(Debug, Clone)] +struct StaticFilter { + array: ArrayRef, + hash_set: ArrayHashSet, +} + /// InList pub struct InListExpr { expr: Arc, list: Vec>, negated: bool, - static_filter: Option>, + static_filter: Option, } impl Debug for InListExpr { @@ -66,13 +77,8 @@ impl Debug for InListExpr { } } -/// A type-erased container of array elements -pub trait Set: Send + Sync { - fn contains(&self, v: &dyn Array, negated: bool) -> Result; - fn has_nulls(&self) -> bool; -} - -struct ArrayHashSet { +#[derive(Debug, Clone)] +pub(crate) struct ArrayHashSet { state: RandomState, /// Used to provide a lookup from value to in list index /// @@ -81,66 +87,56 @@ struct ArrayHashSet { map: HashMap, } -struct ArraySet { - array: T, - hash_set: ArrayHashSet, -} - -impl ArraySet -where - T: Array + From, -{ - fn new(array: &T, hash_set: ArrayHashSet) -> Self { - Self { - array: downcast_array(array), - hash_set, +impl ArrayHashSet { + /// Checks if values in `v` are contained in the `in_array` using this hash set for lookup. + fn contains( + &self, + v: &dyn Array, + in_array: &dyn Array, + negated: bool, + ) -> Result { + // Null type comparisons always return null (SQL three-valued logic) + if v.data_type() == &DataType::Null || in_array.data_type() == &DataType::Null { + return Ok(BooleanArray::from(vec![None; v.len()])); } - } -} -impl Set for ArraySet -where - T: Array + 'static, - for<'a> &'a T: ArrayAccessor, - for<'a> <&'a T as ArrayAccessor>::Item: IsEqual, -{ - fn contains(&self, v: &dyn Array, negated: bool) -> Result { downcast_dictionary_array! { v => { - let values_contains = self.contains(v.values().as_ref(), negated)?; + let values_contains = self.contains(v.values().as_ref(), in_array, negated)?; let result = take(&values_contains, v.keys(), None)?; return Ok(downcast_array(result.as_ref())) } _ => {} } - let v = v.as_any().downcast_ref::().unwrap(); - let in_array = &self.array; - let has_nulls = in_array.null_count() != 0; + let needle_nulls = v.logical_nulls(); + let needle_nulls = needle_nulls.as_ref(); + let haystack_has_nulls = in_array.null_count() != 0; + + with_hashes([v], &self.state, |hashes| { + let cmp = make_comparator(v, in_array, SortOptions::default())?; + Ok((0..v.len()) + .map(|i| { + // SQL three-valued logic: null IN (...) is always null + if needle_nulls.is_some_and(|nulls| nulls.is_null(i)) { + return None; + } - Ok(ArrayIter::new(v) - .map(|v| { - v.and_then(|v| { - let hash = v.hash_one(&self.hash_set.state); + let hash = hashes[i]; let contains = self - .hash_set .map .raw_entry() - .from_hash(hash, |idx| in_array.value(*idx).is_equal(&v)) + .from_hash(hash, |idx| cmp.compare(i, *idx).is_eq()) .is_some(); match contains { true => Some(!negated), - false if has_nulls => None, + false if haystack_has_nulls => None, false => Some(negated), } }) - }) - .collect()) - } - - fn has_nulls(&self) -> bool { - self.array.null_count() != 0 + .collect()) + }) } } @@ -150,64 +146,43 @@ where /// /// Note: This is split into a separate function as higher-rank trait bounds currently /// cause type inference to misbehave -fn make_hash_set(array: &T) -> ArrayHashSet -where - T: ArrayAccessor, - T::Item: IsEqual, -{ +fn make_hash_set(array: &dyn Array) -> Result { + // Null type has no natural order - return empty hash set + if array.data_type() == &DataType::Null { + return Ok(ArrayHashSet { + state: RandomState::new(), + map: HashMap::with_hasher(()), + }); + } + let state = RandomState::new(); - let mut map: HashMap = - HashMap::with_capacity_and_hasher(array.len(), ()); - - let insert_value = |idx| { - let value = array.value(idx); - let hash = value.hash_one(&state); - if let RawEntryMut::Vacant(v) = map - .raw_entry_mut() - .from_hash(hash, |x| array.value(*x).is_equal(&value)) - { - v.insert_with_hasher(hash, idx, (), |x| array.value(*x).hash_one(&state)); - } - }; + let mut map: HashMap = HashMap::with_hasher(()); + + with_hashes([array], &state, |hashes| -> Result<()> { + let cmp = make_comparator(array, array, SortOptions::default())?; + + let insert_value = |idx| { + let hash = hashes[idx]; + if let RawEntryMut::Vacant(v) = map + .raw_entry_mut() + .from_hash(hash, |x| cmp.compare(*x, idx).is_eq()) + { + v.insert_with_hasher(hash, idx, (), |x| hashes[*x]); + } + }; - match array.nulls() { - Some(nulls) => { - BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len()) - .for_each(insert_value) + match array.nulls() { + Some(nulls) => { + BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len()) + .for_each(insert_value) + } + None => (0..array.len()).for_each(insert_value), } - None => (0..array.len()).for_each(insert_value), - } - ArrayHashSet { state, map } -} + Ok(()) + })?; -/// Creates a `Box` for the given list of `IN` expressions and `batch` -fn make_set(array: &dyn Array) -> Result> { - Ok(downcast_primitive_array! { - array => Arc::new(ArraySet::new(array, make_hash_set(&array))), - DataType::Boolean => { - let array = as_boolean_array(array)?; - Arc::new(ArraySet::new(array, make_hash_set(&array))) - }, - DataType::Utf8 => { - let array = as_string_array(array)?; - Arc::new(ArraySet::new(array, make_hash_set(&array))) - } - DataType::LargeUtf8 => { - let array = as_largestring_array(array); - Arc::new(ArraySet::new(array, make_hash_set(&array))) - } - DataType::Binary => { - let array = as_generic_binary_array::(array)?; - Arc::new(ArraySet::new(array, make_hash_set(&array))) - } - DataType::LargeBinary => { - let array = as_generic_binary_array::(array)?; - Arc::new(ArraySet::new(array, make_hash_set(&array))) - } - DataType::Dictionary(_, _) => unreachable!("dictionary should have been flattened"), - d => return not_impl_err!("DataType::{d} not supported in InList") - }) + Ok(ArrayHashSet { state, map }) } /// Evaluates the list of expressions into an array, flattening any dictionaries @@ -232,56 +207,26 @@ fn evaluate_list( ScalarValue::iter_to_array(scalars) } -fn try_cast_static_filter_to_set( +/// Try to evaluate a list of expressions as constants. +/// +/// Returns an ArrayRef if all expressions are constants (can be evaluated on an +/// empty RecordBatch), otherwise returns an error. This is used to detect when +/// a list contains only literals, casts of literals, or other constant expressions. +fn try_evaluate_constant_list( list: &[Arc], schema: &Schema, -) -> Result> { +) -> Result { let batch = RecordBatch::new_empty(Arc::new(schema.clone())); - make_set(evaluate_list(list, &batch)?.as_ref()) + evaluate_list(list, &batch) } -/// Custom equality check function which is used with [`ArrayHashSet`] for existence check. -trait IsEqual: HashValue { - fn is_equal(&self, other: &Self) -> bool; -} - -impl IsEqual for &T { - fn is_equal(&self, other: &Self) -> bool { - T::is_equal(self, other) - } -} - -macro_rules! is_equal { - ($($t:ty),+) => { - $(impl IsEqual for $t { - fn is_equal(&self, other: &Self) -> bool { - self == other - } - })* - }; -} -is_equal!(i8, i16, i32, i64, i128, i256, u8, u16, u32, u64); -is_equal!(bool, str, [u8]); -is_equal!(IntervalDayTime, IntervalMonthDayNano); - -macro_rules! is_equal_float { - ($($t:ty),+) => { - $(impl IsEqual for $t { - fn is_equal(&self, other: &Self) -> bool { - self.to_bits() == other.to_bits() - } - })* - }; -} -is_equal_float!(half::f16, f32, f64); - impl InListExpr { /// Create a new InList expression - pub fn new( + fn new( expr: Arc, list: Vec>, negated: bool, - static_filter: Option>, + static_filter: Option, ) -> Self { Self { expr, @@ -305,19 +250,34 @@ impl InListExpr { pub fn negated(&self) -> bool { self.negated } -} -#[macro_export] -macro_rules! expr_vec_fmt { - ( $ARRAY:expr ) => {{ - $ARRAY - .iter() - .map(|e| format!("{e}")) - .collect::>() - .join(", ") - }}; + /// Create a new InList expression directly from an array, bypassing expression evaluation. + /// + /// This is more efficient than `in_list()` when you already have the list as an array, + /// as it avoids the conversion: `ArrayRef -> Vec -> ArrayRef -> ArrayHashSet`. + /// Instead it goes directly: `ArrayRef -> ArrayHashSet`. + /// + /// The `list` field will be empty when using this constructor, as the array is stored + /// directly in the static filter. + /// + /// This does not make the expression any more performant at runtime, but it does make it slightly + /// cheaper to build. + pub fn try_new_from_array( + expr: Arc, + array: ArrayRef, + negated: bool, + ) -> Result { + let list = (0..array.len()) + .map(|i| { + let scalar = ScalarValue::try_from_array(array.as_ref(), i)?; + Ok(crate::expressions::lit(scalar) as Arc) + }) + .collect::>>()?; + let hash_set = make_hash_set(array.as_ref())?; + let static_filter = StaticFilter { array, hash_set }; + Ok(Self::new(expr, list, negated, Some(static_filter))) + } } - impl std::fmt::Display for InListExpr { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { let list = expr_vec_fmt!(self.list); @@ -352,7 +312,7 @@ impl PhysicalExpr for InListExpr { } if let Some(static_filter) = &self.static_filter { - Ok(static_filter.has_nulls()) + Ok(static_filter.array.null_count() > 0) } else { for expr in &self.list { if expr.nullable(input_schema)? { @@ -367,18 +327,90 @@ impl PhysicalExpr for InListExpr { let num_rows = batch.num_rows(); let value = self.expr.evaluate(batch)?; let r = match &self.static_filter { - Some(f) => f.contains(value.into_array(num_rows)?.as_ref(), self.negated)?, + Some(filter) => { + match value { + ColumnarValue::Array(array) => filter.hash_set.contains( + &array, + filter.array.as_ref(), + self.negated, + )?, + ColumnarValue::Scalar(scalar) => { + if scalar.is_null() { + // SQL three-valued logic: null IN (...) is always null + // The code below would handle this correctly but this is a faster path + return Ok(ColumnarValue::Array(Arc::new( + BooleanArray::from(vec![None; num_rows]), + ))); + } + // Use a 1 row array to avoid code duplication/branching + // Since all we do is compute hash and lookup this should be efficient enough + let array = scalar.to_array()?; + let result_array = filter.hash_set.contains( + array.as_ref(), + filter.array.as_ref(), + self.negated, + )?; + // Broadcast the single result to all rows + // Must check is_null() to preserve NULL values (SQL three-valued logic) + if result_array.is_null(0) { + BooleanArray::from(vec![None; num_rows]) + } else { + BooleanArray::from_iter(std::iter::repeat_n( + result_array.value(0), + num_rows, + )) + } + } + } + } None => { + // No static filter: iterate through each expression, compare, and OR results let value = value.into_array(num_rows)?; - let is_nested = value.data_type().is_nested(); let found = self.list.iter().map(|expr| expr.evaluate(batch)).try_fold( BooleanArray::new(BooleanBuffer::new_unset(num_rows), None), |result, expr| -> Result { - let rhs = compare_with_eq( - &value, - &expr?.into_array(num_rows)?, - is_nested, - )?; + let rhs = match expr? { + ColumnarValue::Array(array) => { + let cmp = make_comparator( + value.as_ref(), + array.as_ref(), + SortOptions::default(), + )?; + (0..num_rows) + .map(|i| { + if value.is_null(i) || array.is_null(i) { + return None; + } + Some(cmp.compare(i, i).is_eq()) + }) + .collect::() + } + ColumnarValue::Scalar(scalar) => { + // Check if scalar is null once, before the loop + if scalar.is_null() { + // If scalar is null, all comparisons return null + BooleanArray::from(vec![None; num_rows]) + } else { + // Convert scalar to 1-element array + let array = scalar.to_array()?; + let cmp = make_comparator( + value.as_ref(), + array.as_ref(), + SortOptions::default(), + )?; + // Compare each row of value with the single scalar element + (0..num_rows) + .map(|i| { + if value.is_null(i) { + None + } else { + Some(cmp.compare(i, 0).is_eq()) + } + }) + .collect::() + } + } + }; Ok(or_kleene(&result, &rhs)?) }, )?; @@ -394,8 +426,7 @@ impl PhysicalExpr for InListExpr { } fn children(&self) -> Vec<&Arc> { - let mut children = vec![]; - children.push(&self.expr); + let mut children = vec![&self.expr]; children.extend(&self.list); children } @@ -444,8 +475,8 @@ impl Hash for InListExpr { fn hash(&self, state: &mut H) { self.expr.hash(state); self.negated.hash(state); - self.list.hash(state); // Add `self.static_filter` when hash is available + self.list.hash(state); } } @@ -465,7 +496,14 @@ pub fn in_list( "The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {list_expr_data_type}" ); } - let static_filter = try_cast_static_filter_to_set(&list, schema).ok(); + + // Try to create a static filter for constant expressions + let static_filter = try_evaluate_constant_list(&list, schema) + .and_then(|array| { + make_hash_set(array.as_ref()).map(|hash_set| StaticFilter { array, hash_set }) + }) + .ok(); + Ok(Arc::new(InListExpr::new( expr, list, @@ -479,11 +517,12 @@ mod tests { use super::*; use crate::expressions; use crate::expressions::{col, lit, try_cast}; + use arrow::buffer::NullBuffer; use datafusion_common::plan_err; use datafusion_expr::type_coercion::binary::comparison_coercion; use datafusion_physical_expr_common::physical_expr::fmt_sql; use insta::assert_snapshot; - use itertools::Itertools as _; + use itertools::Itertools; type InListCastResult = (Arc, Vec>); @@ -519,6 +558,14 @@ mod tests { } } + fn try_cast_static_filter_to_set( + list: &[Arc], + schema: &Schema, + ) -> Result { + let array = try_evaluate_constant_list(list, schema)?; + make_hash_set(array.as_ref()) + } + // Attempts to coerce the types of `list_type` to be comparable with the // `expr_type` fn get_coerce_type(expr_type: &DataType, list_type: &[DataType]) -> Option { @@ -529,7 +576,18 @@ mod tests { }) } - // applies the in_list expr to an input batch and list + /// Test helper macro that evaluates an IN LIST expression with automatic type casting. + /// + /// # Parameters + /// - `$BATCH`: The `RecordBatch` containing the input data to evaluate against + /// - `$LIST`: A `Vec>` of literal expressions representing the IN list values + /// - `$NEGATED`: A `&bool` indicating whether this is a NOT IN operation (true) or IN operation (false) + /// - `$EXPECTED`: A `Vec>` representing the expected boolean results for each row + /// - `$COL`: An `Arc` representing the column expression to evaluate + /// - `$SCHEMA`: A `&Schema` reference for the input batch + /// + /// This macro first applies type casting to the column and list expressions to ensure + /// type compatibility, then delegates to `in_list_raw!` to perform the evaluation and assertion. macro_rules! in_list { ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{ let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST, $SCHEMA)?; @@ -544,7 +602,19 @@ mod tests { }}; } - // applies the in_list expr to an input batch and list without cast + /// Test helper macro that evaluates an IN LIST expression without automatic type casting. + /// + /// # Parameters + /// - `$BATCH`: The `RecordBatch` containing the input data to evaluate against + /// - `$LIST`: A `Vec>` of literal expressions representing the IN list values + /// - `$NEGATED`: A `&bool` indicating whether this is a NOT IN operation (true) or IN operation (false) + /// - `$EXPECTED`: A `Vec>` representing the expected boolean results for each row + /// - `$COL`: An `Arc` representing the column expression to evaluate + /// - `$SCHEMA`: A `&Schema` reference for the input batch + /// + /// This macro creates an IN LIST expression, evaluates it against the batch, converts the result + /// to a `BooleanArray`, and asserts that it matches the expected output. Use this when the column + /// and list expressions are already the correct types and don't require casting. macro_rules! in_list_raw { ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{ let expr = in_list($COL, $LIST, $NEGATED, $SCHEMA).unwrap(); @@ -552,8 +622,7 @@ mod tests { .evaluate(&$BATCH)? .into_array($BATCH.num_rows()) .expect("Failed to convert to array"); - let result = - as_boolean_array(&result).expect("failed to downcast to BooleanArray"); + let result = as_boolean_array(&result); let expected = &BooleanArray::from($EXPECTED); assert_eq!(expected, result); }}; @@ -1134,10 +1203,11 @@ mod tests { expressions::cast(lit(2i32), &schema, DataType::Int64)?, try_cast(lit(3.13f32), &schema, DataType::Int64)?, ]; + let set_array = try_evaluate_constant_list(&phy_exprs, &schema)?; let result = try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); let array = Int64Array::from(vec![1, 2, 3, 4]); - let r = result.contains(&array, false).unwrap(); + let r = result.contains(&array, set_array.as_ref(), false).unwrap(); assert_eq!(r, BooleanArray::from(vec![true, true, true, false])); try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); @@ -1514,4 +1584,1166 @@ mod tests { assert_snapshot!(display_string, @"a@0 NOT IN (SET) ([a, b, NULL])"); Ok(()) } + + #[test] + fn in_list_struct() -> Result<()> { + // Create schema with a struct column + let struct_fields = Fields::from(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Utf8, false), + ]); + let schema = Schema::new(vec![Field::new( + "a", + DataType::Struct(struct_fields.clone()), + true, + )]); + + // Create test data: array of structs + let x_array = Arc::new(Int32Array::from(vec![1, 2, 3])); + let y_array = Arc::new(StringArray::from(vec!["a", "b", "c"])); + let struct_array = + StructArray::new(struct_fields.clone(), vec![x_array, y_array], None); + + let col_a = col("a", &schema)?; + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(struct_array)])?; + + // Create literal structs for the IN list + // Struct {x: 1, y: "a"} + let struct1 = ScalarValue::Struct(Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(StringArray::from(vec!["a"])), + ], + None, + ))); + + // Struct {x: 3, y: "c"} + let struct3 = ScalarValue::Struct(Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![3])), + Arc::new(StringArray::from(vec!["c"])), + ], + None, + ))); + + // Test: a IN ({1, "a"}, {3, "c"}) + let list = vec![lit(struct1.clone()), lit(struct3.clone())]; + in_list_raw!( + batch, + list.clone(), + &false, + vec![Some(true), Some(false), Some(true)], + Arc::clone(&col_a), + &schema + ); + + // Test: a NOT IN ({1, "a"}, {3, "c"}) + in_list_raw!( + batch, + list, + &true, + vec![Some(false), Some(true), Some(false)], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_struct_with_nulls() -> Result<()> { + // Create schema with a struct column + let struct_fields = Fields::from(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Utf8, false), + ]); + let schema = Schema::new(vec![Field::new( + "a", + DataType::Struct(struct_fields.clone()), + true, + )]); + + // Create test data with a null struct + let x_array = Arc::new(Int32Array::from(vec![1, 2])); + let y_array = Arc::new(StringArray::from(vec!["a", "b"])); + let struct_array = StructArray::new( + struct_fields.clone(), + vec![x_array, y_array], + Some(NullBuffer::from(vec![true, false])), + ); + + let col_a = col("a", &schema)?; + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(struct_array)])?; + + // Create literal struct for the IN list + let struct1 = ScalarValue::Struct(Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(StringArray::from(vec!["a"])), + ], + None, + ))); + + // Test: a IN ({1, "a"}) + let list = vec![lit(struct1.clone())]; + in_list_raw!( + batch, + list.clone(), + &false, + vec![Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // Test: a NOT IN ({1, "a"}) + in_list_raw!( + batch, + list, + &true, + vec![Some(false), None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_struct_with_null_in_list() -> Result<()> { + // Create schema with a struct column + let struct_fields = Fields::from(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Utf8, false), + ]); + let schema = Schema::new(vec![Field::new( + "a", + DataType::Struct(struct_fields.clone()), + true, + )]); + + // Create test data + let x_array = Arc::new(Int32Array::from(vec![1, 2, 3])); + let y_array = Arc::new(StringArray::from(vec!["a", "b", "c"])); + let struct_array = + StructArray::new(struct_fields.clone(), vec![x_array, y_array], None); + + let col_a = col("a", &schema)?; + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(struct_array)])?; + + // Create literal structs including a NULL + let struct1 = ScalarValue::Struct(Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(StringArray::from(vec!["a"])), + ], + None, + ))); + + let null_struct = ScalarValue::Struct(Arc::new(StructArray::new_null( + struct_fields.clone(), + 1, + ))); + + // Test: a IN ({1, "a"}, NULL) + let list = vec![lit(struct1), lit(null_struct.clone())]; + in_list_raw!( + batch, + list.clone(), + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // Test: a NOT IN ({1, "a"}, NULL) + in_list_raw!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_nested_struct() -> Result<()> { + // Create nested struct schema + let inner_struct_fields = Fields::from(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + ]); + let outer_struct_fields = Fields::from(vec![ + Field::new( + "inner", + DataType::Struct(inner_struct_fields.clone()), + false, + ), + Field::new("c", DataType::Int32, false), + ]); + let schema = Schema::new(vec![Field::new( + "x", + DataType::Struct(outer_struct_fields.clone()), + true, + )]); + + // Create test data with nested structs + let inner1 = Arc::new(StructArray::new( + inner_struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2])), + Arc::new(StringArray::from(vec!["x", "y"])), + ], + None, + )); + let c_array = Arc::new(Int32Array::from(vec![10, 20])); + let outer_array = + StructArray::new(outer_struct_fields.clone(), vec![inner1, c_array], None); + + let col_x = col("x", &schema)?; + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(outer_array)])?; + + // Create a nested struct literal matching the first row + let inner_match = Arc::new(StructArray::new( + inner_struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(StringArray::from(vec!["x"])), + ], + None, + )); + let outer_match = ScalarValue::Struct(Arc::new(StructArray::new( + outer_struct_fields.clone(), + vec![inner_match, Arc::new(Int32Array::from(vec![10]))], + None, + ))); + + // Test: x IN ({{1, "x"}, 10}) + let list = vec![lit(outer_match)]; + in_list_raw!( + batch, + list.clone(), + &false, + vec![Some(true), Some(false)], + Arc::clone(&col_x), + &schema + ); + + // Test: x NOT IN ({{1, "x"}, 10}) + in_list_raw!( + batch, + list, + &true, + vec![Some(false), Some(true)], + Arc::clone(&col_x), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_struct_with_exprs_not_array() -> Result<()> { + // Test InList using expressions (not the array constructor) with structs + // By using InListExpr::new directly, we bypass the array optimization + // and use the Exprs variant, testing the expression evaluation path + + // Create schema with a struct column {x: Int32, y: Utf8} + let struct_fields = Fields::from(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Utf8, false), + ]); + let schema = Schema::new(vec![Field::new( + "a", + DataType::Struct(struct_fields.clone()), + true, + )]); + + // Create test data: array of structs [{1, "a"}, {2, "b"}, {3, "c"}] + let x_array = Arc::new(Int32Array::from(vec![1, 2, 3])); + let y_array = Arc::new(StringArray::from(vec!["a", "b", "c"])); + let struct_array = + StructArray::new(struct_fields.clone(), vec![x_array, y_array], None); + + let col_a = col("a", &schema)?; + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(struct_array)])?; + + // Create struct literals with the SAME shape (so types are compatible) + // Struct {x: 1, y: "a"} + let struct1 = ScalarValue::Struct(Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(StringArray::from(vec!["a"])), + ], + None, + ))); + + // Struct {x: 3, y: "c"} + let struct3 = ScalarValue::Struct(Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![3])), + Arc::new(StringArray::from(vec!["c"])), + ], + None, + ))); + + // Create list of struct expressions + let list = vec![lit(struct1), lit(struct3)]; + + // Use InListExpr::new directly (not in_list()) to bypass array optimization + // This creates an InList without a static filter + let expr = Arc::new(InListExpr::new(Arc::clone(&col_a), list, false, None)); + + // Verify that the expression doesn't have a static filter + // by checking the display string does NOT contain "(SET)" + let display_string = expr.to_string(); + assert!( + !display_string.contains("(SET)"), + "Expected display string to NOT contain '(SET)' (should use Exprs variant), but got: {display_string}", + ); + + // Evaluate the expression + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + + // Expected: first row {1, "a"} matches struct1, + // second row {2, "b"} doesn't match, + // third row {3, "c"} matches struct3 + let expected = BooleanArray::from(vec![Some(true), Some(false), Some(true)]); + assert_eq!(result, &expected); + + // Test NOT IN as well + let expr_not = Arc::new(InListExpr::new( + Arc::clone(&col_a), + vec![ + lit(ScalarValue::Struct(Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(StringArray::from(vec!["a"])), + ], + None, + )))), + lit(ScalarValue::Struct(Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![3])), + Arc::new(StringArray::from(vec!["c"])), + ], + None, + )))), + ], + true, + None, + )); + + let result_not = expr_not.evaluate(&batch)?.into_array(batch.num_rows())?; + let result_not = as_boolean_array(&result_not); + + let expected_not = BooleanArray::from(vec![Some(false), Some(true), Some(false)]); + assert_eq!(result_not, &expected_not); + + Ok(()) + } + + #[test] + fn test_in_list_null_handling_comprehensive() -> Result<()> { + // Comprehensive test demonstrating SQL three-valued logic for IN expressions + // This test explicitly shows all possible outcomes: true, false, and null + let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); + + // Test data: [1, 2, 3, null] + // - 1 will match in both lists + // - 2 will not match in either list + // - 3 will not match in either list + // - null is always null + let a = Int64Array::from(vec![Some(1), Some(2), Some(3), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // Case 1: List WITHOUT null - demonstrates true/false/null outcomes + // "a IN (1, 4)" - 1 matches, 2 and 3 don't match, null is null + let list = vec![lit(1i64), lit(4i64)]; + in_list!( + batch, + list, + &false, + vec![ + Some(true), // 1 is in the list → true + Some(false), // 2 is not in the list → false + Some(false), // 3 is not in the list → false + None, // null IN (...) → null (SQL three-valued logic) + ], + Arc::clone(&col_a), + &schema + ); + + // Case 2: List WITH null - demonstrates null propagation for non-matches + // "a IN (1, NULL)" - 1 matches (true), 2/3 don't match but list has null (null), null is null + let list = vec![lit(1i64), lit(ScalarValue::Int64(None))]; + in_list!( + batch, + list, + &false, + vec![ + Some(true), // 1 is in the list → true (found match) + None, // 2 is not in list, but list has NULL → null (might match NULL) + None, // 3 is not in list, but list has NULL → null (might match NULL) + None, // null IN (...) → null (SQL three-valued logic) + ], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn test_in_list_with_only_nulls() -> Result<()> { + // Edge case: IN list contains ONLY null values + let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); + let a = Int64Array::from(vec![Some(1), Some(2), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // "a IN (NULL, NULL)" - list has only nulls + let list = vec![lit(ScalarValue::Int64(None)), lit(ScalarValue::Int64(None))]; + + // All results should be NULL because: + // - Non-null values (1, 2) can't match anything concrete, but list might contain matching value + // - NULL value is always NULL in IN expressions + in_list!( + batch, + list.clone(), + &false, + vec![None, None, None], + Arc::clone(&col_a), + &schema + ); + + // "a NOT IN (NULL, NULL)" - list has only nulls + // All results should still be NULL due to three-valued logic + in_list!( + batch, + list, + &true, + vec![None, None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn test_in_list_multiple_nulls_deduplication() -> Result<()> { + // Test that multiple NULLs in the list are handled correctly + // This verifies deduplication doesn't break null handling + let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); + let col_a = col("a", &schema)?; + + // Create array with multiple nulls: [1, 2, NULL, NULL, 3, NULL] + let array = Arc::new(Int64Array::from(vec![ + Some(1), + Some(2), + None, + None, + Some(3), + None, + ])) as ArrayRef; + + // Create InListExpr from array + let expr = Arc::new(InListExpr::try_new_from_array( + Arc::clone(&col_a), + array, + false, + )?) as Arc; + + // Create test data: [1, 2, 3, 4, null] + let a = Int64Array::from(vec![Some(1), Some(2), Some(3), Some(4), None]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // Evaluate the expression + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + + // Expected behavior with multiple NULLs in list: + // - Values in the list (1,2,3) → true + // - Values not in the list (4) → NULL (because list contains NULL) + // - NULL input → NULL + let expected = BooleanArray::from(vec![ + Some(true), // 1 is in list + Some(true), // 2 is in list + Some(true), // 3 is in list + None, // 4 not in list, but list has NULLs + None, // NULL input + ]); + assert_eq!(result, &expected); + + Ok(()) + } + + #[test] + fn test_not_in_null_handling_comprehensive() -> Result<()> { + // Comprehensive test demonstrating SQL three-valued logic for NOT IN expressions + // This test explicitly shows all possible outcomes for NOT IN: true, false, and null + let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); + + // Test data: [1, 2, 3, null] + let a = Int64Array::from(vec![Some(1), Some(2), Some(3), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // Case 1: List WITHOUT null - demonstrates true/false/null outcomes for NOT IN + // "a NOT IN (1, 4)" - 1 matches (false), 2 and 3 don't match (true), null is null + let list = vec![lit(1i64), lit(4i64)]; + in_list!( + batch, + list, + &true, + vec![ + Some(false), // 1 is in the list → NOT IN returns false + Some(true), // 2 is not in the list → NOT IN returns true + Some(true), // 3 is not in the list → NOT IN returns true + None, // null NOT IN (...) → null (SQL three-valued logic) + ], + Arc::clone(&col_a), + &schema + ); + + // Case 2: List WITH null - demonstrates null propagation for NOT IN + // "a NOT IN (1, NULL)" - 1 matches (false), 2/3 don't match but list has null (null), null is null + let list = vec![lit(1i64), lit(ScalarValue::Int64(None))]; + in_list!( + batch, + list, + &true, + vec![ + Some(false), // 1 is in the list → NOT IN returns false + None, // 2 is not in known values, but list has NULL → null (can't prove it's not in list) + None, // 3 is not in known values, but list has NULL → null (can't prove it's not in list) + None, // null NOT IN (...) → null (SQL three-valued logic) + ], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn test_in_list_null_type_column() -> Result<()> { + // Test with a column that has DataType::Null (not just nullable values) + // All values in a NullArray are null by definition + let schema = Schema::new(vec![Field::new("a", DataType::Null, true)]); + let a = NullArray::new(3); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // "null_column IN (1, 2)" - comparing Null type against Int64 list + // Note: This tests type coercion behavior between Null and Int64 + let list = vec![lit(1i64), lit(2i64)]; + + // All results should be NULL because: + // - Every value in the column is null (DataType::Null) + // - null IN (anything) always returns null per SQL three-valued logic + in_list!( + batch, + list.clone(), + &false, + vec![None, None, None], + Arc::clone(&col_a), + &schema + ); + + // "null_column NOT IN (1, 2)" + // Same behavior for NOT IN - null NOT IN (anything) is still null + in_list!( + batch, + list, + &true, + vec![None, None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn test_in_list_null_type_list() -> Result<()> { + // Test with a list that has DataType::Null + let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); + let a = Int64Array::from(vec![Some(1), Some(2), None]); + let col_a = col("a", &schema)?; + + // Create a NullArray as the list + let null_array = Arc::new(NullArray::new(2)) as ArrayRef; + + // Try to create InListExpr with a NullArray list + // This tests whether try_new_from_array can handle Null type arrays + let expr = Arc::new(InListExpr::try_new_from_array( + Arc::clone(&col_a), + null_array, + false, + )?) as Arc; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + + // If it succeeds, all results should be NULL + // because the list contains only null type values + let expected = BooleanArray::from(vec![None, None, None]); + assert_eq!(result, &expected); + + Ok(()) + } + + #[test] + fn test_in_list_null_type_both() -> Result<()> { + // Test when both column and list are DataType::Null + let schema = Schema::new(vec![Field::new("a", DataType::Null, true)]); + let a = NullArray::new(3); + let col_a = col("a", &schema)?; + + // Create a NullArray as the list + let null_array = Arc::new(NullArray::new(2)) as ArrayRef; + + // Try to create InListExpr with both Null types + let expr = Arc::new(InListExpr::try_new_from_array( + Arc::clone(&col_a), + null_array, + false, + )?) as Arc; + + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + + // If successful, all results should be NULL + // null IN [null, null] -> null + let expected = BooleanArray::from(vec![None, None, None]); + assert_eq!(result, &expected); + + Ok(()) + } + + #[test] + fn test_in_list_comprehensive_null_handling() -> Result<()> { + // Comprehensive test for IN LIST operations with various NULL handling scenarios. + // This test covers the key cases validated against DuckDB as the source of truth. + // + // Note: Some scalar literal tests (like NULL IN (1, 2)) are omitted as they + // appear to expose an issue with static filter optimization. These are covered + // by existing tests like in_list_no_cols(). + + let schema = Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, true)])); + let col_b = col("b", &schema)?; + let null_i32 = ScalarValue::Int32(None); + + // Helper to create a batch + let make_batch = |values: Vec>| -> Result { + let array = Arc::new(Int32Array::from(values)); + Ok(RecordBatch::try_new(Arc::clone(&schema), vec![array])?) + }; + + // Helper to run a test + let run_test = |batch: &RecordBatch, + expr: Arc, + list: Vec>, + expected: Vec>| + -> Result<()> { + let in_expr = in_list(expr, list, &false, schema.as_ref())?; + let result = in_expr.evaluate(batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + assert_eq!(result, &BooleanArray::from(expected)); + Ok(()) + }; + + // ======================================================================== + // COLUMN TESTS - col(b) IN [1, 2] + // ======================================================================== + + // [1] IN (1, 2) => [TRUE] + let batch = make_batch(vec![Some(1)])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(1i32), lit(2i32)], + vec![Some(true)], + )?; + + // [1, 2] IN (1, 2) => [TRUE, TRUE] + let batch = make_batch(vec![Some(1), Some(2)])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(1i32), lit(2i32)], + vec![Some(true), Some(true)], + )?; + + // [3, 4] IN (1, 2) => [FALSE, FALSE] + let batch = make_batch(vec![Some(3), Some(4)])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(1i32), lit(2i32)], + vec![Some(false), Some(false)], + )?; + + // [1, NULL] IN (1, 2) => [TRUE, NULL] + let batch = make_batch(vec![Some(1), None])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(1i32), lit(2i32)], + vec![Some(true), None], + )?; + + // [3, NULL] IN (1, 2) => [FALSE, NULL] (no match, NULL is NULL) + let batch = make_batch(vec![Some(3), None])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(1i32), lit(2i32)], + vec![Some(false), None], + )?; + + // ======================================================================== + // COLUMN WITH NULL IN LIST - col(b) IN [NULL, 1] + // ======================================================================== + + // [1] IN (NULL, 1) => [TRUE] (found match) + let batch = make_batch(vec![Some(1)])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(null_i32.clone()), lit(1i32)], + vec![Some(true)], + )?; + + // [2] IN (NULL, 1) => [NULL] (no match, but list has NULL) + let batch = make_batch(vec![Some(2)])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(null_i32.clone()), lit(1i32)], + vec![None], + )?; + + // [NULL] IN (NULL, 1) => [NULL] + let batch = make_batch(vec![None])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(null_i32.clone()), lit(1i32)], + vec![None], + )?; + + // ======================================================================== + // COLUMN WITH ALL NULLS IN LIST - col(b) IN [NULL, NULL] + // ======================================================================== + + // [1] IN (NULL, NULL) => [NULL] + let batch = make_batch(vec![Some(1)])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(null_i32.clone()), lit(null_i32.clone())], + vec![None], + )?; + + // [NULL] IN (NULL, NULL) => [NULL] + let batch = make_batch(vec![None])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(null_i32.clone()), lit(null_i32.clone())], + vec![None], + )?; + + // ======================================================================== + // LITERAL IN LIST WITH COLUMN - lit(1) IN [2, col(b)] + // ======================================================================== + + // 1 IN (2, [1]) => [TRUE] (matches column value) + let batch = make_batch(vec![Some(1)])?; + run_test( + &batch, + lit(1i32), + vec![lit(2i32), Arc::clone(&col_b)], + vec![Some(true)], + )?; + + // 1 IN (2, [3]) => [FALSE] (no match) + let batch = make_batch(vec![Some(3)])?; + run_test( + &batch, + lit(1i32), + vec![lit(2i32), Arc::clone(&col_b)], + vec![Some(false)], + )?; + + // 1 IN (2, [NULL]) => [NULL] (no match, column is NULL) + let batch = make_batch(vec![None])?; + run_test( + &batch, + lit(1i32), + vec![lit(2i32), Arc::clone(&col_b)], + vec![None], + )?; + + // ======================================================================== + // COLUMN IN LIST CONTAINING ITSELF - col(b) IN [1, col(b)] + // ======================================================================== + + // [1] IN (1, [1]) => [TRUE] (always matches - either list literal or itself) + let batch = make_batch(vec![Some(1)])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(1i32), Arc::clone(&col_b)], + vec![Some(true)], + )?; + + // [2] IN (1, [2]) => [TRUE] (matches itself) + let batch = make_batch(vec![Some(2)])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(1i32), Arc::clone(&col_b)], + vec![Some(true)], + )?; + + // [NULL] IN (1, [NULL]) => [NULL] (NULL is never equal to anything) + let batch = make_batch(vec![None])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(1i32), Arc::clone(&col_b)], + vec![None], + )?; + + Ok(()) + } + + #[test] + fn test_in_list_scalar_literal_cases() -> Result<()> { + // Test scalar literal cases (both NULL and non-NULL) to ensure SQL three-valued + // logic is correctly implemented. This covers the important case where a scalar + // value is tested against a list containing NULL. + + let schema = Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, true)])); + let null_i32 = ScalarValue::Int32(None); + + // Helper to create a batch + let make_batch = |values: Vec>| -> Result { + let array = Arc::new(Int32Array::from(values)); + Ok(RecordBatch::try_new(Arc::clone(&schema), vec![array])?) + }; + + // Helper to run a test + let run_test = |batch: &RecordBatch, + expr: Arc, + list: Vec>, + negated: bool, + expected: Vec>| + -> Result<()> { + let in_expr = in_list(expr, list, &negated, schema.as_ref())?; + let result = in_expr.evaluate(batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + let expected_array = BooleanArray::from(expected); + assert_eq!( + result, + &expected_array, + "Expected {:?}, got {:?}", + expected_array, + result.iter().collect::>() + ); + Ok(()) + }; + + let batch = make_batch(vec![Some(1)])?; + + // ======================================================================== + // NULL LITERAL TESTS + // According to SQL semantics, NULL IN (any_list) should always return NULL + // ======================================================================== + + // NULL IN (1, 1) => NULL + run_test( + &batch, + lit(null_i32.clone()), + vec![lit(1i32), lit(1i32)], + false, + vec![None], + )?; + + // NULL IN (NULL, 1) => NULL + run_test( + &batch, + lit(null_i32.clone()), + vec![lit(null_i32.clone()), lit(1i32)], + false, + vec![None], + )?; + + // NULL IN (NULL, NULL) => NULL + run_test( + &batch, + lit(null_i32.clone()), + vec![lit(null_i32.clone()), lit(null_i32.clone())], + false, + vec![None], + )?; + + // ======================================================================== + // NON-NULL SCALAR LITERALS WITH NULL IN LIST - Int32 + // When a scalar value is NOT in a list containing NULL, the result is NULL + // When a scalar value IS in the list, the result is TRUE (NULL doesn't matter) + // ======================================================================== + + // 3 IN (0, 1, 2, NULL) => NULL (not in list, but list has NULL) + run_test( + &batch, + lit(3i32), + vec![lit(0i32), lit(1i32), lit(2i32), lit(null_i32.clone())], + false, + vec![None], + )?; + + // 3 NOT IN (0, 1, 2, NULL) => NULL (not in list, but list has NULL) + run_test( + &batch, + lit(3i32), + vec![lit(0i32), lit(1i32), lit(2i32), lit(null_i32.clone())], + true, + vec![None], + )?; + + // 1 IN (0, 1, 2, NULL) => TRUE (found match, NULL doesn't matter) + run_test( + &batch, + lit(1i32), + vec![lit(0i32), lit(1i32), lit(2i32), lit(null_i32.clone())], + false, + vec![Some(true)], + )?; + + // 1 NOT IN (0, 1, 2, NULL) => FALSE (found match, NULL doesn't matter) + run_test( + &batch, + lit(1i32), + vec![lit(0i32), lit(1i32), lit(2i32), lit(null_i32.clone())], + true, + vec![Some(false)], + )?; + + // ======================================================================== + // NON-NULL SCALAR LITERALS WITH NULL IN LIST - String + // Same semantics as Int32 but with string type + // ======================================================================== + + let schema_str = + Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)])); + let batch_str = RecordBatch::try_new( + Arc::clone(&schema_str), + vec![Arc::new(StringArray::from(vec![Some("dummy")]))], + )?; + let null_str = ScalarValue::Utf8(None); + + let run_test_str = |expr: Arc, + list: Vec>, + negated: bool, + expected: Vec>| + -> Result<()> { + let in_expr = in_list(expr, list, &negated, schema_str.as_ref())?; + let result = in_expr + .evaluate(&batch_str)? + .into_array(batch_str.num_rows())?; + let result = as_boolean_array(&result); + let expected_array = BooleanArray::from(expected); + assert_eq!( + result, + &expected_array, + "Expected {:?}, got {:?}", + expected_array, + result.iter().collect::>() + ); + Ok(()) + }; + + // 'c' IN ('a', 'b', NULL) => NULL (not in list, but list has NULL) + run_test_str( + lit("c"), + vec![lit("a"), lit("b"), lit(null_str.clone())], + false, + vec![None], + )?; + + // 'c' NOT IN ('a', 'b', NULL) => NULL (not in list, but list has NULL) + run_test_str( + lit("c"), + vec![lit("a"), lit("b"), lit(null_str.clone())], + true, + vec![None], + )?; + + // 'a' IN ('a', 'b', NULL) => TRUE (found match, NULL doesn't matter) + run_test_str( + lit("a"), + vec![lit("a"), lit("b"), lit(null_str.clone())], + false, + vec![Some(true)], + )?; + + // 'a' NOT IN ('a', 'b', NULL) => FALSE (found match, NULL doesn't matter) + run_test_str( + lit("a"), + vec![lit("a"), lit("b"), lit(null_str.clone())], + true, + vec![Some(false)], + )?; + + Ok(()) + } + + #[test] + fn test_in_list_tuple_cases() -> Result<()> { + // Test tuple/struct cases from the original request: (lit, lit) IN (lit, lit) + // These test row-wise comparisons like (1, 2) IN ((1, 2), (3, 4)) + + let schema = Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, true)])); + + // Helper to create struct scalars for tuple comparisons + let make_struct = |v1: Option, v2: Option| -> ScalarValue { + let fields = Fields::from(vec![ + Field::new("field_0", DataType::Int32, true), + Field::new("field_1", DataType::Int32, true), + ]); + ScalarValue::Struct(Arc::new(StructArray::new( + fields, + vec![ + Arc::new(Int32Array::from(vec![v1])), + Arc::new(Int32Array::from(vec![v2])), + ], + None, + ))) + }; + + // Need a single row batch for scalar tests + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(vec![Some(1)]))], + )?; + + // Helper to run tuple tests + let run_tuple_test = |lhs: ScalarValue, + list: Vec, + expected: Vec>| + -> Result<()> { + let expr = in_list( + lit(lhs), + list.into_iter().map(lit).collect(), + &false, + schema.as_ref(), + )?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + assert_eq!(result, &BooleanArray::from(expected)); + Ok(()) + }; + + // (NULL, NULL) IN ((1, 2)) => FALSE (tuples don't match) + run_tuple_test( + make_struct(None, None), + vec![make_struct(Some(1), Some(2))], + vec![Some(false)], + )?; + + // (NULL, NULL) IN ((NULL, 1)) => FALSE + run_tuple_test( + make_struct(None, None), + vec![make_struct(None, Some(1))], + vec![Some(false)], + )?; + + // (NULL, NULL) IN ((NULL, NULL)) => TRUE (exact match including nulls) + run_tuple_test( + make_struct(None, None), + vec![make_struct(None, None)], + vec![Some(true)], + )?; + + // (NULL, 1) IN ((1, 2)) => FALSE + run_tuple_test( + make_struct(None, Some(1)), + vec![make_struct(Some(1), Some(2))], + vec![Some(false)], + )?; + + // (NULL, 1) IN ((NULL, 1)) => TRUE (exact match) + run_tuple_test( + make_struct(None, Some(1)), + vec![make_struct(None, Some(1))], + vec![Some(true)], + )?; + + // (NULL, 1) IN ((NULL, NULL)) => FALSE + run_tuple_test( + make_struct(None, Some(1)), + vec![make_struct(None, None)], + vec![Some(false)], + )?; + + // (1, 2) IN ((1, 2)) => TRUE + run_tuple_test( + make_struct(Some(1), Some(2)), + vec![make_struct(Some(1), Some(2))], + vec![Some(true)], + )?; + + // (1, 3) IN ((1, 2)) => FALSE + run_tuple_test( + make_struct(Some(1), Some(3)), + vec![make_struct(Some(1), Some(2))], + vec![Some(false)], + )?; + + // (4, 4) IN ((1, 2)) => FALSE + run_tuple_test( + make_struct(Some(4), Some(4)), + vec![make_struct(Some(1), Some(2))], + vec![Some(false)], + )?; + + // (1, 1) IN ((NULL, 1)) => FALSE + run_tuple_test( + make_struct(Some(1), Some(1)), + vec![make_struct(None, Some(1))], + vec![Some(false)], + )?; + + // (1, 1) IN ((NULL, NULL)) => FALSE + run_tuple_test( + make_struct(Some(1), Some(1)), + vec![make_struct(None, None)], + vec![Some(false)], + )?; + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 59d675753d985..0ccb5c13b3a90 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -23,6 +23,7 @@ mod case; mod cast; mod cast_column; mod column; +mod comparator; mod dynamic_filters; mod in_list; mod is_not_null; diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index f837423d2b616..4cba85a85128d 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -1680,7 +1680,7 @@ pub fn update_hash( hash_map: &mut dyn JoinHashMapType, offset: usize, random_state: &RandomState, - hashes_buffer: &mut Vec, + hashes_buffer: &mut [u64], deleted_offset: usize, fifo_hashmap: bool, ) -> Result<()> { diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index c69e7a19e4f78..77197721e1f14 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -6436,7 +6436,7 @@ physical_plan 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] -06)----------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) +06)----------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) 07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6464,7 +6464,7 @@ physical_plan 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] -06)----------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) +06)----------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) 07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6492,7 +6492,7 @@ physical_plan 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] -06)----------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) +06)----------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) 07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6520,7 +6520,7 @@ physical_plan 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] -06)----------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) +06)----------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) 07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6548,7 +6548,7 @@ physical_plan 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] -06)----------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) +06)----------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) 07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 87345b833e264..df88d26c9c9de 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -1066,6 +1066,213 @@ SELECT '2' NOT IN ('a','b',NULL,1) ---- NULL +# ======================================================================== +# Comprehensive IN LIST tests with NULL handling +# These tests validate SQL three-valued logic for IN operations +# ======================================================================== + +# test_in_list_null_literals +# NULL IN (any_list) should always return NULL per SQL three-valued logic + +query B +SELECT NULL IN (1, 1) +---- +NULL + +query B +SELECT NULL IN (NULL, 1) +---- +NULL + +query B +SELECT NULL IN (NULL, NULL) +---- +NULL + +# test_in_list_with_columns +# Create test table for column-based IN LIST tests + +statement ok +CREATE OR REPLACE TABLE in_list_test(b INT) AS VALUES (1), (2), (3), (4), (NULL); + +# Test: b IN (1, 2) with various values + +query B +SELECT b IN (1, 2) FROM in_list_test WHERE b = 1; +---- +true + +query IB +SELECT b, b IN (1, 2) FROM in_list_test WHERE b IN (1, 2) ORDER BY b; +---- +1 true +2 true + +query IB +SELECT b, b IN (1, 2) FROM in_list_test WHERE b IN (3, 4) ORDER BY b; +---- +3 false +4 false + +query B +SELECT b IN (1, 2) FROM in_list_test WHERE b = 1; +---- +true + +query B +SELECT b IN (1, 2) FROM in_list_test WHERE b = 3; +---- +false + +query B +SELECT b IN (1, 2) FROM in_list_test WHERE b IS NULL; +---- +NULL + +# Test: b IN (NULL, 1) - list contains NULL + +query B +SELECT b IN (NULL, 1) FROM in_list_test WHERE b = 1; +---- +true + +query B +SELECT b IN (NULL, 1) FROM in_list_test WHERE b = 2; +---- +NULL + +query B +SELECT b IN (NULL, 1) FROM in_list_test WHERE b IS NULL; +---- +NULL + +# Test: b IN (NULL, NULL) - list contains only NULLs + +query B +SELECT b IN (NULL, NULL) FROM in_list_test WHERE b = 1; +---- +NULL + +query B +SELECT b IN (NULL, NULL) FROM in_list_test WHERE b IS NULL; +---- +NULL + +# Test: literal IN (list_with_column) - column appears in the list + +statement ok +CREATE OR REPLACE TABLE in_list_col_test(b INT) AS VALUES (1), (3), (NULL); + +query B +SELECT 1 IN (2, b) FROM in_list_col_test WHERE b = 1; +---- +true + +query B +SELECT 1 IN (2, b) FROM in_list_col_test WHERE b = 3; +---- +false + +query B +SELECT 1 IN (2, b) FROM in_list_col_test WHERE b IS NULL; +---- +NULL + +# Test: b IN (1, b) - column references itself in list + +query B +SELECT b IN (1, b) FROM in_list_col_test WHERE b = 1; +---- +true + +query B +SELECT b IN (1, b) FROM in_list_col_test WHERE b = 3; +---- +true + +query B +SELECT b IN (1, b) FROM in_list_col_test WHERE b IS NULL; +---- +NULL + +# test_in_list_tuples +# Test tuple/row-wise IN comparisons using struct syntax +# Note: Using arrow_cast for precise type control + +# (NULL, NULL) IN ((1, 2)) => FALSE +query B +SELECT struct(arrow_cast(NULL, 'Int32'), arrow_cast(NULL, 'Int32')) IN (struct(1, 2)) +---- +false + +# (NULL, NULL) IN ((NULL, 1)) => FALSE +query B +SELECT struct(arrow_cast(NULL, 'Int32'), arrow_cast(NULL, 'Int32')) IN (struct(arrow_cast(NULL, 'Int32'), 1)) +---- +false + +# (NULL, NULL) IN ((NULL, NULL)) => TRUE (exact match) +query B +SELECT struct(arrow_cast(NULL, 'Int32'), arrow_cast(NULL, 'Int32')) IN (struct(arrow_cast(NULL, 'Int32'), arrow_cast(NULL, 'Int32'))) +---- +true + +# (NULL, 1) IN ((1, 2)) => FALSE +query B +SELECT struct(arrow_cast(NULL, 'Int32'), 1) IN (struct(1, 2)) +---- +false + +# (NULL, 1) IN ((NULL, 1)) => TRUE (exact match) +query B +SELECT struct(arrow_cast(NULL, 'Int32'), 1) IN (struct(arrow_cast(NULL, 'Int32'), 1)) +---- +true + +# (NULL, 1) IN ((NULL, NULL)) => FALSE +query B +SELECT struct(arrow_cast(NULL, 'Int32'), 1) IN (struct(arrow_cast(NULL, 'Int32'), arrow_cast(NULL, 'Int32'))) +---- +false + +# (1, 2) IN ((1, 2)) => TRUE +query B +SELECT struct(1, 2) IN (struct(1, 2)) +---- +true + +# (1, 3) IN ((1, 2)) => FALSE +query B +SELECT struct(1, 3) IN (struct(1, 2)) +---- +false + +# (4, 4) IN ((1, 2)) => FALSE +query B +SELECT struct(4, 4) IN (struct(1, 2)) +---- +false + +# (1, 1) IN ((NULL, 1)) => FALSE +query B +SELECT struct(1, 1) IN (struct(NULL, 1)) +---- +false + +# (1, 1) IN ((NULL, NULL)) => FALSE +query B +SELECT struct(1, 1) IN (struct(NULL, NULL)) +---- +false + +# Cleanup test tables + +statement ok +DROP TABLE in_list_test; + +statement ok +DROP TABLE in_list_col_test; + query T SELECT encode('tom','base64'); ---- diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part index 12efc64555b29..d20f090fa5b8f 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part @@ -69,13 +69,13 @@ physical_plan 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] 05)--------CoalesceBatchesExec: target_batch_size=8192 -06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND p_container@3 IN ([SM CASE, SM BOX, SM PACK, SM PKG]) AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN ([MED BAG, MED BOX, MED PKG, MED PACK]) AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN ([LG CASE, LG BOX, LG PACK, LG PKG]) AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_extendedprice@2, l_discount@3] +06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND p_container@3 IN (SET) ([SM CASE, SM BOX, SM PACK, SM PKG]) AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN (SET) ([MED BAG, MED BOX, MED PKG, MED PACK]) AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN (SET) ([LG CASE, LG BOX, LG PACK, LG PKG]) AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_extendedprice@2, l_discount@3] 07)------------CoalesceBatchesExec: target_batch_size=8192 08)--------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 09)----------------FilterExec: (l_quantity@1 >= Some(100),15,2 AND l_quantity@1 <= Some(1100),15,2 OR l_quantity@1 >= Some(1000),15,2 AND l_quantity@1 <= Some(2000),15,2 OR l_quantity@1 >= Some(2000),15,2 AND l_quantity@1 <= Some(3000),15,2) AND (l_shipmode@5 = AIR OR l_shipmode@5 = AIR REG) AND l_shipinstruct@4 = DELIVER IN PERSON, projection=[l_partkey@0, l_quantity@1, l_extendedprice@2, l_discount@3] 10)------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], file_type=csv, has_header=false 11)------------CoalesceBatchesExec: target_batch_size=8192 12)--------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 -13)----------------FilterExec: (p_brand@1 = Brand#12 AND p_container@3 IN ([SM CASE, SM BOX, SM PACK, SM PKG]) AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN ([MED BAG, MED BOX, MED PKG, MED PACK]) AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN ([LG CASE, LG BOX, LG PACK, LG PKG]) AND p_size@2 <= 15) AND p_size@2 >= 1 +13)----------------FilterExec: (p_brand@1 = Brand#12 AND p_container@3 IN (SET) ([SM CASE, SM BOX, SM PACK, SM PKG]) AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN (SET) ([MED BAG, MED BOX, MED PKG, MED PACK]) AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN (SET) ([LG CASE, LG BOX, LG PACK, LG PKG]) AND p_size@2 <= 15) AND p_size@2 >= 1 14)------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 15)--------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_size, p_container], file_type=csv, has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part index 818c7bc989655..a9d95fb1ab79f 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part @@ -90,7 +90,7 @@ physical_plan 14)--------------------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(c_custkey@0, o_custkey@0)], projection=[c_phone@1, c_acctbal@2] 15)----------------------------CoalesceBatchesExec: target_batch_size=8192 16)------------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 -17)--------------------------------FilterExec: substr(c_phone@1, 1, 2) IN ([13, 31, 23, 29, 30, 18, 17]) +17)--------------------------------FilterExec: substr(c_phone@1, 1, 2) IN (SET) ([13, 31, 23, 29, 30, 18, 17]) 18)----------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 19)------------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_phone, c_acctbal], file_type=csv, has_header=false 20)----------------------------CoalesceBatchesExec: target_batch_size=8192 @@ -99,6 +99,6 @@ physical_plan 23)--------------------AggregateExec: mode=Final, gby=[], aggr=[avg(customer.c_acctbal)] 24)----------------------CoalescePartitionsExec 25)------------------------AggregateExec: mode=Partial, gby=[], aggr=[avg(customer.c_acctbal)] -26)--------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND substr(c_phone@0, 1, 2) IN ([13, 31, 23, 29, 30, 18, 17]), projection=[c_acctbal@1] +26)--------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND substr(c_phone@0, 1, 2) IN (SET) ([13, 31, 23, 29, 30, 18, 17]), projection=[c_acctbal@1] 27)----------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 28)------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_phone, c_acctbal], file_type=csv, has_header=false From cd9e2f5724c3c2d1e53720b967de216233d4eec3 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sat, 15 Nov 2025 07:27:13 +0800 Subject: [PATCH 2/6] remove enum comparator --- .../src/expressions/comparator.rs | 1444 ----------------- .../physical-expr/src/expressions/in_list.rs | 9 +- .../physical-expr/src/expressions/mod.rs | 1 - 3 files changed, 4 insertions(+), 1450 deletions(-) delete mode 100644 datafusion/physical-expr/src/expressions/comparator.rs diff --git a/datafusion/physical-expr/src/expressions/comparator.rs b/datafusion/physical-expr/src/expressions/comparator.rs deleted file mode 100644 index d0cf0ffc045db..0000000000000 --- a/datafusion/physical-expr/src/expressions/comparator.rs +++ /dev/null @@ -1,1444 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Enum-based comparator that eliminates dynamic dispatch for scalar types -//! -//! This module provides an optimized comparator implementation that uses an enum -//! with variants for each scalar Arrow type, eliminating the overhead of dynamic -//! dispatch for common comparison operations. Complex recursive types (List, Struct, -//! Map, Dictionary) fall back to dynamic dispatch. -//! -//! While we are implementing this in DataFusion for now we hope to upstream this into arrow-rs -//! and replace the existing completely dynamic comparator there with this more efficient one. - -use arrow::array::types::*; -use arrow::array::{make_comparator as arrow_make_comparator, *}; -use arrow::buffer::{BooleanBuffer, NullBuffer, ScalarBuffer}; -use arrow::compute::SortOptions; -use arrow::datatypes::{ - i256, DataType, IntervalDayTime, IntervalMonthDayNano, IntervalUnit, TimeUnit, -}; -use arrow::error::ArrowError; -use std::cmp::Ordering; - -// Type alias for dynamic comparator (same as arrow_ord::ord::DynComparator) -type DynComparator = Box Ordering + Send + Sync>; - -/// Comparator that uses enum dispatch for scalar types and dynamic dispatch for complex types -pub(crate) enum Comparator { - // Primitive integer types - Int8 { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - Int16 { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - Int32 { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - Int64 { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - - // Unsigned integer types - UInt8 { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - UInt16 { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - UInt32 { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - UInt64 { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - - // Floating point types - Float16 { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - Float32 { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - Float64 { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - - // Date and time types - Date32 { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - Date64 { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - Time32Second { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - Time32Millisecond { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - Time64Microsecond { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - Time64Nanosecond { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - - // Timestamp types - TimestampSecond { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - TimestampMillisecond { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - TimestampMicrosecond { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - TimestampNanosecond { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - - // Duration types - DurationSecond { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - DurationMillisecond { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - DurationMicrosecond { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - DurationNanosecond { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - - // Interval types - IntervalYearMonth { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - IntervalDayTime { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - IntervalMonthDayNano { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - - // Decimal types - Decimal128 { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - Decimal256 { - left: ScalarBuffer, - right: ScalarBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - - // Boolean type - Boolean { - left: BooleanBuffer, - right: BooleanBuffer, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - - // String types - Utf8 { - left: GenericByteArray, - right: GenericByteArray, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - LargeUtf8 { - left: GenericByteArray, - right: GenericByteArray, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - Utf8View { - left: GenericByteViewArray, - right: GenericByteViewArray, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - - // Binary types - Binary { - left: GenericByteArray, - right: GenericByteArray, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - LargeBinary { - left: GenericByteArray, - right: GenericByteArray, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - BinaryView { - left: GenericByteViewArray, - right: GenericByteViewArray, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - - // FixedSizeBinary - FixedSizeBinary { - left: FixedSizeBinaryArray, - right: FixedSizeBinaryArray, - left_nulls: Option, - right_nulls: Option, - opts: SortOptions, - }, - - // Dynamic fallback for recursive/complex types: - // - List, LargeList, FixedSizeList - // - Struct - // - Map - // - Dictionary - Dynamic(DynComparator), -} - -/// Helper macro to reduce duplication for float comparisons using total_cmp -macro_rules! compare_float { - ($left:expr, $right:expr, $left_nulls:expr, $right_nulls:expr, $opts:expr, $i:expr, $j:expr) => {{ - let ord = match ( - $left_nulls.as_ref().is_some_and(|n| n.is_null($i)), - $right_nulls.as_ref().is_some_and(|n| n.is_null($j)), - ) { - (true, true) => return Ordering::Equal, - (true, false) => { - return if $opts.nulls_first { - Ordering::Less - } else { - Ordering::Greater - }; - } - (false, true) => { - return if $opts.nulls_first { - Ordering::Greater - } else { - Ordering::Less - }; - } - (false, false) => { - let left_slice = $left.as_ref(); - let right_slice = $right.as_ref(); - left_slice[$i].total_cmp(&right_slice[$j]) - } - }; - - if $opts.descending { - ord.reverse() - } else { - ord - } - }}; -} - -impl Comparator { - /// Compare elements at indices i (from left array) and j (from right array) - #[inline] - pub fn compare(&self, i: usize, j: usize) -> Ordering { - match self { - Self::Int8 { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::Int16 { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::Int32 { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::Int64 { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::UInt8 { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::UInt16 { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::UInt32 { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::UInt64 { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::Float16 { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_float!(left, right, left_nulls, right_nulls, opts, i, j), - Self::Float32 { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_float!(left, right, left_nulls, right_nulls, opts, i, j), - Self::Float64 { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_float!(left, right, left_nulls, right_nulls, opts, i, j), - Self::Date32 { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::Date64 { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::Time32Second { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::Time32Millisecond { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::Time64Microsecond { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::Time64Nanosecond { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::TimestampSecond { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::TimestampMillisecond { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::TimestampMicrosecond { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::TimestampNanosecond { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::DurationSecond { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::DurationMillisecond { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::DurationMicrosecond { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::DurationNanosecond { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::IntervalYearMonth { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::IntervalDayTime { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::IntervalMonthDayNano { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::Decimal128 { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::Decimal256 { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_ord_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::Boolean { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_boolean_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::Utf8 { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_bytes_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::LargeUtf8 { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_bytes_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::Utf8View { - left, - right, - left_nulls, - right_nulls, - opts, - } => { - compare_byte_view_values(left, right, left_nulls, right_nulls, opts, i, j) - } - Self::Binary { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_bytes_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::LargeBinary { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_bytes_values(left, right, left_nulls, right_nulls, opts, i, j), - Self::BinaryView { - left, - right, - left_nulls, - right_nulls, - opts, - } => { - compare_byte_view_values(left, right, left_nulls, right_nulls, opts, i, j) - } - Self::FixedSizeBinary { - left, - right, - left_nulls, - right_nulls, - opts, - } => compare_fixed_binary_values( - left, - right, - left_nulls, - right_nulls, - opts, - i, - j, - ), - Self::Dynamic(cmp) => cmp(i, j), - } - } -} - -// Helper functions for comparing values with null handling -use arrow::datatypes::ArrowNativeType; - -/// Compare values using Ord::cmp for types that implement Ord (integers, decimals, intervals, etc.) -#[inline] -fn compare_ord_values( - left: &ScalarBuffer, - right: &ScalarBuffer, - left_nulls: &Option, - right_nulls: &Option, - opts: &SortOptions, - i: usize, - j: usize, -) -> Ordering { - // Check nulls first - let ord = match ( - left_nulls.as_ref().is_some_and(|n| n.is_null(i)), - right_nulls.as_ref().is_some_and(|n| n.is_null(j)), - ) { - (true, true) => return Ordering::Equal, - (true, false) => { - return if opts.nulls_first { - Ordering::Less - } else { - Ordering::Greater - }; - } - (false, true) => { - return if opts.nulls_first { - Ordering::Greater - } else { - Ordering::Less - }; - } - (false, false) => { - let left_slice: &[T] = left.as_ref(); - let right_slice: &[T] = right.as_ref(); - left_slice[i].cmp(&right_slice[j]) - } - }; - - if opts.descending { - ord.reverse() - } else { - ord - } -} - -#[inline] -fn compare_boolean_values( - left: &BooleanBuffer, - right: &BooleanBuffer, - left_nulls: &Option, - right_nulls: &Option, - opts: &SortOptions, - i: usize, - j: usize, -) -> Ordering { - // Check nulls first - let ord = match ( - left_nulls.as_ref().is_some_and(|n| n.is_null(i)), - right_nulls.as_ref().is_some_and(|n| n.is_null(j)), - ) { - (true, true) => return Ordering::Equal, - (true, false) => { - return if opts.nulls_first { - Ordering::Less - } else { - Ordering::Greater - }; - } - (false, true) => { - return if opts.nulls_first { - Ordering::Greater - } else { - Ordering::Less - }; - } - (false, false) => left.value(i).cmp(&right.value(j)), - }; - - if opts.descending { - ord.reverse() - } else { - ord - } -} - -#[inline] -fn compare_bytes_values( - left: &GenericByteArray, - right: &GenericByteArray, - left_nulls: &Option, - right_nulls: &Option, - opts: &SortOptions, - i: usize, - j: usize, -) -> Ordering { - // Check nulls first - let ord = match ( - left_nulls.as_ref().is_some_and(|n| n.is_null(i)), - right_nulls.as_ref().is_some_and(|n| n.is_null(j)), - ) { - (true, true) => return Ordering::Equal, - (true, false) => { - return if opts.nulls_first { - Ordering::Less - } else { - Ordering::Greater - }; - } - (false, true) => { - return if opts.nulls_first { - Ordering::Greater - } else { - Ordering::Less - }; - } - (false, false) => { - let l: &[u8] = left.value(i).as_ref(); - let r: &[u8] = right.value(j).as_ref(); - l.cmp(r) - } - }; - - if opts.descending { - ord.reverse() - } else { - ord - } -} - -#[inline] -fn compare_byte_view_values( - left: &GenericByteViewArray, - right: &GenericByteViewArray, - left_nulls: &Option, - right_nulls: &Option, - opts: &SortOptions, - i: usize, - j: usize, -) -> Ordering { - // Check nulls first - let ord = match ( - left_nulls.as_ref().is_some_and(|n| n.is_null(i)), - right_nulls.as_ref().is_some_and(|n| n.is_null(j)), - ) { - (true, true) => return Ordering::Equal, - (true, false) => { - return if opts.nulls_first { - Ordering::Less - } else { - Ordering::Greater - }; - } - (false, true) => { - return if opts.nulls_first { - Ordering::Greater - } else { - Ordering::Less - }; - } - (false, false) => { - let l: &[u8] = left.value(i).as_ref(); - let r: &[u8] = right.value(j).as_ref(); - l.cmp(r) - } - }; - - if opts.descending { - ord.reverse() - } else { - ord - } -} - -#[inline] -fn compare_fixed_binary_values( - left: &FixedSizeBinaryArray, - right: &FixedSizeBinaryArray, - left_nulls: &Option, - right_nulls: &Option, - opts: &SortOptions, - i: usize, - j: usize, -) -> Ordering { - // Check nulls first - let ord = match ( - left_nulls.as_ref().is_some_and(|n| n.is_null(i)), - right_nulls.as_ref().is_some_and(|n| n.is_null(j)), - ) { - (true, true) => return Ordering::Equal, - (true, false) => { - return if opts.nulls_first { - Ordering::Less - } else { - Ordering::Greater - }; - } - (false, true) => { - return if opts.nulls_first { - Ordering::Greater - } else { - Ordering::Less - }; - } - (false, false) => left.value(i).cmp(right.value(j)), - }; - - if opts.descending { - ord.reverse() - } else { - ord - } -} - -/// Create a comparator for the given arrays and sort options. -/// -/// This wraps Arrow's `make_comparator` but returns our enum-based `Comparator` -/// for scalar types, falling back to dynamic dispatch for complex types. -/// -/// # Errors -/// If the data types of the arrays are not supported for comparison. -pub(crate) fn make_comparator( - left: &dyn Array, - right: &dyn Array, - opts: SortOptions, -) -> Result { - use DataType::*; - - let left_nulls = left.nulls().filter(|x| x.null_count() > 0).cloned(); - let right_nulls = right.nulls().filter(|x| x.null_count() > 0).cloned(); - - Ok(match (left.data_type(), right.data_type()) { - (Int8, Int8) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::Int8 { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Int16, Int16) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::Int16 { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Int32, Int32) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::Int32 { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Int64, Int64) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::Int64 { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (UInt8, UInt8) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::UInt8 { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (UInt16, UInt16) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::UInt16 { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (UInt32, UInt32) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::UInt32 { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (UInt64, UInt64) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::UInt64 { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Float16, Float16) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::Float16 { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Float32, Float32) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::Float32 { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Float64, Float64) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::Float64 { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Date32, Date32) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::Date32 { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Date64, Date64) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::Date64 { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Time32(TimeUnit::Second), Time32(TimeUnit::Second)) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::Time32Second { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Time32(TimeUnit::Millisecond), Time32(TimeUnit::Millisecond)) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::Time32Millisecond { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Time64(TimeUnit::Microsecond), Time64(TimeUnit::Microsecond)) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::Time64Microsecond { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Time64(TimeUnit::Nanosecond), Time64(TimeUnit::Nanosecond)) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::Time64Nanosecond { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Timestamp(TimeUnit::Second, _), Timestamp(TimeUnit::Second, _)) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::TimestampSecond { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Timestamp(TimeUnit::Millisecond, _), Timestamp(TimeUnit::Millisecond, _)) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::TimestampMillisecond { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Timestamp(TimeUnit::Microsecond, _), Timestamp(TimeUnit::Microsecond, _)) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::TimestampMicrosecond { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Timestamp(TimeUnit::Nanosecond, _), Timestamp(TimeUnit::Nanosecond, _)) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::TimestampNanosecond { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Duration(TimeUnit::Second), Duration(TimeUnit::Second)) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::DurationSecond { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Duration(TimeUnit::Millisecond), Duration(TimeUnit::Millisecond)) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::DurationMillisecond { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Duration(TimeUnit::Microsecond), Duration(TimeUnit::Microsecond)) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::DurationMicrosecond { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Duration(TimeUnit::Nanosecond), Duration(TimeUnit::Nanosecond)) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::DurationNanosecond { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Interval(IntervalUnit::YearMonth), Interval(IntervalUnit::YearMonth)) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::IntervalYearMonth { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Interval(IntervalUnit::DayTime), Interval(IntervalUnit::DayTime)) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::IntervalDayTime { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Interval(IntervalUnit::MonthDayNano), Interval(IntervalUnit::MonthDayNano)) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::IntervalMonthDayNano { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Decimal128(_, _), Decimal128(_, _)) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::Decimal128 { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Decimal256(_, _), Decimal256(_, _)) => { - let left = left.as_primitive::(); - let right = right.as_primitive::(); - Comparator::Decimal256 { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Boolean, Boolean) => { - let left = left.as_boolean(); - let right = right.as_boolean(); - Comparator::Boolean { - left: left.values().clone(), - right: right.values().clone(), - left_nulls, - right_nulls, - opts, - } - } - (Utf8, Utf8) => { - let left = left.as_string::(); - let right = right.as_string::(); - Comparator::Utf8 { - left: left.clone(), - right: right.clone(), - left_nulls, - right_nulls, - opts, - } - } - (LargeUtf8, LargeUtf8) => { - let left = left.as_string::(); - let right = right.as_string::(); - Comparator::LargeUtf8 { - left: left.clone(), - right: right.clone(), - left_nulls, - right_nulls, - opts, - } - } - (Utf8View, Utf8View) => { - let left = left.as_string_view(); - let right = right.as_string_view(); - Comparator::Utf8View { - left: left.clone(), - right: right.clone(), - left_nulls, - right_nulls, - opts, - } - } - (Binary, Binary) => { - let left = left.as_binary::(); - let right = right.as_binary::(); - Comparator::Binary { - left: left.clone(), - right: right.clone(), - left_nulls, - right_nulls, - opts, - } - } - (LargeBinary, LargeBinary) => { - let left = left.as_binary::(); - let right = right.as_binary::(); - Comparator::LargeBinary { - left: left.clone(), - right: right.clone(), - left_nulls, - right_nulls, - opts, - } - } - (BinaryView, BinaryView) => { - let left = left.as_binary_view(); - let right = right.as_binary_view(); - Comparator::BinaryView { - left: left.clone(), - right: right.clone(), - left_nulls, - right_nulls, - opts, - } - } - (FixedSizeBinary(_), FixedSizeBinary(_)) => { - let left = left.as_fixed_size_binary(); - let right = right.as_fixed_size_binary(); - Comparator::FixedSizeBinary { - left: left.clone(), - right: right.clone(), - left_nulls, - right_nulls, - opts, - } - } - // Fall back to dynamic dispatch for complex types - _ => { - let cmp = arrow_make_comparator(left, right, opts)?; - Comparator::Dynamic(cmp) - } - }) -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow::array::{ - BooleanArray, Date32Array, Float64Array, Int32Array, StringArray, - }; - - #[test] - fn test_int32_compare() { - let left = Int32Array::from(vec![1, 2, 3]); - let right = Int32Array::from(vec![2, 2, 1]); - - let cmp = make_comparator(&left, &right, SortOptions::default()).unwrap(); - - assert_eq!(cmp.compare(0, 0), Ordering::Less); // 1 < 2 - assert_eq!(cmp.compare(1, 1), Ordering::Equal); // 2 == 2 - assert_eq!(cmp.compare(2, 2), Ordering::Greater); // 3 > 1 - } - - #[test] - fn test_int32_compare_with_nulls() { - let left = Int32Array::from(vec![Some(1), None, Some(3)]); - let right = Int32Array::from(vec![Some(2), Some(2), None]); - - let cmp = make_comparator(&left, &right, SortOptions::default()).unwrap(); - - assert_eq!(cmp.compare(0, 0), Ordering::Less); // 1 < 2 - assert_eq!(cmp.compare(1, 1), Ordering::Less); // null < 2 (nulls_first=true) - assert_eq!(cmp.compare(2, 2), Ordering::Greater); // 3 > null - } - - #[test] - fn test_int32_descending() { - let left = Int32Array::from(vec![1, 2, 3]); - let right = Int32Array::from(vec![2, 2, 1]); - - let cmp = make_comparator( - &left, - &right, - SortOptions { - descending: true, - nulls_first: false, - }, - ) - .unwrap(); - - assert_eq!(cmp.compare(0, 0), Ordering::Greater); // 1 > 2 (descending) - assert_eq!(cmp.compare(1, 1), Ordering::Equal); // 2 == 2 - assert_eq!(cmp.compare(2, 2), Ordering::Less); // 3 < 1 (descending) - } - - #[test] - fn test_float64_compare() { - let left = Float64Array::from(vec![1.5, 2.5, f64::NAN]); - let right = Float64Array::from(vec![2.5, 2.5, 1.5]); - - let cmp = make_comparator(&left, &right, SortOptions::default()).unwrap(); - - assert_eq!(cmp.compare(0, 0), Ordering::Less); // 1.5 < 2.5 - assert_eq!(cmp.compare(1, 1), Ordering::Equal); // 2.5 == 2.5 - assert_eq!(cmp.compare(2, 2), Ordering::Greater); // NaN > 1.5 (using total_cmp) - } - - #[test] - fn test_string_compare() { - let left = StringArray::from(vec!["a", "b", "c"]); - let right = StringArray::from(vec!["b", "b", "a"]); - - let cmp = make_comparator(&left, &right, SortOptions::default()).unwrap(); - - assert_eq!(cmp.compare(0, 0), Ordering::Less); // "a" < "b" - assert_eq!(cmp.compare(1, 1), Ordering::Equal); // "b" == "b" - assert_eq!(cmp.compare(2, 2), Ordering::Greater); // "c" > "a" - } - - #[test] - fn test_boolean_compare() { - let left = BooleanArray::from(vec![false, true, false]); - let right = BooleanArray::from(vec![true, true, false]); - - let cmp = make_comparator(&left, &right, SortOptions::default()).unwrap(); - - assert_eq!(cmp.compare(0, 0), Ordering::Less); // false < true - assert_eq!(cmp.compare(1, 1), Ordering::Equal); // true == true - assert_eq!(cmp.compare(2, 2), Ordering::Equal); // false == false - } - - #[test] - fn test_date32_compare() { - let left = Date32Array::from(vec![100, 200, 300]); - let right = Date32Array::from(vec![200, 200, 100]); - - let cmp = make_comparator(&left, &right, SortOptions::default()).unwrap(); - - assert_eq!(cmp.compare(0, 0), Ordering::Less); // 100 < 200 - assert_eq!(cmp.compare(1, 1), Ordering::Equal); // 200 == 200 - assert_eq!(cmp.compare(2, 2), Ordering::Greater); // 300 > 100 - } - - #[test] - fn test_nulls_last() { - let left = Int32Array::from(vec![Some(1), None, Some(3)]); - let right = Int32Array::from(vec![Some(2), Some(2), None]); - - let cmp = make_comparator( - &left, - &right, - SortOptions { - descending: false, - nulls_first: false, - }, - ) - .unwrap(); - - assert_eq!(cmp.compare(0, 0), Ordering::Less); // 1 < 2 - assert_eq!(cmp.compare(1, 1), Ordering::Greater); // null > 2 (nulls_first=false) - assert_eq!(cmp.compare(2, 2), Ordering::Less); // 3 < null - } -} diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 5da7e18df9131..1012fdb8b20b4 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -25,7 +25,6 @@ use std::sync::Arc; use crate::physical_expr::physical_exprs_bag_equal; use crate::PhysicalExpr; -use super::comparator::make_comparator; use arrow::array::*; use arrow::buffer::BooleanBuffer; use arrow::compute::kernels::boolean::{not, or_kleene}; @@ -126,7 +125,7 @@ impl ArrayHashSet { let contains = self .map .raw_entry() - .from_hash(hash, |idx| cmp.compare(i, *idx).is_eq()) + .from_hash(hash, |idx| cmp(i, *idx).is_eq()) .is_some(); match contains { @@ -165,7 +164,7 @@ fn make_hash_set(array: &dyn Array) -> Result { let hash = hashes[idx]; if let RawEntryMut::Vacant(v) = map .raw_entry_mut() - .from_hash(hash, |x| cmp.compare(*x, idx).is_eq()) + .from_hash(hash, |x| cmp(*x, idx).is_eq()) { v.insert_with_hasher(hash, idx, (), |x| hashes[*x]); } @@ -381,7 +380,7 @@ impl PhysicalExpr for InListExpr { if value.is_null(i) || array.is_null(i) { return None; } - Some(cmp.compare(i, i).is_eq()) + Some(cmp(i, i).is_eq()) }) .collect::() } @@ -404,7 +403,7 @@ impl PhysicalExpr for InListExpr { if value.is_null(i) { None } else { - Some(cmp.compare(i, 0).is_eq()) + Some(cmp(i, 0).is_eq()) } }) .collect::() diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 0ccb5c13b3a90..59d675753d985 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -23,7 +23,6 @@ mod case; mod cast; mod cast_column; mod column; -mod comparator; mod dynamic_filters; mod in_list; mod is_not_null; From 896820e2d2ebab56346954263fcda083b19ca2f6 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sat, 15 Nov 2025 07:31:41 +0800 Subject: [PATCH 3/6] use const thread local --- datafusion/common/src/hash_utils.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index e52203244321f..0fa47671d303a 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -52,7 +52,7 @@ thread_local! { /// The buffer is reused across calls and truncated if it exceeds MAX_BUFFER_SIZE. /// Defaults to a capacity of 8192 u64 elements which is the default batch size. /// This corresponds to 64KB of memory. - static HASH_BUFFER: RefCell> = RefCell::new(Vec::with_capacity(8192)); + static HASH_BUFFER: RefCell> = const { RefCell::new(Vec::new()) }; } /// Creates hashes for the given arrays using a thread-local buffer, then calls the provided callback From 621cfe5aa6fc0e2e0cb29c429e5f20fccfc97d76 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 17 Nov 2025 17:38:35 -0500 Subject: [PATCH 4/6] Consolidate StaticFilter and ArrayHashSet (#44) * Consolidate StaticFilter and ArrayHashSet * Fix docs --- .../physical-expr/src/expressions/in_list.rs | 154 ++++++++---------- 1 file changed, 71 insertions(+), 83 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 1012fdb8b20b4..aa56fba40ae9a 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -54,8 +54,13 @@ use hashbrown::hash_map::RawEntryMut; /// Static filter for InList that stores the array and hash set for O(1) lookups #[derive(Debug, Clone)] struct StaticFilter { - array: ArrayRef, - hash_set: ArrayHashSet, + in_array: ArrayRef, + state: RandomState, + /// Used to provide a lookup from value to in list index + /// + /// Note: usize::hash is not used, instead the raw entry + /// API is used to store entries w.r.t their value + map: HashMap, } /// InList @@ -76,32 +81,19 @@ impl Debug for InListExpr { } } -#[derive(Debug, Clone)] -pub(crate) struct ArrayHashSet { - state: RandomState, - /// Used to provide a lookup from value to in list index - /// - /// Note: usize::hash is not used, instead the raw entry - /// API is used to store entries w.r.t their value - map: HashMap, -} - -impl ArrayHashSet { +impl StaticFilter { /// Checks if values in `v` are contained in the `in_array` using this hash set for lookup. - fn contains( - &self, - v: &dyn Array, - in_array: &dyn Array, - negated: bool, - ) -> Result { + fn contains(&self, v: &dyn Array, negated: bool) -> Result { // Null type comparisons always return null (SQL three-valued logic) - if v.data_type() == &DataType::Null || in_array.data_type() == &DataType::Null { + if v.data_type() == &DataType::Null + || self.in_array.data_type() == &DataType::Null + { return Ok(BooleanArray::from(vec![None; v.len()])); } downcast_dictionary_array! { v => { - let values_contains = self.contains(v.values().as_ref(), in_array, negated)?; + let values_contains = self.contains(v.values().as_ref(), negated)?; let result = take(&values_contains, v.keys(), None)?; return Ok(downcast_array(result.as_ref())) } @@ -110,10 +102,10 @@ impl ArrayHashSet { let needle_nulls = v.logical_nulls(); let needle_nulls = needle_nulls.as_ref(); - let haystack_has_nulls = in_array.null_count() != 0; + let haystack_has_nulls = self.in_array.null_count() != 0; with_hashes([v], &self.state, |hashes| { - let cmp = make_comparator(v, in_array, SortOptions::default())?; + let cmp = make_comparator(v, &self.in_array, SortOptions::default())?; Ok((0..v.len()) .map(|i| { // SQL three-valued logic: null IN (...) is always null @@ -137,51 +129,56 @@ impl ArrayHashSet { .collect()) }) } -} -/// Computes an [`ArrayHashSet`] for the provided [`Array`] if there -/// are nulls present or there are more than the configured number of -/// elements. -/// -/// Note: This is split into a separate function as higher-rank trait bounds currently -/// cause type inference to misbehave -fn make_hash_set(array: &dyn Array) -> Result { - // Null type has no natural order - return empty hash set - if array.data_type() == &DataType::Null { - return Ok(ArrayHashSet { - state: RandomState::new(), - map: HashMap::with_hasher(()), - }); - } + /// Computes a [`StaticFilter`] for the provided [`Array`] if there + /// are nulls present or there are more than the configured number of + /// elements. + /// + /// Note: This is split into a separate function as higher-rank trait bounds currently + /// cause type inference to misbehave + fn try_new(in_array: ArrayRef) -> Result { + // Null type has no natural order - return empty hash set + if in_array.data_type() == &DataType::Null { + return Ok(StaticFilter { + in_array, + state: RandomState::new(), + map: HashMap::with_hasher(()), + }); + } - let state = RandomState::new(); - let mut map: HashMap = HashMap::with_hasher(()); + let state = RandomState::new(); + let mut map: HashMap = HashMap::with_hasher(()); - with_hashes([array], &state, |hashes| -> Result<()> { - let cmp = make_comparator(array, array, SortOptions::default())?; + with_hashes([&in_array], &state, |hashes| -> Result<()> { + let cmp = make_comparator(&in_array, &in_array, SortOptions::default())?; - let insert_value = |idx| { - let hash = hashes[idx]; - if let RawEntryMut::Vacant(v) = map - .raw_entry_mut() - .from_hash(hash, |x| cmp(*x, idx).is_eq()) - { - v.insert_with_hasher(hash, idx, (), |x| hashes[*x]); - } - }; + let insert_value = |idx| { + let hash = hashes[idx]; + if let RawEntryMut::Vacant(v) = map + .raw_entry_mut() + .from_hash(hash, |x| cmp(*x, idx).is_eq()) + { + v.insert_with_hasher(hash, idx, (), |x| hashes[*x]); + } + }; - match array.nulls() { - Some(nulls) => { - BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len()) - .for_each(insert_value) + match in_array.nulls() { + Some(nulls) => { + BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len()) + .for_each(insert_value) + } + None => (0..in_array.len()).for_each(insert_value), } - None => (0..array.len()).for_each(insert_value), - } - Ok(()) - })?; + Ok(()) + })?; - Ok(ArrayHashSet { state, map }) + Ok(Self { + in_array, + state, + map, + }) + } } /// Evaluates the list of expressions into an array, flattening any dictionaries @@ -253,8 +250,8 @@ impl InListExpr { /// Create a new InList expression directly from an array, bypassing expression evaluation. /// /// This is more efficient than `in_list()` when you already have the list as an array, - /// as it avoids the conversion: `ArrayRef -> Vec -> ArrayRef -> ArrayHashSet`. - /// Instead it goes directly: `ArrayRef -> ArrayHashSet`. + /// as it avoids the conversion: `ArrayRef -> Vec -> ArrayRef -> StaticFilter`. + /// Instead it goes directly: `ArrayRef -> StaticFilter`. /// /// The `list` field will be empty when using this constructor, as the array is stored /// directly in the static filter. @@ -272,8 +269,7 @@ impl InListExpr { Ok(crate::expressions::lit(scalar) as Arc) }) .collect::>>()?; - let hash_set = make_hash_set(array.as_ref())?; - let static_filter = StaticFilter { array, hash_set }; + let static_filter = StaticFilter::try_new(array)?; Ok(Self::new(expr, list, negated, Some(static_filter))) } } @@ -311,7 +307,7 @@ impl PhysicalExpr for InListExpr { } if let Some(static_filter) = &self.static_filter { - Ok(static_filter.array.null_count() > 0) + Ok(static_filter.in_array.null_count() > 0) } else { for expr in &self.list { if expr.nullable(input_schema)? { @@ -328,11 +324,9 @@ impl PhysicalExpr for InListExpr { let r = match &self.static_filter { Some(filter) => { match value { - ColumnarValue::Array(array) => filter.hash_set.contains( - &array, - filter.array.as_ref(), - self.negated, - )?, + ColumnarValue::Array(array) => { + filter.contains(&array, self.negated)? + } ColumnarValue::Scalar(scalar) => { if scalar.is_null() { // SQL three-valued logic: null IN (...) is always null @@ -344,11 +338,8 @@ impl PhysicalExpr for InListExpr { // Use a 1 row array to avoid code duplication/branching // Since all we do is compute hash and lookup this should be efficient enough let array = scalar.to_array()?; - let result_array = filter.hash_set.contains( - array.as_ref(), - filter.array.as_ref(), - self.negated, - )?; + let result_array = + filter.contains(array.as_ref(), self.negated)?; // Broadcast the single result to all rows // Must check is_null() to preserve NULL values (SQL three-valued logic) if result_array.is_null(0) { @@ -498,9 +489,7 @@ pub fn in_list( // Try to create a static filter for constant expressions let static_filter = try_evaluate_constant_list(&list, schema) - .and_then(|array| { - make_hash_set(array.as_ref()).map(|hash_set| StaticFilter { array, hash_set }) - }) + .and_then(StaticFilter::try_new) .ok(); Ok(Arc::new(InListExpr::new( @@ -560,9 +549,9 @@ mod tests { fn try_cast_static_filter_to_set( list: &[Arc], schema: &Schema, - ) -> Result { + ) -> Result { let array = try_evaluate_constant_list(list, schema)?; - make_hash_set(array.as_ref()) + StaticFilter::try_new(array) } // Attempts to coerce the types of `list_type` to be comparable with the @@ -1202,11 +1191,10 @@ mod tests { expressions::cast(lit(2i32), &schema, DataType::Int64)?, try_cast(lit(3.13f32), &schema, DataType::Int64)?, ]; - let set_array = try_evaluate_constant_list(&phy_exprs, &schema)?; - let result = try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); + let static_filter = try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); let array = Int64Array::from(vec![1, 2, 3, 4]); - let r = result.contains(&array, set_array.as_ref(), false).unwrap(); + let r = static_filter.contains(&array, false).unwrap(); assert_eq!(r, BooleanArray::from(vec![true, true, true, false])); try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); From 8a2ee060a4ae43a04c2ff1f8df2571ba75b52e96 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 18 Nov 2025 17:18:58 +0800 Subject: [PATCH 5/6] fix rebase --- .../physical-expr/src/expressions/in_list.rs | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index aa56fba40ae9a..92fdb2fbeef45 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -30,21 +30,11 @@ use arrow::buffer::BooleanBuffer; use arrow::compute::kernels::boolean::{not, or_kleene}; use arrow::compute::{take, SortOptions}; use arrow::datatypes::*; -use arrow::downcast_dictionary_array; use arrow::util::bit_iterator::BitIndexIterator; -use arrow::{downcast_dictionary_array, downcast_primitive_array}; -use datafusion_common::cast::{ - as_boolean_array, as_generic_binary_array, as_string_array, -}; -use datafusion_common::hash_utils::HashValue; +use datafusion_common::hash_utils::with_hashes; use datafusion_common::{ - assert_or_internal_err, exec_err, not_impl_err, DFSchema, DataFusionError, Result, - ScalarValue, + assert_or_internal_err, exec_err, DFSchema, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::ColumnarValue; -use datafusion_physical_expr_common::datum::compare_with_eq; -use datafusion_common::hash_utils::with_hashes; -use datafusion_common::{exec_err, internal_err, DFSchema, Result, ScalarValue}; use datafusion_expr::{expr_vec_fmt, ColumnarValue}; use ahash::RandomState; From 06a476370428da1f3abb87f63552396f99518f05 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 17 Nov 2025 10:36:18 -0500 Subject: [PATCH 6/6] Add specialized sets for primitive types --- .../physical-expr/src/expressions/in_list.rs | 144 +++++++++++++++--- 1 file changed, 120 insertions(+), 24 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 92fdb2fbeef45..95029c1efe74c 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -33,7 +33,8 @@ use arrow::datatypes::*; use arrow::util::bit_iterator::BitIndexIterator; use datafusion_common::hash_utils::with_hashes; use datafusion_common::{ - assert_or_internal_err, exec_err, DFSchema, DataFusionError, Result, ScalarValue, + assert_or_internal_err, exec_datafusion_err, exec_err, DFSchema, DataFusionError, + HashSet, Result, ScalarValue, }; use datafusion_expr::{expr_vec_fmt, ColumnarValue}; @@ -41,16 +42,12 @@ use ahash::RandomState; use datafusion_common::HashMap; use hashbrown::hash_map::RawEntryMut; -/// Static filter for InList that stores the array and hash set for O(1) lookups -#[derive(Debug, Clone)] -struct StaticFilter { - in_array: ArrayRef, - state: RandomState, - /// Used to provide a lookup from value to in list index - /// - /// Note: usize::hash is not used, instead the raw entry - /// API is used to store entries w.r.t their value - map: HashMap, +/// Trait for InList static filters +trait StaticFilter { + fn null_count(&self) -> usize; + + /// Checks if values in `v` are contained in the filter + fn contains(&self, v: &dyn Array, negated: bool) -> Result; } /// InList @@ -58,7 +55,7 @@ pub struct InListExpr { expr: Arc, list: Vec>, negated: bool, - static_filter: Option, + static_filter: Option>, } impl Debug for InListExpr { @@ -71,7 +68,23 @@ impl Debug for InListExpr { } } -impl StaticFilter { +/// Static filter for InList that stores the array and hash set for O(1) lookups +#[derive(Debug, Clone)] +struct ArrayStaticFilter { + in_array: ArrayRef, + state: RandomState, + /// Used to provide a lookup from value to in list index + /// + /// Note: usize::hash is not used, instead the raw entry + /// API is used to store entries w.r.t their value + map: HashMap, +} + +impl StaticFilter for ArrayStaticFilter { + fn null_count(&self) -> usize { + self.in_array.null_count() + } + /// Checks if values in `v` are contained in the `in_array` using this hash set for lookup. fn contains(&self, v: &dyn Array, negated: bool) -> Result { // Null type comparisons always return null (SQL three-valued logic) @@ -119,17 +132,31 @@ impl StaticFilter { .collect()) }) } +} + +fn instantiate_static_filter( + in_array: ArrayRef, +) -> Result> { + match in_array.data_type() { + DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)), + _ => { + /* fall through to generic implementation */ + Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)) + } + } +} +impl ArrayStaticFilter { /// Computes a [`StaticFilter`] for the provided [`Array`] if there /// are nulls present or there are more than the configured number of /// elements. /// /// Note: This is split into a separate function as higher-rank trait bounds currently /// cause type inference to misbehave - fn try_new(in_array: ArrayRef) -> Result { + fn try_new(in_array: ArrayRef) -> Result { // Null type has no natural order - return empty hash set if in_array.data_type() == &DataType::Null { - return Ok(StaticFilter { + return Ok(ArrayStaticFilter { in_array, state: RandomState::new(), map: HashMap::with_hasher(()), @@ -171,6 +198,68 @@ impl StaticFilter { } } +struct Int32StaticFilter { + null_count: usize, + values: HashSet, +} + +impl Int32StaticFilter { + fn try_new(in_array: &ArrayRef) -> Result { + let in_array = in_array + .as_primitive_opt::() + .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?; + + let mut values = HashSet::with_capacity(in_array.len()); + let null_count = in_array.null_count(); + + for v in in_array.iter().flatten() { + values.insert(v); + } + + Ok(Self { null_count, values }) + } +} + +impl StaticFilter for Int32StaticFilter { + fn null_count(&self) -> usize { + self.null_count + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + let v = v + .as_primitive_opt::() + .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?; + + let result = match (v.null_count() > 0, negated) { + (true, false) => { + // has nulls, not negated" + BooleanArray::from_iter( + v.iter().map(|value| Some(self.values.contains(&value?))), + ) + } + (true, true) => { + // has nulls, negated + BooleanArray::from_iter( + v.iter().map(|value| Some(!self.values.contains(&value?))), + ) + } + (false, false) => { + //no null, not negated + BooleanArray::from_iter( + v.values().iter().map(|value| self.values.contains(value)), + ) + } + (false, true) => { + // no null, negated + BooleanArray::from_iter( + v.values().iter().map(|value| !self.values.contains(value)), + ) + } + }; + Ok(result) + } +} + /// Evaluates the list of expressions into an array, flattening any dictionaries fn evaluate_list( list: &[Arc], @@ -212,7 +301,7 @@ impl InListExpr { expr: Arc, list: Vec>, negated: bool, - static_filter: Option, + static_filter: Option>, ) -> Self { Self { expr, @@ -259,8 +348,12 @@ impl InListExpr { Ok(crate::expressions::lit(scalar) as Arc) }) .collect::>>()?; - let static_filter = StaticFilter::try_new(array)?; - Ok(Self::new(expr, list, negated, Some(static_filter))) + Ok(Self::new( + expr, + list, + negated, + Some(instantiate_static_filter(array)?), + )) } } impl std::fmt::Display for InListExpr { @@ -297,7 +390,7 @@ impl PhysicalExpr for InListExpr { } if let Some(static_filter) = &self.static_filter { - Ok(static_filter.in_array.null_count() > 0) + Ok(static_filter.null_count() > 0) } else { for expr in &self.list { if expr.nullable(input_schema)? { @@ -420,7 +513,7 @@ impl PhysicalExpr for InListExpr { Arc::clone(&children[0]), children[1..].to_vec(), self.negated, - self.static_filter.clone(), + self.static_filter.as_ref().map(Arc::clone), ))) } @@ -479,8 +572,11 @@ pub fn in_list( // Try to create a static filter for constant expressions let static_filter = try_evaluate_constant_list(&list, schema) - .and_then(StaticFilter::try_new) - .ok(); + .and_then(ArrayStaticFilter::try_new) + .ok() + .map(|static_filter| { + Arc::new(static_filter) as Arc + }); Ok(Arc::new(InListExpr::new( expr, @@ -539,9 +635,9 @@ mod tests { fn try_cast_static_filter_to_set( list: &[Arc], schema: &Schema, - ) -> Result { + ) -> Result { let array = try_evaluate_constant_list(list, schema)?; - StaticFilter::try_new(array) + ArrayStaticFilter::try_new(array) } // Attempts to coerce the types of `list_type` to be comparable with the