Skip to content

Commit

Permalink
Account for memory usage in SortPreservingMerge (apache#5885)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed May 18, 2023
1 parent c24830a commit 435337f
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 56 deletions.
29 changes: 22 additions & 7 deletions datafusion/core/src/physical_plan/sorts/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use crate::common::Result;
use arrow::compute::interleave;
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_execution::memory_pool::MemoryReservation;

#[derive(Debug, Copy, Clone, Default)]
struct BatchCursor {
Expand All @@ -37,6 +38,9 @@ pub struct BatchBuilder {
/// Maintain a list of [`RecordBatch`] and their corresponding stream
batches: Vec<(usize, RecordBatch)>,

/// Accounts for memory used by buffered batches
reservation: MemoryReservation,

/// The current [`BatchCursor`] for each stream
cursors: Vec<BatchCursor>,

Expand All @@ -47,23 +51,31 @@ pub struct BatchBuilder {

impl BatchBuilder {
/// Create a new [`BatchBuilder`] with the provided `stream_count` and `batch_size`
pub fn new(schema: SchemaRef, stream_count: usize, batch_size: usize) -> Self {
pub fn new(
schema: SchemaRef,
stream_count: usize,
batch_size: usize,
reservation: MemoryReservation,
) -> Self {
Self {
schema,
batches: Vec::with_capacity(stream_count * 2),
cursors: vec![BatchCursor::default(); stream_count],
indices: Vec::with_capacity(batch_size),
reservation,
}
}

/// Append a new batch in `stream_idx`
pub fn push_batch(&mut self, stream_idx: usize, batch: RecordBatch) {
pub fn push_batch(&mut self, stream_idx: usize, batch: RecordBatch) -> Result<()> {
self.reservation.try_grow(batch.get_array_memory_size())?;
let batch_idx = self.batches.len();
self.batches.push((stream_idx, batch));
self.cursors[stream_idx] = BatchCursor {
batch_idx,
row_idx: 0,
}
};
Ok(())
}

/// Append the next row from `stream_idx`
Expand Down Expand Up @@ -119,14 +131,17 @@ impl BatchBuilder {
// We can therefore drop all but the last batch for each stream
let mut batch_idx = 0;
let mut retained = 0;
self.batches.retain(|(stream_idx, _)| {
self.batches.retain(|(stream_idx, batch)| {
let stream_cursor = &mut self.cursors[*stream_idx];
let retain = stream_cursor.batch_idx == batch_idx;
batch_idx += 1;

if retain {
stream_cursor.batch_idx = retained;
retained += 1;
match retain {
true => {
stream_cursor.batch_idx = retained;
retained += 1;
}
false => self.reservation.shrink(batch.get_array_memory_size()),
}
retain
});
Expand Down
9 changes: 7 additions & 2 deletions datafusion/core/src/physical_plan/sorts/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use arrow::datatypes::ArrowNativeTypeOp;
use arrow::row::{Row, Rows};
use arrow_array::types::ByteArrayType;
use arrow_array::{Array, ArrowPrimitiveType, GenericByteArray, PrimitiveArray};
use datafusion_execution::memory_pool::MemoryReservation;
use std::cmp::Ordering;

/// A [`Cursor`] for [`Rows`]
Expand All @@ -29,6 +30,9 @@ pub struct RowCursor {
num_rows: usize,

rows: Rows,

#[allow(dead_code)]
reservation: MemoryReservation,
}

impl std::fmt::Debug for RowCursor {
Expand All @@ -41,12 +45,13 @@ impl std::fmt::Debug for RowCursor {
}

impl RowCursor {
/// Create a new SortKeyCursor
pub fn new(rows: Rows) -> Self {
/// Create a new SortKeyCursor from `rows` and the associated `reservation`
pub fn new(rows: Rows, reservation: MemoryReservation) -> Self {
Self {
cur_row: 0,
num_rows: rows.num_rows(),
rows,
reservation,
}
}

Expand Down
30 changes: 20 additions & 10 deletions datafusion/core/src/physical_plan/sorts/merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use crate::physical_plan::{
use arrow::datatypes::{DataType, SchemaRef};
use arrow::record_batch::RecordBatch;
use arrow_array::*;
use datafusion_execution::memory_pool::MemoryReservation;
use futures::Stream;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
Expand All @@ -39,13 +40,14 @@ macro_rules! primitive_merge_helper {
}

macro_rules! merge_helper {
($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident) => {{
($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident, $reservation:ident) => {{
let streams = FieldCursorStream::<$t>::new($sort, $streams);
return Ok(Box::pin(SortPreservingMergeStream::new(
Box::new(streams),
$schema,
$tracking_metrics,
$batch_size,
$reservation,
)));
}};
}
Expand All @@ -57,27 +59,35 @@ pub(crate) fn streaming_merge(
expressions: &[PhysicalSortExpr],
metrics: BaselineMetrics,
batch_size: usize,
reservation: MemoryReservation,
) -> Result<SendableRecordBatchStream> {
// Special case single column comparisons with optimized cursor implementations
if expressions.len() == 1 {
let sort = expressions[0].clone();
let data_type = sort.expr.data_type(schema.as_ref())?;
downcast_primitive! {
data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size),
DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size)
DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size)
DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size)
DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size)
data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, reservation),
DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, reservation)
DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, reservation)
DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, reservation)
DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, reservation)
_ => {}
}
}

let streams = RowCursorStream::try_new(schema.as_ref(), expressions, streams)?;
let streams = RowCursorStream::try_new(
schema.as_ref(),
expressions,
streams,
reservation.split_empty(),
)?;

Ok(Box::pin(SortPreservingMergeStream::new(
Box::new(streams),
schema,
metrics,
batch_size,
reservation,
)))
}

Expand Down Expand Up @@ -148,11 +158,12 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
schema: SchemaRef,
metrics: BaselineMetrics,
batch_size: usize,
reservation: MemoryReservation,
) -> Self {
let stream_count = streams.partitions();

Self {
in_progress: BatchBuilder::new(schema, stream_count, batch_size),
in_progress: BatchBuilder::new(schema, stream_count, batch_size, reservation),
streams,
metrics,
aborted: false,
Expand Down Expand Up @@ -181,8 +192,7 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
Some(Err(e)) => Poll::Ready(Err(e)),
Some(Ok((cursor, batch))) => {
self.cursors[idx] = Some(cursor);
self.in_progress.push_batch(idx, batch);
Poll::Ready(Ok(()))
Poll::Ready(self.in_progress.push_batch(idx, batch))
}
}
}
Expand Down
55 changes: 35 additions & 20 deletions datafusion/core/src/physical_plan/sorts/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ use tempfile::NamedTempFile;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::task;

/// How much memory to reserve for performing in-memory sorts
const EXTERNAL_SORTER_MERGE_RESERVATION: usize = 10 * 1024 * 1024;

struct ExternalSorterMetrics {
/// metrics
baseline: BaselineMetrics,
Expand Down Expand Up @@ -94,8 +97,10 @@ struct ExternalSorter {
expr: Arc<[PhysicalSortExpr]>,
metrics: ExternalSorterMetrics,
fetch: Option<usize>,
/// Reservation for in_mem_batches
reservation: MemoryReservation,
partition_id: usize,
/// Reservation for in memory sorting of batches
merge_reservation: MemoryReservation,
runtime: Arc<RuntimeEnv>,
batch_size: usize,
}
Expand All @@ -115,6 +120,12 @@ impl ExternalSorter {
.with_can_spill(true)
.register(&runtime.memory_pool);

let mut merge_reservation =
MemoryConsumer::new(format!("ExternalSorterMerge[{partition_id}]"))
.register(&runtime.memory_pool);

merge_reservation.resize(EXTERNAL_SORTER_MERGE_RESERVATION);

Self {
schema,
in_mem_batches: vec![],
Expand All @@ -124,7 +135,7 @@ impl ExternalSorter {
metrics,
fetch,
reservation,
partition_id,
merge_reservation,
runtime,
batch_size,
}
Expand Down Expand Up @@ -189,12 +200,10 @@ impl ExternalSorter {
&self.expr,
self.metrics.baseline.clone(),
self.batch_size,
self.reservation.split_empty(),
)
} else if !self.in_mem_batches.is_empty() {
let result = self.in_mem_sort_stream(self.metrics.baseline.clone());
// Report to the memory manager we are no longer using memory
self.reservation.free();
result
self.in_mem_sort_stream(self.metrics.baseline.clone())
} else {
Ok(Box::pin(EmptyRecordBatchStream::new(self.schema.clone())))
}
Expand Down Expand Up @@ -238,6 +247,9 @@ impl ExternalSorter {
return Ok(());
}

// Release the memory reserved for merge
self.merge_reservation.free();

self.in_mem_batches = self
.in_mem_sort_stream(self.metrics.baseline.intermediate())?
.try_collect()
Expand All @@ -249,7 +261,11 @@ impl ExternalSorter {
.map(|x| x.get_array_memory_size())
.sum();

self.reservation.resize(size);
// Reserve headroom for next sort
self.merge_reservation
.resize(EXTERNAL_SORTER_MERGE_RESERVATION);

self.reservation.try_resize(size)?;
self.in_mem_batches_sorted = true;
Ok(())
}
Expand All @@ -262,9 +278,8 @@ impl ExternalSorter {
assert_ne!(self.in_mem_batches.len(), 0);
if self.in_mem_batches.len() == 1 {
let batch = self.in_mem_batches.remove(0);
let stream = self.sort_batch_stream(batch, metrics)?;
self.in_mem_batches.clear();
return Ok(stream);
let reservation = self.reservation.take();
return self.sort_batch_stream(batch, metrics, reservation);
}

// If less than 1MB of in-memory data, concatenate and sort in place
Expand All @@ -274,14 +289,19 @@ impl ExternalSorter {
// Concatenate memory batches together and sort
let batch = concat_batches(&self.schema, &self.in_mem_batches)?;
self.in_mem_batches.clear();
return self.sort_batch_stream(batch, metrics);
self.reservation.try_resize(batch.get_array_memory_size())?;
let reservation = self.reservation.take();
return self.sort_batch_stream(batch, metrics, reservation);
}

let streams = std::mem::take(&mut self.in_mem_batches)
.into_iter()
.map(|batch| {
let metrics = self.metrics.baseline.intermediate();
Ok(spawn_buffered(self.sort_batch_stream(batch, metrics)?, 1))
let reservation =
self.reservation.split(batch.get_array_memory_size())?;
let input = self.sort_batch_stream(batch, metrics, reservation)?;
Ok(spawn_buffered(input, 1))
})
.collect::<Result<_>>()?;

Expand All @@ -293,30 +313,25 @@ impl ExternalSorter {
&self.expr,
metrics,
self.batch_size,
self.merge_reservation.split_empty(),
)
}

fn sort_batch_stream(
&self,
batch: RecordBatch,
metrics: BaselineMetrics,
reservation: MemoryReservation,
) -> Result<SendableRecordBatchStream> {
let schema = batch.schema();

let mut reservation =
MemoryConsumer::new(format!("sort_batch_stream{}", self.partition_id))
.register(&self.runtime.memory_pool);

// TODO: This should probably be try_grow (#5885)
reservation.resize(batch.get_array_memory_size());

let fetch = self.fetch;
let expressions = self.expr.clone();
let stream = futures::stream::once(futures::future::lazy(move |_| {
let sorted = sort_batch(&batch, &expressions, fetch)?;
metrics.record_output(sorted.num_rows());
drop(batch);
reservation.free();
drop(reservation);
Ok(sorted)
}));
Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use std::any::Any;
use std::sync::Arc;

use arrow::datatypes::SchemaRef;
use datafusion_execution::memory_pool::MemoryConsumer;
use log::{debug, trace};

use crate::error::{DataFusionError, Result};
Expand Down Expand Up @@ -165,6 +166,10 @@ impl ExecutionPlan for SortPreservingMergeExec {
);
let schema = self.schema();

let reservation =
MemoryConsumer::new(format!("SortPreservingMergeExec[{partition}]"))
.register(&context.runtime_env().memory_pool);

match input_partitions {
0 => Err(DataFusionError::Internal(
"SortPreservingMergeExec requires at least one input partition"
Expand Down Expand Up @@ -192,6 +197,7 @@ impl ExecutionPlan for SortPreservingMergeExec {
&self.expr,
BaselineMetrics::new(&self.metrics, partition),
context.session_config().batch_size(),
reservation,
)?;

debug!("Got stream result from SortPreservingMergeStream::new_from_receivers");
Expand Down
Loading

0 comments on commit 435337f

Please sign in to comment.