diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index dcf477135a37..5515fb7a8923 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
@@ -21,13 +21,17 @@ use crate::fuzz_cases::aggregation_fuzzer::{
     AggregationFuzzerBuilder, ColumnDescr, DatasetGeneratorConfig, QueryBuilder,
 };
 
-use arrow::array::{types::Int64Type, Array, ArrayRef, AsArray, Int64Array, RecordBatch};
+use arrow::array::{
+    types::Int64Type, Array, ArrayRef, AsArray, Int32Array, Int64Array, RecordBatch,
+    StringArray,
+};
 use arrow::compute::{concat_batches, SortOptions};
 use arrow::datatypes::{
     DataType, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
     DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
 };
 use arrow::util::pretty::pretty_format_batches;
+use arrow_schema::{Field, Schema, SchemaRef};
 use datafusion::common::Result;
 use datafusion::datasource::memory::MemorySourceConfig;
 use datafusion::datasource::source::DataSourceExec;
@@ -42,14 +46,18 @@ use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}
 use datafusion_common::HashMap;
 use datafusion_common_runtime::JoinSet;
 use datafusion_functions_aggregate::sum::sum_udaf;
-use datafusion_physical_expr::expressions::col;
+use datafusion_physical_expr::expressions::{col, lit, Column};
 use datafusion_physical_expr::PhysicalSortExpr;
 use datafusion_physical_expr_common::sort_expr::LexOrdering;
 use datafusion_physical_plan::InputOrderMode;
 use test_utils::{add_empty_batches, StringBatchGenerator};
 
+use datafusion_execution::memory_pool::FairSpillPool;
+use datafusion_execution::runtime_env::RuntimeEnvBuilder;
+use datafusion_execution::TaskContext;
+use datafusion_physical_plan::metrics::MetricValue;
 use rand::rngs::StdRng;
-use rand::{thread_rng, Rng, SeedableRng};
+use rand::{random, thread_rng, Rng, SeedableRng};
 
 // ========================================================================
 // The new aggregation fuzz tests based on [`AggregationFuzzer`]
@@ -663,3 +671,134 @@ fn extract_result_counts(results: Vec<RecordBatch>) -> HashMap<Option<i64>, i
     }
     output
 }
+
+fn assert_spill_count_metric(expect_spill: bool, single_aggregate: Arc<AggregateExec>) {
+    if let Some(metrics_set) = single_aggregate.metrics() {
+        let mut spill_count = 0;
+
+        // Inspect metrics for SpillCount
+        for metric in metrics_set.iter() {
+            if let MetricValue::SpillCount(count) = metric.value() {
+                spill_count = count.value();
+                break;
+            }
+        }
+
+        if expect_spill && spill_count == 0 {
+            panic!("Expected spill but SpillCount metric not found or SpillCount was 0.");
+        } else if !expect_spill && spill_count > 0 {
+            panic!("Expected no spill but found SpillCount metric with value greater than 0.");
+        }
+    } else {
+        panic!("No metrics returned from the operator; cannot verify spilling.");
+    }
+}
+
+// Fix for https://github.com/apache/datafusion/issues/15530
+#[tokio::test]
+async fn test_single_mode_aggregate_with_spill() -> Result<()> {
+    let scan_schema = Arc::new(Schema::new(vec![
+        Field::new("col_0", DataType::Int64, true),
+        Field::new("col_1", DataType::Utf8, true),
+        Field::new("col_2", DataType::Utf8, true),
+        Field::new("col_3", DataType::Utf8, true),
+        Field::new("col_4", DataType::Utf8, true),
+        Field::new("col_5", DataType::Int32, true),
+        Field::new("col_6", DataType::Utf8, true),
+        Field::new("col_7", DataType::Utf8, true),
+        Field::new("col_8", DataType::Utf8, true),
+    ]));
+
+    let group_by = PhysicalGroupBy::new_single(vec![
(Arc::new(Column::new("col_1", 1)), "col_1".to_string()), + (Arc::new(Column::new("col_7", 7)), "col_7".to_string()), + (Arc::new(Column::new("col_0", 0)), "col_0".to_string()), + (Arc::new(Column::new("col_8", 8)), "col_8".to_string()), + ]); + + fn generate_int64_array() -> ArrayRef { + Arc::new(Int64Array::from_iter_values( + (0..1024).map(|_| random::()), + )) + } + fn generate_int32_array() -> ArrayRef { + Arc::new(Int32Array::from_iter_values( + (0..1024).map(|_| random::()), + )) + } + + fn generate_string_array() -> ArrayRef { + Arc::new(StringArray::from( + (0..1024) + .map(|_| -> String { + thread_rng() + .sample_iter::(rand::distributions::Standard) + .take(5) + .collect() + }) + .collect::>(), + )) + } + + fn generate_record_batch(schema: &SchemaRef) -> Result { + RecordBatch::try_new( + Arc::clone(schema), + vec![ + generate_int64_array(), + generate_string_array(), + generate_string_array(), + generate_string_array(), + generate_string_array(), + generate_int32_array(), + generate_string_array(), + generate_string_array(), + generate_string_array(), + ], + ) + .map_err(|err| err.into()) + } + + let aggregate_expressions = vec![Arc::new( + AggregateExprBuilder::new(sum_udaf(), vec![lit(1i64)]) + .schema(Arc::clone(&scan_schema)) + .alias("SUM(1i64)") + .build()?, + )]; + + let batches = (0..5) + .map(|_| generate_record_batch(&scan_schema)) + .collect::>>()?; + + let plan: Arc = + MemorySourceConfig::try_new_exec(&[batches], Arc::clone(&scan_schema), None) + .unwrap(); + + let single_aggregate = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + group_by, + aggregate_expressions.clone(), + vec![None; aggregate_expressions.len()], + plan, + Arc::clone(&scan_schema), + )?); + + let memory_pool = Arc::new(FairSpillPool::new(250000)); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(248)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )), + ); + + datafusion_physical_plan::common::collect( + single_aggregate.execute(0, Arc::clone(&task_ctx))?, + ) + .await?; + + assert_spill_count_metric(true, single_aggregate); + + Ok(()) +} diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 077f18d51033..502ea3317adc 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -507,6 +507,16 @@ impl GroupedHashAggregateStream { AggregateMode::Partial, )?; + // Need to update the GROUP BY expressions to point to the correct column after schema change + let merging_group_by_expr = agg_group_by + .expr + .iter() + .enumerate() + .map(|(idx, (_, name))| { + (Arc::new(Column::new(name.as_str(), idx)) as _, name.clone()) + }) + .collect(); + let partial_agg_schema = Arc::new(partial_agg_schema); let spill_expr = group_schema @@ -550,7 +560,7 @@ impl GroupedHashAggregateStream { spill_schema: partial_agg_schema, is_stream_merging: false, merging_aggregate_arguments, - merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()), + merging_group_by: PhysicalGroupBy::new_single(merging_group_by_expr), peak_mem_used: MetricBuilder::new(&agg.metrics) .gauge("peak_mem_used", partition), spill_manager,