diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs
index ff9cdedab8b1..409f248621f7 100644
--- a/datafusion/common/src/utils/mod.rs
+++ b/datafusion/common/src/utils/mod.rs
@@ -82,6 +82,29 @@ pub fn project_schema(
     Ok(schema)
 }
 
+/// Extracts the row at `idx` from `columns` into `buf`, reusing the buffer's allocation.
+///
+/// `buf` is cleared first and then receives one [`ScalarValue`] per column, in
+/// column order.
+///
+/// # Errors
+///
+/// Returns the first error produced by [`ScalarValue::try_from_array`]. In that
+/// case `buf` holds only the values converted before the failure.
+pub fn extract_row_at_idx_to_buf(
+    columns: &[ArrayRef],
+    idx: usize,
+    buf: &mut Vec<ScalarValue>,
+) -> Result<()> {
+    buf.clear();
+    buf.reserve(columns.len());
+
+    for arr in columns {
+        buf.push(ScalarValue::try_from_array(arr, idx)?);
+    }
+
+    Ok(())
+}
 /// Given column vectors, returns row at `idx`.
 pub fn get_row_at_idx(columns: &[ArrayRef], idx: usize) -> Result<Vec<ScalarValue>> {
     columns
diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index 46221acfcc9b..1b98a19581ea 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
@@ -15,7 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
-use std::str;
 use std::sync::Arc;
 
 use crate::fuzz_cases::aggregation_fuzzer::{
@@ -88,6 +87,33 @@ async fn test_min() {
         .await;
 }
 
+/// Fuzz test for `first_value`, using every column of the baseline config both as
+/// aggregate argument and as group by candidate.
+#[tokio::test(flavor = "multi_thread")]
+async fn test_first_val() {
+    let mut data_gen_config: DatasetGeneratorConfig = baseline_config();
+
+    // Minimize the chance of identical values in the order by columns to make the
+    // test more stable: columns without an explicit cap get the largest one.
+    for column in data_gen_config.columns.iter_mut() {
+        if column.get_max_num_distinct().is_none() {
+            *column = column.clone().with_max_num_distinct(usize::MAX);
+        }
+    }
+
+    let query_builder = QueryBuilder::new()
+        .with_table_name("fuzz_table")
+        .with_aggregate_function("first_value")
+        .with_aggregate_arguments(data_gen_config.all_columns())
+        .set_group_by_columns(data_gen_config.all_columns());
+
+    AggregationFuzzerBuilder::from(data_gen_config)
+        .add_query_builder(query_builder)
+        .build()
+        .run()
+        .await;
+}
+
 #[tokio::test(flavor = "multi_thread")]
 async fn test_max() {
     let data_gen_config = baseline_config();
diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs
index 54c5744c861b..d61835a0804e 100644
--- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs
@@ -228,6 +228,11 @@ impl ColumnDescr {
         }
     }
 
+    /// Returns the configured maximum number of distinct values, if one was set.
+    pub fn get_max_num_distinct(&self) -> Option<usize> {
+        self.max_num_distinct
+    }
+
     /// set the maximum number of distinct values in this column
     ///
     /// If `None`, the number of distinct values is randomly selected between 1
diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs
index c608adda5d1c..bb24fb554d65 100644
--- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs
@@ -15,13
+15,14 @@
 // specific language governing permissions and limitations
 // under the License.
 
 use std::collections::HashSet;
 use std::sync::Arc;
 
 use arrow::array::RecordBatch;
 use arrow::util::pretty::pretty_format_batches;
 use datafusion_common::{DataFusionError, Result};
 use datafusion_common_runtime::JoinSet;
+use rand::seq::SliceRandom;
 use rand::{thread_rng, Rng};
 
 use crate::fuzz_cases::aggregation_fuzzer::{
@@ -452,7 +453,11 @@ impl QueryBuilder {
     pub fn generate_query(&self) -> String {
         let group_by = self.random_group_by();
         let mut query = String::from("SELECT ");
-        query.push_str(&self.random_aggregate_functions().join(", "));
+        query.push_str(&group_by.join(", "));
+        if !group_by.is_empty() {
+            query.push_str(", ");
+        }
+        query.push_str(&self.random_aggregate_functions(&group_by).join(", "));
         query.push_str(" FROM ");
         query.push_str(&self.table_name);
         if !group_by.is_empty() {
@@ -474,7 +479,7 @@ impl QueryBuilder {
     /// * `function_names` are randomly selected from [`Self::aggregate_functions`]
     /// * ` argument` is randomly selected from [`Self::arguments`]
     /// * `alias` is a unique alias `colN` for the column (to avoid duplicate column names)
-    fn random_aggregate_functions(&self) -> Vec<String> {
+    fn random_aggregate_functions(&self, group_by_cols: &[String]) -> Vec<String> {
         const MAX_NUM_FUNCTIONS: usize = 5;
         let mut rng = thread_rng();
         let num_aggregate_functions = rng.gen_range(1..MAX_NUM_FUNCTIONS);
@@ -482,6 +487,14 @@ impl QueryBuilder {
         let mut alias_gen = 1;
 
         let mut aggregate_functions = vec![];
+
+        let mut order_by_black_list: HashSet<String> =
+            group_by_cols.iter().cloned().collect();
+        // remove one random col
+        if let Some(first) = order_by_black_list.iter().next().cloned() {
+            order_by_black_list.remove(&first);
+        }
+
         while aggregate_functions.len() < num_aggregate_functions {
             let idx = rng.gen_range(0..self.aggregate_functions.len());
             let (function_name, is_distinct) = &self.aggregate_functions[idx];
@@ -489,7 +502,19 @@ impl QueryBuilder {
             let alias = format!("col{}", alias_gen);
             let distinct = if *is_distinct { "DISTINCT " } else { "" };
             alias_gen += 1;
-            let function = format!("{function_name}({distinct}{argument}) as {alias}");
+
+            // Only `first_value` gets an ORDER BY and a null-handling suffix. Among the
+            // order by columns, at most one group by column can be included to avoid
+            // all order by column values being identical.
+            let (order_by, null_opt) = if function_name.eq("first_value") {
+                (self.order_by(&order_by_black_list), self.null_opt())
+            } else {
+                (String::new(), String::new())
+            };
+
+            let function = format!(
+                "{function_name}({distinct}{argument}{order_by}) {null_opt} as {alias}"
+            );
             aggregate_functions.push(function);
         }
         aggregate_functions
@@ -502,6 +527,51 @@ impl QueryBuilder {
         self.arguments[idx].clone()
     }
 
+    /// Builds a random ` order by <col> <ASC|DESC>,...` clause for an aggregate call.
+    ///
+    /// Columns are drawn from [`Self::arguments`], skipping every name in `black_list`,
+    /// shuffled, and capped at `MAX_ORDER_BY_COLUMNS`; each gets a random direction.
+    ///
+    /// Returns an empty string when no column is eligible, so the caller emits the
+    /// aggregate without an ordering instead of panicking on an empty column list.
+    fn order_by(&self, black_list: &HashSet<String>) -> String {
+        const MAX_ORDER_BY_COLUMNS: usize = 12;
+
+        let mut rng = thread_rng();
+
+        let mut available_columns: Vec<&String> = self
+            .arguments
+            .iter()
+            .filter(|col| !black_list.contains(*col))
+            .collect();
+        available_columns.shuffle(&mut rng);
+
+        let terms: Vec<String> = available_columns
+            .into_iter()
+            .take(MAX_ORDER_BY_COLUMNS)
+            .map(|col| {
+                let order = if rng.gen_bool(0.5) { "ASC" } else { "DESC" };
+                format!("{col} {order}")
+            })
+            .collect();
+
+        if terms.is_empty() {
+            return String::new();
+        }
+
+        format!(" order by {}", terms.join(","))
+    }
+
+    /// Randomly picks the null-handling clause appended after the aggregate call.
+    fn null_opt(&self) -> String {
+        let clause = if thread_rng().gen_bool(0.5) {
+            "RESPECT NULLS"
+        } else {
+            "IGNORE NULLS"
+        };
+        clause.to_string()
+    }
+
     /// Pick a random number of fields to group by (non-repeating)
     ///
     /// Limited to 3 group by columns to ensure coverage for large groups.
With diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 6df8ede4fc77..28e6a8723dfd 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -22,18 +22,30 @@ use std::fmt::Debug; use std::mem::size_of_val; use std::sync::Arc; -use arrow::array::{ArrayRef, AsArray, BooleanArray}; -use arrow::compute::{self, LexicographicalComparator, SortColumn}; -use arrow::datatypes::{DataType, Field}; -use datafusion_common::utils::{compare_rows, get_row_at_idx}; +use arrow::array::{ + Array, ArrayRef, ArrowPrimitiveType, AsArray, BooleanArray, BooleanBufferBuilder, + PrimitiveArray, +}; +use arrow::buffer::{BooleanBuffer, NullBuffer}; +use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions}; +use arrow::datatypes::{ + DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, Float16Type, + Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, + TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, + UInt8Type, +}; +use datafusion_common::cast::as_boolean_array; +use datafusion_common::utils::{compare_rows, extract_row_at_idx_to_buf, get_row_at_idx}; use datafusion_common::{ arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, Expr, ExprFunctionExt, Signature, - SortExpr, Volatility, + Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, ExprFunctionExt, + GroupsAccumulator, Signature, SortExpr, Volatility, }; use datafusion_functions_aggregate_common::utils::get_sort_options; use 
datafusion_macros::user_doc;
@@ -153,6 +165,119 @@ impl AggregateUDFImpl for FirstValue {
         Ok(fields)
     }
 
+    /// Returns `true` when [`Self::create_groups_accumulator`] can handle the
+    /// aggregate's return type.
+    ///
+    /// Keep this list in sync with the `match` in `create_groups_accumulator`;
+    /// `Time32`/`Time64` are restricted to the units that method dispatches on.
+    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+        use DataType::*;
+        matches!(
+            args.return_type,
+            Int8 | Int16
+                | Int32
+                | Int64
+                | UInt8
+                | UInt16
+                | UInt32
+                | UInt64
+                | Float16
+                | Float32
+                | Float64
+                | Decimal128(_, _)
+                | Decimal256(_, _)
+                | Date32
+                | Date64
+                | Time32(TimeUnit::Second | TimeUnit::Millisecond)
+                | Time64(TimeUnit::Microsecond | TimeUnit::Nanosecond)
+                | Timestamp(_, _)
+        )
+    }
+
+    /// Creates a [`FirstPrimitiveGroupsAccumulator`] specialised for the primitive
+    /// return type of the aggregate.
+    ///
+    /// # Errors
+    ///
+    /// Returns an internal error for a return type without a specialisation, and
+    /// propagates errors from resolving the ordering data types or from `try_new`.
+    fn create_groups_accumulator(
+        &self,
+        args: AccumulatorArgs,
+    ) -> Result<Box<dyn GroupsAccumulator>> {
+        // Builds the accumulator for one concrete Arrow primitive type `T`.
+        fn create_accumulator<T>(
+            args: AccumulatorArgs,
+        ) -> Result<Box<dyn GroupsAccumulator>>
+        where
+            T: ArrowPrimitiveType + Send,
+        {
+            let ordering_dtypes = args
+                .ordering_req
+                .iter()
+                .map(|e| e.expr.data_type(args.schema))
+                .collect::<Result<Vec<_>>>()?;
+
+            Ok(Box::new(FirstPrimitiveGroupsAccumulator::<T>::try_new(
+                args.ordering_req.clone(),
+                args.ignore_nulls,
+                args.return_type,
+                &ordering_dtypes,
+            )?))
+        }
+
+        match args.return_type {
+            DataType::Int8 => create_accumulator::<Int8Type>(args),
+            DataType::Int16 => create_accumulator::<Int16Type>(args),
+            DataType::Int32 => create_accumulator::<Int32Type>(args),
+            DataType::Int64 => create_accumulator::<Int64Type>(args),
+            DataType::UInt8 => create_accumulator::<UInt8Type>(args),
+            DataType::UInt16 => create_accumulator::<UInt16Type>(args),
+            DataType::UInt32 => create_accumulator::<UInt32Type>(args),
+            DataType::UInt64 => create_accumulator::<UInt64Type>(args),
+            DataType::Float16 => create_accumulator::<Float16Type>(args),
+            DataType::Float32 => create_accumulator::<Float32Type>(args),
+            DataType::Float64 => create_accumulator::<Float64Type>(args),
+
+            DataType::Decimal128(_, _) => create_accumulator::<Decimal128Type>(args),
+            DataType::Decimal256(_, _) => create_accumulator::<Decimal256Type>(args),
+
+            DataType::Timestamp(TimeUnit::Second, _) => {
+                create_accumulator::<TimestampSecondType>(args)
+            }
+            DataType::Timestamp(TimeUnit::Millisecond, _) => {
+                create_accumulator::<TimestampMillisecondType>(args)
+            }
+            DataType::Timestamp(TimeUnit::Microsecond, _) => {
+                create_accumulator::<TimestampMicrosecondType>(args)
+            }
+            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+                create_accumulator::<TimestampNanosecondType>(args)
+            }
+
+            DataType::Date32 => create_accumulator::<Date32Type>(args),
+            DataType::Date64 => create_accumulator::<Date64Type>(args),
+            DataType::Time32(TimeUnit::Second) => {
+                create_accumulator::<Time32SecondType>(args)
+            }
+            DataType::Time32(TimeUnit::Millisecond) => {
+                create_accumulator::<Time32MillisecondType>(args)
+            }
+
+            DataType::Time64(TimeUnit::Microsecond) => {
+                create_accumulator::<Time64MicrosecondType>(args)
+            }
+            DataType::Time64(TimeUnit::Nanosecond) => {
+                create_accumulator::<Time64NanosecondType>(args)
+            }
+
+            _ => internal_err!(
+                "GroupsAccumulator not supported for first({})",
+                args.return_type
+            ),
+        }
+    }
+
     fn aliases(&self) -> &[String] {
         &[]
     }
@@ -179,6 +304,483 @@ impl AggregateUDFImpl for FirstValue {
     }
 }
 
+/// [`GroupsAccumulator`] for `first_value` over primitive types.
+///
+/// Per-group state is kept column-wise (`vals`, `is_sets`, `null_builder`) next to
+/// row-wise ordering keys (`orderings`).
+struct FirstPrimitiveGroupsAccumulator<T>
+where
+    T: ArrowPrimitiveType + Send,
+{
+    // ================ state ===========
+    vals: Vec<T::Native>,
+    // Stores ordering values, of the aggregator requirement corresponding to first value
+    // of the aggregator.
+    // The `orderings` are stored row-wise, meaning that `orderings[group_idx]`
+    // represents the ordering values corresponding to the `group_idx`-th group.
+    orderings: Vec<Vec<ScalarValue>>,
+    // At the beginning, `is_sets[group_idx]` is false, which means `first` is not seen yet.
+    // Once we see the first value, we set the `is_sets[group_idx]` flag
+    is_sets: BooleanBufferBuilder,
+    // null_builder[group_idx] == false => vals[group_idx] is null
+    null_builder: BooleanBufferBuilder,
+    // size of `self.orderings`
+    // Calculating the memory usage of `self.orderings` using `ScalarValue::size_of_vec` is quite costly.
+    // Therefore, we cache it and compute `size_of` only after each update
+    // to avoid calling `ScalarValue::size_of_vec` by Self.size.
+    size_of_orderings: usize,
+
+    // buffer for `get_filtered_min_of_each_group`
+    // min_of_each_group_buf.0[group_idx] -> idx_in_val
+    // only valid if min_of_each_group_buf.1[group_idx] == true
+    min_of_each_group_buf: (Vec<usize>, BooleanBufferBuilder),
+
+    // =========== option ============
+
+    // Stores the applicable ordering requirement.
+    ordering_req: LexOrdering,
+    // derived from `ordering_req`.
+    sort_options: Vec<SortOptions>,
+    // Stores whether incoming data already satisfies the ordering requirement.
+    input_requirement_satisfied: bool,
+    // Ignore null values.
+    ignore_nulls: bool,
+    /// The output type
+    data_type: DataType,
+    default_orderings: Vec<ScalarValue>,
+}
+
+impl<T> FirstPrimitiveGroupsAccumulator<T>
+where
+    T: ArrowPrimitiveType + Send,
+{
+    /// Creates an empty accumulator.
+    ///
+    /// `ordering_dtypes` are the data types of the `ordering_req` expressions; they
+    /// are used to build the default ordering row stored for newly added groups.
+    fn try_new(
+        ordering_req: LexOrdering,
+        ignore_nulls: bool,
+        data_type: &DataType,
+        ordering_dtypes: &[DataType],
+    ) -> Result<Self> {
+        let requirement_satisfied = ordering_req.is_empty();
+
+        let default_orderings = ordering_dtypes
+            .iter()
+            .map(ScalarValue::try_from)
+            .collect::<Result<Vec<_>>>()?;
+
+        let sort_options = get_sort_options(ordering_req.as_ref());
+
+        Ok(Self {
+            null_builder: BooleanBufferBuilder::new(0),
+            ordering_req,
+            sort_options,
+            input_requirement_satisfied: requirement_satisfied,
+            ignore_nulls,
+            default_orderings,
+            data_type: data_type.clone(),
+            vals: Vec::new(),
+            orderings: Vec::new(),
+            is_sets: BooleanBufferBuilder::new(0),
+            size_of_orderings: 0,
+            min_of_each_group_buf: (Vec::new(), BooleanBufferBuilder::new(0)),
+        })
+    }
+
+    /// Returns `true` if a new row for `group_idx` may still replace the stored value:
+    /// the group has no value yet, nulls are ignored and the stored value is null,
+    /// or the input is not known to satisfy the ordering requirement.
+    fn need_update(&self, group_idx: usize) -> bool {
+        if !self.is_sets.get_bit(group_idx) {
+            return true;
+        }
+
+        if self.ignore_nulls && !self.null_builder.get_bit(group_idx) {
+            return true;
+        }
+
+        !self.input_requirement_satisfied
+    }
+
+    /// Returns `true` if the group has no value yet, or if the stored ordering
+    /// compares greater than `new_ordering_values` under `sort_options`.
+    fn should_update_state(
+        &self,
+        group_idx: usize,
+        new_ordering_values: &[ScalarValue],
+    ) -> Result<bool> {
+        if !self.is_sets.get_bit(group_idx) {
+            return Ok(true);
+        }
+
+        assert_eq!(new_ordering_values.len(), self.ordering_req.len());
+        let current_ordering = &self.orderings[group_idx];
+        compare_rows(current_ordering, new_ordering_values, &self.sort_options)
+            .map(|x| x.is_gt())
+    }
+
+    /// Removes and returns the ordering rows of the emitted groups, keeping the
+    /// cached `size_of_orderings` in sync.
+    fn take_orderings(&mut self, emit_to: EmitTo) -> Vec<Vec<ScalarValue>> {
+        let result = emit_to.take_needed(&mut self.orderings);
+
+        match emit_to {
+            EmitTo::All => self.size_of_orderings = 0,
+            EmitTo::First(_) => {
+                self.size_of_orderings -=
+                    result.iter().map(ScalarValue::size_of_vec).sum::<usize>()
+            }
+        }
+
+        result
+    }
+
+    /// Drains `bool_buf_builder` and returns the bits of the emitted groups; for
+    /// `EmitTo::First(n)` the remaining bits are appended back to the builder.
+    fn take_need(
+        bool_buf_builder: &mut BooleanBufferBuilder,
+        emit_to: EmitTo,
+    ) -> BooleanBuffer {
+        let bool_buf = bool_buf_builder.finish();
+        match emit_to {
+            EmitTo::All => bool_buf,
+            EmitTo::First(n) => {
+                // split off the first N values in seen_values
+                //
+                // TODO make this more efficient rather than two
+                // copies and bitwise manipulation
+                let first_n: BooleanBuffer = bool_buf.iter().take(n).collect();
+                // reset the existing buffer
+                for b in bool_buf.iter().skip(n) {
+                    bool_buf_builder.append(b);
+                }
+                first_n
+            }
+        }
+    }
+
+    /// Resizes all per-group state to hold `new_size` groups.
+    fn resize_states(&mut self, new_size: usize) {
+        self.vals.resize(new_size, T::default_value());
+
+        self.null_builder.resize(new_size);
+
+        if self.orderings.len() < new_size {
+            let current_len = self.orderings.len();
+
+            self.orderings
+                .resize(new_size, self.default_orderings.clone());
+
+            self.size_of_orderings += (new_size - current_len)
+                * ScalarValue::size_of_vec(
+                    // Note: In some cases (such as in the unit test below)
+                    // ScalarValue::size_of_vec(&self.default_orderings) != ScalarValue::size_of_vec(&self.default_orderings.clone())
+                    // This may be caused by the different vec.capacity() values?
+                    self.orderings.last().unwrap(),
+                );
+        }
+
+        self.is_sets.resize(new_size);
+
+        self.min_of_each_group_buf.0.resize(new_size, 0);
+        self.min_of_each_group_buf.1.resize(new_size);
+    }
+
+    /// Stores `new_val` (with its null flag) and its ordering row as the current
+    /// first value of `group_idx`, adjusting the cached `size_of_orderings`.
+    fn update_state(
+        &mut self,
+        group_idx: usize,
+        orderings: &[ScalarValue],
+        new_val: T::Native,
+        is_null: bool,
+    ) {
+        self.vals[group_idx] = new_val;
+        self.is_sets.set_bit(group_idx, true);
+
+        self.null_builder.set_bit(group_idx, !is_null);
+
+        assert_eq!(orderings.len(), self.ordering_req.len());
+        let old_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
+        self.orderings[group_idx].clear();
+        self.orderings[group_idx].extend_from_slice(orderings);
+        let new_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
+        self.size_of_orderings = self.size_of_orderings - old_size + new_size;
+    }
+
+    /// Emits `(values, ordering rows, is_set bits)` for the groups selected by `emit_to`.
+    fn take_state(
+        &mut self,
+        emit_to: EmitTo,
+    ) -> (ArrayRef, Vec<Vec<ScalarValue>>, BooleanBuffer) {
+        emit_to.take_needed(&mut self.min_of_each_group_buf.0);
+        self.min_of_each_group_buf
+            .1
+            .truncate(self.min_of_each_group_buf.0.len());
+
+        (
+            self.take_vals_and_null_buf(emit_to),
+            self.take_orderings(emit_to),
+            Self::take_need(&mut self.is_sets, emit_to),
+        )
+    }
+
+    // should be used in test only
+    #[cfg(test)]
+    fn compute_size_of_orderings(&self) -> usize {
+        self.orderings
+            .iter()
+            .map(ScalarValue::size_of_vec)
+            .sum::<usize>()
+    }
+
+    /// Returns a vector of tuples `(group_idx, idx_in_val)` representing the index of the
+    /// minimum value in `orderings` for each group, using lexicographical comparison.
+    /// Values are filtered using `opt_filter` and `is_set_arr` if provided.
+    fn get_filtered_min_of_each_group(
+        &mut self,
+        orderings: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        vals: &PrimitiveArray<T>,
+        is_set_arr: Option<&BooleanArray>,
+    ) -> Result<Vec<(usize, usize)>> {
+        // Set all values in min_of_each_group_buf.1 to false.
+        self.min_of_each_group_buf.1.truncate(0);
+        self.min_of_each_group_buf
+            .1
+            .append_n(self.vals.len(), false);
+
+        // No need to call `clear` since `self.min_of_each_group_buf.0[group_idx]`
+        // is only valid when `self.min_of_each_group_buf.1[group_idx] == true`.
+
+        let comparator = {
+            assert_eq!(orderings.len(), self.ordering_req.len());
+            let sort_columns = orderings
+                .iter()
+                .zip(self.ordering_req.iter())
+                .map(|(array, req)| SortColumn {
+                    values: Arc::clone(array),
+                    options: Some(req.options),
+                })
+                .collect::<Vec<_>>();
+
+            LexicographicalComparator::try_new(&sort_columns)?
+        };
+
+        for (idx_in_val, group_idx) in group_indices.iter().enumerate() {
+            let group_idx = *group_idx;
+
+            let passed_filter = opt_filter.is_none_or(|x| x.value(idx_in_val));
+
+            let is_set = is_set_arr.is_none_or(|x| x.value(idx_in_val));
+
+            if !passed_filter || !is_set {
+                continue;
+            }
+
+            if !self.need_update(group_idx) {
+                continue;
+            }
+
+            if self.ignore_nulls && vals.is_null(idx_in_val) {
+                continue;
+            }
+
+            let is_valid = self.min_of_each_group_buf.1.get_bit(group_idx);
+            if is_valid
+                && comparator
+                    .compare(self.min_of_each_group_buf.0[group_idx], idx_in_val)
+                    .is_gt()
+            {
+                self.min_of_each_group_buf.0[group_idx] = idx_in_val;
+            } else if !is_valid {
+                self.min_of_each_group_buf.1.set_bit(group_idx, true);
+                self.min_of_each_group_buf.0[group_idx] = idx_in_val;
+            }
+        }
+
+        Ok(self
+            .min_of_each_group_buf
+            .0
+            .iter()
+            .enumerate()
+            .filter(|(group_idx, _)| self.min_of_each_group_buf.1.get_bit(*group_idx))
+            .map(|(group_idx, idx_in_val)| (group_idx, *idx_in_val))
+            .collect::<Vec<_>>())
+    }
+
+    /// Emits the values of the selected groups as a [`PrimitiveArray`] with their validity.
+    fn take_vals_and_null_buf(&mut self, emit_to: EmitTo) -> ArrayRef {
+        let r = emit_to.take_needed(&mut self.vals);
+
+        let null_buf = NullBuffer::new(Self::take_need(&mut self.null_builder, emit_to));
+
+        let values = PrimitiveArray::<T>::new(r.into(), Some(null_buf)) // no copy
+            .with_data_type(self.data_type.clone());
+        Arc::new(values)
+    }
+}
+
+impl<T> GroupsAccumulator for FirstPrimitiveGroupsAccumulator<T>
+where
+    T: ArrowPrimitiveType + Send,
+{
+    fn update_batch(
+        &mut self,
+        // e.g. first_value(a order by b): values_and_order_cols will be [a, b]
+        values_and_order_cols: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        self.resize_states(total_num_groups);
+
+        let vals = values_and_order_cols[0].as_primitive::<T>();
+
+        let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());
+
+        // The overhead of calling `extract_row_at_idx_to_buf` is somewhat high, so we need to minimize its calls as much as possible.
+        for (group_idx, idx) in self
+            .get_filtered_min_of_each_group(
+                &values_and_order_cols[1..],
+                group_indices,
+                opt_filter,
+                vals,
+                None,
+            )?
+            .into_iter()
+        {
+            extract_row_at_idx_to_buf(
+                &values_and_order_cols[1..],
+                idx,
+                &mut ordering_buf,
+            )?;
+
+            if self.should_update_state(group_idx, &ordering_buf)? {
+                self.update_state(
+                    group_idx,
+                    &ordering_buf,
+                    vals.value(idx),
+                    vals.is_null(idx),
+                );
+            }
+        }
+
+        Ok(())
+    }
+
+    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
+        Ok(self.take_state(emit_to).0)
+    }
+
+    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+        let (val_arr, orderings, is_sets) = self.take_state(emit_to);
+        // Layout: [value, one column per ordering expression.., is_set]
+        let mut result = Vec::with_capacity(self.ordering_req.len() + 2);
+
+        result.push(val_arr);
+
+        let ordering_cols = {
+            let mut ordering_cols = Vec::with_capacity(self.ordering_req.len());
+            for _ in 0..self.ordering_req.len() {
+                ordering_cols.push(Vec::with_capacity(orderings.len()));
+            }
+            for row in orderings.into_iter() {
+                assert_eq!(row.len(), self.ordering_req.len());
+                for (col_idx, ordering) in row.into_iter().enumerate() {
+                    ordering_cols[col_idx].push(ordering);
+                }
+            }
+
+            ordering_cols
+        };
+        for ordering_col in ordering_cols {
+            result.push(ScalarValue::iter_to_array(ordering_col)?);
+        }
+
+        result.push(Arc::new(BooleanArray::new(is_sets, None)));
+
+        Ok(result)
+    }
+
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        self.resize_states(total_num_groups);
+
+        let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());
+
+        // The last state column is the `is_set` flag; the rest are value + orderings.
+        let (is_set_arr, val_and_order_cols) = match values.split_last() {
+            Some(result) => result,
+            None => return internal_err!("Empty row in FIRST_VALUE"),
+        };
+
+        let is_set_arr = as_boolean_array(is_set_arr)?;
+
+        let vals = values[0].as_primitive::<T>();
+        // The overhead of calling `extract_row_at_idx_to_buf` is somewhat high, so we need to minimize its calls as much as possible.
+        let groups = self.get_filtered_min_of_each_group(
+            &val_and_order_cols[1..],
+            group_indices,
+            opt_filter,
+            vals,
+            Some(is_set_arr),
+        )?;
+
+        for (group_idx, idx) in groups.into_iter() {
+            extract_row_at_idx_to_buf(&val_and_order_cols[1..], idx, &mut ordering_buf)?;
+
+            if self.should_update_state(group_idx, &ordering_buf)? {
+                self.update_state(
+                    group_idx,
+                    &ordering_buf,
+                    vals.value(idx),
+                    vals.is_null(idx),
+                );
+            }
+        }
+
+        Ok(())
+    }
+
+    fn size(&self) -> usize {
+        self.vals.capacity() * size_of::<T::Native>()
+            + self.null_builder.capacity() / 8 // capacity is in bits, so convert to bytes
+            + self.is_sets.capacity() / 8
+            + self.size_of_orderings
+            + self.min_of_each_group_buf.0.capacity() * size_of::<usize>()
+            + self.min_of_each_group_buf.1.capacity() / 8
+    }
+
+    fn supports_convert_to_state(&self) -> bool {
+        true
+    }
+
+    fn convert_to_state(
+        &self,
+        values: &[ArrayRef],
+        opt_filter: Option<&BooleanArray>,
+    ) -> Result<Vec<ArrayRef>> {
+        // The input columns are passed through unchanged; only the trailing `is_set`
+        // column is added: the filter if there is one, otherwise all `true`.
+        let is_set = match opt_filter {
+            Some(f) => f.clone(),
+            None => BooleanArray::from(vec![true; values[0].len()]),
+        };
+
+        let mut result = values.to_vec();
+        result.push(Arc::new(is_set));
+        Ok(result)
+    }
+}
 #[derive(Debug)]
 pub struct FirstValueAccumulator {
     first: ScalarValue,
@@ -684,7 +1286,8 @@ fn convert_to_sort_cols(arrs: &[ArrayRef], sort_exprs: &LexOrdering) -> Vec
Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Boolean, true), + ])); + + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]); + + let mut group_acc = FirstPrimitiveGroupsAccumulator::::try_new( + sort_key, + true, + &DataType::Int64, + &[DataType::Int64], + )?; + + let mut val_with_orderings = { + let mut val_with_orderings = Vec::::new(); + + let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)])); + let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6])); + + val_with_orderings.push(vals); + val_with_orderings.push(orderings); + + val_with_orderings + }; + + group_acc.update_batch( + &val_with_orderings, + &[0, 1, 2, 1], + Some(&BooleanArray::from(vec![true, true, false, true])), + 3, + )?; + assert_eq!( + group_acc.size_of_orderings, + group_acc.compute_size_of_orderings() + ); + + let state = group_acc.state(EmitTo::All)?; + + let expected_state: Vec> = vec![ + Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])), + Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])), + Arc::new(BooleanArray::from(vec![true, true, false])), + ]; + assert_eq!(state, expected_state); + + assert_eq!( + group_acc.size_of_orderings, + group_acc.compute_size_of_orderings() + ); + + group_acc.merge_batch( + &state, + &[0, 1, 2], + Some(&BooleanArray::from(vec![true, false, false])), + 3, + )?; + + assert_eq!( + group_acc.size_of_orderings, + group_acc.compute_size_of_orderings() + ); + + val_with_orderings.clear(); + val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6]))); + val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6]))); + + group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?; + + let binding = group_acc.evaluate(EmitTo::All)?; 
+        let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();
+
+        let expect: PrimitiveArray<Int64Type> =
+            Int64Array::from(vec![Some(1), Some(6), Some(6), None]);
+
+        assert_eq!(eval_result, &expect);
+
+        assert_eq!(
+            group_acc.size_of_orderings,
+            group_acc.compute_size_of_orderings()
+        );
+
+        Ok(())
+    }
+
+    /// Checks that the cached `size_of_orderings` matches a full recomputation after
+    /// every update / state / merge / evaluate step, for both `EmitTo` modes.
+    #[test]
+    fn test_first_group_acc_size_of_ordering() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int64, true),
+            Field::new("b", DataType::Int64, true),
+            Field::new("c", DataType::Int64, true),
+            Field::new("d", DataType::Int32, true),
+            Field::new("e", DataType::Boolean, true),
+        ]));
+
+        let sort_key = LexOrdering::new(vec![PhysicalSortExpr {
+            expr: col("c", &schema)?,
+            options: SortOptions::default(),
+        }]);
+
+        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
+            sort_key,
+            true,
+            &DataType::Int64,
+            &[DataType::Int64],
+        )?;
+
+        // Column 0 holds the values, column 1 the ordering keys.
+        let val_with_orderings: Vec<ArrayRef> = vec![
+            Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)])),
+            Arc::new(Int64Array::from(vec![1, -9, 3, -6])),
+        ];
+
+        for _ in 0..10 {
+            group_acc.update_batch(
+                &val_with_orderings,
+                &[0, 1, 2, 1],
+                Some(&BooleanArray::from(vec![true, true, false, true])),
+                100,
+            )?;
+            assert_eq!(
+                group_acc.size_of_orderings,
+                group_acc.compute_size_of_orderings()
+            );
+
+            group_acc.state(EmitTo::First(2))?;
+            assert_eq!(
+                group_acc.size_of_orderings,
+                group_acc.compute_size_of_orderings()
+            );
+
+            let s = group_acc.state(EmitTo::All)?;
+            assert_eq!(
+                group_acc.size_of_orderings,
+                group_acc.compute_size_of_orderings()
+            );
+
+            group_acc.merge_batch(&s, &Vec::from_iter(0..s[0].len()), None, 100)?;
+            assert_eq!(
+                group_acc.size_of_orderings,
+                group_acc.compute_size_of_orderings()
+            );
+
+            group_acc.evaluate(EmitTo::First(2))?;
+            assert_eq!(
+                group_acc.size_of_orderings,
+                group_acc.compute_size_of_orderings()
+            );
+
+            group_acc.evaluate(EmitTo::All)?;
+            assert_eq!(
+                group_acc.size_of_orderings,
+                group_acc.compute_size_of_orderings()
+            );
+        }
+
+        Ok(())
+    }
 }
diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt
index 0cc8045dccd0..135d9c620fbb 100644
--- a/datafusion/sqllogictest/test_files/group_by.slt
+++ b/datafusion/sqllogictest/test_files/group_by.slt
@@ -2665,6 +2665,47 @@ TUR [100.0, 75.0] 175
 # test_reverse_aggregate_expr
 # Some of the Aggregators can be reversed, by this way we can still run aggregators without re-ordering
 # that have contradictory requirements at first glance.
+
+statement ok
+CREATE TABLE null_group (
+  a INT, b INT, c INT, d INT
+) as VALUES
+  (6, 6, null, null),
+  (6, 6, 1, null),
+  (6, 6, null, 1)
+
+query III rowsort
+select c, d, first_value(a order by b) from null_group group by c, d;
+----
+1 NULL 6
+NULL 1 6
+NULL NULL 6
+
+
+
+statement ok
+CREATE TABLE first_null (
+  k INT,
+  val INT,
+  o int
+  ) as VALUES
+  (0, NULL, -9),
+  (0, 1, 1),
+  (1, 1, 1);
+
+query II rowsort
+select k, first_value(val order by o) IGNORE NULLS from first_null group by k;
+----
+0 1
+1 1
+
+query II rowsort
+select k, first_value(val order by o) respect NULLS from first_null group by k;
+----
+0 NULL
+1 1
+
+
 query TT
 EXPLAIN SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts,
   FIRST_VALUE(amount ORDER BY amount ASC) AS fv1,