From e9beec8d2f400a0ac7f40fa0f708f5b1e7b998e4 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sun, 13 Apr 2025 22:35:52 +0300 Subject: [PATCH 01/21] feat: add multi level merge sort that will always fit in memory --- .../physical-plan/src/aggregates/row_hash.rs | 32 +- datafusion/physical-plan/src/sorts/mod.rs | 1 + .../src/sorts/multi_level_merge.rs | 342 ++++++++++++++++++ datafusion/physical-plan/src/sorts/sort.rs | 40 +- .../src/sorts/streaming_merge.rs | 59 ++- .../physical-plan/src/spill/get_size.rs | 199 ++++++++++ datafusion/physical-plan/src/spill/mod.rs | 1 + .../physical-plan/src/spill/spill_manager.rs | 165 ++++++++- 8 files changed, 799 insertions(+), 40 deletions(-) create mode 100644 datafusion/physical-plan/src/sorts/multi_level_merge.rs create mode 100644 datafusion/physical-plan/src/spill/get_size.rs diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 232565a04466..ec414f27b192 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -29,7 +29,7 @@ use crate::aggregates::{ }; use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput}; use crate::sorts::sort::sort_batch; -use crate::sorts::streaming_merge::StreamingMergeBuilder; +use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; use crate::spill::spill_manager::SpillManager; use crate::stream::RecordBatchStreamAdapter; use crate::{aggregates, metrics, ExecutionPlan, PhysicalExpr}; @@ -39,7 +39,6 @@ use arrow::array::*; use arrow::compute::SortOptions; use arrow::datatypes::SchemaRef; use datafusion_common::{internal_err, DataFusionError, Result}; -use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; @@ 
-100,7 +99,7 @@ struct SpillState { // ======================================================================== /// If data has previously been spilled, the locations of the /// spill files (in Arrow IPC format) - spills: Vec, + spills: Vec, /// true when streaming merge is in progress is_stream_merging: bool, @@ -999,13 +998,21 @@ impl GroupedHashAggregateStream { let sorted = sort_batch(&emit, self.spill_state.spill_expr.as_ref(), None)?; // Spill sorted state to disk - let spillfile = self.spill_state.spill_manager.spill_record_batch_by_size( - &sorted, - "HashAggSpill", - self.batch_size, - )?; + let spillfile = self + .spill_state + .spill_manager + .spill_record_batch_by_size_and_return_max_batch_memory( + &sorted, + "HashAggSpill", + self.batch_size, + )?; match spillfile { - Some(spillfile) => self.spill_state.spills.push(spillfile), + Some((spillfile, max_record_batch_memory)) => { + self.spill_state.spills.push(SortedSpillFile { + file: spillfile, + max_record_batch_memory, + }) + } None => { return internal_err!( "Calling spill with no intermediate batch to spill" @@ -1066,14 +1073,13 @@ impl GroupedHashAggregateStream { sort_batch(&batch, expr.as_ref(), None) })), ))); - for spill in self.spill_state.spills.drain(..) 
{ - let stream = self.spill_state.spill_manager.read_spill_as_stream(spill)?; - streams.push(stream); - } + self.spill_state.is_stream_merging = true; self.input = StreamingMergeBuilder::new() .with_streams(streams) .with_schema(schema) + .with_spill_manager(self.spill_state.spill_manager.clone()) + .with_sorted_spill_files(std::mem::take(&mut self.spill_state.spills)) .with_expressions(self.spill_state.spill_expr.as_ref()) .with_metrics(self.baseline_metrics.clone()) .with_batch_size(self.batch_size) diff --git a/datafusion/physical-plan/src/sorts/mod.rs b/datafusion/physical-plan/src/sorts/mod.rs index c7ffae4061c0..9c72e34fe343 100644 --- a/datafusion/physical-plan/src/sorts/mod.rs +++ b/datafusion/physical-plan/src/sorts/mod.rs @@ -20,6 +20,7 @@ mod builder; mod cursor; mod merge; +mod multi_level_merge; pub mod partial_sort; pub mod sort; pub mod sort_preserving_merge; diff --git a/datafusion/physical-plan/src/sorts/multi_level_merge.rs b/datafusion/physical-plan/src/sorts/multi_level_merge.rs new file mode 100644 index 000000000000..8bcf995be059 --- /dev/null +++ b/datafusion/physical-plan/src/sorts/multi_level_merge.rs @@ -0,0 +1,342 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Create a stream that do a multi level merge stream + +use crate::metrics::BaselineMetrics; +use crate::{EmptyRecordBatchStream, SpillManager}; +use arrow::array::RecordBatch; +use std::fmt::{Debug, Formatter}; +use std::mem; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::datatypes::SchemaRef; +use datafusion_common::Result; +use datafusion_execution::memory_pool::{ + MemoryConsumer, MemoryPool, MemoryReservation, UnboundedMemoryPool, +}; + +use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; +use crate::stream::RecordBatchStreamAdapter; +use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use futures::TryStreamExt; +use futures::{Stream, StreamExt}; + +/// Merges a stream of sorted cursors and record batches into a single sorted stream +pub(crate) struct MultiLevelMergeBuilder { + spill_manager: SpillManager, + schema: SchemaRef, + sorted_spill_files: Vec, + sorted_streams: Vec, + expr: LexOrdering, + metrics: BaselineMetrics, + batch_size: usize, + reservation: MemoryReservation, + fetch: Option, + enable_round_robin_tie_breaker: bool, + + // This is for avoiding double reservation of memory from our side and the sort preserving merge stream + // side. 
+ // and doing a lot of code changes to avoid accounting for the memory used by the streams + unbounded_memory_pool: Arc, +} + +impl Debug for MultiLevelMergeBuilder { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "MultiLevelMergeBuilder") + } +} + +impl MultiLevelMergeBuilder { + pub(crate) fn new( + spill_manager: SpillManager, + schema: SchemaRef, + sorted_spill_files: Vec, + sorted_streams: Vec, + expr: LexOrdering, + metrics: BaselineMetrics, + batch_size: usize, + reservation: MemoryReservation, + fetch: Option, + enable_round_robin_tie_breaker: bool, + ) -> Self { + Self { + spill_manager, + schema, + sorted_spill_files, + sorted_streams, + expr, + metrics, + batch_size, + reservation, + enable_round_robin_tie_breaker, + fetch, + unbounded_memory_pool: Arc::new(UnboundedMemoryPool::default()), + } + } + + pub(crate) fn create_spillable_merge_stream(self) -> SendableRecordBatchStream { + Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&self.schema), + futures::stream::once(self.create_stream()).try_flatten(), + )) + } + + async fn create_stream(mut self) -> Result { + loop { + // Hold this for the lifetime of the stream + let mut current_memory_reservation = self.reservation.new_empty(); + let mut stream = + self.create_sorted_stream(&mut current_memory_reservation)?; + + // TODO - add a threshold for number of files to disk even if empty and reading from disk so + // we can avoid the memory reservation + + // If no spill files are left, we can return the stream as this is the last sorted run + // TODO - We can write to disk before reading it back to avoid having multiple streams in memory + if self.sorted_spill_files.is_empty() { + // Attach the memory reservation to the stream as we are done with it + // but because we replaced the memory reservation of the merge stream, we must hold + // this to make sure we have enough memory + return Ok(Box::pin(StreamAttachedReservation::new( + stream, + current_memory_reservation, + ))); 
+ } + + // Need to sort to a spill file + let Some((spill_file, max_record_batch_memory)) = self + .spill_manager + .spill_record_batch_stream_by_size( + &mut stream, + self.batch_size, + "MultiLevelMergeBuilder intermediate spill", + ) + .await? + else { + continue; + }; + + // Add the spill file + self.sorted_spill_files.push(SortedSpillFile { + file: spill_file, + max_record_batch_memory, + }); + } + } + + fn create_sorted_stream( + &mut self, + memory_reservation: &mut MemoryReservation, + ) -> Result { + match (self.sorted_spill_files.len(), self.sorted_streams.len()) { + // No data so empty batch + (0, 0) => Ok(Box::pin(EmptyRecordBatchStream::new(Arc::clone( + &self.schema, + )))), + + // Only in-memory stream, return that + (0, 1) => Ok(self.sorted_streams.remove(0)), + + // Only single sorted spill file so return it + (1, 0) => { + let spill_file = self.sorted_spill_files.remove(0); + + self.spill_manager + .read_spill_as_stream(spill_file.file) + } + + // Only in memory streams, so merge them all in a single pass + (0, _) => { + let sorted_stream = mem::take(&mut self.sorted_streams); + self.create_new_merge_sort( + sorted_stream, + // If we have no sorted spill files left, this is the last run + true, + ) + } + + // Need to merge multiple streams + (_, _) => { + // Don't account for existing streams memory + // as we are not holding the memory for them + let mut sorted_streams = mem::take(&mut self.sorted_streams); + + let (sorted_spill_files, buffer_size) = self + .get_sorted_spill_files_to_merge( + 2, + // we must have at least 2 streams to merge + 2_usize.saturating_sub(sorted_streams.len()), + memory_reservation, + )?; + + for spill in sorted_spill_files { + let stream = self + .spill_manager + .clone() + .with_batch_read_buffer_capacity(buffer_size) + .read_spill_as_stream(spill.file)?; + sorted_streams.push(stream); + } + + self.create_new_merge_sort( + sorted_streams, + // If we have no sorted spill files left, this is the last run + 
self.sorted_spill_files.is_empty(), + ) + } + } + } + + fn create_new_merge_sort( + &mut self, + streams: Vec, + is_output: bool, + ) -> Result { + StreamingMergeBuilder::new() + .with_schema(Arc::clone(&self.schema)) + .with_expressions(&self.expr) + .with_batch_size(self.batch_size) + .with_fetch(self.fetch) + .with_metrics(if is_output { + // Only add the metrics to the last run + self.metrics.clone() + } else { + self.metrics.intermediate() + }) + .with_round_robin_tie_breaker(self.enable_round_robin_tie_breaker) + .with_streams(streams) + // Don't track memory used by this stream as we reserve that memory by worst case sceneries + // (reserving memory for the biggest batch in each stream) + // This is a hack + .with_reservation( + MemoryConsumer::new("merge stream mock memory") + .register(&self.unbounded_memory_pool), + ) + .build() + } + + /// Return the sorted spill files to use for the next phase, and the buffer size + /// This will try to get as many spill files as possible to merge, and if we don't have enough streams + /// it will try to reduce the buffer size until we have enough streams to merge + /// otherwise it will return an error + fn get_sorted_spill_files_to_merge( + &mut self, + buffer_size: usize, + minimum_number_of_required_streams: usize, + reservation: &mut MemoryReservation, + ) -> Result<(Vec, usize)> { + assert_ne!(buffer_size, 0, "Buffer size must be greater than 0"); + let mut number_of_spills_to_read_for_current_phase = 0; + + for spill in &self.sorted_spill_files { + // For memory pools that are not shared this is good, for other this is not + // and there should be some upper limit to memory reservation so we won't starve the system + match reservation.try_grow(spill.max_record_batch_memory * buffer_size) { + Ok(_) => { + number_of_spills_to_read_for_current_phase += 1; + } + // If we can't grow the reservation, we need to stop + Err(err) => { + // We must have at least 2 streams to merge, so if we don't have enough memory + // 
fail + if minimum_number_of_required_streams + > number_of_spills_to_read_for_current_phase + { + // Free the memory we reserved for this merge as we either try again or fail + reservation.free(); + if buffer_size > 1 { + // Try again with smaller buffer size, it will be slower but at least we can merge + return self.get_sorted_spill_files_to_merge( + buffer_size - 1, + minimum_number_of_required_streams, + reservation, + ); + } + + return Err(err); + } + + // We reached the maximum amount of memory we can use + // for this merge + break; + } + } + } + + let spills = self + .sorted_spill_files + .drain(..number_of_spills_to_read_for_current_phase) + .collect::>(); + + Ok((spills, buffer_size)) + } +} + +struct StreamAttachedReservation { + stream: SendableRecordBatchStream, + reservation: MemoryReservation, +} + +impl StreamAttachedReservation { + fn new(stream: SendableRecordBatchStream, reservation: MemoryReservation) -> Self { + Self { + stream, + reservation, + } + } +} + +impl Stream for StreamAttachedReservation { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let res = self.stream.poll_next_unpin(cx); + + match res { + Poll::Ready(res) => { + match res { + Some(Ok(batch)) => Poll::Ready(Some(Ok(batch))), + Some(Err(err)) => { + // Had an error so drop the data + self.reservation.free(); + Poll::Ready(Some(Err(err))) + } + None => { + // Stream is done so free the memory + self.reservation.free(); + + Poll::Ready(None) + } + } + } + Poll::Pending => Poll::Pending, + } + } +} + +impl RecordBatchStream for StreamAttachedReservation { + fn schema(&self) -> SchemaRef { + self.stream.schema() + } +} diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 5cc7512f8813..bfac868480c4 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -32,8 +32,9 @@ use crate::metrics::{ BaselineMetrics, 
ExecutionPlanMetricsSet, MetricsSet, SpillMetrics, }; use crate::projection::{make_with_child, update_expr, ProjectionExec}; -use crate::sorts::streaming_merge::StreamingMergeBuilder; +use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; use crate::spill::get_record_batch_memory_size; +use crate::spill::get_size::GetActualSize; use crate::spill::in_progress_spill_file::InProgressSpillFile; use crate::spill::spill_manager::SpillManager; use crate::stream::RecordBatchStreamAdapter; @@ -53,7 +54,6 @@ use arrow::row::{RowConverter, Rows, SortField}; use datafusion_common::{ exec_datafusion_err, internal_datafusion_err, internal_err, Result, }; -use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; @@ -223,12 +223,12 @@ struct ExternalSorter { /// During external sorting, in-memory intermediate data will be appended to /// this file incrementally. Once finished, this file will be moved to [`Self::finished_spill_files`]. - in_progress_spill_file: Option, + in_progress_spill_file: Option<(InProgressSpillFile, usize)>, /// If data has previously been spilled, the locations of the spill files (in /// Arrow IPC format) /// Within the same spill file, the data might be chunked into multiple batches, /// and ordered by sort keys. - finished_spill_files: Vec, + finished_spill_files: Vec, // ======================================================================== // EXECUTION RESOURCES: @@ -355,8 +355,6 @@ impl ExternalSorter { self.merge_reservation.free(); if self.spilled_before() { - let mut streams = vec![]; - // Sort `in_mem_batches` and spill it first. If there are many // `in_mem_batches` and the memory limit is almost reached, merging // them with the spilled files at the same time might cause OOM. 
@@ -364,18 +362,11 @@ impl ExternalSorter { self.sort_and_spill_in_mem_batches().await?; } - for spill in self.finished_spill_files.drain(..) { - if !spill.path().exists() { - return internal_err!("Spill file {:?} does not exist", spill.path()); - } - let stream = self.spill_manager.read_spill_as_stream(spill)?; - streams.push(stream); - } - let expressions: LexOrdering = self.expr.iter().cloned().collect(); StreamingMergeBuilder::new() - .with_streams(streams) + .with_sorted_spill_files(std::mem::take(&mut self.finished_spill_files)) + .with_spill_manager(self.spill_manager.clone()) .with_schema(Arc::clone(&self.schema)) .with_expressions(expressions.as_ref()) .with_metrics(self.metrics.baseline.clone()) @@ -421,7 +412,7 @@ impl ExternalSorter { // Lazily initialize the in-progress spill file if self.in_progress_spill_file.is_none() { self.in_progress_spill_file = - Some(self.spill_manager.create_in_progress_file("Sorting")?); + Some((self.spill_manager.create_in_progress_file("Sorting")?, 0)); } Self::organize_stringview_arrays(globally_sorted_batches)?; @@ -431,12 +422,16 @@ impl ExternalSorter { let batches_to_spill = std::mem::take(globally_sorted_batches); self.reservation.free(); - let in_progress_file = self.in_progress_spill_file.as_mut().ok_or_else(|| { - internal_datafusion_err!("In-progress spill file should be initialized") - })?; + let (in_progress_file, max_record_batch_size) = + self.in_progress_spill_file.as_mut().ok_or_else(|| { + internal_datafusion_err!("In-progress spill file should be initialized") + })?; for batch in batches_to_spill { in_progress_file.append_batch(&batch)?; + + *max_record_batch_size = + (*max_record_batch_size).max(batch.get_actually_used_size()); } if !globally_sorted_batches.is_empty() { @@ -448,14 +443,17 @@ impl ExternalSorter { /// Finishes the in-progress spill file and moves it to the finished spill files. 
async fn spill_finish(&mut self) -> Result<()> { - let mut in_progress_file = + let (mut in_progress_file, max_record_batch_memory) = self.in_progress_spill_file.take().ok_or_else(|| { internal_datafusion_err!("Should be called after `spill_append`") })?; let spill_file = in_progress_file.finish()?; if let Some(spill_file) = spill_file { - self.finished_spill_files.push(spill_file); + self.finished_spill_files.push(SortedSpillFile { + file: spill_file, + max_record_batch_memory, + }); } Ok(()) diff --git a/datafusion/physical-plan/src/sorts/streaming_merge.rs b/datafusion/physical-plan/src/sorts/streaming_merge.rs index 3f022ec6095a..a5d055d8f693 100644 --- a/datafusion/physical-plan/src/sorts/streaming_merge.rs +++ b/datafusion/physical-plan/src/sorts/streaming_merge.rs @@ -19,15 +19,17 @@ //! This is an order-preserving merge. use crate::metrics::BaselineMetrics; +use crate::sorts::multi_level_merge::MultiLevelMergeBuilder; use crate::sorts::{ merge::SortPreservingMergeStream, stream::{FieldCursorStream, RowCursorStream}, }; -use crate::SendableRecordBatchStream; +use crate::{SendableRecordBatchStream, SpillManager}; use arrow::array::*; use arrow::datatypes::{DataType, SchemaRef}; use datafusion_common::{internal_err, Result}; -use datafusion_execution::memory_pool::MemoryReservation; +use datafusion_execution::disk_manager::RefCountedTempFile; +use datafusion_execution::memory_pool::{human_readable_size, MemoryReservation}; use datafusion_physical_expr_common::sort_expr::LexOrdering; macro_rules! primitive_merge_helper { @@ -52,8 +54,28 @@ macro_rules! 
merge_helper { }}; } +pub struct SortedSpillFile { + pub file: RefCountedTempFile, + + /// how much memory the largest memory batch is taking + pub max_record_batch_memory: usize, +} + +impl std::fmt::Debug for SortedSpillFile { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "SortedSpillFile({:?}) takes {}", + self.file.path(), + human_readable_size(self.max_record_batch_memory) + ) + } +} + pub struct StreamingMergeBuilder<'a> { streams: Vec, + sorted_spill_files: Vec, + spill_manager: Option, schema: Option, expressions: &'a LexOrdering, metrics: Option, @@ -67,6 +89,8 @@ impl Default for StreamingMergeBuilder<'_> { fn default() -> Self { Self { streams: vec![], + sorted_spill_files: vec![], + spill_manager: None, schema: None, expressions: LexOrdering::empty(), metrics: None, @@ -91,6 +115,19 @@ impl<'a> StreamingMergeBuilder<'a> { self } + pub fn with_sorted_spill_files( + mut self, + sorted_spill_files: Vec, + ) -> Self { + self.sorted_spill_files = sorted_spill_files; + self + } + + pub fn with_spill_manager(mut self, spill_manager: SpillManager) -> Self { + self.spill_manager = Some(spill_manager); + self + } + pub fn with_schema(mut self, schema: SchemaRef) -> Self { self.schema = Some(schema); self @@ -136,6 +173,8 @@ impl<'a> StreamingMergeBuilder<'a> { pub fn build(self) -> Result { let Self { streams, + sorted_spill_files, + spill_manager, schema, metrics, batch_size, @@ -145,6 +184,22 @@ impl<'a> StreamingMergeBuilder<'a> { enable_round_robin_tie_breaker, } = self; + if !sorted_spill_files.is_empty() && spill_manager.is_some() { + return Ok(MultiLevelMergeBuilder::new( + spill_manager.unwrap(), + schema.unwrap(), + sorted_spill_files, + streams, + expressions.clone(), + metrics.unwrap(), + batch_size.unwrap(), + reservation.unwrap(), + fetch, + enable_round_robin_tie_breaker, + ) + .create_spillable_merge_stream()); + } + // Early return if streams or expressions are empty let checks = [ ( diff --git 
a/datafusion/physical-plan/src/spill/get_size.rs b/datafusion/physical-plan/src/spill/get_size.rs new file mode 100644 index 000000000000..3e67c03214d0 --- /dev/null +++ b/datafusion/physical-plan/src/spill/get_size.rs @@ -0,0 +1,199 @@ +use arrow::array::{ + Array, ArrayRef, ArrowPrimitiveType, BooleanArray, FixedSizeBinaryArray, + FixedSizeListArray, GenericByteArray, GenericListArray, OffsetSizeTrait, + PrimitiveArray, RecordBatch, StructArray, +}; +use arrow::buffer::{BooleanBuffer, Buffer, NullBuffer, OffsetBuffer}; +use arrow::datatypes::{ArrowNativeType, ByteArrayType}; +use arrow::downcast_primitive_array; +use arrow_schema::DataType; + +/// TODO - NEED TO MOVE THIS TO ARROW +/// this is needed as unlike `get_buffer_memory_size` or `get_array_memory_size` +/// or even `get_record_batch_memory_size` calling it on a batch before writing to spill and after reading it back +/// will return the same size (minus some optimization that IPC writer does for dictionaries) +/// +pub trait GetActualSize { + fn get_actually_used_size(&self) -> usize; +} + +impl GetActualSize for RecordBatch { + fn get_actually_used_size(&self) -> usize { + self.columns() + .iter() + .map(|c| c.get_actually_used_size()) + .sum() + } +} + +pub trait GetActualSizeArray: Array { + fn get_actually_used_size(&self) -> usize { + self.get_buffer_memory_size() + } +} + +impl GetActualSize for dyn Array { + fn get_actually_used_size(&self) -> usize { + use arrow::array::AsArray; + let array = self; + + // we can avoid this is we move this trait function to be in arrow + downcast_primitive_array!( + array => { + array.get_actually_used_size() + }, + DataType::Utf8 => { + array.as_string::().get_actually_used_size() + }, + DataType::LargeUtf8 => { + array.as_string::().get_actually_used_size() + }, + DataType::Binary => { + array.as_binary::().get_actually_used_size() + }, + DataType::LargeBinary => { + array.as_binary::().get_actually_used_size() + }, + DataType::FixedSizeBinary(_) => { + 
array.as_fixed_size_binary().get_actually_used_size() + }, + DataType::Struct(_) => { + array.as_struct().get_actually_used_size() + }, + DataType::List(_) => { + array.as_list::().get_actually_used_size() + }, + DataType::LargeList(_) => { + array.as_list::().get_actually_used_size() + }, + DataType::FixedSizeList(_, _) => { + array.as_fixed_size_list().get_actually_used_size() + }, + DataType::Boolean => { + array.as_boolean().get_actually_used_size() + }, + + _ => { + array.get_buffer_memory_size() + } + ) + } +} + +impl GetActualSizeArray for ArrayRef { + fn get_actually_used_size(&self) -> usize { + self.as_ref().get_actually_used_size() + } +} + +impl GetActualSize for Option<&NullBuffer> { + fn get_actually_used_size(&self) -> usize { + self.map(|b| b.get_actually_used_size()).unwrap_or(0) + } +} + +impl GetActualSize for NullBuffer { + fn get_actually_used_size(&self) -> usize { + // len return in bits + self.len() / 8 + } +} +impl GetActualSize for Buffer { + fn get_actually_used_size(&self) -> usize { + self.len() + } +} + +impl GetActualSize for BooleanBuffer { + fn get_actually_used_size(&self) -> usize { + // len return in bits + self.len() / 8 + } +} + +impl GetActualSize for OffsetBuffer { + fn get_actually_used_size(&self) -> usize { + self.inner().inner().get_actually_used_size() + } +} + +impl GetActualSizeArray for BooleanArray { + fn get_actually_used_size(&self) -> usize { + let null_size = self.nulls().get_actually_used_size(); + let values_size = self.values().get_actually_used_size(); + null_size + values_size + } +} + +impl GetActualSizeArray for GenericByteArray { + fn get_actually_used_size(&self) -> usize { + let null_size = self.nulls().get_actually_used_size(); + let offsets_size = self.offsets().get_actually_used_size(); + + let values_size = { + let first_offset = self.value_offsets()[0].as_usize(); + let last_offset = self.value_offsets()[self.len()].as_usize(); + last_offset - first_offset + }; + null_size + offsets_size + 
values_size + } +} + +impl GetActualSizeArray for FixedSizeBinaryArray { + fn get_actually_used_size(&self) -> usize { + let null_size = self.nulls().get_actually_used_size(); + let values_size = self.value_length() as usize * self.len(); + + null_size + values_size + } +} + +impl GetActualSizeArray for PrimitiveArray { + fn get_actually_used_size(&self) -> usize { + let null_size = self.nulls().get_actually_used_size(); + let values_size = self.values().inner().get_actually_used_size(); + + null_size + values_size + } +} + +impl GetActualSizeArray for GenericListArray { + fn get_actually_used_size(&self) -> usize { + let null_size = self.nulls().get_actually_used_size(); + let offsets_size = self.offsets().get_actually_used_size(); + + let values_size = { + let first_offset = self.value_offsets()[0].as_usize(); + let last_offset = self.value_offsets()[self.len()].as_usize(); + + self.values() + .slice(first_offset, last_offset - first_offset) + .get_actually_used_size() + }; + null_size + offsets_size + values_size + } +} + +impl GetActualSizeArray for FixedSizeListArray { + fn get_actually_used_size(&self) -> usize { + let null_size = self.nulls().get_actually_used_size(); + let values_size = self.values().get_actually_used_size(); + + null_size + values_size + } +} + +impl GetActualSizeArray for StructArray { + fn get_actually_used_size(&self) -> usize { + let null_size = self.nulls().get_actually_used_size(); + let values_size = self + .columns() + .iter() + .map(|array| array.get_actually_used_size()) + .sum::(); + + null_size + values_size + } +} + +// TODO - need to add to more arrays diff --git a/datafusion/physical-plan/src/spill/mod.rs b/datafusion/physical-plan/src/spill/mod.rs index 1101616a4106..b007276c5858 100644 --- a/datafusion/physical-plan/src/spill/mod.rs +++ b/datafusion/physical-plan/src/spill/mod.rs @@ -17,6 +17,7 @@ //! 
Defines the spilling functions +pub(crate) mod get_size; pub(crate) mod in_progress_spill_file; pub(crate) mod spill_manager; diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index 78cd47a8bad0..2731c46aba26 100644 --- a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -17,19 +17,18 @@ //! Define the `SpillManager` struct, which is responsible for reading and writing `RecordBatch`es to raw files based on the provided configurations. -use std::sync::Arc; - use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_execution::runtime_env::RuntimeEnv; +use std::sync::Arc; use datafusion_common::Result; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::SendableRecordBatchStream; -use crate::{common::spawn_buffered, metrics::SpillMetrics}; - use super::{in_progress_spill_file::InProgressSpillFile, SpillReaderStream}; +use crate::spill::get_size::GetActualSize; +use crate::{common::spawn_buffered, metrics::SpillMetrics}; /// The `SpillManager` is responsible for the following tasks: /// - Reading and writing `RecordBatch`es to raw files based on the provided configurations. @@ -57,6 +56,14 @@ impl SpillManager { } } + pub fn with_batch_read_buffer_capacity( + mut self, + batch_read_buffer_capacity: usize, + ) -> Self { + self.batch_read_buffer_capacity = batch_read_buffer_capacity; + self + } + /// Creates a temporary file for in-progress operations, returning an error /// message if file creation fails. The file can be used to append batches /// incrementally and then finish the file when done. @@ -118,6 +125,156 @@ impl SpillManager { self.spill_record_batch_and_finish(&batches, request_description) } + /// Refer to the documentation for [`Self::spill_record_batch_and_finish`]. 
This method + /// additionally spills the `RecordBatch` into smaller batches, divided by `row_limit`. + /// + /// # Errors + /// - Returns an error if spilling would exceed the disk usage limit configured + /// by `max_temp_directory_size` in `DiskManager` + pub(crate) fn spill_record_batch_by_size_and_return_max_batch_memory( + &self, + batch: &RecordBatch, + request_description: &str, + row_limit: usize, + ) -> Result> { + let total_rows = batch.num_rows(); + let mut batches = Vec::new(); + let mut offset = 0; + + // It's ok to calculate all slices first, because slicing is zero-copy. + while offset < total_rows { + let length = std::cmp::min(total_rows - offset, row_limit); + let sliced_batch = batch.slice(offset, length); + batches.push(sliced_batch); + offset += length; + } + + let mut in_progress_file = self.create_in_progress_file(request_description)?; + + let mut max_record_batch_size = 0; + + for batch in batches { + in_progress_file.append_batch(&batch)?; + + max_record_batch_size = + max_record_batch_size.max(batch.get_actually_used_size()); + } + + let file = in_progress_file.finish()?; + + Ok(file.map(|f| (f, max_record_batch_size))) + } + + /// Spill the `RecordBatch` to disk as smaller batches + /// split by `batch_size_rows`. 
+ /// + /// will return the spill file and the size of the largest batch in memory + pub async fn spill_record_batch_stream_by_size( + &self, + stream: &mut SendableRecordBatchStream, + batch_size_rows: usize, + request_msg: &str, + ) -> Result> { + use futures::StreamExt; + let mut in_progress_file = self.create_in_progress_file(request_msg)?; + + let mut max_record_batch_size = 0; + + let mut maybe_last_batch: Option = None; + + while let Some(batch) = stream.next().await { + let mut batch = batch?; + + if let Some(mut last_batch) = maybe_last_batch.take() { + assert!( + last_batch.num_rows() < batch_size_rows, + "last batch size must be smaller than the requested batch size" + ); + + // Get the number of rows to take from current batch so the last_batch + // will have `batch_size_rows` rows + let current_batch_offset = std::cmp::min( + // rows needed to fill + batch_size_rows - last_batch.num_rows(), + // Current length of the batch + batch.num_rows(), + ); + + // if have last batch that has less rows than concat and spill + last_batch = arrow::compute::concat_batches( + &stream.schema(), + &[last_batch, batch.slice(0, current_batch_offset)], + )?; + + assert!(last_batch.num_rows() <= batch_size_rows, "must build a batch that is smaller or equal to the requested batch size from the current batch"); + + // If not enough rows + if last_batch.num_rows() < batch_size_rows { + // keep the last batch for next iteration + maybe_last_batch = Some(last_batch); + continue; + } + + max_record_batch_size = + max_record_batch_size.max(last_batch.get_actually_used_size()); + + in_progress_file.append_batch(&last_batch)?; + + if current_batch_offset == batch.num_rows() { + // No remainder + continue; + } + + // remainder + batch = batch.slice( + current_batch_offset, + batch.num_rows() - current_batch_offset, + ); + } + + let mut offset = 0; + let total_rows = batch.num_rows(); + + // Keep slicing the batch until we have left with a batch that is smaller than + // the wanted 
batch size + while total_rows - offset >= batch_size_rows { + let batch = batch.slice(offset, batch_size_rows); + offset += batch_size_rows; + + max_record_batch_size = + max_record_batch_size.max(batch.get_actually_used_size()); + + in_progress_file.append_batch(&batch)?; + } + + // If there is a remainder for the current batch that is smaller than the wanted batch size + // keep it for next iteration + if offset < total_rows { + // remainder + let batch = batch.slice(offset, total_rows - offset); + + maybe_last_batch = Some(batch); + } + } + if let Some(last_batch) = maybe_last_batch.take() { + assert!( + last_batch.num_rows() < batch_size_rows, + "last batch size must be smaller than the requested batch size" + ); + + // Write it to disk + in_progress_file.append_batch(&last_batch)?; + + max_record_batch_size = + max_record_batch_size.max(last_batch.get_actually_used_size()); + } + + // Flush disk + let spill_file = in_progress_file.finish()?; + + Ok(spill_file.map(|f| (f, max_record_batch_size))) + } + /// Reads a spill file as a stream. The file must be created by the current `SpillManager`. /// This method will generate output in FIFO order: the batch appended first /// will be read first. From 58b299dacc75671f295a4763517711cbbfdcc51b Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Mon, 14 Apr 2025 00:50:57 +0300 Subject: [PATCH 02/21] test: add fuzz test for aggregate --- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 253 +++++++++++++++++- 1 file changed, 244 insertions(+), 9 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index ff3b66986ced..09def1ff9754 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -15,15 +15,18 @@ // specific language governing permissions and limitations // under the License. 
-use std::sync::Arc; - use crate::fuzz_cases::aggregation_fuzzer::{ AggregationFuzzerBuilder, DatasetGeneratorConfig, QueryBuilder, }; +use std::any::Any; +use std::fmt::{Debug, Formatter}; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll}; use arrow::array::{ types::Int64Type, Array, ArrayRef, AsArray, Int32Array, Int64Array, RecordBatch, - StringArray, + StringArray, UInt64Array, }; use arrow::compute::{concat_batches, SortOptions}; use arrow::datatypes::DataType; @@ -40,24 +43,31 @@ use datafusion::physical_plan::aggregates::{ use datafusion::physical_plan::{collect, displayable, ExecutionPlan}; use datafusion::prelude::{DataFrame, SessionConfig, SessionContext}; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; -use datafusion_common::HashMap; +use datafusion_common::{DataFusionError, HashMap}; use datafusion_common_runtime::JoinSet; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::{col, lit, Column}; -use datafusion_physical_expr::PhysicalSortExpr; +use datafusion_physical_expr::{EquivalenceProperties, Partitioning, PhysicalSortExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_plan::InputOrderMode; +use datafusion_physical_plan::{ + DisplayAs, DisplayFormatType, InputOrderMode, PlanProperties, +}; +use futures::{Stream, StreamExt}; use test_utils::{add_empty_batches, StringBatchGenerator}; +use super::record_batch_generator::get_supported_types_columns; +use datafusion_datasource::source::DataSource; +use datafusion_execution::memory_pool::units::MB; use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::RuntimeEnvBuilder; -use datafusion_execution::TaskContext; +use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext}; +use datafusion_functions_aggregate::array_agg::array_agg_udaf; +use datafusion_physical_plan::execution_plan::{Boundedness, 
EmissionType}; use datafusion_physical_plan::metrics::MetricValue; +use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use rand::rngs::StdRng; use rand::{random, thread_rng, Rng, SeedableRng}; -use super::record_batch_generator::get_supported_types_columns; - // ======================================================================== // The new aggregation fuzz tests based on [`AggregationFuzzer`] // ======================================================================== @@ -640,6 +650,8 @@ fn assert_spill_count_metric(expect_spill: bool, single_aggregate: Arc 0 { panic!("Expected no spill but found SpillCount metric with value greater than 0."); } + + println!("SpillCount = {}", spill_count); } else { panic!("No metrics returned from the operator; cannot verify spilling."); } @@ -753,3 +765,226 @@ async fn test_single_mode_aggregate_with_spill() -> Result<()> { Ok(()) } + +/// A Mock ExecutionPlan that can be used for writing tests of other +/// ExecutionPlans +pub struct StreamExec { + /// the results to send back + stream: Mutex>, + /// if true (the default), sends data using a separate task to ensure the + /// batches are not available without this stream yielding first + use_task: bool, + cache: PlanProperties, +} + +impl Debug for StreamExec { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "StreamExec") + } +} + +impl StreamExec { + /// Create a new `MockExec` with a single partition that returns + /// the specified `Results`s. + /// + /// By default, the batches are not produced immediately (the + /// caller has to actually yield and another task must run) to + /// ensure any poll loops are correct. 
This behavior can be + /// changed with `with_use_task` + pub fn new(stream: SendableRecordBatchStream) -> Self { + let cache = Self::compute_properties(stream.schema()); + Self { + stream: Mutex::new(Some(stream)), + use_task: true, + cache, + } + } + + /// If `use_task` is true (the default) then the batches are sent + /// back using a separate task to ensure the underlying stream is + /// not immediately ready + pub fn with_use_task(mut self, use_task: bool) -> Self { + self.use_task = use_task; + self + } + + /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. + fn compute_properties(schema: SchemaRef) -> PlanProperties { + PlanProperties::new( + EquivalenceProperties::new(schema), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ) + } +} + +impl DisplayAs for StreamExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "StreamExec:") + } + DisplayFormatType::TreeRender => { + // TODO: collect info + write!(f, "") + } + } + } +} + +impl ExecutionPlan for StreamExec { + fn name(&self) -> &'static str { + Self::static_name() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + unimplemented!() + } + + /// Returns a stream which yields data + fn execute( + &self, + partition: usize, + _context: Arc, + ) -> Result { + assert_eq!(partition, 0); + + let stream = self.stream.lock().unwrap().take(); + + stream.ok_or(DataFusionError::Internal( + "Stream already consumed".to_string(), + )) + } +} + + +#[tokio::test] +async fn test_low_cardinality() -> Result<()> { + let record_batch_size = 8192; + let two_mb = 2 * MB as usize; + 
let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(two_mb)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + run_test_low_cardinality(task_ctx, 100).await +} + +async fn run_test_low_cardinality( + task_ctx: TaskContext, + number_of_record_batches: usize, +) -> Result<()> { + let scan_schema = Arc::new(Schema::new(vec![ + Field::new("col_0", DataType::UInt64, true), + Field::new("col_1", DataType::Utf8, true), + ])); + + let group_by = PhysicalGroupBy::new_single(vec![( + Arc::new(Column::new("col_0", 0)), + "col_0".to_string(), + )]); + + let aggregate_expressions = vec![Arc::new( + AggregateExprBuilder::new( + array_agg_udaf(), + vec![col("col_1", &scan_schema).unwrap()], + ) + .schema(Arc::clone(&scan_schema)) + .alias("array_agg(col_1)") + .build()?, + )]; + + let record_batch_size = task_ctx.session_config().batch_size() as u64; + + + let schema = Arc::clone(&scan_schema); + let plan: Arc = + Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + futures::stream::iter((0..number_of_record_batches as u64).map(move |index| { + let string_array = if index == number_of_record_batches as u64 / 2 { + // Simulate a large record batch in a middle of the stream + Arc::new(StringArray::from_iter_values( + (0..record_batch_size).map(|_| "b".repeat(24)), + )) + } else { + Arc::new(StringArray::from_iter_values( + (0..record_batch_size).map(|_| "a".repeat(8)), + )) + }; + + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + // Grouping key + Arc::new(UInt64Array::from_iter_values( + (index * record_batch_size) + ..(index * record_batch_size) + record_batch_size, + )), + // Grouping value + string_array, + ], + ) + .map_err(|err| err.into()) + })), + )))); + + let aggregate_exec = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + 
group_by.clone(), + aggregate_expressions.clone(), + vec![None; aggregate_expressions.len()], + plan, + Arc::clone(&scan_schema), + )?); + let aggregate_final = Arc::new(AggregateExec::try_new( + AggregateMode::Final, + group_by, + aggregate_expressions.clone(), + vec![None; aggregate_expressions.len()], + aggregate_exec, + Arc::clone(&scan_schema), + )?); + + let task_ctx = Arc::new(task_ctx); + + let mut result = aggregate_final.execute(0, task_ctx)?; + + let mut number_of_groups = 0; + + while let Some(batch) = result.next().await { + let batch = batch?; + number_of_groups += batch.num_rows(); + } + + assert_eq!(number_of_groups, number_of_record_batches * record_batch_size as usize); + + assert_spill_count_metric(true, aggregate_final); + + Ok(()) +} From a8ac81e8cdd07ce0ce9b431ef1304b895ff16348 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 15 Apr 2025 22:45:24 +0300 Subject: [PATCH 03/21] update --- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 55 ++++++++++--------- 1 file changed, 29 insertions(+), 26 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 09def1ff9754..e0d73eab03e1 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -926,32 +926,35 @@ async fn run_test_low_cardinality( let plan: Arc = Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new( Arc::clone(&schema), - futures::stream::iter((0..number_of_record_batches as u64).map(move |index| { - let string_array = if index == number_of_record_batches as u64 / 2 { - // Simulate a large record batch in a middle of the stream - Arc::new(StringArray::from_iter_values( - (0..record_batch_size).map(|_| "b".repeat(24)), - )) - } else { - Arc::new(StringArray::from_iter_values( - (0..record_batch_size).map(|_| "a".repeat(8)), - )) - }; - - RecordBatch::try_new( - Arc::clone(&schema), - vec![ - 
// Grouping key - Arc::new(UInt64Array::from_iter_values( - (index * record_batch_size) - ..(index * record_batch_size) + record_batch_size, - )), - // Grouping value - string_array, - ], - ) - .map_err(|err| err.into()) - })), + futures::stream::iter((0..number_of_record_batches as u64).map( + move |index| { + // Simulate a large record batch 3 times in a stream + let string_array = + if index % (number_of_record_batches as u64 / 3) == 0 { + Arc::new(StringArray::from_iter_values( + (0..record_batch_size).map(|_| "b".repeat(64)), + )) + } else { + Arc::new(StringArray::from_iter_values( + (0..record_batch_size).map(|_| "a".repeat(8)), + )) + }; + + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + // Grouping key + Arc::new(UInt64Array::from_iter_values( + (index * record_batch_size) + ..(index * record_batch_size) + record_batch_size, + )), + // Grouping value + string_array, + ], + ) + .map_err(|err| err.into()) + }, + )), )))); let aggregate_exec = Arc::new(AggregateExec::try_new( From 082a009c3edeace867f17037617bc85ad763870c Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 15 Apr 2025 23:16:46 +0300 Subject: [PATCH 04/21] add more tests --- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 136 ++++++++++++------ 1 file changed, 89 insertions(+), 47 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index e0d73eab03e1..c9bf89e8af09 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -29,7 +29,7 @@ use arrow::array::{ StringArray, UInt64Array, }; use arrow::compute::{concat_batches, SortOptions}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, UInt64Type}; use arrow::util::pretty::pretty_format_batches; use arrow_schema::{Field, Schema, SchemaRef}; use datafusion::common::Result; @@ -57,7 +57,7 @@ use test_utils::{add_empty_batches, 
StringBatchGenerator}; use super::record_batch_generator::get_supported_types_columns; use datafusion_datasource::source::DataSource; -use datafusion_execution::memory_pool::units::MB; +use datafusion_execution::memory_pool::units::{KB, MB}; use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext}; @@ -879,25 +879,69 @@ impl ExecutionPlan for StreamExec { #[tokio::test] -async fn test_low_cardinality() -> Result<()> { +async fn test_high_cardinality_with_limited_memory() -> Result<()> { let record_batch_size = 8192; let two_mb = 2 * MB as usize; let task_ctx = { let memory_pool = Arc::new(FairSpillPool::new(two_mb)); TaskContext::default() - .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) - .with_runtime(Arc::new( - RuntimeEnvBuilder::new() - .with_memory_pool(memory_pool) - .build()?, - )) + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) }; - run_test_low_cardinality(task_ctx, 100).await + + run_test_high_cardinality(task_ctx, 100, Box::pin(|_| (16 * KB) as usize)).await } -async fn run_test_low_cardinality( + +#[tokio::test] +async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch() -> Result<()> { + let record_batch_size = 8192; + let two_mb = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(two_mb)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_test_high_cardinality(task_ctx, 100, Box::pin(|i| { + if i % 25 == 0 { + (64 * KB) as usize + } else { + (16 * KB) as usize + } + })).await +} + +#[tokio::test] +async fn 
test_high_cardinality_with_limited_memory_and_large_record_batch() -> Result<()> { + let record_batch_size = 8192; + let two_mb = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(two_mb)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + run_test_high_cardinality(task_ctx, 100, Box::pin(move |_| two_mb / 4)).await +} + +async fn run_test_high_cardinality( task_ctx: TaskContext, number_of_record_batches: usize, + get_size_of_record_batch_to_generate: Pin usize + Send + 'static>>, ) -> Result<()> { let scan_schema = Arc::new(Schema::new(vec![ Field::new("col_0", DataType::UInt64, true), @@ -914,48 +958,43 @@ async fn run_test_low_cardinality( array_agg_udaf(), vec![col("col_1", &scan_schema).unwrap()], ) - .schema(Arc::clone(&scan_schema)) - .alias("array_agg(col_1)") - .build()?, + .schema(Arc::clone(&scan_schema)) + .alias("array_agg(col_1)") + .build()?, )]; let record_batch_size = task_ctx.session_config().batch_size() as u64; - let schema = Arc::clone(&scan_schema); let plan: Arc = - Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new( - Arc::clone(&schema), - futures::stream::iter((0..number_of_record_batches as u64).map( - move |index| { - // Simulate a large record batch 3 times in a stream - let string_array = - if index % (number_of_record_batches as u64 / 3) == 0 { - Arc::new(StringArray::from_iter_values( - (0..record_batch_size).map(|_| "b".repeat(64)), - )) - } else { - Arc::new(StringArray::from_iter_values( - (0..record_batch_size).map(|_| "a".repeat(8)), - )) - }; - - RecordBatch::try_new( - Arc::clone(&schema), - vec![ - // Grouping key - Arc::new(UInt64Array::from_iter_values( - (index * record_batch_size) - ..(index * record_batch_size) + record_batch_size, - )), - // Grouping value - string_array, - ], - ) + 
Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + futures::stream::iter((0..number_of_record_batches as u64).map( + move |index| { + let mut record_batch_memory_size = get_size_of_record_batch_to_generate(index as usize); + record_batch_memory_size = record_batch_memory_size.saturating_sub(size_of::() * record_batch_size as usize); + + let string_item_size = record_batch_memory_size / record_batch_size as usize; + let string_array = Arc::new(StringArray::from_iter_values( + (0..record_batch_size).map(|_| "a".repeat(string_item_size)), + )); + + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + // Grouping key + Arc::new(UInt64Array::from_iter_values( + (index * record_batch_size) + ..(index * record_batch_size) + record_batch_size, + )), + // Grouping value + string_array, + ], + ) .map_err(|err| err.into()) - }, - )), - )))); + }, + )), + )))); let aggregate_exec = Arc::new(AggregateExec::try_new( AggregateMode::Partial, @@ -985,7 +1024,10 @@ async fn run_test_low_cardinality( number_of_groups += batch.num_rows(); } - assert_eq!(number_of_groups, number_of_record_batches * record_batch_size as usize); + assert_eq!( + number_of_groups, + number_of_record_batches * record_batch_size as usize + ); assert_spill_count_metric(true, aggregate_final); From 6cdf5dc676cb2e2a980b6cbe0e3554f5dd0706d3 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 15 Apr 2025 23:27:58 +0300 Subject: [PATCH 05/21] fix test --- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 187 ++++++++++-------- .../src/sorts/multi_level_merge.rs | 3 +- 2 files changed, 111 insertions(+), 79 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index c9bf89e8af09..7dc5d162080b 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -633,7 +633,10 @@ fn 
extract_result_counts(results: Vec) -> HashMap, i output } -fn assert_spill_count_metric(expect_spill: bool, single_aggregate: Arc) { +fn assert_spill_count_metric( + expect_spill: bool, + single_aggregate: Arc, +) -> usize { if let Some(metrics_set) = single_aggregate.metrics() { let mut spill_count = 0; @@ -652,6 +655,8 @@ fn assert_spill_count_metric(expect_spill: bool, single_aggregate: Arc std::fmt::Result { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { write!(f, "StreamExec:") @@ -877,72 +878,101 @@ impl ExecutionPlan for StreamExec { } } - #[tokio::test] async fn test_high_cardinality_with_limited_memory() -> Result<()> { let record_batch_size = 8192; - let two_mb = 2 * MB as usize; + let pool_size = 2 * MB as usize; let task_ctx = { - let memory_pool = Arc::new(FairSpillPool::new(two_mb)); + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); TaskContext::default() - .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) - .with_runtime(Arc::new( - RuntimeEnvBuilder::new() - .with_memory_pool(memory_pool) - .build()?, - )) + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) }; - run_test_high_cardinality(task_ctx, 100, Box::pin(|_| (16 * KB) as usize)).await -} + let record_batch_size = pool_size / 16; + + // Basic test with a lot of groups that cannot all fit in memory and 1 record batch + // from each spill file is too much memory + let spill_count = + run_test_high_cardinality(task_ctx, 100, Box::pin(move |_| record_batch_size)) + .await?; + let total_spill_files_size = spill_count * record_batch_size; + assert!( + total_spill_files_size > pool_size, + "Total spill files size {} should be greater than pool size {}", + total_spill_files_size, + pool_size + ); + + Ok(()) +} #[tokio::test] 
-async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch() -> Result<()> { +async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch( +) -> Result<()> { let record_batch_size = 8192; - let two_mb = 2 * MB as usize; + let pool_size = 2 * MB as usize; let task_ctx = { - let memory_pool = Arc::new(FairSpillPool::new(two_mb)); + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); TaskContext::default() - .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) - .with_runtime(Arc::new( - RuntimeEnvBuilder::new() - .with_memory_pool(memory_pool) - .build()?, - )) + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) }; - run_test_high_cardinality(task_ctx, 100, Box::pin(|i| { - if i % 25 == 0 { - (64 * KB) as usize - } else { - (16 * KB) as usize - } - })).await + run_test_high_cardinality( + task_ctx, + 100, + Box::pin(move |i| { + if i + 1 % 25 == 0 { + pool_size / 4 + } else { + (16 * KB) as usize + } + }), + ) + .await?; + + Ok(()) } #[tokio::test] -async fn test_high_cardinality_with_limited_memory_and_large_record_batch() -> Result<()> { +async fn test_high_cardinality_with_limited_memory_and_large_record_batch() -> Result<()> +{ let record_batch_size = 8192; - let two_mb = 2 * MB as usize; + let pool_size = 2 * MB as usize; let task_ctx = { - let memory_pool = Arc::new(FairSpillPool::new(two_mb)); + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); TaskContext::default() - .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) - .with_runtime(Arc::new( - RuntimeEnvBuilder::new() - .with_memory_pool(memory_pool) - .build()?, - )) + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) }; - 
run_test_high_cardinality(task_ctx, 100, Box::pin(move |_| two_mb / 4)).await + + // Test that the merge degree of multi level merge sort cannot be fixed size when there is not enough memory + run_test_high_cardinality(task_ctx, 100, Box::pin(move |_| pool_size / 4)).await?; + + Ok(()) } async fn run_test_high_cardinality( task_ctx: TaskContext, number_of_record_batches: usize, - get_size_of_record_batch_to_generate: Pin usize + Send + 'static>>, -) -> Result<()> { + get_size_of_record_batch_to_generate: Pin< + Box usize + Send + 'static>, + >, +) -> Result { let scan_schema = Arc::new(Schema::new(vec![ Field::new("col_0", DataType::UInt64, true), Field::new("col_1", DataType::Utf8, true), @@ -958,43 +988,46 @@ async fn run_test_high_cardinality( array_agg_udaf(), vec![col("col_1", &scan_schema).unwrap()], ) - .schema(Arc::clone(&scan_schema)) - .alias("array_agg(col_1)") - .build()?, + .schema(Arc::clone(&scan_schema)) + .alias("array_agg(col_1)") + .build()?, )]; let record_batch_size = task_ctx.session_config().batch_size() as u64; let schema = Arc::clone(&scan_schema); let plan: Arc = - Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new( - Arc::clone(&schema), - futures::stream::iter((0..number_of_record_batches as u64).map( - move |index| { - let mut record_batch_memory_size = get_size_of_record_batch_to_generate(index as usize); - record_batch_memory_size = record_batch_memory_size.saturating_sub(size_of::() * record_batch_size as usize); - - let string_item_size = record_batch_memory_size / record_batch_size as usize; - let string_array = Arc::new(StringArray::from_iter_values( - (0..record_batch_size).map(|_| "a".repeat(string_item_size)), - )); - - RecordBatch::try_new( - Arc::clone(&schema), - vec![ - // Grouping key - Arc::new(UInt64Array::from_iter_values( - (index * record_batch_size) - ..(index * record_batch_size) + record_batch_size, - )), - // Grouping value - string_array, - ], - ) + 
Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + futures::stream::iter((0..number_of_record_batches as u64).map( + move |index| { + let mut record_batch_memory_size = + get_size_of_record_batch_to_generate(index as usize); + record_batch_memory_size = record_batch_memory_size + .saturating_sub(size_of::() * record_batch_size as usize); + + let string_item_size = + record_batch_memory_size / record_batch_size as usize; + let string_array = Arc::new(StringArray::from_iter_values( + (0..record_batch_size).map(|_| "a".repeat(string_item_size)), + )); + + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + // Grouping key + Arc::new(UInt64Array::from_iter_values( + (index * record_batch_size) + ..(index * record_batch_size) + record_batch_size, + )), + // Grouping value + string_array, + ], + ) .map_err(|err| err.into()) - }, - )), - )))); + }, + )), + )))); let aggregate_exec = Arc::new(AggregateExec::try_new( AggregateMode::Partial, @@ -1029,7 +1062,7 @@ async fn run_test_high_cardinality( number_of_record_batches * record_batch_size as usize ); - assert_spill_count_metric(true, aggregate_final); + let spill_count = assert_spill_count_metric(true, aggregate_final); - Ok(()) + Ok(spill_count) } diff --git a/datafusion/physical-plan/src/sorts/multi_level_merge.rs b/datafusion/physical-plan/src/sorts/multi_level_merge.rs index 8bcf995be059..c07fb375ff18 100644 --- a/datafusion/physical-plan/src/sorts/multi_level_merge.rs +++ b/datafusion/physical-plan/src/sorts/multi_level_merge.rs @@ -159,8 +159,7 @@ impl MultiLevelMergeBuilder { (1, 0) => { let spill_file = self.sorted_spill_files.remove(0); - self.spill_manager - .read_spill_as_stream(spill_file.file) + self.spill_manager.read_spill_as_stream(spill_file.file) } // Only in memory streams, so merge them all in a single pass From 04bb5ed8b3c8ef190c006c542590932780e76ce1 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 15 Apr 
2025 23:32:54 +0300 Subject: [PATCH 06/21] update tests --- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 110 +++++++++--------- 1 file changed, 55 insertions(+), 55 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 7dc5d162080b..4fc29acdf21e 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -885,12 +885,12 @@ async fn test_high_cardinality_with_limited_memory() -> Result<()> { let task_ctx = { let memory_pool = Arc::new(FairSpillPool::new(pool_size)); TaskContext::default() - .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) - .with_runtime(Arc::new( - RuntimeEnvBuilder::new() - .with_memory_pool(memory_pool) - .build()?, - )) + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) }; let record_batch_size = pool_size / 16; @@ -898,8 +898,8 @@ async fn test_high_cardinality_with_limited_memory() -> Result<()> { // Basic test with a lot of groups that cannot all fit in memory and 1 record batch // from each spill file is too much memory let spill_count = - run_test_high_cardinality(task_ctx, 100, Box::pin(move |_| record_batch_size)) - .await?; + run_test_high_cardinality(task_ctx, 100, Box::pin(move |_| record_batch_size)) + .await?; let total_spill_files_size = spill_count * record_batch_size; assert!( @@ -920,26 +920,26 @@ async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record let task_ctx = { let memory_pool = Arc::new(FairSpillPool::new(pool_size)); TaskContext::default() - .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) - .with_runtime(Arc::new( - RuntimeEnvBuilder::new() - .with_memory_pool(memory_pool) - .build()?, - )) + 
.with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) }; run_test_high_cardinality( task_ctx, 100, Box::pin(move |i| { - if i + 1 % 25 == 0 { + if i % 25 == 1 { pool_size / 4 } else { (16 * KB) as usize } }), ) - .await?; + .await?; Ok(()) } @@ -952,12 +952,12 @@ async fn test_high_cardinality_with_limited_memory_and_large_record_batch() -> R let task_ctx = { let memory_pool = Arc::new(FairSpillPool::new(pool_size)); TaskContext::default() - .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) - .with_runtime(Arc::new( - RuntimeEnvBuilder::new() - .with_memory_pool(memory_pool) - .build()?, - )) + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) }; // Test that the merge degree of multi level merge sort cannot be fixed size when there is not enough memory @@ -988,46 +988,46 @@ async fn run_test_high_cardinality( array_agg_udaf(), vec![col("col_1", &scan_schema).unwrap()], ) - .schema(Arc::clone(&scan_schema)) - .alias("array_agg(col_1)") - .build()?, + .schema(Arc::clone(&scan_schema)) + .alias("array_agg(col_1)") + .build()?, )]; let record_batch_size = task_ctx.session_config().batch_size() as u64; let schema = Arc::clone(&scan_schema); let plan: Arc = - Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new( - Arc::clone(&schema), - futures::stream::iter((0..number_of_record_batches as u64).map( - move |index| { - let mut record_batch_memory_size = - get_size_of_record_batch_to_generate(index as usize); - record_batch_memory_size = record_batch_memory_size - .saturating_sub(size_of::() * record_batch_size as usize); - - let string_item_size = - record_batch_memory_size / record_batch_size as usize; - let string_array = Arc::new(StringArray::from_iter_values( - 
(0..record_batch_size).map(|_| "a".repeat(string_item_size)), - )); - - RecordBatch::try_new( - Arc::clone(&schema), - vec![ - // Grouping key - Arc::new(UInt64Array::from_iter_values( - (index * record_batch_size) - ..(index * record_batch_size) + record_batch_size, - )), - // Grouping value - string_array, - ], - ) + Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + futures::stream::iter((0..number_of_record_batches as u64).map( + move |index| { + let mut record_batch_memory_size = + get_size_of_record_batch_to_generate(index as usize); + record_batch_memory_size = record_batch_memory_size + .saturating_sub(size_of::() * record_batch_size as usize); + + let string_item_size = + record_batch_memory_size / record_batch_size as usize; + let string_array = Arc::new(StringArray::from_iter_values( + (0..record_batch_size).map(|_| "a".repeat(string_item_size)), + )); + + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + // Grouping key + Arc::new(UInt64Array::from_iter_values( + (index * record_batch_size) + ..(index * record_batch_size) + record_batch_size, + )), + // Grouping value + string_array, + ], + ) .map_err(|err| err.into()) - }, - )), - )))); + }, + )), + )))); let aggregate_exec = Arc::new(AggregateExec::try_new( AggregateMode::Partial, From d630cf1254fb312fd8373ce9f6853ec59c18176d Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Wed, 16 Apr 2025 00:11:01 +0300 Subject: [PATCH 07/21] added more aggregate fuzz --- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 278 +++++++++++++----- 1 file changed, 211 insertions(+), 67 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 4fc29acdf21e..093941696049 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -58,7 +58,9 @@ use test_utils::{add_empty_batches, 
StringBatchGenerator}; use super::record_batch_generator::get_supported_types_columns; use datafusion_datasource::source::DataSource; use datafusion_execution::memory_pool::units::{KB, MB}; -use datafusion_execution::memory_pool::FairSpillPool; +use datafusion_execution::memory_pool::{ + FairSpillPool, MemoryConsumer, MemoryReservation, +}; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext}; use datafusion_functions_aggregate::array_agg::array_agg_udaf; @@ -885,21 +887,26 @@ async fn test_high_cardinality_with_limited_memory() -> Result<()> { let task_ctx = { let memory_pool = Arc::new(FairSpillPool::new(pool_size)); TaskContext::default() - .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) - .with_runtime(Arc::new( - RuntimeEnvBuilder::new() - .with_memory_pool(memory_pool) - .build()?, - )) + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) }; let record_batch_size = pool_size / 16; // Basic test with a lot of groups that cannot all fit in memory and 1 record batch // from each spill file is too much memory - let spill_count = - run_test_high_cardinality(task_ctx, 100, Box::pin(move |_| record_batch_size)) - .await?; + let spill_count = run_test_high_cardinality(RunTestHighCardinalityArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |_| record_batch_size), + memory_behavior: Default::default(), + }) + .await?; let total_spill_files_size = spill_count * record_batch_size; assert!( @@ -920,26 +927,96 @@ async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record let task_ctx = { let memory_pool = Arc::new(FairSpillPool::new(pool_size)); TaskContext::default() - 
.with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) - .with_runtime(Arc::new( - RuntimeEnvBuilder::new() - .with_memory_pool(memory_pool) - .build()?, - )) + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) }; - run_test_high_cardinality( + run_test_high_cardinality(RunTestHighCardinalityArgs { + pool_size, task_ctx, - 100, - Box::pin(move |i| { + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { if i % 25 == 1 { pool_size / 4 } else { (16 * KB) as usize } }), - ) - .await?; + memory_behavior: Default::default(), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_test_high_cardinality(RunTestHighCardinalityArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 4 + } else { + (16 * KB) as usize + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(10), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + 
.with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_test_high_cardinality(RunTestHighCardinalityArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 4 + } else { + (16 * KB) as usize + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAtTheBeginning, + }) + .await?; Ok(()) } @@ -952,27 +1029,52 @@ async fn test_high_cardinality_with_limited_memory_and_large_record_batch() -> R let task_ctx = { let memory_pool = Arc::new(FairSpillPool::new(pool_size)); TaskContext::default() - .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) - .with_runtime(Arc::new( - RuntimeEnvBuilder::new() - .with_memory_pool(memory_pool) - .build()?, - )) + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) }; // Test that the merge degree of multi level merge sort cannot be fixed size when there is not enough memory - run_test_high_cardinality(task_ctx, 100, Box::pin(move |_| pool_size / 4)).await?; + run_test_high_cardinality(RunTestHighCardinalityArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 4), + memory_behavior: Default::default(), + }) + .await?; Ok(()) } -async fn run_test_high_cardinality( +struct RunTestHighCardinalityArgs { + pool_size: usize, task_ctx: TaskContext, number_of_record_batches: usize, - get_size_of_record_batch_to_generate: Pin< - Box usize + Send + 'static>, - >, -) -> Result { + get_size_of_record_batch_to_generate: + Pin usize + Send + 'static>>, + memory_behavior: MemoryBehavior, +} + +#[derive(Default)] +enum MemoryBehavior { + #[default] + AsIs, + 
TakeAllMemoryAtTheBeginning, + TakeAllMemoryAndReleaseEveryNthBatch(usize), +} + +async fn run_test_high_cardinality(args: RunTestHighCardinalityArgs) -> Result { + let RunTestHighCardinalityArgs { + pool_size, + task_ctx, + number_of_record_batches, + get_size_of_record_batch_to_generate, + memory_behavior, + } = args; let scan_schema = Arc::new(Schema::new(vec![ Field::new("col_0", DataType::UInt64, true), Field::new("col_1", DataType::Utf8, true), @@ -988,46 +1090,46 @@ async fn run_test_high_cardinality( array_agg_udaf(), vec![col("col_1", &scan_schema).unwrap()], ) - .schema(Arc::clone(&scan_schema)) - .alias("array_agg(col_1)") - .build()?, + .schema(Arc::clone(&scan_schema)) + .alias("array_agg(col_1)") + .build()?, )]; let record_batch_size = task_ctx.session_config().batch_size() as u64; let schema = Arc::clone(&scan_schema); let plan: Arc = - Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new( - Arc::clone(&schema), - futures::stream::iter((0..number_of_record_batches as u64).map( - move |index| { - let mut record_batch_memory_size = - get_size_of_record_batch_to_generate(index as usize); - record_batch_memory_size = record_batch_memory_size - .saturating_sub(size_of::() * record_batch_size as usize); - - let string_item_size = - record_batch_memory_size / record_batch_size as usize; - let string_array = Arc::new(StringArray::from_iter_values( - (0..record_batch_size).map(|_| "a".repeat(string_item_size)), - )); - - RecordBatch::try_new( - Arc::clone(&schema), - vec![ - // Grouping key - Arc::new(UInt64Array::from_iter_values( - (index * record_batch_size) - ..(index * record_batch_size) + record_batch_size, - )), - // Grouping value - string_array, - ], - ) + Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + futures::stream::iter((0..number_of_record_batches as u64).map( + move |index| { + let mut record_batch_memory_size = + get_size_of_record_batch_to_generate(index as usize); + 
record_batch_memory_size = record_batch_memory_size + .saturating_sub(size_of::() * record_batch_size as usize); + + let string_item_size = + record_batch_memory_size / record_batch_size as usize; + let string_array = Arc::new(StringArray::from_iter_values( + (0..record_batch_size).map(|_| "a".repeat(string_item_size)), + )); + + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + // Grouping key + Arc::new(UInt64Array::from_iter_values( + (index * record_batch_size) + ..(index * record_batch_size) + record_batch_size, + )), + // Grouping value + string_array, + ], + ) .map_err(|err| err.into()) - }, - )), - )))); + }, + )), + )))); let aggregate_exec = Arc::new(AggregateExec::try_new( AggregateMode::Partial, @@ -1048,13 +1150,43 @@ async fn run_test_high_cardinality( let task_ctx = Arc::new(task_ctx); - let mut result = aggregate_final.execute(0, task_ctx)?; + let mut result = aggregate_final.execute(0, Arc::clone(&task_ctx))?; let mut number_of_groups = 0; + let memory_pool = task_ctx.memory_pool(); + let memory_consumer = MemoryConsumer::new("consume_memory_in_middle"); + let mut memory_reservation = memory_consumer.register(memory_pool); + + let mut index = 0; + let mut memory_took = false; + while let Some(batch) = result.next().await { + match memory_behavior { + MemoryBehavior::AsIs => { + // Do nothing + } + MemoryBehavior::TakeAllMemoryAtTheBeginning => { + if !memory_took { + memory_took = true; + grow_memory_as_much_as_possible(10, &mut memory_reservation)?; + } + } + MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(n) => { + if !memory_took { + memory_took = true; + grow_memory_as_much_as_possible(pool_size, &mut memory_reservation)?; + } else if index % n == 0 { + // release memory + memory_reservation.free(); + } + } + } + let batch = batch?; number_of_groups += batch.num_rows(); + + index += 1; } assert_eq!( @@ -1066,3 +1198,15 @@ async fn run_test_high_cardinality( Ok(spill_count) } + +fn grow_memory_as_much_as_possible( + memory_step: usize, 
+ memory_reservation: &mut MemoryReservation, +) -> Result { + let mut was_able_to_grow = false; + while memory_reservation.try_grow(memory_step).is_ok() { + was_able_to_grow = true; + } + + Ok(was_able_to_grow) +} From 5a551353b0654d140b706b8de898df53d43eb399 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Wed, 16 Apr 2025 00:24:25 +0300 Subject: [PATCH 08/21] align with add fuzz tests --- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 133 ++---------------- datafusion/core/tests/fuzz_cases/mod.rs | 1 + .../core/tests/fuzz_cases/stream_exec.rs | 115 +++++++++++++++ 3 files changed, 125 insertions(+), 124 deletions(-) create mode 100644 datafusion/core/tests/fuzz_cases/stream_exec.rs diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 093941696049..eaa7c624c5ee 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -18,18 +18,15 @@ use crate::fuzz_cases::aggregation_fuzzer::{ AggregationFuzzerBuilder, DatasetGeneratorConfig, QueryBuilder, }; -use std::any::Any; -use std::fmt::{Debug, Formatter}; use std::pin::Pin; -use std::sync::{Arc, Mutex}; -use std::task::{Context, Poll}; +use std::sync::Arc; use arrow::array::{ types::Int64Type, Array, ArrayRef, AsArray, Int32Array, Int64Array, RecordBatch, StringArray, UInt64Array, }; use arrow::compute::{concat_batches, SortOptions}; -use arrow::datatypes::{DataType, UInt64Type}; +use arrow::datatypes::DataType; use arrow::util::pretty::pretty_format_batches; use arrow_schema::{Field, Schema, SchemaRef}; use datafusion::common::Result; @@ -43,28 +40,25 @@ use datafusion::physical_plan::aggregates::{ use datafusion::physical_plan::{collect, displayable, ExecutionPlan}; use datafusion::prelude::{DataFrame, SessionConfig, SessionContext}; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; -use 
datafusion_common::{DataFusionError, HashMap}; +use datafusion_common::HashMap; use datafusion_common_runtime::JoinSet; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::{col, lit, Column}; -use datafusion_physical_expr::{EquivalenceProperties, Partitioning, PhysicalSortExpr}; +use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_plan::{ - DisplayAs, DisplayFormatType, InputOrderMode, PlanProperties, -}; -use futures::{Stream, StreamExt}; +use datafusion_physical_plan::InputOrderMode; +use futures::StreamExt; use test_utils::{add_empty_batches, StringBatchGenerator}; use super::record_batch_generator::get_supported_types_columns; -use datafusion_datasource::source::DataSource; +use crate::fuzz_cases::stream_exec::StreamExec; use datafusion_execution::memory_pool::units::{KB, MB}; use datafusion_execution::memory_pool::{ FairSpillPool, MemoryConsumer, MemoryReservation, }; use datafusion_execution::runtime_env::RuntimeEnvBuilder; -use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext}; +use datafusion_execution::TaskContext; use datafusion_functions_aggregate::array_agg::array_agg_udaf; -use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion_physical_plan::metrics::MetricValue; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use rand::rngs::StdRng; @@ -656,8 +650,6 @@ fn assert_spill_count_metric( panic!("Expected no spill but found SpillCount metric with value greater than 0."); } - println!("SpillCount = {}", spill_count); - spill_count } else { panic!("No metrics returned from the operator; cannot verify spilling."); @@ -773,113 +765,6 @@ async fn test_single_mode_aggregate_with_spill() -> Result<()> { Ok(()) } -/// A Mock ExecutionPlan that can be used for writing tests of other -/// ExecutionPlans -pub struct StreamExec { - /// the results to send back - 
stream: Mutex>, - /// if true (the default), sends data using a separate task to ensure the - /// batches are not available without this stream yielding first - use_task: bool, - cache: PlanProperties, -} - -impl Debug for StreamExec { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "StreamExec") - } -} - -impl StreamExec { - /// Create a new `MockExec` with a single partition that returns - /// the specified `Results`s. - /// - /// By default, the batches are not produced immediately (the - /// caller has to actually yield and another task must run) to - /// ensure any poll loops are correct. This behavior can be - /// changed with `with_use_task` - pub fn new(stream: SendableRecordBatchStream) -> Self { - let cache = Self::compute_properties(stream.schema()); - Self { - stream: Mutex::new(Some(stream)), - use_task: true, - cache, - } - } - - /// If `use_task` is true (the default) then the batches are sent - /// back using a separate task to ensure the underlying stream is - /// not immediately ready - pub fn with_use_task(mut self, use_task: bool) -> Self { - self.use_task = use_task; - self - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
- fn compute_properties(schema: SchemaRef) -> PlanProperties { - PlanProperties::new( - EquivalenceProperties::new(schema), - Partitioning::UnknownPartitioning(1), - EmissionType::Incremental, - Boundedness::Bounded, - ) - } -} - -impl DisplayAs for StreamExec { - fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - match t { - DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "StreamExec:") - } - DisplayFormatType::TreeRender => { - // TODO: collect info - write!(f, "") - } - } - } -} - -impl ExecutionPlan for StreamExec { - fn name(&self) -> &'static str { - Self::static_name() - } - - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - &self.cache - } - - fn children(&self) -> Vec<&Arc> { - vec![] - } - - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - unimplemented!() - } - - /// Returns a stream which yields data - fn execute( - &self, - partition: usize, - _context: Arc, - ) -> Result { - assert_eq!(partition, 0); - - let stream = self.stream.lock().unwrap().take(); - - stream.ok_or(DataFusionError::Internal( - "Stream already consumed".to_string(), - )) - } -} - #[tokio::test] async fn test_high_cardinality_with_limited_memory() -> Result<()> { let record_batch_size = 8192; @@ -1155,7 +1040,7 @@ async fn run_test_high_cardinality(args: RunTestHighCardinalityArgs) -> Result>, + cache: PlanProperties, +} + +impl Debug for StreamExec { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "StreamExec") + } +} + +impl StreamExec { + pub fn new(stream: SendableRecordBatchStream) -> Self { + let cache = Self::compute_properties(stream.schema()); + Self { + stream: Mutex::new(Some(stream)), + cache, + } + } + + /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
+ fn compute_properties(schema: SchemaRef) -> PlanProperties { + PlanProperties::new( + EquivalenceProperties::new(schema), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ) + } +} + +impl DisplayAs for StreamExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "StreamExec:") + } + DisplayFormatType::TreeRender => { + write!(f, "") + } + } + } +} + +impl ExecutionPlan for StreamExec { + fn name(&self) -> &'static str { + Self::static_name() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> datafusion_common::Result> { + unimplemented!() + } + + /// Returns a stream which yields data + fn execute( + &self, + partition: usize, + _context: Arc, + ) -> datafusion_common::Result { + assert_eq!(partition, 0); + + let stream = self.stream.lock().unwrap().take(); + + stream.ok_or(DataFusionError::Internal( + "Stream already consumed".to_string(), + )) + } +} From a818582b68e863bbb9efdcbd3f779d8a9338f111 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Wed, 16 Apr 2025 00:25:13 +0300 Subject: [PATCH 09/21] add sort fuzz --- datafusion/core/tests/fuzz_cases/sort_fuzz.rs | 346 +++++++++++++++++- 1 file changed, 345 insertions(+), 1 deletion(-) diff --git a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs index 0b0f0aa2f105..bdb576c8b142 100644 --- a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs @@ -17,13 +17,17 @@ //! 
Fuzz Test for various corner cases sorting RecordBatches exceeds available memory and should spill +use std::pin::Pin; use std::sync::Arc; +use arrow::array::UInt64Array; use arrow::{ array::{as_string_array, ArrayRef, Int32Array, StringArray}, compute::SortOptions, record_batch::RecordBatch, }; +use arrow_schema::{DataType, Field, Schema}; +use datafusion::common::Result; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::physical_plan::expressions::PhysicalSortExpr; @@ -31,10 +35,17 @@ use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::{collect, ExecutionPlan}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::cast::as_int32_array; -use datafusion_execution::memory_pool::GreedyMemoryPool; +use datafusion_execution::memory_pool::{ + FairSpillPool, GreedyMemoryPool, MemoryConsumer, MemoryReservation, +}; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr_common::sort_expr::LexOrdering; +use futures::StreamExt; +use crate::fuzz_cases::stream_exec::StreamExec; +use datafusion_execution::memory_pool::units::MB; +use datafusion_execution::TaskContext; +use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use rand::Rng; use test_utils::{batches_to_vec, partitions_to_sorted_vec}; @@ -379,3 +390,336 @@ fn make_staggered_i32_utf8_batches(len: usize) -> Vec { batches } + +#[tokio::test] +async fn test_sort_with_limited_memory() -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + let record_batch_size = pool_size / 16; + + 
// Basic test with a lot of groups that cannot all fit in memory and 1 record batch + // from each spill file is too much memory + let spill_count = + run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |_| record_batch_size), + memory_behavior: Default::default(), + }) + .await?; + + let total_spill_files_size = spill_count * record_batch_size; + assert!( + total_spill_files_size > pool_size, + "Total spill files size {} should be greater than pool size {}", + total_spill_files_size, + pool_size + ); + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch() -> Result<()> +{ + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 4 + } else { + (16 * KB) as usize + } + }), + memory_behavior: Default::default(), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + 
RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 4 + } else { + (16 * KB) as usize + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(10), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 4 + } else { + (16 * KB) as usize + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAtTheBeginning, + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_large_record_batch() -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + // Test that the merge degree of multi level merge sort cannot be fixed size when there is not enough memory + 
run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 4), + memory_behavior: Default::default(), + }) + .await?; + + Ok(()) +} + +struct RunSortTestWithLimitedMemoryArgs { + pool_size: usize, + task_ctx: TaskContext, + number_of_record_batches: usize, + get_size_of_record_batch_to_generate: + Pin usize + Send + 'static>>, + memory_behavior: MemoryBehavior, +} + +#[derive(Default)] +enum MemoryBehavior { + #[default] + AsIs, + TakeAllMemoryAtTheBeginning, + TakeAllMemoryAndReleaseEveryNthBatch(usize), +} + +async fn run_sort_test_with_limited_memory( + args: RunSortTestWithLimitedMemoryArgs, +) -> Result { + let RunSortTestWithLimitedMemoryArgs { + pool_size, + task_ctx, + number_of_record_batches, + get_size_of_record_batch_to_generate, + memory_behavior, + } = args; + let scan_schema = Arc::new(Schema::new(vec![ + Field::new("col_0", DataType::UInt64, true), + Field::new("col_1", DataType::Utf8, true), + ])); + + let record_batch_size = task_ctx.session_config().batch_size() as u64; + + let schema = Arc::clone(&scan_schema); + let plan: Arc = + Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + futures::stream::iter((0..number_of_record_batches as u64).map( + move |index| { + let mut record_batch_memory_size = + get_size_of_record_batch_to_generate(index as usize); + record_batch_memory_size = record_batch_memory_size + .saturating_sub(size_of::() * record_batch_size as usize); + + let string_item_size = + record_batch_memory_size / record_batch_size as usize; + let string_array = Arc::new(StringArray::from_iter_values( + (0..record_batch_size).map(|_| "a".repeat(string_item_size)), + )); + + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(UInt64Array::from_iter_values( + (index * record_batch_size) + ..(index * record_batch_size) + record_batch_size, + )), + 
string_array, + ], + ) + .map_err(|err| err.into()) + }, + )), + )))); + let sort_exec = Arc::new(SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr { + expr: col("col_0", &scan_schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]), + plan, + )); + + let task_ctx = Arc::new(task_ctx); + + let mut result = sort_exec.execute(0, Arc::clone(&task_ctx))?; + + let mut number_of_rows = 0; + + let memory_pool = task_ctx.memory_pool(); + let memory_consumer = MemoryConsumer::new("mock_memory_consumer"); + let mut memory_reservation = memory_consumer.register(memory_pool); + + let mut index = 0; + let mut memory_took = false; + + while let Some(batch) = result.next().await { + match memory_behavior { + MemoryBehavior::AsIs => { + // Do nothing + } + MemoryBehavior::TakeAllMemoryAtTheBeginning => { + if !memory_took { + memory_took = true; + grow_memory_as_much_as_possible(10, &mut memory_reservation)?; + } + } + MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(n) => { + if !memory_took { + memory_took = true; + grow_memory_as_much_as_possible(pool_size, &mut memory_reservation)?; + } else if index % n == 0 { + // release memory + memory_reservation.free(); + } + } + } + + let batch = batch?; + number_of_rows += batch.num_rows(); + + index += 1; + } + + assert_eq!( + number_of_rows, + number_of_record_batches * record_batch_size as usize + ); + + let spill_count = sort_exec.metrics().unwrap().spill_count().unwrap(); + assert!( + spill_count > 0, + "Expected spill, but did not: {number_of_record_batches:?}" + ); + + Ok(spill_count) +} + +fn grow_memory_as_much_as_possible( + memory_step: usize, + memory_reservation: &mut MemoryReservation, +) -> Result { + let mut was_able_to_grow = false; + while memory_reservation.try_grow(memory_step).is_ok() { + was_able_to_grow = true; + } + + Ok(was_able_to_grow) +} From ba173293b96fe21c9fd26ce161a94fa91bea4fd7 Mon Sep 17 00:00:00 2001 From: Raz Luvaton 
<16746759+rluvaton@users.noreply.github.com> Date: Wed, 16 Apr 2025 00:47:09 +0300 Subject: [PATCH 10/21] fix lints and formatting --- Cargo.lock | 8 ++++---- datafusion/core/tests/fuzz_cases/sort_fuzz.rs | 6 +++--- .../src/sorts/multi_level_merge.rs | 1 + .../src/sorts/streaming_merge.rs | 20 +++++++++++++------ .../physical-plan/src/spill/get_size.rs | 17 ++++++++++++++++ 5 files changed, 39 insertions(+), 13 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2b3eeecf5d9b..ee6dda88a0e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3966,9 +3966,9 @@ dependencies = [ [[package]] name = "libz-rs-sys" -version = "0.4.2" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "902bc563b5d65ad9bba616b490842ef0651066a1a1dc3ce1087113ffcb873c8d" +checksum = "6489ca9bd760fe9642d7644e827b0c9add07df89857b0416ee15c1cc1a3b8c5a" dependencies = [ "zlib-rs", ] @@ -7512,9 +7512,9 @@ dependencies = [ [[package]] name = "zlib-rs" -version = "0.4.2" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b20717f0917c908dc63de2e44e97f1e6b126ca58d0e391cee86d504eb8fbd05" +checksum = "868b928d7949e09af2f6086dfc1e01936064cc7a819253bce650d4e2a2d63ba8" [[package]] name = "zstd" diff --git a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs index bdb576c8b142..7a5ebf4b439d 100644 --- a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs @@ -463,7 +463,7 @@ async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch() -> if i % 25 == 1 { pool_size / 4 } else { - (16 * KB) as usize + 16 * KB } }), memory_behavior: Default::default(), @@ -501,7 +501,7 @@ async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_c if i % 25 == 1 { pool_size / 4 } else { - (16 * KB) as usize + 16 * KB } }), memory_behavior: MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(10), @@ -539,7 +539,7 @@ 
async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_t if i % 25 == 1 { pool_size / 4 } else { - (16 * KB) as usize + 16 * KB } }), memory_behavior: MemoryBehavior::TakeAllMemoryAtTheBeginning, diff --git a/datafusion/physical-plan/src/sorts/multi_level_merge.rs b/datafusion/physical-plan/src/sorts/multi_level_merge.rs index c07fb375ff18..fc55465b0e01 100644 --- a/datafusion/physical-plan/src/sorts/multi_level_merge.rs +++ b/datafusion/physical-plan/src/sorts/multi_level_merge.rs @@ -65,6 +65,7 @@ impl Debug for MultiLevelMergeBuilder { } impl MultiLevelMergeBuilder { + #[allow(clippy::too_many_arguments)] pub(crate) fn new( spill_manager: SpillManager, schema: SchemaRef, diff --git a/datafusion/physical-plan/src/sorts/streaming_merge.rs b/datafusion/physical-plan/src/sorts/streaming_merge.rs index a5d055d8f693..3f3cc57de9ef 100644 --- a/datafusion/physical-plan/src/sorts/streaming_merge.rs +++ b/datafusion/physical-plan/src/sorts/streaming_merge.rs @@ -184,16 +184,24 @@ impl<'a> StreamingMergeBuilder<'a> { enable_round_robin_tie_breaker, } = self; - if !sorted_spill_files.is_empty() && spill_manager.is_some() { + if !sorted_spill_files.is_empty() { + // Unwrapping mandatory fields + let schema = schema.expect("Schema cannot be empty for streaming merge"); + let metrics = metrics.expect("Metrics cannot be empty for streaming merge"); + let batch_size = + batch_size.expect("Batch size cannot be empty for streaming merge"); + let reservation = + reservation.expect("Reservation cannot be empty for streaming merge"); + return Ok(MultiLevelMergeBuilder::new( - spill_manager.unwrap(), - schema.unwrap(), + spill_manager.expect("spill_manager should exist"), + schema, sorted_spill_files, streams, expressions.clone(), - metrics.unwrap(), - batch_size.unwrap(), - reservation.unwrap(), + metrics, + batch_size, + reservation, fetch, enable_round_robin_tie_breaker, ) diff --git a/datafusion/physical-plan/src/spill/get_size.rs 
b/datafusion/physical-plan/src/spill/get_size.rs index 3e67c03214d0..1cb56b2678ce 100644 --- a/datafusion/physical-plan/src/spill/get_size.rs +++ b/datafusion/physical-plan/src/spill/get_size.rs @@ -1,3 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ use arrow::array::{ Array, ArrayRef, ArrowPrimitiveType, BooleanArray, FixedSizeBinaryArray, FixedSizeListArray, GenericByteArray, GenericListArray, OffsetSizeTrait, From efde4d8a793224c7c8734397a90307164c9359ed Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Thu, 17 Jul 2025 12:26:04 +0300 Subject: [PATCH 11/21] moved spill in memory constrained envs to separate test --- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 352 +-------- datafusion/core/tests/fuzz_cases/mod.rs | 1 + datafusion/core/tests/fuzz_cases/sort_fuzz.rs | 345 +-------- ...spilling_fuzz_in_memory_constrained_env.rs | 682 ++++++++++++++++++ 4 files changed, 691 insertions(+), 689 deletions(-) create mode 100644 datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 14712d949bed..7fa09c3468e3 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -22,11 +22,10 @@ use crate::fuzz_cases::aggregation_fuzzer::query_builder::QueryBuilder; use crate::fuzz_cases::aggregation_fuzzer::{ AggregationFuzzerBuilder, DatasetGeneratorConfig, }; -use std::pin::Pin; use arrow::array::{ types::Int64Type, Array, ArrayRef, AsArray, Int32Array, Int64Array, RecordBatch, - StringArray, UInt64Array, + StringArray, }; use arrow::compute::concat_batches; use arrow::datatypes::DataType; @@ -43,27 +42,19 @@ use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::{col, lit, Column}; use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_plan::InputOrderMode; -use futures::StreamExt; use test_utils::{add_empty_batches, StringBatchGenerator}; -use crate::fuzz_cases::stream_exec::StreamExec; -use datafusion_execution::memory_pool::units::{KB, MB}; -use datafusion_execution::memory_pool::{ - FairSpillPool, MemoryConsumer, 
MemoryReservation, -}; +use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_execution::TaskContext; -use datafusion_functions_aggregate::array_agg::array_agg_udaf; -use datafusion_physical_plan::metrics::MetricValue; -use datafusion_physical_plan::stream::RecordBatchStreamAdapter; -use rand::rngs::StdRng; -use rand::{random, rng, Rng, SeedableRng}; - use datafusion_physical_expr::aggregate::AggregateExprBuilder; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; +use datafusion_physical_plan::metrics::MetricValue; use datafusion_physical_plan::{collect, displayable, ExecutionPlan}; +use rand::rngs::StdRng; +use rand::{random, rng, Rng, SeedableRng}; // ======================================================================== // The new aggregation fuzz tests based on [`AggregationFuzzer`] @@ -641,7 +632,7 @@ fn extract_result_counts(results: Vec) -> HashMap, i output } -fn assert_spill_count_metric( +pub(crate) fn assert_spill_count_metric( expect_spill: bool, single_aggregate: Arc, ) -> usize { @@ -670,7 +661,7 @@ fn assert_spill_count_metric( // Fix for https://github.com/apache/datafusion/issues/15530 #[tokio::test] -async fn test_single_mode_aggregate_with_spill() -> Result<()> { +async fn test_single_mode_aggregate_single_mode_aggregate_with_spill() -> Result<()> { let scan_schema = Arc::new(Schema::new(vec![ Field::new("col_0", DataType::Int64, true), Field::new("col_1", DataType::Utf8, true), @@ -776,332 +767,3 @@ async fn test_single_mode_aggregate_with_spill() -> Result<()> { Ok(()) } - -#[tokio::test] -async fn test_high_cardinality_with_limited_memory() -> Result<()> { - let record_batch_size = 8192; - let pool_size = 2 * MB as usize; - let task_ctx = { - let memory_pool = Arc::new(FairSpillPool::new(pool_size)); - TaskContext::default() - .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) - 
.with_runtime(Arc::new( - RuntimeEnvBuilder::new() - .with_memory_pool(memory_pool) - .build()?, - )) - }; - - let record_batch_size = pool_size / 16; - - // Basic test with a lot of groups that cannot all fit in memory and 1 record batch - // from each spill file is too much memory - let spill_count = run_test_high_cardinality(RunTestHighCardinalityArgs { - pool_size, - task_ctx, - number_of_record_batches: 100, - get_size_of_record_batch_to_generate: Box::pin(move |_| record_batch_size), - memory_behavior: Default::default(), - }) - .await?; - - let total_spill_files_size = spill_count * record_batch_size; - assert!( - total_spill_files_size > pool_size, - "Total spill files size {total_spill_files_size} should be greater than pool size {pool_size}", - ); - - Ok(()) -} - -#[tokio::test] -async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch( -) -> Result<()> { - let record_batch_size = 8192; - let pool_size = 2 * MB as usize; - let task_ctx = { - let memory_pool = Arc::new(FairSpillPool::new(pool_size)); - TaskContext::default() - .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) - .with_runtime(Arc::new( - RuntimeEnvBuilder::new() - .with_memory_pool(memory_pool) - .build()?, - )) - }; - - run_test_high_cardinality(RunTestHighCardinalityArgs { - pool_size, - task_ctx, - number_of_record_batches: 100, - get_size_of_record_batch_to_generate: Box::pin(move |i| { - if i % 25 == 1 { - pool_size / 4 - } else { - (16 * KB) as usize - } - }), - memory_behavior: Default::default(), - }) - .await?; - - Ok(()) -} - -#[tokio::test] -async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation( -) -> Result<()> { - let record_batch_size = 8192; - let pool_size = 2 * MB as usize; - let task_ctx = { - let memory_pool = Arc::new(FairSpillPool::new(pool_size)); - TaskContext::default() - 
.with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) - .with_runtime(Arc::new( - RuntimeEnvBuilder::new() - .with_memory_pool(memory_pool) - .build()?, - )) - }; - - run_test_high_cardinality(RunTestHighCardinalityArgs { - pool_size, - task_ctx, - number_of_record_batches: 100, - get_size_of_record_batch_to_generate: Box::pin(move |i| { - if i % 25 == 1 { - pool_size / 4 - } else { - (16 * KB) as usize - } - }), - memory_behavior: MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(10), - }) - .await?; - - Ok(()) -} - -#[tokio::test] -async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory( -) -> Result<()> { - let record_batch_size = 8192; - let pool_size = 2 * MB as usize; - let task_ctx = { - let memory_pool = Arc::new(FairSpillPool::new(pool_size)); - TaskContext::default() - .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) - .with_runtime(Arc::new( - RuntimeEnvBuilder::new() - .with_memory_pool(memory_pool) - .build()?, - )) - }; - - run_test_high_cardinality(RunTestHighCardinalityArgs { - pool_size, - task_ctx, - number_of_record_batches: 100, - get_size_of_record_batch_to_generate: Box::pin(move |i| { - if i % 25 == 1 { - pool_size / 4 - } else { - (16 * KB) as usize - } - }), - memory_behavior: MemoryBehavior::TakeAllMemoryAtTheBeginning, - }) - .await?; - - Ok(()) -} - -#[tokio::test] -async fn test_high_cardinality_with_limited_memory_and_large_record_batch() -> Result<()> -{ - let record_batch_size = 8192; - let pool_size = 2 * MB as usize; - let task_ctx = { - let memory_pool = Arc::new(FairSpillPool::new(pool_size)); - TaskContext::default() - .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) - .with_runtime(Arc::new( - RuntimeEnvBuilder::new() - .with_memory_pool(memory_pool) - .build()?, - )) - }; - - // Test that the merge degree of multi level merge sort cannot be fixed size when there is not enough memory - 
run_test_high_cardinality(RunTestHighCardinalityArgs { - pool_size, - task_ctx, - number_of_record_batches: 100, - get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 4), - memory_behavior: Default::default(), - }) - .await?; - - Ok(()) -} - -struct RunTestHighCardinalityArgs { - pool_size: usize, - task_ctx: TaskContext, - number_of_record_batches: usize, - get_size_of_record_batch_to_generate: - Pin usize + Send + 'static>>, - memory_behavior: MemoryBehavior, -} - -#[derive(Default)] -enum MemoryBehavior { - #[default] - AsIs, - TakeAllMemoryAtTheBeginning, - TakeAllMemoryAndReleaseEveryNthBatch(usize), -} - -async fn run_test_high_cardinality(args: RunTestHighCardinalityArgs) -> Result { - let RunTestHighCardinalityArgs { - pool_size, - task_ctx, - number_of_record_batches, - get_size_of_record_batch_to_generate, - memory_behavior, - } = args; - let scan_schema = Arc::new(Schema::new(vec![ - Field::new("col_0", DataType::UInt64, true), - Field::new("col_1", DataType::Utf8, true), - ])); - - let group_by = PhysicalGroupBy::new_single(vec![( - Arc::new(Column::new("col_0", 0)), - "col_0".to_string(), - )]); - - let aggregate_expressions = vec![Arc::new( - AggregateExprBuilder::new( - array_agg_udaf(), - vec![col("col_1", &scan_schema).unwrap()], - ) - .schema(Arc::clone(&scan_schema)) - .alias("array_agg(col_1)") - .build()?, - )]; - - let record_batch_size = task_ctx.session_config().batch_size() as u64; - - let schema = Arc::clone(&scan_schema); - let plan: Arc = - Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new( - Arc::clone(&schema), - futures::stream::iter((0..number_of_record_batches as u64).map( - move |index| { - let mut record_batch_memory_size = - get_size_of_record_batch_to_generate(index as usize); - record_batch_memory_size = record_batch_memory_size - .saturating_sub(size_of::() * record_batch_size as usize); - - let string_item_size = - record_batch_memory_size / record_batch_size as usize; - let string_array = 
Arc::new(StringArray::from_iter_values( - (0..record_batch_size).map(|_| "a".repeat(string_item_size)), - )); - - RecordBatch::try_new( - Arc::clone(&schema), - vec![ - // Grouping key - Arc::new(UInt64Array::from_iter_values( - (index * record_batch_size) - ..(index * record_batch_size) + record_batch_size, - )), - // Grouping value - string_array, - ], - ) - .map_err(|err| err.into()) - }, - )), - )))); - - let aggregate_exec = Arc::new(AggregateExec::try_new( - AggregateMode::Partial, - group_by.clone(), - aggregate_expressions.clone(), - vec![None; aggregate_expressions.len()], - plan, - Arc::clone(&scan_schema), - )?); - let aggregate_final = Arc::new(AggregateExec::try_new( - AggregateMode::Final, - group_by, - aggregate_expressions.clone(), - vec![None; aggregate_expressions.len()], - aggregate_exec, - Arc::clone(&scan_schema), - )?); - - let task_ctx = Arc::new(task_ctx); - - let mut result = aggregate_final.execute(0, Arc::clone(&task_ctx))?; - - let mut number_of_groups = 0; - - let memory_pool = task_ctx.memory_pool(); - let memory_consumer = MemoryConsumer::new("mock_memory_consumer"); - let mut memory_reservation = memory_consumer.register(memory_pool); - - let mut index = 0; - let mut memory_took = false; - - while let Some(batch) = result.next().await { - match memory_behavior { - MemoryBehavior::AsIs => { - // Do nothing - } - MemoryBehavior::TakeAllMemoryAtTheBeginning => { - if !memory_took { - memory_took = true; - grow_memory_as_much_as_possible(10, &mut memory_reservation)?; - } - } - MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(n) => { - if !memory_took { - memory_took = true; - grow_memory_as_much_as_possible(pool_size, &mut memory_reservation)?; - } else if index % n == 0 { - // release memory - memory_reservation.free(); - } - } - } - - let batch = batch?; - number_of_groups += batch.num_rows(); - - index += 1; - } - - assert_eq!( - number_of_groups, - number_of_record_batches * record_batch_size as usize - ); - - let spill_count = 
assert_spill_count_metric(true, aggregate_final); - - Ok(spill_count) -} - -fn grow_memory_as_much_as_possible( - memory_step: usize, - memory_reservation: &mut MemoryReservation, -) -> Result { - let mut was_able_to_grow = false; - while memory_reservation.try_grow(memory_step).is_ok() { - was_able_to_grow = true; - } - - Ok(was_able_to_grow) -} diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs index c48695914906..69fbf44a0609 100644 --- a/datafusion/core/tests/fuzz_cases/mod.rs +++ b/datafusion/core/tests/fuzz_cases/mod.rs @@ -34,4 +34,5 @@ mod window_fuzz; // Utility modules mod record_batch_generator; +mod spilling_fuzz_in_memory_constrained_env; mod stream_exec; diff --git a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs index 025f2957514b..27becaa96f62 100644 --- a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs @@ -17,17 +17,13 @@ //! Fuzz Test for various corner cases sorting RecordBatches exceeds available memory and should spill -use std::pin::Pin; use std::sync::Arc; -use arrow::array::UInt64Array; use arrow::{ array::{as_string_array, ArrayRef, Int32Array, StringArray}, compute::SortOptions, record_batch::RecordBatch, }; -use arrow_schema::{DataType, Field, Schema}; -use datafusion::common::Result; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::physical_plan::expressions::PhysicalSortExpr; @@ -35,17 +31,10 @@ use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::{collect, ExecutionPlan}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::cast::as_int32_array; -use datafusion_execution::memory_pool::{ - FairSpillPool, GreedyMemoryPool, MemoryConsumer, MemoryReservation, -}; +use datafusion_execution::memory_pool::GreedyMemoryPool; use 
datafusion_physical_expr::expressions::col; use datafusion_physical_expr_common::sort_expr::LexOrdering; -use futures::StreamExt; -use crate::fuzz_cases::stream_exec::StreamExec; -use datafusion_execution::memory_pool::units::MB; -use datafusion_execution::TaskContext; -use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use rand::Rng; use test_utils::{batches_to_vec, partitions_to_sorted_vec}; @@ -388,335 +377,3 @@ fn make_staggered_i32_utf8_batches(len: usize) -> Vec { batches } - -#[tokio::test] -async fn test_sort_with_limited_memory() -> Result<()> { - let record_batch_size = 8192; - let pool_size = 2 * MB as usize; - let task_ctx = { - let memory_pool = Arc::new(FairSpillPool::new(pool_size)); - TaskContext::default() - .with_session_config( - SessionConfig::new() - .with_batch_size(record_batch_size) - .with_sort_spill_reservation_bytes(1), - ) - .with_runtime(Arc::new( - RuntimeEnvBuilder::new() - .with_memory_pool(memory_pool) - .build()?, - )) - }; - - let record_batch_size = pool_size / 16; - - // Basic test with a lot of groups that cannot all fit in memory and 1 record batch - // from each spill file is too much memory - let spill_count = - run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs { - pool_size, - task_ctx, - number_of_record_batches: 100, - get_size_of_record_batch_to_generate: Box::pin(move |_| record_batch_size), - memory_behavior: Default::default(), - }) - .await?; - - let total_spill_files_size = spill_count * record_batch_size; - assert!( - total_spill_files_size > pool_size, - "Total spill files size {total_spill_files_size} should be greater than pool size {pool_size}", - ); - - Ok(()) -} - -#[tokio::test] -async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch() -> Result<()> -{ - let record_batch_size = 8192; - let pool_size = 2 * MB as usize; - let task_ctx = { - let memory_pool = Arc::new(FairSpillPool::new(pool_size)); - TaskContext::default() - .with_session_config( - 
SessionConfig::new() - .with_batch_size(record_batch_size) - .with_sort_spill_reservation_bytes(1), - ) - .with_runtime(Arc::new( - RuntimeEnvBuilder::new() - .with_memory_pool(memory_pool) - .build()?, - )) - }; - - run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs { - pool_size, - task_ctx, - number_of_record_batches: 100, - get_size_of_record_batch_to_generate: Box::pin(move |i| { - if i % 25 == 1 { - pool_size / 4 - } else { - 16 * KB - } - }), - memory_behavior: Default::default(), - }) - .await?; - - Ok(()) -} - -#[tokio::test] -async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation( -) -> Result<()> { - let record_batch_size = 8192; - let pool_size = 2 * MB as usize; - let task_ctx = { - let memory_pool = Arc::new(FairSpillPool::new(pool_size)); - TaskContext::default() - .with_session_config( - SessionConfig::new() - .with_batch_size(record_batch_size) - .with_sort_spill_reservation_bytes(1), - ) - .with_runtime(Arc::new( - RuntimeEnvBuilder::new() - .with_memory_pool(memory_pool) - .build()?, - )) - }; - - run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs { - pool_size, - task_ctx, - number_of_record_batches: 100, - get_size_of_record_batch_to_generate: Box::pin(move |i| { - if i % 25 == 1 { - pool_size / 4 - } else { - 16 * KB - } - }), - memory_behavior: MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(10), - }) - .await?; - - Ok(()) -} - -#[tokio::test] -async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory( -) -> Result<()> { - let record_batch_size = 8192; - let pool_size = 2 * MB as usize; - let task_ctx = { - let memory_pool = Arc::new(FairSpillPool::new(pool_size)); - TaskContext::default() - .with_session_config( - SessionConfig::new() - .with_batch_size(record_batch_size) - .with_sort_spill_reservation_bytes(1), - ) - .with_runtime(Arc::new( - RuntimeEnvBuilder::new() - .with_memory_pool(memory_pool) - .build()?, - )) 
- }; - - run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs { - pool_size, - task_ctx, - number_of_record_batches: 100, - get_size_of_record_batch_to_generate: Box::pin(move |i| { - if i % 25 == 1 { - pool_size / 4 - } else { - 16 * KB - } - }), - memory_behavior: MemoryBehavior::TakeAllMemoryAtTheBeginning, - }) - .await?; - - Ok(()) -} - -#[tokio::test] -async fn test_sort_with_limited_memory_and_large_record_batch() -> Result<()> { - let record_batch_size = 8192; - let pool_size = 2 * MB as usize; - let task_ctx = { - let memory_pool = Arc::new(FairSpillPool::new(pool_size)); - TaskContext::default() - .with_session_config( - SessionConfig::new() - .with_batch_size(record_batch_size) - .with_sort_spill_reservation_bytes(1), - ) - .with_runtime(Arc::new( - RuntimeEnvBuilder::new() - .with_memory_pool(memory_pool) - .build()?, - )) - }; - - // Test that the merge degree of multi level merge sort cannot be fixed size when there is not enough memory - run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs { - pool_size, - task_ctx, - number_of_record_batches: 100, - get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 4), - memory_behavior: Default::default(), - }) - .await?; - - Ok(()) -} - -struct RunSortTestWithLimitedMemoryArgs { - pool_size: usize, - task_ctx: TaskContext, - number_of_record_batches: usize, - get_size_of_record_batch_to_generate: - Pin usize + Send + 'static>>, - memory_behavior: MemoryBehavior, -} - -#[derive(Default)] -enum MemoryBehavior { - #[default] - AsIs, - TakeAllMemoryAtTheBeginning, - TakeAllMemoryAndReleaseEveryNthBatch(usize), -} - -async fn run_sort_test_with_limited_memory( - args: RunSortTestWithLimitedMemoryArgs, -) -> Result { - let RunSortTestWithLimitedMemoryArgs { - pool_size, - task_ctx, - number_of_record_batches, - get_size_of_record_batch_to_generate, - memory_behavior, - } = args; - let scan_schema = Arc::new(Schema::new(vec![ - Field::new("col_0", DataType::UInt64, true), - 
Field::new("col_1", DataType::Utf8, true), - ])); - - let record_batch_size = task_ctx.session_config().batch_size() as u64; - - let schema = Arc::clone(&scan_schema); - let plan: Arc = - Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new( - Arc::clone(&schema), - futures::stream::iter((0..number_of_record_batches as u64).map( - move |index| { - let mut record_batch_memory_size = - get_size_of_record_batch_to_generate(index as usize); - record_batch_memory_size = record_batch_memory_size - .saturating_sub(size_of::() * record_batch_size as usize); - - let string_item_size = - record_batch_memory_size / record_batch_size as usize; - let string_array = Arc::new(StringArray::from_iter_values( - (0..record_batch_size).map(|_| "a".repeat(string_item_size)), - )); - - RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(UInt64Array::from_iter_values( - (index * record_batch_size) - ..(index * record_batch_size) + record_batch_size, - )), - string_array, - ], - ) - .map_err(|err| err.into()) - }, - )), - )))); - let sort_exec = Arc::new(SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { - expr: col("col_0", &scan_schema).unwrap(), - options: SortOptions { - descending: false, - nulls_first: true, - }, - }]) - .unwrap(), - plan, - )); - - let task_ctx = Arc::new(task_ctx); - - let mut result = sort_exec.execute(0, Arc::clone(&task_ctx))?; - - let mut number_of_rows = 0; - - let memory_pool = task_ctx.memory_pool(); - let memory_consumer = MemoryConsumer::new("mock_memory_consumer"); - let mut memory_reservation = memory_consumer.register(memory_pool); - - let mut index = 0; - let mut memory_took = false; - - while let Some(batch) = result.next().await { - match memory_behavior { - MemoryBehavior::AsIs => { - // Do nothing - } - MemoryBehavior::TakeAllMemoryAtTheBeginning => { - if !memory_took { - memory_took = true; - grow_memory_as_much_as_possible(10, &mut memory_reservation)?; - } - } - 
MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(n) => { - if !memory_took { - memory_took = true; - grow_memory_as_much_as_possible(pool_size, &mut memory_reservation)?; - } else if index % n == 0 { - // release memory - memory_reservation.free(); - } - } - } - - let batch = batch?; - number_of_rows += batch.num_rows(); - - index += 1; - } - - assert_eq!( - number_of_rows, - number_of_record_batches * record_batch_size as usize - ); - - let spill_count = sort_exec.metrics().unwrap().spill_count().unwrap(); - assert!( - spill_count > 0, - "Expected spill, but did not: {number_of_record_batches:?}" - ); - - Ok(spill_count) -} - -fn grow_memory_as_much_as_possible( - memory_step: usize, - memory_reservation: &mut MemoryReservation, -) -> Result { - let mut was_able_to_grow = false; - while memory_reservation.try_grow(memory_step).is_ok() { - was_able_to_grow = true; - } - - Ok(was_able_to_grow) -} diff --git a/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs b/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs new file mode 100644 index 000000000000..f7f8f23734e4 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs @@ -0,0 +1,682 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Fuzz Test for different operators in memory constrained environment + +use std::pin::Pin; +use std::sync::Arc; + +use arrow::array::UInt64Array; +use arrow::{array::StringArray, compute::SortOptions, record_batch::RecordBatch}; +use arrow_schema::{DataType, Field, Schema}; +use datafusion::common::Result; +use datafusion::execution::runtime_env::RuntimeEnvBuilder; +use datafusion::physical_plan::expressions::PhysicalSortExpr; +use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionConfig; +use datafusion_execution::memory_pool::{ + FairSpillPool, MemoryConsumer, MemoryReservation, +}; +use datafusion_physical_expr::expressions::{col, Column}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use futures::StreamExt; + +use crate::fuzz_cases::aggregate_fuzz::assert_spill_count_metric; +use crate::fuzz_cases::stream_exec::StreamExec; +use datafusion_execution::memory_pool::units::{KB, MB}; +use datafusion_execution::TaskContext; +use datafusion_functions_aggregate::array_agg::array_agg_udaf; +use datafusion_physical_expr::aggregate::AggregateExprBuilder; +use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, +}; +use datafusion_physical_plan::stream::RecordBatchStreamAdapter; + +#[tokio::test] +async fn test_sort_with_limited_memory() -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + let record_batch_size = pool_size / 16; + + // Basic test with a lot of 
groups that cannot all fit in memory and 1 record batch + // from each spill file is too much memory + let spill_count = run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |_| record_batch_size), + memory_behavior: Default::default(), + }) + .await?; + + let total_spill_files_size = spill_count * record_batch_size; + assert!( + total_spill_files_size > pool_size, + "Total spill files size {total_spill_files_size} should be greater than pool size {pool_size}", + ); + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch() -> Result<()> +{ + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 4 + } else { + 16 * KB as usize + } + }), + memory_behavior: Default::default(), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + 
.build()?, + )) + }; + + run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 4 + } else { + 16 * KB as usize + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(10), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 4 + } else { + 16 * KB as usize + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAtTheBeginning, + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_large_record_batch() -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + // Test that the merge degree of multi level merge sort cannot be fixed size when there is not enough memory + run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { + pool_size, + 
task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 4), + memory_behavior: Default::default(), + }) + .await?; + + Ok(()) +} + +struct RunTestWithLimitedMemoryArgs { + pool_size: usize, + task_ctx: TaskContext, + number_of_record_batches: usize, + get_size_of_record_batch_to_generate: + Pin usize + Send + 'static>>, + memory_behavior: MemoryBehavior, +} + +#[derive(Default)] +enum MemoryBehavior { + #[default] + AsIs, + TakeAllMemoryAtTheBeginning, + TakeAllMemoryAndReleaseEveryNthBatch(usize), +} + +async fn run_sort_test_with_limited_memory( + args: RunTestWithLimitedMemoryArgs, +) -> Result { + let RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx, + number_of_record_batches, + get_size_of_record_batch_to_generate, + memory_behavior, + } = args; + let scan_schema = Arc::new(Schema::new(vec![ + Field::new("col_0", DataType::UInt64, true), + Field::new("col_1", DataType::Utf8, true), + ])); + + let record_batch_size = task_ctx.session_config().batch_size() as u64; + + let schema = Arc::clone(&scan_schema); + let plan: Arc = + Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + futures::stream::iter((0..number_of_record_batches as u64).map( + move |index| { + let mut record_batch_memory_size = + get_size_of_record_batch_to_generate(index as usize); + record_batch_memory_size = record_batch_memory_size + .saturating_sub(size_of::() * record_batch_size as usize); + + let string_item_size = + record_batch_memory_size / record_batch_size as usize; + let string_array = Arc::new(StringArray::from_iter_values( + (0..record_batch_size).map(|_| "a".repeat(string_item_size)), + )); + + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(UInt64Array::from_iter_values( + (index * record_batch_size) + ..(index * record_batch_size) + record_batch_size, + )), + string_array, + ], + ) + .map_err(|err| err.into()) + }, + )), + )))); + let sort_exec = 
Arc::new(SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr { + expr: col("col_0", &scan_schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]) + .unwrap(), + plan, + )); + + let task_ctx = Arc::new(task_ctx); + + let mut result = sort_exec.execute(0, Arc::clone(&task_ctx))?; + + let mut number_of_rows = 0; + + let memory_pool = task_ctx.memory_pool(); + let memory_consumer = MemoryConsumer::new("mock_memory_consumer"); + let mut memory_reservation = memory_consumer.register(memory_pool); + + let mut index = 0; + let mut memory_took = false; + + while let Some(batch) = result.next().await { + match memory_behavior { + MemoryBehavior::AsIs => { + // Do nothing + } + MemoryBehavior::TakeAllMemoryAtTheBeginning => { + if !memory_took { + memory_took = true; + grow_memory_as_much_as_possible(10, &mut memory_reservation)?; + } + } + MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(n) => { + if !memory_took { + memory_took = true; + grow_memory_as_much_as_possible(pool_size, &mut memory_reservation)?; + } else if index % n == 0 { + // release memory + memory_reservation.free(); + } + } + } + + let batch = batch?; + number_of_rows += batch.num_rows(); + + index += 1; + } + + assert_eq!( + number_of_rows, + number_of_record_batches * record_batch_size as usize + ); + + let spill_count = sort_exec.metrics().unwrap().spill_count().unwrap(); + assert!( + spill_count > 0, + "Expected spill, but did not: {number_of_record_batches:?}" + ); + + Ok(spill_count) +} + +fn grow_memory_as_much_as_possible( + memory_step: usize, + memory_reservation: &mut MemoryReservation, +) -> Result { + let mut was_able_to_grow = false; + while memory_reservation.try_grow(memory_step).is_ok() { + was_able_to_grow = true; + } + + Ok(was_able_to_grow) +} + +#[tokio::test] +async fn test_aggregate_with_high_cardinality_with_limited_memory() -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let 
memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + let record_batch_size = pool_size / 16; + + // Basic test with a lot of groups that cannot all fit in memory and 1 record batch + // from each spill file is too much memory + let spill_count = + run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |_| record_batch_size), + memory_behavior: Default::default(), + }) + .await?; + + let total_spill_files_size = spill_count * record_batch_size; + assert!( + total_spill_files_size > pool_size, + "Total spill files size {total_spill_files_size} should be greater than pool size {pool_size}", + ); + + Ok(()) +} + +#[tokio::test] +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 4 + } else { + (16 * KB) as usize + } + }), + memory_behavior: Default::default(), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation( +) -> Result<()> { + let 
record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 4 + } else { + (16 * KB) as usize + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(10), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 4 + } else { + (16 * KB) as usize + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAtTheBeginning, + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_large_record_batch( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) 
+ .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + // Test that the merge degree of multi level merge sort cannot be fixed size when there is not enough memory + run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 4), + memory_behavior: Default::default(), + }) + .await?; + + Ok(()) +} + +async fn run_test_aggregate_with_high_cardinality( + args: RunTestWithLimitedMemoryArgs, +) -> Result { + let RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx, + number_of_record_batches, + get_size_of_record_batch_to_generate, + memory_behavior, + } = args; + let scan_schema = Arc::new(Schema::new(vec![ + Field::new("col_0", DataType::UInt64, true), + Field::new("col_1", DataType::Utf8, true), + ])); + + let group_by = PhysicalGroupBy::new_single(vec![( + Arc::new(Column::new("col_0", 0)), + "col_0".to_string(), + )]); + + let aggregate_expressions = vec![Arc::new( + AggregateExprBuilder::new( + array_agg_udaf(), + vec![col("col_1", &scan_schema).unwrap()], + ) + .schema(Arc::clone(&scan_schema)) + .alias("array_agg(col_1)") + .build()?, + )]; + + let record_batch_size = task_ctx.session_config().batch_size() as u64; + + let schema = Arc::clone(&scan_schema); + let plan: Arc = + Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + futures::stream::iter((0..number_of_record_batches as u64).map( + move |index| { + let mut record_batch_memory_size = + get_size_of_record_batch_to_generate(index as usize); + record_batch_memory_size = record_batch_memory_size + .saturating_sub(size_of::() * record_batch_size as usize); + + let string_item_size = + record_batch_memory_size / record_batch_size as usize; + let string_array = Arc::new(StringArray::from_iter_values( + (0..record_batch_size).map(|_| "a".repeat(string_item_size)), + )); + + 
RecordBatch::try_new( + Arc::clone(&schema), + vec![ + // Grouping key + Arc::new(UInt64Array::from_iter_values( + (index * record_batch_size) + ..(index * record_batch_size) + record_batch_size, + )), + // Grouping value + string_array, + ], + ) + .map_err(|err| err.into()) + }, + )), + )))); + + let aggregate_exec = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by.clone(), + aggregate_expressions.clone(), + vec![None; aggregate_expressions.len()], + plan, + Arc::clone(&scan_schema), + )?); + let aggregate_final = Arc::new(AggregateExec::try_new( + AggregateMode::Final, + group_by, + aggregate_expressions.clone(), + vec![None; aggregate_expressions.len()], + aggregate_exec, + Arc::clone(&scan_schema), + )?); + + let task_ctx = Arc::new(task_ctx); + + let mut result = aggregate_final.execute(0, Arc::clone(&task_ctx))?; + + let mut number_of_groups = 0; + + let memory_pool = task_ctx.memory_pool(); + let memory_consumer = MemoryConsumer::new("mock_memory_consumer"); + let mut memory_reservation = memory_consumer.register(memory_pool); + + let mut index = 0; + let mut memory_took = false; + + while let Some(batch) = result.next().await { + match memory_behavior { + MemoryBehavior::AsIs => { + // Do nothing + } + MemoryBehavior::TakeAllMemoryAtTheBeginning => { + if !memory_took { + memory_took = true; + grow_memory_as_much_as_possible(10, &mut memory_reservation)?; + } + } + MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(n) => { + if !memory_took { + memory_took = true; + grow_memory_as_much_as_possible(pool_size, &mut memory_reservation)?; + } else if index % n == 0 { + // release memory + memory_reservation.free(); + } + } + } + + let batch = batch?; + number_of_groups += batch.num_rows(); + + index += 1; + } + + assert_eq!( + number_of_groups, + number_of_record_batches * record_batch_size as usize + ); + + let spill_count = assert_spill_count_metric(true, aggregate_final); + + Ok(spill_count) +} From 
ec0f4c73f70611fee57200fc566f1aba61b4e843 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Thu, 17 Jul 2025 12:50:49 +0300 Subject: [PATCH 12/21] rename `StreamExec` to `OnceExec` --- datafusion/core/tests/fuzz_cases/mod.rs | 2 +- .../fuzz_cases/{stream_exec.rs => once_exec.rs} | 14 +++++++------- .../spilling_fuzz_in_memory_constrained_env.rs | 6 +++--- 3 files changed, 11 insertions(+), 11 deletions(-) rename datafusion/core/tests/fuzz_cases/{stream_exec.rs => once_exec.rs} (94%) diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs index 69fbf44a0609..9e2fd170f7f0 100644 --- a/datafusion/core/tests/fuzz_cases/mod.rs +++ b/datafusion/core/tests/fuzz_cases/mod.rs @@ -33,6 +33,6 @@ mod sort_preserving_repartition_fuzz; mod window_fuzz; // Utility modules +mod once_exec; mod record_batch_generator; mod spilling_fuzz_in_memory_constrained_env; -mod stream_exec; diff --git a/datafusion/core/tests/fuzz_cases/stream_exec.rs b/datafusion/core/tests/fuzz_cases/once_exec.rs similarity index 94% rename from datafusion/core/tests/fuzz_cases/stream_exec.rs rename to datafusion/core/tests/fuzz_cases/once_exec.rs index 6e71b9988d79..ec77c1e64c74 100644 --- a/datafusion/core/tests/fuzz_cases/stream_exec.rs +++ b/datafusion/core/tests/fuzz_cases/once_exec.rs @@ -29,19 +29,19 @@ use std::sync::{Arc, Mutex}; /// Execution plan that return the stream on the call to `execute`. 
further calls to `execute` will /// return an error -pub struct StreamExec { +pub struct OnceExec { /// the results to send back stream: Mutex>, cache: PlanProperties, } -impl Debug for StreamExec { +impl Debug for OnceExec { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "StreamExec") + write!(f, "OnceExec") } } -impl StreamExec { +impl OnceExec { pub fn new(stream: SendableRecordBatchStream) -> Self { let cache = Self::compute_properties(stream.schema()); Self { @@ -61,11 +61,11 @@ impl StreamExec { } } -impl DisplayAs for StreamExec { +impl DisplayAs for OnceExec { fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "StreamExec:") + write!(f, "OnceExec:") } DisplayFormatType::TreeRender => { write!(f, "") @@ -74,7 +74,7 @@ impl DisplayAs for StreamExec { } } -impl ExecutionPlan for StreamExec { +impl ExecutionPlan for OnceExec { fn name(&self) -> &'static str { Self::static_name() } diff --git a/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs b/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs index f7f8f23734e4..29723eee5f0a 100644 --- a/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs +++ b/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs @@ -37,7 +37,7 @@ use datafusion_physical_expr_common::sort_expr::LexOrdering; use futures::StreamExt; use crate::fuzz_cases::aggregate_fuzz::assert_spill_count_metric; -use crate::fuzz_cases::stream_exec::StreamExec; +use crate::fuzz_cases::once_exec::OnceExec; use datafusion_execution::memory_pool::units::{KB, MB}; use datafusion_execution::TaskContext; use datafusion_functions_aggregate::array_agg::array_agg_udaf; @@ -270,7 +270,7 @@ async fn run_sort_test_with_limited_memory( let schema = Arc::clone(&scan_schema); let plan: Arc = - 
Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new( + Arc::new(OnceExec::new(Box::pin(RecordBatchStreamAdapter::new( Arc::clone(&schema), futures::stream::iter((0..number_of_record_batches as u64).map( move |index| { @@ -581,7 +581,7 @@ async fn run_test_aggregate_with_high_cardinality( let schema = Arc::clone(&scan_schema); let plan: Arc = - Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new( + Arc::new(OnceExec::new(Box::pin(RecordBatchStreamAdapter::new( Arc::clone(&schema), futures::stream::iter((0..number_of_record_batches as u64).map( move |index| { From 00fa9f0d667089b80275fee8dfb5d13e445cd8af Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Thu, 17 Jul 2025 12:51:08 +0300 Subject: [PATCH 13/21] added comment on the usize in the `in_progress_spill_file` inside ExternalSorter --- datafusion/physical-plan/src/sorts/sort.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 17e0a9cc3837..f585ee01dac1 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -220,6 +220,10 @@ struct ExternalSorter { /// During external sorting, in-memory intermediate data will be appended to /// this file incrementally. Once finished, this file will be moved to [`Self::finished_spill_files`]. + /// + /// this is a tuple of: + /// 1. `InProgressSpillFile` - the file that is being written to + /// 2. `max_record_batch_memory` - the maximum memory usage of a single batch in this spill file. 
in_progress_spill_file: Option<(InProgressSpillFile, usize)>, /// If data has previously been spilled, the locations of the spill files (in /// Arrow IPC format) From 6198d90c901d574db4ca5f6a6c290c171c862e45 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Thu, 17 Jul 2025 13:03:31 +0300 Subject: [PATCH 14/21] rename buffer_size to buffer_len --- .../physical-plan/src/sorts/multi_level_merge.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/datafusion/physical-plan/src/sorts/multi_level_merge.rs b/datafusion/physical-plan/src/sorts/multi_level_merge.rs index fc55465b0e01..6564be447684 100644 --- a/datafusion/physical-plan/src/sorts/multi_level_merge.rs +++ b/datafusion/physical-plan/src/sorts/multi_level_merge.rs @@ -239,17 +239,17 @@ impl MultiLevelMergeBuilder { /// otherwise it will return an error fn get_sorted_spill_files_to_merge( &mut self, - buffer_size: usize, + buffer_len: usize, minimum_number_of_required_streams: usize, reservation: &mut MemoryReservation, ) -> Result<(Vec, usize)> { - assert_ne!(buffer_size, 0, "Buffer size must be greater than 0"); + assert_ne!(buffer_len, 0, "Buffer length must be greater than 0"); let mut number_of_spills_to_read_for_current_phase = 0; for spill in &self.sorted_spill_files { // For memory pools that are not shared this is good, for other this is not // and there should be some upper limit to memory reservation so we won't starve the system - match reservation.try_grow(spill.max_record_batch_memory * buffer_size) { + match reservation.try_grow(spill.max_record_batch_memory * buffer_len) { Ok(_) => { number_of_spills_to_read_for_current_phase += 1; } @@ -262,10 +262,10 @@ impl MultiLevelMergeBuilder { { // Free the memory we reserved for this merge as we either try again or fail reservation.free(); - if buffer_size > 1 { + if buffer_len > 1 { // Try again with smaller buffer size, it will be slower but at least we can merge return 
self.get_sorted_spill_files_to_merge( - buffer_size - 1, + buffer_len - 1, minimum_number_of_required_streams, reservation, ); @@ -286,7 +286,7 @@ impl MultiLevelMergeBuilder { .drain(..number_of_spills_to_read_for_current_phase) .collect::>(); - Ok((spills, buffer_size)) + Ok((spills, buffer_len)) } } From af6b5c52afe316127cd007cd8464d07c3038a0f9 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Thu, 17 Jul 2025 13:30:54 +0300 Subject: [PATCH 15/21] reuse code in spill fuzz --- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 4 +- ...spilling_fuzz_in_memory_constrained_env.rs | 180 ++++++++---------- 2 files changed, 78 insertions(+), 106 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 7fa09c3468e3..bcf60eb2d72a 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -634,9 +634,9 @@ fn extract_result_counts(results: Vec) -> HashMap, i pub(crate) fn assert_spill_count_metric( expect_spill: bool, - single_aggregate: Arc, + plan_that_spills: Arc, ) -> usize { - if let Some(metrics_set) = single_aggregate.metrics() { + if let Some(metrics_set) = plan_that_spills.metrics() { let mut spill_count = 0; // Inspect metrics for SpillCount diff --git a/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs b/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs index 29723eee5f0a..9975746f5cde 100644 --- a/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs +++ b/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs @@ -20,6 +20,8 @@ use std::pin::Pin; use std::sync::Arc; +use crate::fuzz_cases::aggregate_fuzz::assert_spill_count_metric; +use crate::fuzz_cases::once_exec::OnceExec; use arrow::array::UInt64Array; use arrow::{array::StringArray, compute::SortOptions, record_batch::RecordBatch}; 
use arrow_schema::{DataType, Field, Schema}; @@ -29,23 +31,20 @@ use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionConfig; +use datafusion_execution::memory_pool::units::{KB, MB}; use datafusion_execution::memory_pool::{ FairSpillPool, MemoryConsumer, MemoryReservation, }; -use datafusion_physical_expr::expressions::{col, Column}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use futures::StreamExt; - -use crate::fuzz_cases::aggregate_fuzz::assert_spill_count_metric; -use crate::fuzz_cases::once_exec::OnceExec; -use datafusion_execution::memory_pool::units::{KB, MB}; -use datafusion_execution::TaskContext; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_functions_aggregate::array_agg::array_agg_udaf; use datafusion_physical_expr::aggregate::AggregateExprBuilder; +use datafusion_physical_expr::expressions::{col, Column}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; +use futures::StreamExt; #[tokio::test] async fn test_sort_with_limited_memory() -> Result<()> { @@ -72,7 +71,7 @@ async fn test_sort_with_limited_memory() -> Result<()> { // from each spill file is too much memory let spill_count = run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { pool_size, - task_ctx, + task_ctx: Arc::new(task_ctx), number_of_record_batches: 100, get_size_of_record_batch_to_generate: Box::pin(move |_| record_batch_size), memory_behavior: Default::default(), @@ -110,7 +109,7 @@ async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch() -> run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { pool_size, - task_ctx, + task_ctx: Arc::new(task_ctx), number_of_record_batches: 100, 
get_size_of_record_batch_to_generate: Box::pin(move |i| { if i % 25 == 1 { @@ -148,7 +147,7 @@ async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_c run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { pool_size, - task_ctx, + task_ctx: Arc::new(task_ctx), number_of_record_batches: 100, get_size_of_record_batch_to_generate: Box::pin(move |i| { if i % 25 == 1 { @@ -186,7 +185,7 @@ async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_t run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { pool_size, - task_ctx, + task_ctx: Arc::new(task_ctx), number_of_record_batches: 100, get_size_of_record_batch_to_generate: Box::pin(move |i| { if i % 25 == 1 { @@ -224,7 +223,7 @@ async fn test_sort_with_limited_memory_and_large_record_batch() -> Result<()> { // Test that the merge degree of multi level merge sort cannot be fixed size when there is not enough memory run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { pool_size, - task_ctx, + task_ctx: Arc::new(task_ctx), number_of_record_batches: 100, get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 4), memory_behavior: Default::default(), @@ -236,7 +235,7 @@ async fn test_sort_with_limited_memory_and_large_record_batch() -> Result<()> { struct RunTestWithLimitedMemoryArgs { pool_size: usize, - task_ctx: TaskContext, + task_ctx: Arc, number_of_record_batches: usize, get_size_of_record_batch_to_generate: Pin usize + Send + 'static>>, @@ -252,27 +251,25 @@ enum MemoryBehavior { } async fn run_sort_test_with_limited_memory( - args: RunTestWithLimitedMemoryArgs, + mut args: RunTestWithLimitedMemoryArgs, ) -> Result { - let RunTestWithLimitedMemoryArgs { - pool_size, - task_ctx, - number_of_record_batches, - get_size_of_record_batch_to_generate, - memory_behavior, - } = args; + let get_size_of_record_batch_to_generate = std::mem::replace( + &mut args.get_size_of_record_batch_to_generate, + Box::pin(move |_| unreachable!("should 
not be called after take")), + ); + let scan_schema = Arc::new(Schema::new(vec![ Field::new("col_0", DataType::UInt64, true), Field::new("col_1", DataType::Utf8, true), ])); - let record_batch_size = task_ctx.session_config().batch_size() as u64; + let record_batch_size = args.task_ctx.session_config().batch_size() as u64; let schema = Arc::clone(&scan_schema); let plan: Arc = Arc::new(OnceExec::new(Box::pin(RecordBatchStreamAdapter::new( Arc::clone(&schema), - futures::stream::iter((0..number_of_record_batches as u64).map( + futures::stream::iter((0..args.number_of_record_batches as u64).map( move |index| { let mut record_batch_memory_size = get_size_of_record_batch_to_generate(index as usize); @@ -311,59 +308,9 @@ async fn run_sort_test_with_limited_memory( plan, )); - let task_ctx = Arc::new(task_ctx); + let result = sort_exec.execute(0, Arc::clone(&args.task_ctx))?; - let mut result = sort_exec.execute(0, Arc::clone(&task_ctx))?; - - let mut number_of_rows = 0; - - let memory_pool = task_ctx.memory_pool(); - let memory_consumer = MemoryConsumer::new("mock_memory_consumer"); - let mut memory_reservation = memory_consumer.register(memory_pool); - - let mut index = 0; - let mut memory_took = false; - - while let Some(batch) = result.next().await { - match memory_behavior { - MemoryBehavior::AsIs => { - // Do nothing - } - MemoryBehavior::TakeAllMemoryAtTheBeginning => { - if !memory_took { - memory_took = true; - grow_memory_as_much_as_possible(10, &mut memory_reservation)?; - } - } - MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(n) => { - if !memory_took { - memory_took = true; - grow_memory_as_much_as_possible(pool_size, &mut memory_reservation)?; - } else if index % n == 0 { - // release memory - memory_reservation.free(); - } - } - } - - let batch = batch?; - number_of_rows += batch.num_rows(); - - index += 1; - } - - assert_eq!( - number_of_rows, - number_of_record_batches * record_batch_size as usize - ); - - let spill_count = 
sort_exec.metrics().unwrap().spill_count().unwrap(); - assert!( - spill_count > 0, - "Expected spill, but did not: {number_of_record_batches:?}" - ); - - Ok(spill_count) + run_test(args, sort_exec, result).await } fn grow_memory_as_much_as_possible( @@ -400,7 +347,7 @@ async fn test_aggregate_with_high_cardinality_with_limited_memory() -> Result<() let spill_count = run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { pool_size, - task_ctx, + task_ctx: Arc::new(task_ctx), number_of_record_batches: 100, get_size_of_record_batch_to_generate: Box::pin(move |_| record_batch_size), memory_behavior: Default::default(), @@ -434,7 +381,7 @@ async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_ run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { pool_size, - task_ctx, + task_ctx: Arc::new(task_ctx), number_of_record_batches: 100, get_size_of_record_batch_to_generate: Box::pin(move |i| { if i % 25 == 1 { @@ -468,7 +415,7 @@ async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_ run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { pool_size, - task_ctx, + task_ctx: Arc::new(task_ctx), number_of_record_batches: 100, get_size_of_record_batch_to_generate: Box::pin(move |i| { if i % 25 == 1 { @@ -502,7 +449,7 @@ async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_ run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { pool_size, - task_ctx, + task_ctx: Arc::new(task_ctx), number_of_record_batches: 100, get_size_of_record_batch_to_generate: Box::pin(move |i| { if i % 25 == 1 { @@ -537,7 +484,7 @@ async fn test_aggregate_with_high_cardinality_with_limited_memory_and_large_reco // Test that the merge degree of multi level merge sort cannot be fixed size when there is not enough memory run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { pool_size, - task_ctx, + task_ctx: Arc::new(task_ctx), 
number_of_record_batches: 100, get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 4), memory_behavior: Default::default(), @@ -548,15 +495,12 @@ async fn test_aggregate_with_high_cardinality_with_limited_memory_and_large_reco } async fn run_test_aggregate_with_high_cardinality( - args: RunTestWithLimitedMemoryArgs, + mut args: RunTestWithLimitedMemoryArgs, ) -> Result { - let RunTestWithLimitedMemoryArgs { - pool_size, - task_ctx, - number_of_record_batches, - get_size_of_record_batch_to_generate, - memory_behavior, - } = args; + let get_size_of_record_batch_to_generate = std::mem::replace( + &mut args.get_size_of_record_batch_to_generate, + Box::pin(move |_| unreachable!("should not be called after take")), + ); let scan_schema = Arc::new(Schema::new(vec![ Field::new("col_0", DataType::UInt64, true), Field::new("col_1", DataType::Utf8, true), @@ -577,13 +521,13 @@ async fn run_test_aggregate_with_high_cardinality( .build()?, )]; - let record_batch_size = task_ctx.session_config().batch_size() as u64; + let record_batch_size = args.task_ctx.session_config().batch_size() as u64; let schema = Arc::clone(&scan_schema); let plan: Arc = Arc::new(OnceExec::new(Box::pin(RecordBatchStreamAdapter::new( Arc::clone(&schema), - futures::stream::iter((0..number_of_record_batches as u64).map( + futures::stream::iter((0..args.number_of_record_batches as u64).map( move |index| { let mut record_batch_memory_size = get_size_of_record_batch_to_generate(index as usize); @@ -630,21 +574,48 @@ async fn run_test_aggregate_with_high_cardinality( Arc::clone(&scan_schema), )?); - let task_ctx = Arc::new(task_ctx); + let result = aggregate_final.execute(0, Arc::clone(&args.task_ctx))?; + + run_test(args, aggregate_final, result).await +} + +async fn run_test( + args: RunTestWithLimitedMemoryArgs, + plan: Arc, + result_stream: SendableRecordBatchStream, +) -> Result { + let number_of_record_batches = args.number_of_record_batches; - let mut result = aggregate_final.execute(0, 
Arc::clone(&task_ctx))?; + consume_stream_and_simulate_other_running_memory_consumers(args, result_stream) + .await?; - let mut number_of_groups = 0; + let spill_count = assert_spill_count_metric(true, plan); - let memory_pool = task_ctx.memory_pool(); + assert!( + spill_count > 0, + "Expected spill, but did not, number of record batches: {number_of_record_batches}", + ); + + Ok(spill_count) +} + +/// Consume the stream and change the amount of memory used while consuming it based on the [`MemoryBehavior`] provided +async fn consume_stream_and_simulate_other_running_memory_consumers( + args: RunTestWithLimitedMemoryArgs, + mut result_stream: SendableRecordBatchStream, +) -> Result<()> { + let mut number_of_rows = 0; + let record_batch_size = args.task_ctx.session_config().batch_size() as u64; + + let memory_pool = args.task_ctx.memory_pool(); let memory_consumer = MemoryConsumer::new("mock_memory_consumer"); let mut memory_reservation = memory_consumer.register(memory_pool); let mut index = 0; let mut memory_took = false; - while let Some(batch) = result.next().await { - match memory_behavior { + while let Some(batch) = result_stream.next().await { + match args.memory_behavior { MemoryBehavior::AsIs => { // Do nothing } @@ -657,7 +628,10 @@ async fn run_test_aggregate_with_high_cardinality( MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(n) => { if !memory_took { memory_took = true; - grow_memory_as_much_as_possible(pool_size, &mut memory_reservation)?; + grow_memory_as_much_as_possible( + args.pool_size, + &mut memory_reservation, + )?; } else if index % n == 0 { // release memory memory_reservation.free(); @@ -666,17 +640,15 @@ async fn run_test_aggregate_with_high_cardinality( } let batch = batch?; - number_of_groups += batch.num_rows(); + number_of_rows += batch.num_rows(); index += 1; } assert_eq!( - number_of_groups, - number_of_record_batches * record_batch_size as usize + number_of_rows, + args.number_of_record_batches * record_batch_size as usize ); - 
let spill_count = assert_spill_count_metric(true, aggregate_final); - - Ok(spill_count) + Ok(()) } From ae1ed6d6bb89a70bd9bdf370944f67b5254b31fe Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Thu, 17 Jul 2025 15:29:36 +0300 Subject: [PATCH 16/21] double the amount of memory needed to sort --- datafusion/physical-plan/src/sorts/multi_level_merge.rs | 5 ++++- datafusion/physical-plan/src/sorts/sort.rs | 9 +++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-plan/src/sorts/multi_level_merge.rs b/datafusion/physical-plan/src/sorts/multi_level_merge.rs index 6564be447684..c1125786cc57 100644 --- a/datafusion/physical-plan/src/sorts/multi_level_merge.rs +++ b/datafusion/physical-plan/src/sorts/multi_level_merge.rs @@ -32,6 +32,7 @@ use datafusion_execution::memory_pool::{ MemoryConsumer, MemoryPool, MemoryReservation, UnboundedMemoryPool, }; +use crate::sorts::sort::get_reserved_byte_for_record_batch_size; use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; use crate::stream::RecordBatchStreamAdapter; use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; @@ -249,7 +250,9 @@ impl MultiLevelMergeBuilder { for spill in &self.sorted_spill_files { // For memory pools that are not shared this is good, for other this is not // and there should be some upper limit to memory reservation so we won't starve the system - match reservation.try_grow(spill.max_record_batch_memory * buffer_len) { + match reservation.try_grow(get_reserved_byte_for_record_batch_size( + spill.max_record_batch_memory * buffer_len, + )) { Ok(_) => { number_of_spills_to_read_for_current_phase += 1; } diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index f585ee01dac1..b4ea44f0d160 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -784,11 +784,16 @@ impl ExternalSorter { 
/// in sorting and merging. The sorted copies are in either row format or array format. /// Please refer to cursor.rs and stream.rs for more details. No matter what format the /// sorted copies are, they will use more memory than the original record batch. -fn get_reserved_byte_for_record_batch(batch: &RecordBatch) -> usize { +pub(crate) fn get_reserved_byte_for_record_batch_size(record_batch_size: usize) -> usize { // 2x may not be enough for some cases, but it's a good start. // If 2x is not enough, user can set a larger value for `sort_spill_reservation_bytes` // to compensate for the extra memory needed. - get_record_batch_memory_size(batch) * 2 + record_batch_size * 2 +} + +/// Estimate how much memory is needed to sort a `RecordBatch`. +fn get_reserved_byte_for_record_batch(batch: &RecordBatch) -> usize { + get_reserved_byte_for_record_batch_size(get_record_batch_memory_size(batch)) } impl Debug for ExternalSorter { From 040329637ac9ba59d635806c817f2cd197d76dd3 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Mon, 21 Jul 2025 14:44:14 +0300 Subject: [PATCH 17/21] add diagram for explaining the overview --- .../src/sorts/multi_level_merge.rs | 87 +++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/datafusion/physical-plan/src/sorts/multi_level_merge.rs b/datafusion/physical-plan/src/sorts/multi_level_merge.rs index c1125786cc57..b34ee6d1d3d1 100644 --- a/datafusion/physical-plan/src/sorts/multi_level_merge.rs +++ b/datafusion/physical-plan/src/sorts/multi_level_merge.rs @@ -41,6 +41,92 @@ use futures::TryStreamExt; use futures::{Stream, StreamExt}; /// Merges a stream of sorted cursors and record batches into a single sorted stream +/// +/// This is a wrapper around [`SortPreservingMergeStream`](crate::sorts::merge::SortPreservingMergeStream) +/// that provide it the sorted streams/files to merge while making sure we can merge them in memory. 
+/// In case we can't merge all of them in a single pass we will spill the intermediate results to disk +/// and repeat the process. +/// +/// ## High level Algorithm +/// 1. Get the maximum amount of sorted in-memory streams and spill files we can merge with the available memory +/// 2. Sort them to a sorted stream +/// 3. Do we have more spill files to to merge? +/// - Yes: write that sorted stream to a spill file, +/// add that spill file back to the spill files to merge and +/// repeat the process +/// - No: return that sorted stream as the final output stream +/// +/// ```text +/// Initial State: Multiple sorted streams + spill files +/// ┌───────────┐ +/// │ Phase 1 │ +/// └───────────┘ +/// ┌──Can hold in memory─┐ +/// │ ┌──────────────┐ │ +/// │ │ In-memory │ +/// │ │sorted stream │──┼────────┐ +/// │ │ 1 │ │ │ +/// └──────────────┘ │ │ +/// │ ┌──────────────┐ │ │ +/// │ │ In-memory │ │ +/// │ │sorted stream │──┼────────┤ +/// │ │ 2 │ │ │ +/// └──────────────┘ │ │ +/// │ ┌──────────────┐ │ │ +/// │ │ In-memory │ │ +/// │ │sorted stream │──┼────────┤ +/// │ │ 3 │ │ │ +/// └──────────────┘ │ │ +/// │ ┌──────────────┐ │ │ ┌───────────┐ +/// │ │ Sorted Spill │ │ │ Phase 2 │ +/// │ │ file 1 │──┼────────┤ └───────────┘ +/// │ └──────────────┘ │ │ +/// ──── ──── ──── ──── ─┘ │ ┌──Can hold in memory─┐ +/// │ │ │ +/// ┌──────────────┐ │ │ ┌──────────────┐ +/// │ Sorted Spill │ │ │ │ Sorted Spill │ │ +/// │ file 2 │──────────────────────▶│ file 2 │──┼─────┐ +/// └──────────────┘ │ └──────────────┘ │ │ +/// ┌──────────────┐ │ │ ┌──────────────┐ │ │ +/// │ Sorted Spill │ │ │ │ Sorted Spill │ │ +/// │ file 3 │──────────────────────▶│ file 3 │──┼─────┤ +/// └──────────────┘ │ │ └──────────────┘ │ │ +/// ┌──────────────┐ │ ┌──────────────┐ │ │ +/// │ Sorted Spill │ │ │ │ Sorted Spill │ │ │ +/// │ file 4 │──────────────────────▶│ file 4 │────────┤ ┌───────────┐ +/// └──────────────┘ │ │ └──────────────┘ │ │ │ Phase 3 │ +/// │ │ │ │ └───────────┘ +/// │ ──── ──── ──── ──── 
─┘ │ ┌──Can hold in memory─┐ +/// │ │ │ │ +/// ┌──────────────┐ │ ┌──────────────┐ │ │ ┌──────────────┐ +/// │ Sorted Spill │ │ │ Sorted Spill │ │ │ │ Sorted Spill │ │ +/// │ file 5 │──────────────────────▶│ file 5 │────────────────▶│ file 5 │───┼───┐ +/// └──────────────┘ │ └──────────────┘ │ │ └──────────────┘ │ │ +/// │ │ │ │ │ +/// │ ┌──────────────┐ │ │ ┌──────────────┐ │ +/// │ │ Sorted Spill │ │ │ │ Sorted Spill │ │ │ ┌── ─── ─── ─── ─── ─── ─── ──┐ +/// └──────────▶│ file 6 │────────────────▶│ file 6 │───┼───┼──────▶ Output Stream +/// └──────────────┘ │ │ └──────────────┘ │ │ └── ─── ─── ─── ─── ─── ─── ──┘ +/// │ │ │ │ +/// │ │ ┌──────────────┐ │ +/// │ │ │ Sorted Spill │ │ │ +/// └───────▶│ file 7 │───┼───┘ +/// │ └──────────────┘ │ +/// │ │ +/// └─ ──── ──── ──── ──── +/// ``` +/// +/// ## Memory Management Strategy +/// +/// This multi-level merge make sure that we can handle any amount of data to sort as long as +/// we have enough memory to merge at least 2 streams at a time. +/// +/// 1. **Worst-Case Memory Reservation**: Reserves memory based on the largest +/// batch size encountered in each spill file to merge, ensuring sufficient memory is always +/// available during merge operations. +/// 2. **Adaptive Buffer Sizing**: Reduces buffer sizes when memory is constrained +/// 3. 
**Spill-to-Disk**: Spill to disk when we cannot merge all files in memory +/// pub(crate) struct MultiLevelMergeBuilder { spill_manager: SpillManager, schema: SchemaRef, @@ -114,6 +200,7 @@ impl MultiLevelMergeBuilder { // If no spill files are left, we can return the stream as this is the last sorted run // TODO - We can write to disk before reading it back to avoid having multiple streams in memory if self.sorted_spill_files.is_empty() { + assert!(self.sorted_streams.is_empty(), "We should not have any sorted streams left"); // Attach the memory reservation to the stream as we are done with it // but because we replaced the memory reservation of the merge stream, we must hold // this to make sure we have enough memory From 001022276155318cd5a0c64c19cb567339f32a7a Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Mon, 21 Jul 2025 15:47:30 +0300 Subject: [PATCH 18/21] update based on code review --- .../src/sorts/multi_level_merge.rs | 99 ++++++++++------- .../src/sorts/streaming_merge.rs | 17 ++- .../physical-plan/src/spill/spill_manager.rs | 105 ++---------------- 3 files changed, 84 insertions(+), 137 deletions(-) diff --git a/datafusion/physical-plan/src/sorts/multi_level_merge.rs b/datafusion/physical-plan/src/sorts/multi_level_merge.rs index b34ee6d1d3d1..a0987fcca919 100644 --- a/datafusion/physical-plan/src/sorts/multi_level_merge.rs +++ b/datafusion/physical-plan/src/sorts/multi_level_merge.rs @@ -28,9 +28,7 @@ use std::task::{Context, Poll}; use arrow::datatypes::SchemaRef; use datafusion_common::Result; -use datafusion_execution::memory_pool::{ - MemoryConsumer, MemoryPool, MemoryReservation, UnboundedMemoryPool, -}; +use datafusion_execution::memory_pool::{MemoryReservation}; use crate::sorts::sort::get_reserved_byte_for_record_batch_size; use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; @@ -50,11 +48,12 @@ use futures::{Stream, StreamExt}; /// ## High level Algorithm /// 1. 
Get the maximum amount of sorted in-memory streams and spill files we can merge with the available memory /// 2. Sort them to a sorted stream -/// 3. Do we have more spill files to to merge? -/// - Yes: write that sorted stream to a spill file, -/// add that spill file back to the spill files to merge and -/// repeat the process -/// - No: return that sorted stream as the final output stream +/// 3. Do we have more spill files to merge? +/// - Yes: write that sorted stream to a spill file, +/// add that spill file back to the spill files to merge and +/// repeat the process +/// +/// - No: return that sorted stream as the final output stream /// /// ```text /// Initial State: Multiple sorted streams + spill files @@ -138,11 +137,6 @@ pub(crate) struct MultiLevelMergeBuilder { reservation: MemoryReservation, fetch: Option, enable_round_robin_tie_breaker: bool, - - // This is for avoiding double reservation of memory from our side and the sort preserving merge stream - // side. - // and doing a lot of code changes to avoid accounting for the memory used by the streams - unbounded_memory_pool: Arc, } impl Debug for MultiLevelMergeBuilder { @@ -176,7 +170,6 @@ impl MultiLevelMergeBuilder { reservation, enable_round_robin_tie_breaker, fetch, - unbounded_memory_pool: Arc::new(UnboundedMemoryPool::default()), } } @@ -189,10 +182,7 @@ impl MultiLevelMergeBuilder { async fn create_stream(mut self) -> Result { loop { - // Hold this for the lifetime of the stream - let mut current_memory_reservation = self.reservation.new_empty(); - let mut stream = - self.create_sorted_stream(&mut current_memory_reservation)?; + let mut stream = self.merge_sorted_runs_within_mem_limit()?; // TODO - add a threshold for number of files to disk even if empty and reading from disk so // we can avoid the memory reservation @@ -200,22 +190,19 @@ impl MultiLevelMergeBuilder { // If no spill files are left, we can return the stream as this is the last sorted run // TODO - We can write to disk before 
reading it back to avoid having multiple streams in memory if self.sorted_spill_files.is_empty() { - assert!(self.sorted_streams.is_empty(), "We should not have any sorted streams left"); - // Attach the memory reservation to the stream as we are done with it - // but because we replaced the memory reservation of the merge stream, we must hold - // this to make sure we have enough memory - return Ok(Box::pin(StreamAttachedReservation::new( - stream, - current_memory_reservation, - ))); + assert!( + self.sorted_streams.is_empty(), + "We should not have any sorted streams left" + ); + + return Ok(stream); } // Need to sort to a spill file let Some((spill_file, max_record_batch_memory)) = self .spill_manager - .spill_record_batch_stream_by_size( + .spill_record_batch_stream_and_return_max_batch_memory( &mut stream, - self.batch_size, "MultiLevelMergeBuilder intermediate spill", ) .await? @@ -231,9 +218,10 @@ impl MultiLevelMergeBuilder { } } - fn create_sorted_stream( + /// This tries to create a stream that merges the most sorted streams and sorted spill files + /// as possible within the memory limit. 
+ fn merge_sorted_runs_within_mem_limit( &mut self, - memory_reservation: &mut MemoryReservation, ) -> Result { match (self.sorted_spill_files.len(), self.sorted_streams.len()) { // No data so empty batch @@ -248,6 +236,7 @@ impl MultiLevelMergeBuilder { (1, 0) => { let spill_file = self.sorted_spill_files.remove(0); + // Not reserving any memory for this disk as we are not holding it in memory self.spill_manager.read_spill_as_stream(spill_file.file) } @@ -258,11 +247,14 @@ impl MultiLevelMergeBuilder { sorted_stream, // If we have no sorted spill files left, this is the last run true, + true, ) } // Need to merge multiple streams (_, _) => { + let mut memory_reservation = self.reservation.new_empty(); + // Don't account for existing streams memory // as we are not holding the memory for them let mut sorted_streams = mem::take(&mut self.sorted_streams); @@ -272,9 +264,11 @@ impl MultiLevelMergeBuilder { 2, // we must have at least 2 streams to merge 2_usize.saturating_sub(sorted_streams.len()), - memory_reservation, + &mut memory_reservation, )?; + let is_only_merging_memory_streams = sorted_spill_files.is_empty(); + for spill in sorted_spill_files { let stream = self .spill_manager @@ -284,11 +278,27 @@ impl MultiLevelMergeBuilder { sorted_streams.push(stream); } - self.create_new_merge_sort( + let merge_sort_stream = self.create_new_merge_sort( sorted_streams, // If we have no sorted spill files left, this is the last run self.sorted_spill_files.is_empty(), - ) + is_only_merging_memory_streams, + )?; + + // If we're only merging memory streams, we don't need to attach the memory reservation + // as it's empty + if is_only_merging_memory_streams { + assert_eq!(memory_reservation.size(), 0, "when only merging memory streams, we should not have any memory reservation and let the merge sort handle the memory"); + + Ok(merge_sort_stream) + } else { + // Attach the memory reservation to the stream to make sure we have enough memory + // throughout the merge process as 
we bypassed the memory pool for the merge sort stream + Ok(Box::pin(StreamAttachedReservation::new( + merge_sort_stream, + memory_reservation, + ))) + } } } } @@ -297,8 +307,9 @@ impl MultiLevelMergeBuilder { &mut self, streams: Vec, is_output: bool, + all_in_memory: bool, ) -> Result { - StreamingMergeBuilder::new() + let mut builder = StreamingMergeBuilder::new() .with_schema(Arc::clone(&self.schema)) .with_expressions(&self.expr) .with_batch_size(self.batch_size) @@ -310,15 +321,21 @@ impl MultiLevelMergeBuilder { self.metrics.intermediate() }) .with_round_robin_tie_breaker(self.enable_round_robin_tie_breaker) - .with_streams(streams) + .with_streams(streams); + + if !all_in_memory { // Don't track memory used by this stream as we reserve that memory by worst case sceneries // (reserving memory for the biggest batch in each stream) - // This is a hack - .with_reservation( - MemoryConsumer::new("merge stream mock memory") - .register(&self.unbounded_memory_pool), - ) - .build() + // TODO - avoid this hack as this can be broken easily when `SortPreservingMergeStream` + // changes the implementation to use more/less memory + builder = builder.with_bypass_mempool(); + } else { + // If we are only merging in-memory streams, we need to use the memory reservation + // because we don't know the maximum size of the batches in the streams + builder = builder.with_reservation(self.reservation.new_empty()); + } + + builder.build() } /// Return the sorted spill files to use for the next phase, and the buffer size diff --git a/datafusion/physical-plan/src/sorts/streaming_merge.rs b/datafusion/physical-plan/src/sorts/streaming_merge.rs index 86adc7e4c6a0..191b13575341 100644 --- a/datafusion/physical-plan/src/sorts/streaming_merge.rs +++ b/datafusion/physical-plan/src/sorts/streaming_merge.rs @@ -29,8 +29,12 @@ use arrow::array::*; use arrow::datatypes::{DataType, SchemaRef}; use datafusion_common::{internal_err, Result}; use 
datafusion_execution::disk_manager::RefCountedTempFile; -use datafusion_execution::memory_pool::{human_readable_size, MemoryReservation}; +use datafusion_execution::memory_pool::{ + human_readable_size, MemoryConsumer, MemoryPool, MemoryReservation, + UnboundedMemoryPool, +}; use datafusion_physical_expr_common::sort_expr::LexOrdering; +use std::sync::Arc; macro_rules! primitive_merge_helper { ($t:ty, $($v:ident),+) => { @@ -154,6 +158,17 @@ impl<'a> StreamingMergeBuilder<'a> { self } + /// Bypass the mempool and avoid using the memory reservation. + /// + /// This is not marked as `pub` because it is not recommended to use this method + pub(super) fn with_bypass_mempool(self) -> Self { + let mem_pool: Arc = Arc::new(UnboundedMemoryPool::default()); + + self.with_reservation( + MemoryConsumer::new("merge stream mock memory").register(&mem_pool), + ) + } + pub fn build(self) -> Result { let Self { streams, diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index 60e20e046ef3..32e870614936 100644 --- a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -173,114 +173,29 @@ impl SpillManager { Ok(file.map(|f| (f, max_record_batch_size))) } - /// Spill the `RecordBatch` to disk as smaller batches - /// split by `batch_size_rows`. 
- /// - /// will return the spill file and the size of the largest batch in memory - pub async fn spill_record_batch_stream_by_size( + /// Spill a stream of `RecordBatch`es to disk and return the spill file and the size of the largest batch in memory + pub(crate) async fn spill_record_batch_stream_and_return_max_batch_memory( &self, stream: &mut SendableRecordBatchStream, - batch_size_rows: usize, - request_msg: &str, + request_description: &str, ) -> Result> { use futures::StreamExt; - let mut in_progress_file = self.create_in_progress_file(request_msg)?; - let mut max_record_batch_size = 0; + let mut in_progress_file = self.create_in_progress_file(request_description)?; - let mut maybe_last_batch: Option = None; + let mut max_record_batch_size = 0; while let Some(batch) = stream.next().await { - let mut batch = batch?; - - if let Some(mut last_batch) = maybe_last_batch.take() { - assert!( - last_batch.num_rows() < batch_size_rows, - "last batch size must be smaller than the requested batch size" - ); - - // Get the number of rows to take from current batch so the last_batch - // will have `batch_size_rows` rows - let current_batch_offset = std::cmp::min( - // rows needed to fill - batch_size_rows - last_batch.num_rows(), - // Current length of the batch - batch.num_rows(), - ); - - // if have last batch that has less rows than concat and spill - last_batch = arrow::compute::concat_batches( - &stream.schema(), - &[last_batch, batch.slice(0, current_batch_offset)], - )?; - - assert!(last_batch.num_rows() <= batch_size_rows, "must build a batch that is smaller or equal to the requested batch size from the current batch"); - - // If not enough rows - if last_batch.num_rows() < batch_size_rows { - // keep the last batch for next iteration - maybe_last_batch = Some(last_batch); - continue; - } - - max_record_batch_size = - max_record_batch_size.max(last_batch.get_actually_used_size()); - - in_progress_file.append_batch(&last_batch)?; - - if current_batch_offset == 
batch.num_rows() { - // No remainder - continue; - } - - // remainder - batch = batch.slice( - current_batch_offset, - batch.num_rows() - current_batch_offset, - ); - } - - let mut offset = 0; - let total_rows = batch.num_rows(); - - // Keep slicing the batch until we have left with a batch that is smaller than - // the wanted batch size - while total_rows - offset >= batch_size_rows { - let batch = batch.slice(offset, batch_size_rows); - offset += batch_size_rows; - - max_record_batch_size = - max_record_batch_size.max(batch.get_actually_used_size()); - - in_progress_file.append_batch(&batch)?; - } - - // If there is a remainder for the current batch that is smaller than the wanted batch size - // keep it for next iteration - if offset < total_rows { - // remainder - let batch = batch.slice(offset, total_rows - offset); - - maybe_last_batch = Some(batch); - } - } - if let Some(last_batch) = maybe_last_batch.take() { - assert!( - last_batch.num_rows() < batch_size_rows, - "last batch size must be smaller than the requested batch size" - ); - - // Write it to disk - in_progress_file.append_batch(&last_batch)?; + let batch = batch?; + in_progress_file.append_batch(&batch)?; max_record_batch_size = - max_record_batch_size.max(last_batch.get_actually_used_size()); + max_record_batch_size.max(batch.get_actually_used_size()); } - // Flush disk - let spill_file = in_progress_file.finish()?; + let file = in_progress_file.finish()?; - Ok(spill_file.map(|f| (f, max_record_batch_size))) + Ok(file.map(|f| (f, max_record_batch_size))) } /// Reads a spill file as a stream. The file must be created by the current `SpillManager`. 
From 8e72e799b514f446ce4ba6dd8f6b1b0cf8a4b29a Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Mon, 21 Jul 2025 15:54:55 +0300 Subject: [PATCH 19/21] fix test based on new memory calculation --- .../spilling_fuzz_in_memory_constrained_env.rs | 16 ++++++++-------- .../physical-plan/src/sorts/multi_level_merge.rs | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs b/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs index 9975746f5cde..6c1bd316cdd3 100644 --- a/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs +++ b/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs @@ -113,7 +113,7 @@ async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch() -> number_of_record_batches: 100, get_size_of_record_batch_to_generate: Box::pin(move |i| { if i % 25 == 1 { - pool_size / 4 + pool_size / 6 } else { 16 * KB as usize } @@ -151,7 +151,7 @@ async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_c number_of_record_batches: 100, get_size_of_record_batch_to_generate: Box::pin(move |i| { if i % 25 == 1 { - pool_size / 4 + pool_size / 6 } else { 16 * KB as usize } @@ -189,7 +189,7 @@ async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_t number_of_record_batches: 100, get_size_of_record_batch_to_generate: Box::pin(move |i| { if i % 25 == 1 { - pool_size / 4 + pool_size / 6 } else { 16 * KB as usize } @@ -225,7 +225,7 @@ async fn test_sort_with_limited_memory_and_large_record_batch() -> Result<()> { pool_size, task_ctx: Arc::new(task_ctx), number_of_record_batches: 100, - get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 4), + get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 6), memory_behavior: Default::default(), }) .await?; @@ -385,7 +385,7 @@ async fn 
test_aggregate_with_high_cardinality_with_limited_memory_and_different_ number_of_record_batches: 100, get_size_of_record_batch_to_generate: Box::pin(move |i| { if i % 25 == 1 { - pool_size / 4 + pool_size / 6 } else { (16 * KB) as usize } @@ -419,7 +419,7 @@ async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_ number_of_record_batches: 100, get_size_of_record_batch_to_generate: Box::pin(move |i| { if i % 25 == 1 { - pool_size / 4 + pool_size / 6 } else { (16 * KB) as usize } @@ -453,7 +453,7 @@ async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_ number_of_record_batches: 100, get_size_of_record_batch_to_generate: Box::pin(move |i| { if i % 25 == 1 { - pool_size / 4 + pool_size / 6 } else { (16 * KB) as usize } @@ -486,7 +486,7 @@ async fn test_aggregate_with_high_cardinality_with_limited_memory_and_large_reco pool_size, task_ctx: Arc::new(task_ctx), number_of_record_batches: 100, - get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 4), + get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 6), memory_behavior: Default::default(), }) .await?; diff --git a/datafusion/physical-plan/src/sorts/multi_level_merge.rs b/datafusion/physical-plan/src/sorts/multi_level_merge.rs index a0987fcca919..bb6fc751b897 100644 --- a/datafusion/physical-plan/src/sorts/multi_level_merge.rs +++ b/datafusion/physical-plan/src/sorts/multi_level_merge.rs @@ -28,7 +28,7 @@ use std::task::{Context, Poll}; use arrow::datatypes::SchemaRef; use datafusion_common::Result; -use datafusion_execution::memory_pool::{MemoryReservation}; +use datafusion_execution::memory_pool::MemoryReservation; use crate::sorts::sort::get_reserved_byte_for_record_batch_size; use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; From bc76825e23eb4e8bc0225218424199d382bf56c6 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Wed, 23 Jul 2025 22:34:32 +0300 Subject: 
[PATCH 20/21] remove get_size in favor of get_sliced_size --- datafusion/physical-plan/src/sorts/sort.rs | 5 +- .../physical-plan/src/spill/get_size.rs | 216 ------------------ datafusion/physical-plan/src/spill/mod.rs | 1 - .../physical-plan/src/spill/spill_manager.rs | 23 +- 4 files changed, 20 insertions(+), 225 deletions(-) delete mode 100644 datafusion/physical-plan/src/spill/get_size.rs diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 8cd8fffa551b..fc03dd5f9444 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -37,9 +37,8 @@ use crate::metrics::{ use crate::projection::{make_with_child, update_ordering, ProjectionExec}; use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; use crate::spill::get_record_batch_memory_size; -use crate::spill::get_size::GetActualSize; use crate::spill::in_progress_spill_file::InProgressSpillFile; -use crate::spill::spill_manager::SpillManager; +use crate::spill::spill_manager::{GetSlicedSize, SpillManager}; use crate::stream::RecordBatchStreamAdapter; use crate::topk::TopK; use crate::{ @@ -413,7 +412,7 @@ impl ExternalSorter { in_progress_file.append_batch(&batch)?; *max_record_batch_size = - (*max_record_batch_size).max(batch.get_actually_used_size()); + (*max_record_batch_size).max(batch.get_sliced_size()); } if !globally_sorted_batches.is_empty() { diff --git a/datafusion/physical-plan/src/spill/get_size.rs b/datafusion/physical-plan/src/spill/get_size.rs deleted file mode 100644 index 1cb56b2678ce..000000000000 --- a/datafusion/physical-plan/src/spill/get_size.rs +++ /dev/null @@ -1,216 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::array::{ - Array, ArrayRef, ArrowPrimitiveType, BooleanArray, FixedSizeBinaryArray, - FixedSizeListArray, GenericByteArray, GenericListArray, OffsetSizeTrait, - PrimitiveArray, RecordBatch, StructArray, -}; -use arrow::buffer::{BooleanBuffer, Buffer, NullBuffer, OffsetBuffer}; -use arrow::datatypes::{ArrowNativeType, ByteArrayType}; -use arrow::downcast_primitive_array; -use arrow_schema::DataType; - -/// TODO - NEED TO MOVE THIS TO ARROW -/// this is needed as unlike `get_buffer_memory_size` or `get_array_memory_size` -/// or even `get_record_batch_memory_size` calling it on a batch before writing to spill and after reading it back -/// will return the same size (minus some optimization that IPC writer does for dictionaries) -/// -pub trait GetActualSize { - fn get_actually_used_size(&self) -> usize; -} - -impl GetActualSize for RecordBatch { - fn get_actually_used_size(&self) -> usize { - self.columns() - .iter() - .map(|c| c.get_actually_used_size()) - .sum() - } -} - -pub trait GetActualSizeArray: Array { - fn get_actually_used_size(&self) -> usize { - self.get_buffer_memory_size() - } -} - -impl GetActualSize for dyn Array { - fn get_actually_used_size(&self) -> usize { - use arrow::array::AsArray; - let array = self; - - // we can avoid this is we move this trait function to be in arrow - downcast_primitive_array!( - array => { - array.get_actually_used_size() - }, 
- DataType::Utf8 => { - array.as_string::().get_actually_used_size() - }, - DataType::LargeUtf8 => { - array.as_string::().get_actually_used_size() - }, - DataType::Binary => { - array.as_binary::().get_actually_used_size() - }, - DataType::LargeBinary => { - array.as_binary::().get_actually_used_size() - }, - DataType::FixedSizeBinary(_) => { - array.as_fixed_size_binary().get_actually_used_size() - }, - DataType::Struct(_) => { - array.as_struct().get_actually_used_size() - }, - DataType::List(_) => { - array.as_list::().get_actually_used_size() - }, - DataType::LargeList(_) => { - array.as_list::().get_actually_used_size() - }, - DataType::FixedSizeList(_, _) => { - array.as_fixed_size_list().get_actually_used_size() - }, - DataType::Boolean => { - array.as_boolean().get_actually_used_size() - }, - - _ => { - array.get_buffer_memory_size() - } - ) - } -} - -impl GetActualSizeArray for ArrayRef { - fn get_actually_used_size(&self) -> usize { - self.as_ref().get_actually_used_size() - } -} - -impl GetActualSize for Option<&NullBuffer> { - fn get_actually_used_size(&self) -> usize { - self.map(|b| b.get_actually_used_size()).unwrap_or(0) - } -} - -impl GetActualSize for NullBuffer { - fn get_actually_used_size(&self) -> usize { - // len return in bits - self.len() / 8 - } -} -impl GetActualSize for Buffer { - fn get_actually_used_size(&self) -> usize { - self.len() - } -} - -impl GetActualSize for BooleanBuffer { - fn get_actually_used_size(&self) -> usize { - // len return in bits - self.len() / 8 - } -} - -impl GetActualSize for OffsetBuffer { - fn get_actually_used_size(&self) -> usize { - self.inner().inner().get_actually_used_size() - } -} - -impl GetActualSizeArray for BooleanArray { - fn get_actually_used_size(&self) -> usize { - let null_size = self.nulls().get_actually_used_size(); - let values_size = self.values().get_actually_used_size(); - null_size + values_size - } -} - -impl GetActualSizeArray for GenericByteArray { - fn get_actually_used_size(&self) 
-> usize { - let null_size = self.nulls().get_actually_used_size(); - let offsets_size = self.offsets().get_actually_used_size(); - - let values_size = { - let first_offset = self.value_offsets()[0].as_usize(); - let last_offset = self.value_offsets()[self.len()].as_usize(); - last_offset - first_offset - }; - null_size + offsets_size + values_size - } -} - -impl GetActualSizeArray for FixedSizeBinaryArray { - fn get_actually_used_size(&self) -> usize { - let null_size = self.nulls().get_actually_used_size(); - let values_size = self.value_length() as usize * self.len(); - - null_size + values_size - } -} - -impl GetActualSizeArray for PrimitiveArray { - fn get_actually_used_size(&self) -> usize { - let null_size = self.nulls().get_actually_used_size(); - let values_size = self.values().inner().get_actually_used_size(); - - null_size + values_size - } -} - -impl GetActualSizeArray for GenericListArray { - fn get_actually_used_size(&self) -> usize { - let null_size = self.nulls().get_actually_used_size(); - let offsets_size = self.offsets().get_actually_used_size(); - - let values_size = { - let first_offset = self.value_offsets()[0].as_usize(); - let last_offset = self.value_offsets()[self.len()].as_usize(); - - self.values() - .slice(first_offset, last_offset - first_offset) - .get_actually_used_size() - }; - null_size + offsets_size + values_size - } -} - -impl GetActualSizeArray for FixedSizeListArray { - fn get_actually_used_size(&self) -> usize { - let null_size = self.nulls().get_actually_used_size(); - let values_size = self.values().get_actually_used_size(); - - null_size + values_size - } -} - -impl GetActualSizeArray for StructArray { - fn get_actually_used_size(&self) -> usize { - let null_size = self.nulls().get_actually_used_size(); - let values_size = self - .columns() - .iter() - .map(|array| array.get_actually_used_size()) - .sum::(); - - null_size + values_size - } -} - -// TODO - need to add to more arrays diff --git 
a/datafusion/physical-plan/src/spill/mod.rs b/datafusion/physical-plan/src/spill/mod.rs index 0542899a8918..a81221c8b6a9 100644 --- a/datafusion/physical-plan/src/spill/mod.rs +++ b/datafusion/physical-plan/src/spill/mod.rs @@ -17,7 +17,6 @@ //! Defines the spilling functions -pub(crate) mod get_size; pub(crate) mod in_progress_spill_file; pub(crate) mod spill_manager; diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index 32e870614936..6a138de710cb 100644 --- a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -28,7 +28,6 @@ use datafusion_execution::SendableRecordBatchStream; use super::{in_progress_spill_file::InProgressSpillFile, SpillReaderStream}; use crate::coop::cooperative; -use crate::spill::get_size::GetActualSize; use crate::{common::spawn_buffered, metrics::SpillMetrics}; /// The `SpillManager` is responsible for the following tasks: @@ -164,8 +163,7 @@ impl SpillManager { for batch in batches { in_progress_file.append_batch(&batch)?; - max_record_batch_size = - max_record_batch_size.max(batch.get_actually_used_size()); + max_record_batch_size = max_record_batch_size.max(batch.get_sliced_size()); } let file = in_progress_file.finish()?; @@ -189,8 +187,7 @@ impl SpillManager { let batch = batch?; in_progress_file.append_batch(&batch)?; - max_record_batch_size = - max_record_batch_size.max(batch.get_actually_used_size()); + max_record_batch_size = max_record_batch_size.max(batch.get_sliced_size()); } let file = in_progress_file.finish()?; @@ -213,3 +210,19 @@ impl SpillManager { Ok(spawn_buffered(stream, self.batch_read_buffer_capacity)) } } + +pub(crate) trait GetSlicedSize { + /// Returns the size of the `RecordBatch` when sliced. 
+ fn get_sliced_size(&self) -> usize; +} + +impl GetSlicedSize for RecordBatch { + fn get_sliced_size(&self) -> usize { + let mut total = 0; + for array in self.columns() { + let data = array.to_data(); + total += data.get_slice_memory_size().unwrap(); + } + total + } +} From 00dcf583b4848dc998d336e7a767d0a47bc0cda4 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Thu, 24 Jul 2025 10:44:01 +0300 Subject: [PATCH 21/21] change to result --- datafusion/physical-plan/src/sorts/sort.rs | 2 +- datafusion/physical-plan/src/spill/spill_manager.rs | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index fc03dd5f9444..0b7d3977d270 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -412,7 +412,7 @@ impl ExternalSorter { in_progress_file.append_batch(&batch)?; *max_record_batch_size = - (*max_record_batch_size).max(batch.get_sliced_size()); + (*max_record_batch_size).max(batch.get_sliced_size()?); } if !globally_sorted_batches.is_empty() { diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index 6a138de710cb..6c47af129fce 100644 --- a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -163,7 +163,7 @@ impl SpillManager { for batch in batches { in_progress_file.append_batch(&batch)?; - max_record_batch_size = max_record_batch_size.max(batch.get_sliced_size()); + max_record_batch_size = max_record_batch_size.max(batch.get_sliced_size()?); } let file = in_progress_file.finish()?; @@ -187,7 +187,7 @@ impl SpillManager { let batch = batch?; in_progress_file.append_batch(&batch)?; - max_record_batch_size = max_record_batch_size.max(batch.get_sliced_size()); + max_record_batch_size = max_record_batch_size.max(batch.get_sliced_size()?); } let 
file = in_progress_file.finish()?; @@ -213,16 +213,16 @@ impl SpillManager { pub(crate) trait GetSlicedSize { /// Returns the size of the `RecordBatch` when sliced. - fn get_sliced_size(&self) -> usize; + fn get_sliced_size(&self) -> Result<usize>; } impl GetSlicedSize for RecordBatch { - fn get_sliced_size(&self) -> usize { + fn get_sliced_size(&self) -> Result<usize> { let mut total = 0; for array in self.columns() { let data = array.to_data(); - total += data.get_slice_memory_size().unwrap(); + total += data.get_slice_memory_size()?; } - total + Ok(total) } }