From 7dc8b9f9b0d9e46427f6f07c8938df9a554b3242 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Wed, 8 Oct 2025 21:55:11 +0300 Subject: [PATCH] fix: update `PrimitiveGroupValueBuilder` to match NaN correctly --- .../group_values/multi_group_by/primitive.rs | 56 +++++++++++-------- 1 file changed, 33 insertions(+), 23 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs index f560121cd79ed..f2b12a8c954f7 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs @@ -70,7 +70,7 @@ impl GroupColumn // Otherwise, we need to check their values } - self.group_values[lhs_row] == array.as_primitive::().value(rhs_row) + self.group_values[lhs_row].is_eq(array.as_primitive::().value(rhs_row)) } fn append_val(&mut self, array: &ArrayRef, row: usize) -> Result<()> { @@ -217,14 +217,14 @@ mod tests { use std::sync::Arc; use crate::aggregates::group_values::multi_group_by::primitive::PrimitiveGroupValueBuilder; - use arrow::array::{ArrayRef, Int64Array, NullBufferBuilder}; - use arrow::datatypes::{DataType, Int64Type}; + use arrow::array::{ArrayRef, Float32Array, Int64Array, NullBufferBuilder}; + use arrow::datatypes::{DataType, Float32Type, Int64Type}; use super::GroupColumn; #[test] fn test_nullable_primitive_equal_to() { - let append = |builder: &mut PrimitiveGroupValueBuilder, + let append = |builder: &mut PrimitiveGroupValueBuilder, builder_array: &ArrayRef, append_rows: &[usize]| { for &index in append_rows { @@ -232,7 +232,7 @@ mod tests { } }; - let equal_to = |builder: &PrimitiveGroupValueBuilder, + let equal_to = |builder: &PrimitiveGroupValueBuilder, lhs_rows: &[usize], input_array: &ArrayRef, rhs_rows: &[usize], @@ -248,7 +248,7 @@ mod tests { #[test] fn test_nullable_primitive_vectorized_equal_to() { - let append = |builder: &mut PrimitiveGroupValueBuilder, + let append = |builder: &mut PrimitiveGroupValueBuilder, builder_array: &ArrayRef, append_rows: &[usize]| { builder @@ -256,7 +256,7 @@ mod tests { .unwrap(); }; - let equal_to = |builder: &PrimitiveGroupValueBuilder, + let equal_to = |builder: &PrimitiveGroupValueBuilder, lhs_rows: &[usize], input_array: &ArrayRef, rhs_rows: &[usize], @@ -274,9 +274,9 @@ mod tests { fn test_nullable_primitive_equal_to_internal(mut append: A, mut equal_to: E) where - A: FnMut(&mut PrimitiveGroupValueBuilder, &ArrayRef, &[usize]), + A: FnMut(&mut PrimitiveGroupValueBuilder, &ArrayRef, &[usize]), E: FnMut( - &PrimitiveGroupValueBuilder, + &PrimitiveGroupValueBuilder, &[usize], &ArrayRef, &[usize], @@ -293,48 +293,58 @@ mod tests { // Define PrimitiveGroupValueBuilder let mut builder = - PrimitiveGroupValueBuilder::::new(DataType::Int64); - let builder_array = Arc::new(Int64Array::from(vec![ + PrimitiveGroupValueBuilder::::new(DataType::Float32); + let builder_array = Arc::new(Float32Array::from(vec![ None, None, None, - Some(1), - Some(2), - Some(3), + Some(1.0), + Some(2.0), + Some(f32::NAN), + Some(3.0), ])) as ArrayRef; - append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5]); + append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5, 6]); // Define input array - let (_, values, _nulls) = - Int64Array::from(vec![Some(1), Some(2), None, None, Some(1), Some(3)]) - .into_parts(); + let (_, values, _nulls) = Float32Array::from(vec![ + Some(1.0), + Some(2.0), + None, + Some(1.0), + None, + Some(f32::NAN), + None, + ]) + .into_parts(); // explicitly build a null buffer where one of the null values also happens to match let mut nulls = NullBufferBuilder::new(6); nulls.append_non_null(); nulls.append_null(); // this sets Some(2) to null above nulls.append_null(); - nulls.append_null(); nulls.append_non_null(); + nulls.append_null(); nulls.append_non_null(); - let input_array = Arc::new(Int64Array::new(values, nulls.finish())) as ArrayRef; + nulls.append_null(); + let input_array = Arc::new(Float32Array::new(values, nulls.finish())) as ArrayRef; // Check let mut equal_to_results = vec![true; builder.len()]; equal_to( &builder, - &[0, 1, 2, 3, 4, 5], + &[0, 1, 2, 3, 4, 5, 6], &input_array, - &[0, 1, 2, 3, 4, 5], + &[0, 1, 2, 3, 4, 5, 6], &mut equal_to_results, ); assert!(!equal_to_results[0]); assert!(equal_to_results[1]); assert!(equal_to_results[2]); - assert!(!equal_to_results[3]); + assert!(equal_to_results[3]); assert!(!equal_to_results[4]); assert!(equal_to_results[5]); + assert!(!equal_to_results[6]); } #[test]