diff --git a/datafusion/core/src/datasource/physical_plan/parquet.rs b/datafusion/core/src/datasource/physical_plan/parquet.rs index 6f27acfe7f6c..56eec4b1b7d5 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet.rs @@ -380,7 +380,7 @@ impl ExecutionPlan for ParquetExec { let stream = FileStream::new(&self.base_config, partition_index, opener, &self.metrics)?; - Ok(Box::pin(stream)) + Ok(Box::pin(stream) as SendableRecordBatchStream) } fn metrics(&self) -> Option { diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 2dfcf12e350a..86c05a2376d2 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -58,6 +58,7 @@ tokio = { version = "1.28", features = ["sync", "fs", "parking_lot"] } uuid = { version = "^1.2", features = ["v4"] } [dev-dependencies] +paste = "1.0.14" rstest = "0.18.0" termtree = "0.4.1" tokio = { version = "1.28", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } diff --git a/datafusion/physical-plan/src/analyze.rs b/datafusion/physical-plan/src/analyze.rs index 3a2cac59cfdf..7bc67ab5ee41 100644 --- a/datafusion/physical-plan/src/analyze.rs +++ b/datafusion/physical-plan/src/analyze.rs @@ -29,7 +29,9 @@ use datafusion_common::{internal_err, DataFusionError, Result}; use futures::StreamExt; use super::expressions::PhysicalSortExpr; -use super::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; +use super::stream::{ + ReceiverStream, RecordBatchReceiverStreamAdaptor, RecordBatchStreamAdapter, +}; use super::{DisplayAs, Distribution, SendableRecordBatchStream}; use datafusion_execution::TaskContext; @@ -155,11 +157,11 @@ impl ExecutionPlan for AnalyzeExec { // parallel (on a separate tokio task) using a JoinSet to // cancel outstanding futures on drop let num_input_partitions = self.input.output_partitioning().partition_count(); - let mut builder = - RecordBatchReceiverStream::builder(self.schema(), num_input_partitions); + let mut builder = ReceiverStream::builder(self.schema(), num_input_partitions); + let input = Arc::new(RecordBatchReceiverStreamAdaptor::new(self.input.clone())); for input_partition in 0..num_input_partitions { - builder.run_input(self.input.clone(), input_partition, context.clone()); + builder.run_input(input.clone(), input_partition, context.clone()); } // Create future that computes thefinal output diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs index 646d42795ba4..3472606918e0 100644 --- a/datafusion/physical-plan/src/coalesce_partitions.rs +++ b/datafusion/physical-plan/src/coalesce_partitions.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use super::expressions::PhysicalSortExpr; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; -use super::stream::{ObservedStream, RecordBatchReceiverStream}; +use super::stream::{ObservedStream, ReceiverStream, RecordBatchReceiverStreamAdaptor}; use super::{DisplayAs, SendableRecordBatchStream, Statistics}; use crate::{DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning}; @@ -145,12 +145,14 @@ impl ExecutionPlan for CoalescePartitionsExec { // least one result in an attempt to maximize // parallelism. 
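The analyze.rs hunk above, and the coalesce_partitions.rs hunk that follows, replace the `RecordBatch`-specific receiver stream with a generic `ReceiverStream` builder plus an adaptor (`RecordBatchReceiverStreamAdaptor`) wrapped around the input plan. A minimal sketch of that channel-plus-`JoinSet` pattern, assuming a toy `ReceiverBuilder` type and a plain `String` error type (neither is part of this patch nor the DataFusion API):

```rust
// Sketch only: requires tokio with "macros", "rt-multi-thread", "sync" features.
use tokio::{sync::mpsc, task::JoinSet};

struct ReceiverBuilder<O> {
    tx: mpsc::Sender<Result<O, String>>,
    rx: mpsc::Receiver<Result<O, String>>,
    join_set: JoinSet<()>,
}

impl<O: Send + 'static> ReceiverBuilder<O> {
    fn new(capacity: usize) -> Self {
        let (tx, rx) = mpsc::channel(capacity);
        Self { tx, rx, join_set: JoinSet::new() }
    }

    /// Spawn one producer task; its items flow into the shared channel.
    fn spawn_producer(&mut self, items: Vec<O>) {
        let tx = self.tx.clone();
        self.join_set.spawn(async move {
            for item in items {
                // Receiver dropped => the consumer is gone; stop quietly.
                if tx.send(Ok(item)).await.is_err() {
                    return;
                }
            }
        });
    }

    /// Hand back the receiving end plus the JoinSet that keeps
    /// (and, on drop, cancels) the producer tasks.
    fn build(self) -> (mpsc::Receiver<Result<O, String>>, JoinSet<()>) {
        (self.rx, self.join_set)
    }
}

#[tokio::main]
async fn main() {
    let mut builder = ReceiverBuilder::<u64>::new(8);
    builder.spawn_producer(vec![1, 2, 3]);
    builder.spawn_producer(vec![10, 20]);
    let (mut rx, _tasks) = builder.build();
    while let Some(item) = rx.recv().await {
        println!("{item:?}");
    }
}
```

The real builder additionally propagates panics from the joined tasks and cancels outstanding producers when the output stream is dropped; the sketch only keeps the `JoinSet` alive so finished producers are not aborted early.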
let mut builder = - RecordBatchReceiverStream::builder(self.schema(), input_partitions); + ReceiverStream::builder(self.schema(), input_partitions); + let input = + Arc::new(RecordBatchReceiverStreamAdaptor::new(self.input.clone())); // spawn independent tasks whose resulting streams (of batches) // are sent to the channel for consumption. for part_i in 0..input_partitions { - builder.run_input(self.input.clone(), part_i, context.clone()); + builder.run_input(input.clone(), part_i, context.clone()); } let stream = builder.build(); diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index c6cfbbfbbac7..2c398a25ffb8 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -18,7 +18,7 @@ //! Defines common code used in execution plans use super::SendableRecordBatchStream; -use crate::stream::RecordBatchReceiverStream; +use crate::stream::ReceiverStream; use crate::{ColumnStatistics, ExecutionPlan, Statistics}; use arrow::datatypes::Schema; use arrow::ipc::writer::{FileWriter, IpcWriteOptions}; @@ -102,7 +102,7 @@ pub(crate) fn spawn_buffered( Ok(handle) if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::MultiThread => { - let mut builder = RecordBatchReceiverStream::builder(input.schema(), buffer); + let mut builder = ReceiverStream::builder(input.schema(), buffer); let sender = builder.tx(); diff --git a/datafusion/physical-plan/src/sorts/batch_cursor.rs b/datafusion/physical-plan/src/sorts/batch_cursor.rs new file mode 100644 index 000000000000..1ada42f9ed8a --- /dev/null +++ b/datafusion/physical-plan/src/sorts/batch_cursor.rs @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::Result; + +use super::cursor::Cursor; + +pub type BatchId = u64; + +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +pub struct BatchOffset(pub usize); + +pub type SlicedBatchCursorIdentifier = (BatchId, BatchOffset); + +/// The [`BatchCursor`] represents a complete, or partial, [`Cursor`] for a given record batch ([`BatchId`]). +/// +/// A record batch (represented by its [`Cursor`]) can be sliced due to the following reason: +/// 1. a merge node takes in 10 streams +/// 2 at any given time, this means up to 10 cursors (record batches) are being merged (e.g. in the loser tree) +/// 3. merge nodes will yield once it hits a size limit +/// 4. at the moment of yielding, there may be some cursors which are partially yielded +/// +/// Unique representation of sliced cursor is denoted by the [`SlicedBatchCursorIdentifier`]. +#[derive(Debug)] +pub struct BatchCursor { + /// The index into BatchTrackingStream::batches + batch: BatchId, + /// The offset of the row within the given batch, based on the idea of a sliced cursor. 
+ /// When a batch is partially yielded, then the offset->end will determine how much was yielded. + row_offset: BatchOffset, + + /// The cursor for the given batch. + pub cursor: C, +} + +impl BatchCursor { + /// Create a new [`BatchCursor`] from a [`Cursor`] and a [`BatchId`]. + /// + /// New [`BatchCursor`]s will have a [`BatchOffset`] of 0. + /// Subsequent batch_cursors can be created by slicing. + pub fn new(batch: BatchId, cursor: C) -> Self { + Self { + batch, + row_offset: BatchOffset(0), + cursor, + } + } + + /// A unique identifier used to identify a [`BatchCursor`] + pub fn identifier(&self) -> SlicedBatchCursorIdentifier { + (self.batch, self.row_offset) + } + + /// Slicing of a batch cursor is done by slicing the underlying cursor, + /// and adjust the BatchOffset + pub fn slice(&self, offset: usize, length: usize) -> Result { + Ok(Self { + batch: self.batch, + row_offset: BatchOffset(self.row_offset.0 + offset), + cursor: self.cursor.slice(offset, length)?, + }) + } +} + +impl std::fmt::Display for BatchCursor { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "BatchCursor(batch: {}, offset: {}, num_rows: {})", + self.batch, + self.row_offset.0, + self.cursor.num_rows() + ) + } +} diff --git a/datafusion/physical-plan/src/sorts/builder.rs b/datafusion/physical-plan/src/sorts/builder.rs index 3527d5738223..0105a65a6b1b 100644 --- a/datafusion/physical-plan/src/sorts/builder.rs +++ b/datafusion/physical-plan/src/sorts/builder.rs @@ -15,136 +15,192 @@ // specific language governing permissions and limitations // under the License. -use arrow::compute::interleave; -use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; use datafusion_common::Result; -use datafusion_execution::memory_pool::MemoryReservation; - -#[derive(Debug, Copy, Clone, Default)] -struct BatchCursor { - /// The index into BatchBuilder::batches - batch_idx: usize, - /// The row index within the given batch - row_idx: usize, -} +use std::collections::VecDeque; +use std::mem::take; -/// Provides an API to incrementally build a [`RecordBatch`] from partitioned [`RecordBatch`] -#[derive(Debug)] -pub struct BatchBuilder { - /// The schema of the RecordBatches yielded by this stream - schema: SchemaRef, +use super::batch_cursor::{BatchCursor, SlicedBatchCursorIdentifier}; +use super::cursor::Cursor; - /// Maintain a list of [`RecordBatch`] and their corresponding stream - batches: Vec<(usize, RecordBatch)>, +pub type SortOrder = (SlicedBatchCursorIdentifier, usize); // batch_id, row_idx (without offset) +pub type YieldedSortOrder = (Vec>, Vec); - /// Accounts for memory used by buffered batches - reservation: MemoryReservation, +/// Provides an API to incrementally build a [`SortOrder`] from partitioned [`RecordBatch`](arrow::record_batch::RecordBatch)es +#[derive(Debug)] +pub struct SortOrderBuilder { + /// Maintain a list of cursors for each finished (sorted) batch + /// The number of total batches can be larger than the number of total streams + batch_cursors: VecDeque>, - /// The current [`BatchCursor`] for each stream - cursors: Vec, + /// The current [`BatchCursor`] for each stream_idx + cursors: Vec>>, /// The accumulated stream indexes from which to pull rows - /// Consists of a tuple of `(batch_idx, row_idx)` - indices: Vec<(usize, usize)>, + indices: Vec, + + stream_count: usize, } -impl BatchBuilder { - /// Create a new [`BatchBuilder`] with the provided `stream_count` and `batch_size` - pub fn new( - schema: SchemaRef, - stream_count: usize, - 
batch_size: usize, - reservation: MemoryReservation, - ) -> Self { +impl SortOrderBuilder { + /// Create a new [`SortOrderBuilder`] with the provided `stream_count` and `batch_size` + pub fn new(stream_count: usize, batch_size: usize) -> Self { Self { - schema, - batches: Vec::with_capacity(stream_count * 2), - cursors: vec![BatchCursor::default(); stream_count], + batch_cursors: VecDeque::with_capacity(stream_count * 2), + cursors: (0..stream_count).map(|_| None).collect(), indices: Vec::with_capacity(batch_size), - reservation, + stream_count, } } /// Append a new batch in `stream_idx` - pub fn push_batch(&mut self, stream_idx: usize, batch: RecordBatch) -> Result<()> { - self.reservation.try_grow(batch.get_array_memory_size())?; - let batch_idx = self.batches.len(); - self.batches.push((stream_idx, batch)); - self.cursors[stream_idx] = BatchCursor { - batch_idx, - row_idx: 0, - }; + pub fn push_batch( + &mut self, + stream_idx: usize, + batch_cursor: BatchCursor, + ) -> Result<()> { + self.cursors[stream_idx] = Some(batch_cursor); Ok(()) } /// Append the next row from `stream_idx` pub fn push_row(&mut self, stream_idx: usize) { - let cursor = &mut self.cursors[stream_idx]; - let row_idx = cursor.row_idx; - cursor.row_idx += 1; - self.indices.push((cursor.batch_idx, row_idx)); + if let Some(batch_cursor) = &mut self.cursors[stream_idx] { + // This is tightly coupled to the loser tree implementation. + // The winner (top of the tree) is from the stream_idx + // and gets pushed after the cursor already advanced (to fill a leaf node). + // Therefore, we need to subtract 1 from the current_idx. + let row_idx = batch_cursor.cursor.current_idx() - 1; + self.indices.push((batch_cursor.identifier(), row_idx)); + + if batch_cursor.cursor.is_finished() { + self.cursor_finished(stream_idx); + } + } else { + unreachable!("row pushed for non-existant cursor"); + }; } - /// Returns the number of in-progress rows in this [`BatchBuilder`] + /// Returns the number of in-progress rows in this [`SortOrderBuilder`] pub fn len(&self) -> usize { self.indices.len() } - /// Returns `true` if this [`BatchBuilder`] contains no in-progress rows + /// Returns `true` if this [`SortOrderBuilder`] contains no in-progress rows pub fn is_empty(&self) -> bool { self.indices.is_empty() } - /// Returns the schema of this [`BatchBuilder`] - pub fn schema(&self) -> &SchemaRef { - &self.schema + /// For a finished cursor, remove from BatchCursor (per stream_idx), and track in batch_cursors (per batch_idx) + fn cursor_finished(&mut self, stream_idx: usize) { + if let Some(cursor) = self.cursors[stream_idx].take() { + self.batch_cursors.push_back(cursor); + } } - /// Drains the in_progress row indexes, and builds a new RecordBatch from them + /// Advance the cursor for `stream_idx` + /// Returns `true` if the cursor was advanced + pub fn advance(&mut self, stream_idx: usize) -> bool { + match &mut self.cursors[stream_idx] { + Some(batch_cursor) => { + let cursor = &mut batch_cursor.cursor; + if cursor.is_finished() { + return false; + } + cursor.advance(); + true + } + None => false, + } + } + + /// Returns true if there is an in-progress cursor for a given stream + pub fn cursor_in_progress(&mut self, stream_idx: usize) -> bool { + self.cursors[stream_idx] + .as_mut() + .map_or(false, |batch_cursor| !batch_cursor.cursor.is_finished()) + } + + /// Returns `true` if the cursor at index `a` is greater than at index `b` + #[inline] + pub fn is_gt(&mut self, stream_idx_a: usize, stream_idx_b: usize) -> bool { + match ( + 
self.cursor_in_progress(stream_idx_a), + self.cursor_in_progress(stream_idx_b), + ) { + (false, _) => true, + (_, false) => false, + _ => match (&self.cursors[stream_idx_a], &self.cursors[stream_idx_b]) { + (Some(a), Some(b)) => a + .cursor + .cmp(&b.cursor) + .then_with(|| stream_idx_a.cmp(&stream_idx_b)) + .is_gt(), + _ => unreachable!(), + }, + } + } + + /// Takes the batches which already are sorted, and returns them with the corresponding cursors and sort order. /// - /// Will then drop any batches for which all rows have been yielded to the output + /// This will drain the internal state of the builder, and return `None` if there are no pending. /// - /// Returns `None` if no pending rows - pub fn build_record_batch(&mut self) -> Result> { + /// This slices cursors for each record batch, as follows: + /// 1. input was N record_batchs of up to max M size + /// 2. yielded ordered rows can only equal up to M size + /// 3. of the N record_batches, each will be: + /// a. fully yielded (all rows) + /// b. partially yielded (some rows) => slice the batch_cursor + /// c. not yielded (no rows) => retain cursor + /// 4. output will be: + /// - SortOrder + /// - corresponding cursors, each up to total yielded rows [cursor_batch_0, cursor_batch_1, ..] + pub fn yield_sort_order(&mut self) -> Result>> { if self.is_empty() { return Ok(None); } - let columns = (0..self.schema.fields.len()) - .map(|column_idx| { - let arrays: Vec<_> = self - .batches - .iter() - .map(|(_, batch)| batch.column(column_idx).as_ref()) - .collect(); - Ok(interleave(&arrays, &self.indices)?) - }) - .collect::>>()?; - - self.indices.clear(); - - // New cursors are only created once the previous cursor for the stream - // is finished. This means all remaining rows from all but the last batch - // for each stream have been yielded to the newly created record batch - // - // We can therefore drop all but the last batch for each stream - let mut batch_idx = 0; - let mut retained = 0; - self.batches.retain(|(stream_idx, batch)| { - let stream_cursor = &mut self.cursors[*stream_idx]; - let retain = stream_cursor.batch_idx == batch_idx; - batch_idx += 1; - - if retain { - stream_cursor.batch_idx = retained; - retained += 1; + let sort_order = take(&mut self.indices); + let mut cursors_to_yield: Vec> = + Vec::with_capacity(self.stream_count * 2); + + // drain already complete cursors + for mut batch_cursor in take(&mut self.batch_cursors) { + batch_cursor.cursor.reset(); + cursors_to_yield.push(batch_cursor); + } + + // split any in_progress cursor + for stream_idx in 0..self.cursors.len() { + let mut batch_cursor = match self.cursors[stream_idx].take() { + Some(c) => c, + None => continue, + }; + + if batch_cursor.cursor.is_finished() { + batch_cursor.cursor.reset(); + cursors_to_yield.push(batch_cursor); + } else if batch_cursor.cursor.in_progress() { + let row_idx = batch_cursor.cursor.current_idx(); + let num_rows = batch_cursor.cursor.num_rows(); + + let batch_cursor_to_yield = batch_cursor.slice(0, row_idx)?; + let batch_cursor_to_retain = + batch_cursor.slice(row_idx, num_rows - row_idx)?; + assert_eq!( + batch_cursor_to_yield.cursor.num_rows() + + batch_cursor_to_retain.cursor.num_rows(), + num_rows + ); + drop(batch_cursor.cursor); // drop the original cursor + + self.cursors[stream_idx] = Some(batch_cursor_to_retain); + cursors_to_yield.push(batch_cursor_to_yield); } else { - self.reservation.shrink(batch.get_array_memory_size()); + // retained all (nothing yielded) + self.cursors[stream_idx] = Some(batch_cursor); } - 
retain - }); + } - Ok(Some(RecordBatch::try_new(self.schema.clone(), columns)?)) + Ok(Some((cursors_to_yield, sort_order))) } } diff --git a/datafusion/physical-plan/src/sorts/cascade.rs b/datafusion/physical-plan/src/sorts/cascade.rs new file mode 100644 index 000000000000..800f767300aa --- /dev/null +++ b/datafusion/physical-plan/src/sorts/cascade.rs @@ -0,0 +1,324 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::metrics::BaselineMetrics; +use crate::sorts::builder::SortOrder; +use crate::sorts::cursor::Cursor; +use crate::sorts::merge::SortPreservingMergeStream; +use crate::sorts::stream::{ + BatchCursorStream, BatchTracker, BatchTrackerStream, MergeStream, YieldedCursorStream, +}; +use crate::stream::ReceiverStream; +use crate::RecordBatchStream; + +use super::batch_cursor::BatchId; + +use arrow::compute::interleave; +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::Result; +use datafusion_execution::memory_pool::MemoryReservation; +use futures::{Stream, StreamExt}; +use std::collections::{HashMap, VecDeque}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +static MAX_STREAMS_PER_MERGE: usize = 10; + +/// Sort preserving cascade stream +/// +/// The cascade works as a tree of sort-preserving-merges, where each merge has +/// a limited fan-in (number of inputs) and a limit size yielded (batch size) per poll. +/// The poll is called from the root merge, which will poll its children, and so on. +/// +/// ```text +/// ┌─────┐ ┌─────┐ +/// │ 2 │ │ 1 │ +/// │ 3 │ │ 2 │ +/// │ 1 │─ ─▶ sort ─ ─▶│ 2 │─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ 4 │ │ 3 │ +/// │ 2 │ │ 4 │ │ +/// └─────┘ └─────┘ +/// ┌─────┐ ┌─────┐ ▼ +/// │ 1 │ │ 1 │ +/// │ 4 │─ ▶ sort ─ ─ ▶│ 1 ├ ─ ─ ─ ─ ─ ▶ merge ─ ─ ─ ─ +/// │ 1 │ │ 4 │ │ +/// └─────┘ └─────┘ +/// ... ... ... ▼ +/// +/// merge ─ ─ ─ ─ ─ ─ ▶ sorted output +/// stream +/// ▲ +/// ... ... ... │ +/// ┌─────┐ ┌─────┐ +/// │ 3 │ │ 3 │ │ +/// │ 1 │─ ▶ sort ─ ─ ▶│ 1 │─ ─ ─ ─ ─ ─▶ merge ─ ─ ─ ─ +/// └─────┘ └─────┘ +/// ┌─────┐ ┌─────┐ ▲ +/// │ 4 │ │ 3 │ +/// │ 3 │─ ▶ sort ─ ─ ▶│ 4 │─ ─ ─ ─ ─ ─ ─ ─ ┘ +/// └─────┘ └─────┘ +/// +/// in_mem_batches do a series of merges that +/// each has a limited fan-in +/// (number of inputs) +/// ``` +/// +/// The cascade is built using a series of streams, each with a different purpose: +/// * Streams leading into the leaf nodes: +/// 1. [`BatchCursorStream`] yields the initial cursors and batches. (e.g. a RowCursorStream) +/// * This initial CursorStream is for a number of partitions (e.g. 100). +/// * only a single BatchCursorStream. +/// 2. [`BatchCursorStream::take_partitions()`] allows us to take a subset of the partitioned streams. 
+/// * This enables parallelism of [`BatchCursorStream`] with mutable access (for polling), without locking. +/// * The total fan-in is always limited to 10. E.g. each leaf node will pull from a dedicated 10 (out of 100) partitions. +/// 3. [`BatchTrackerStream`] is used to collect the record batches from the leaf nodes. +/// * contains a single, shared using [`BatchTracker`]. +/// * polling of streams is non-blocking. +/// +/// * Streams between merge nodes: +/// 1. a single [`MergeStream`] is yielded per node. +/// 2. A connector [`YieldedCursorStream`] converts a [`MergeStream`] into a [`CursorStream`](super::stream::CursorStream). +/// 3. next merge node takes a fan-in of up to 10 [`CursorStream`](super::stream::CursorStream)s. +/// +pub(crate) struct SortPreservingCascadeStream { + /// If the stream has encountered an error, or fetch is reached + aborted: bool, + + /// used to record execution metrics + metrics: BaselineMetrics, + + /// The cascading stream + cascade: MergeStream, + + /// The schema of the RecordBatches yielded by this stream + schema: SchemaRef, + + /// Batches are collected on first yield from the RowCursorStream + /// Subsequent merges in cascade all refer to the [`BatchId`]s + record_batch_collector: Arc, +} + +impl SortPreservingCascadeStream { + pub(crate) fn new>>( + mut streams: S, + schema: SchemaRef, + metrics: BaselineMetrics, + batch_size: usize, + fetch: Option, + reservation: MemoryReservation, + ) -> Self { + let stream_count = streams.partitions(); + + let batch_tracker = Arc::new(BatchTracker::new(reservation.new_empty())); + + let max_streams_per_merge = MAX_STREAMS_PER_MERGE; + let mut divided_streams: VecDeque> = + VecDeque::with_capacity(stream_count / max_streams_per_merge + 1); + + // build leaves + for stream_idx in (0..stream_count).step_by(max_streams_per_merge) { + let limit = std::cmp::min(max_streams_per_merge, stream_count - stream_idx); + + // divide the BatchCursorStream across multiple leafnode merges. + let streams = BatchTrackerStream::new( + streams.take_partitions(0..limit), + batch_tracker.clone(), + ); + + divided_streams.push_back(spawn_buffered_merge( + Box::pin(SortPreservingMergeStream::new( + Box::new(streams), + metrics.clone(), + batch_size, + None, // fetch, the LIMIT, is applied to the final merge + )), + schema.clone(), + 2, + )); + } + + // build rest of tree + let mut next_level: VecDeque> = + VecDeque::with_capacity(divided_streams.len() / max_streams_per_merge + 1); + while divided_streams.len() > 1 || !next_level.is_empty() { + let fan_in: Vec> = divided_streams + .drain(0..std::cmp::min(max_streams_per_merge, divided_streams.len())) + .collect(); + + next_level.push_back(spawn_buffered_merge( + Box::pin(SortPreservingMergeStream::new( + Box::new(YieldedCursorStream::new(fan_in)), + metrics.clone(), + batch_size, + if divided_streams.is_empty() && next_level.is_empty() { + fetch + } else { + None + }, // fetch, the LIMIT, is applied to the final merge + )), + schema.clone(), + 2, + )); + // in order to maintain sort-preserving streams, don't mix the merge tree levels. + if divided_streams.is_empty() { + divided_streams = std::mem::take(&mut next_level); + } + } + + Self { + aborted: false, + cascade: divided_streams + .remove(0) + .expect("must have a root merge stream"), + schema, + metrics, + record_batch_collector: batch_tracker, + } + } + + /// Construct and yield the root node [`RecordBatch`]s. 
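The core of the `build_record_batch` function defined below is arrow's `interleave` kernel, which gathers rows out of the tracked batches according to the yielded `(batch, row)` sort order. A self-contained sketch with made-up values, using a single `Int32` column in place of full `RecordBatch`es:

```rust
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, Int32Array};
use arrow::compute::interleave;
use arrow::error::ArrowError;

fn main() -> Result<(), ArrowError> {
    // Two already-sorted inputs, standing in for tracked record batches.
    let batch_a: ArrayRef = Arc::new(Int32Array::from(vec![1, 4, 7]));
    let batch_b: ArrayRef = Arc::new(Int32Array::from(vec![2, 3, 9]));

    // A merged sort order as (batch index, row index) pairs, analogous to
    // the offset-adjusted sort order assembled in build_record_batch.
    let sort_order: Vec<(usize, usize)> =
        vec![(0, 0), (1, 0), (1, 1), (0, 1), (0, 2), (1, 2)];

    // Gather the rows in that order into a single output column.
    let merged = interleave(&[batch_a.as_ref(), batch_b.as_ref()], &sort_order)?;
    assert_eq!(merged.len(), 6); // rows: 1, 2, 3, 4, 7, 9
    println!("{merged:?}");
    Ok(())
}
```

In the actual function the row indexes are first rebased by each sliced cursor's `BatchOffset`, and fully consumed batches are removed from the `BatchTracker` once the output batch is built.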
+ fn build_record_batch(&mut self, sort_order: Vec) -> Result { + let mut batches_needed = Vec::with_capacity(sort_order.len()); + let mut batches_seen: HashMap = + HashMap::with_capacity(sort_order.len()); // (batch_idx, rows_sorted) + + // The sort_order yielded at each poll is relative to the sliced batch it came from. + // Therefore, the sort_order row_idx needs to be adjusted by the offset of the sliced batch. + let mut sort_order_offset_adjusted = Vec::with_capacity(sort_order.len()); + + for ((batch_id, offset), row_idx) in sort_order.iter() { + let batch_idx = match batches_seen.get(batch_id) { + Some((batch_idx, _)) => *batch_idx, + None => { + let batch_idx = batches_seen.len(); + batches_needed.push(*batch_id); + batch_idx + } + }; + sort_order_offset_adjusted.push((batch_idx, *row_idx + offset.0)); + batches_seen.insert(*batch_id, (batch_idx, *row_idx + offset.0 + 1)); + } + + let batches = self + .record_batch_collector + .get_batches(batches_needed.as_slice()); + + // remove record_batches (from the batch tracker) that are fully yielded + let batches_to_remove = batches + .iter() + .zip(batches_needed) + .filter_map(|(batch, batch_id)| { + if batch.num_rows() == batches_seen[&batch_id].1 { + Some(batch_id) + } else { + None + } + }) + .collect::>(); + + // record_batch data to yield + let columns = (0..self.schema.fields.len()) + .map(|column_idx| { + let arrays: Vec<_> = batches + .iter() + .map(|batch| batch.column(column_idx).as_ref()) + .collect(); + Ok(interleave(&arrays, sort_order_offset_adjusted.as_slice())?) + }) + .collect::>>()?; + + self.record_batch_collector + .remove_batches(batches_to_remove.as_slice()); + + Ok(RecordBatch::try_new(self.schema.clone(), columns)?) + } + + fn poll_next_inner( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + if self.aborted { + return Poll::Ready(None); + } + + match futures::ready!(self.cascade.as_mut().poll_next(cx)) { + None => Poll::Ready(None), + Some(Err(e)) => { + self.aborted = true; + Poll::Ready(Some(Err(e))) + } + Some(Ok((_, sort_order))) => match self.build_record_batch(sort_order) { + Ok(batch) => Poll::Ready(Some(Ok(batch))), + Err(e) => { + self.aborted = true; + Poll::Ready(Some(Err(e))) + } + }, + } + } +} + +impl Stream for SortPreservingCascadeStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let poll = self.poll_next_inner(cx); + self.metrics.record_poll(poll) + } +} + +impl RecordBatchStream + for SortPreservingCascadeStream +{ + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +fn spawn_buffered_merge( + mut input: MergeStream, + schema: SchemaRef, + buffer: usize, +) -> MergeStream { + // Use tokio only if running from a multi-thread tokio context + match tokio::runtime::Handle::try_current() { + Ok(handle) + if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::MultiThread => + { + let mut builder = ReceiverStream::builder(schema, buffer); + + let sender = builder.tx(); + + builder.spawn(async move { + while let Some(item) = input.next().await { + if sender.send(item).await.is_err() { + return Ok(()); + } + } + Ok(()) + }); + + builder.build() + } + _ => input, + } +} diff --git a/datafusion/physical-plan/src/sorts/cursor.rs b/datafusion/physical-plan/src/sorts/cursor.rs index baa417649fb0..a08bef0fa4b4 100644 --- a/datafusion/physical-plan/src/sorts/cursor.rs +++ b/datafusion/physical-plan/src/sorts/cursor.rs @@ -21,15 +21,17 @@ use arrow::datatypes::ArrowNativeTypeOp; use arrow::row::{Row, Rows}; use 
arrow_array::types::ByteArrayType; use arrow_array::{Array, ArrowPrimitiveType, GenericByteArray, PrimitiveArray}; +use datafusion_common::{DataFusionError, Result}; use datafusion_execution::memory_pool::MemoryReservation; -use std::cmp::Ordering; +use std::{cmp::Ordering, sync::Arc}; /// A [`Cursor`] for [`Rows`] pub struct RowCursor { cur_row: usize, - num_rows: usize, + row_offset: usize, + row_limit: usize, // exclusive [offset..limit] - rows: Rows, + rows: Arc, /// Tracks for the memory used by in the `Rows` of this /// cursor. Freed on drop @@ -41,7 +43,7 @@ impl std::fmt::Debug for RowCursor { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { f.debug_struct("SortKeyCursor") .field("cur_row", &self.cur_row) - .field("num_rows", &self.num_rows) + .field("num_rows", &self.num_rows()) .finish() } } @@ -60,8 +62,9 @@ impl RowCursor { ); Self { cur_row: 0, - num_rows: rows.num_rows(), - rows, + row_offset: 0, + row_limit: rows.num_rows(), + rows: Arc::new(rows), reservation, } } @@ -93,25 +96,75 @@ impl Ord for RowCursor { } /// A cursor into a sorted batch of rows -pub trait Cursor: Ord { +pub trait Cursor: Ord + Send + Sync { /// Returns true if there are no more rows in this cursor fn is_finished(&self) -> bool; + /// Returns true this cursor is midway completed + fn in_progress(&self) -> bool; + /// Advance the cursor, returning the previous row index fn advance(&mut self) -> usize; + + /// Current row index + fn current_idx(&self) -> usize; + + /// Go to start row + fn reset(&mut self); + + /// Slice the cursor at a given row index, returning a new cursor + /// + /// Returns an error if the slice is out of bounds, or memory is insufficient + fn slice(&self, offset: usize, length: usize) -> Result + where + Self: Sized; + + /// Returns the number of rows in this cursor + fn num_rows(&self) -> usize; } impl Cursor for RowCursor { #[inline] fn is_finished(&self) -> bool { - self.num_rows == self.cur_row + self.cur_row >= self.row_limit + } + + #[inline] + fn in_progress(&self) -> bool { + !self.is_finished() && self.cur_row > self.row_offset } #[inline] fn advance(&mut self) -> usize { let t = self.cur_row; self.cur_row += 1; - t + t - self.row_offset + } + + #[inline] + fn current_idx(&self) -> usize { + self.cur_row - self.row_offset + } + + #[inline] + fn reset(&mut self) { + self.cur_row = self.row_offset; + } + + #[inline] + fn slice(&self, offset: usize, length: usize) -> Result { + Ok(Self { + cur_row: self.row_offset + offset, + row_offset: self.row_offset + offset, + row_limit: self.row_offset + offset + length, + rows: self.rows.clone(), + reservation: self.reservation.new_empty(), // Arc cloning of Rows is cheap + }) + } + + #[inline] + fn num_rows(&self) -> usize { + self.row_limit - self.row_offset } } @@ -131,6 +184,10 @@ pub trait FieldValues { fn compare(a: &Self::Value, b: &Self::Value) -> Ordering; fn value(&self, idx: usize) -> &Self::Value; + + fn slice(&self, offset: usize, length: usize) -> Result + where + Self: Sized; } impl FieldArray for PrimitiveArray { @@ -160,6 +217,14 @@ impl FieldValues for PrimitiveValues { fn value(&self, idx: usize) -> &Self::Value { &self.0[idx] } + + #[inline] + fn slice(&self, offset: usize, length: usize) -> Result { + if offset + length > self.len() { + return Err(DataFusionError::Internal("slice out of bounds".into())); + } + Ok(Self(self.0.slice(offset, length))) + } } impl FieldArray for GenericByteArray { @@ -191,6 +256,14 @@ impl FieldValues for GenericByteArray { fn value(&self, idx: usize) -> &Self::Value { 
self.value(idx) } + + #[inline] + fn slice(&self, offset: usize, length: usize) -> Result { + if offset + length > Array::len(self) { + return Err(DataFusionError::Internal("slice out of bounds".into())); + } + Ok(self.slice(offset, length)) + } } /// A cursor over sorted, nullable [`FieldValues`] @@ -265,16 +338,48 @@ impl Ord for FieldCursor { } } -impl Cursor for FieldCursor { +impl Cursor for FieldCursor { fn is_finished(&self) -> bool { self.offset == self.values.len() } + fn in_progress(&self) -> bool { + !self.is_finished() && self.offset > 0 + } + fn advance(&mut self) -> usize { let t = self.offset; self.offset += 1; t } + + fn current_idx(&self) -> usize { + self.offset + } + + fn reset(&mut self) { + self.offset = 0; + } + + fn slice(&self, offset: usize, length: usize) -> Result { + let FieldCursor { + values, + offset: _, + null_threshold, + options, + } = self; + + Ok(Self { + values: values.slice(offset, length)?, + offset: 0, + null_threshold: *null_threshold, + options: *options, + }) + } + + fn num_rows(&self) -> usize { + self.values.len() + } } #[cfg(test)] diff --git a/datafusion/physical-plan/src/sorts/merge.rs b/datafusion/physical-plan/src/sorts/merge.rs index 67685509abe5..9352eef1cc80 100644 --- a/datafusion/physical-plan/src/sorts/merge.rs +++ b/datafusion/physical-plan/src/sorts/merge.rs @@ -19,87 +19,19 @@ //! This is an order-preserving merge. use crate::metrics::BaselineMetrics; -use crate::sorts::builder::BatchBuilder; +use crate::sorts::builder::SortOrderBuilder; use crate::sorts::cursor::Cursor; -use crate::sorts::stream::{FieldCursorStream, PartitionedStream, RowCursorStream}; -use crate::{PhysicalSortExpr, RecordBatchStream, SendableRecordBatchStream}; -use arrow::datatypes::{DataType, SchemaRef}; -use arrow::record_batch::RecordBatch; -use arrow_array::*; +use crate::sorts::stream::CursorStream; use datafusion_common::Result; -use datafusion_execution::memory_pool::MemoryReservation; use futures::Stream; use std::pin::Pin; use std::task::{ready, Context, Poll}; -macro_rules! primitive_merge_helper { - ($t:ty, $($v:ident),+) => { - merge_helper!(PrimitiveArray<$t>, $($v),+) - }; -} - -macro_rules! merge_helper { - ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident, $fetch:ident, $reservation:ident) => {{ - let streams = FieldCursorStream::<$t>::new($sort, $streams); - return Ok(Box::pin(SortPreservingMergeStream::new( - Box::new(streams), - $schema, - $tracking_metrics, - $batch_size, - $fetch, - $reservation, - ))); - }}; -} - -/// Perform a streaming merge of [`SendableRecordBatchStream`] based on provided sort expressions -/// while preserving order. -pub fn streaming_merge( - streams: Vec, - schema: SchemaRef, - expressions: &[PhysicalSortExpr], - metrics: BaselineMetrics, - batch_size: usize, - fetch: Option, - reservation: MemoryReservation, -) -> Result { - // Special case single column comparisons with optimized cursor implementations - if expressions.len() == 1 { - let sort = expressions[0].clone(); - let data_type = sort.expr.data_type(schema.as_ref())?; - downcast_primitive! 
{ - data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch, reservation), - DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch, reservation) - DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch, reservation) - DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation) - DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation) - _ => {} - } - } - - let streams = RowCursorStream::try_new( - schema.as_ref(), - expressions, - streams, - reservation.new_empty(), - )?; - - Ok(Box::pin(SortPreservingMergeStream::new( - Box::new(streams), - schema, - metrics, - batch_size, - fetch, - reservation, - ))) -} - -/// A fallible [`PartitionedStream`] of [`Cursor`] and [`RecordBatch`] -type CursorStream = Box>>; +use super::builder::YieldedSortOrder; #[derive(Debug)] -struct SortPreservingMergeStream { - in_progress: BatchBuilder, +pub(crate) struct SortPreservingMergeStream { + in_progress: SortOrderBuilder, /// The sorted input streams to merge together streams: CursorStream, @@ -151,9 +83,6 @@ struct SortPreservingMergeStream { /// target batch size batch_size: usize, - /// Vector that holds cursors for each non-exhausted input partition - cursors: Vec>, - /// Optional number of rows to fetch fetch: Option, @@ -162,22 +91,19 @@ struct SortPreservingMergeStream { } impl SortPreservingMergeStream { - fn new( + pub(crate) fn new( streams: CursorStream, - schema: SchemaRef, metrics: BaselineMetrics, batch_size: usize, fetch: Option, - reservation: MemoryReservation, ) -> Self { let stream_count = streams.partitions(); Self { - in_progress: BatchBuilder::new(schema, stream_count, batch_size, reservation), + in_progress: SortOrderBuilder::new(stream_count, batch_size), streams, metrics, aborted: false, - cursors: (0..stream_count).map(|_| None).collect(), loser_tree: vec![], loser_tree_adjusted: false, batch_size, @@ -194,7 +120,7 @@ impl SortPreservingMergeStream { cx: &mut Context<'_>, idx: usize, ) -> Poll> { - if self.cursors[idx].is_some() { + if self.in_progress.cursor_in_progress(idx) { // Cursor is not finished - don't need a new RecordBatch yet return Poll::Ready(Ok(())); } @@ -202,9 +128,8 @@ impl SortPreservingMergeStream { match futures::ready!(self.streams.poll_next(cx, idx)) { None => Poll::Ready(Ok(())), Some(Err(e)) => Poll::Ready(Err(e)), - Some(Ok((cursor, batch))) => { - self.cursors[idx] = Some(cursor); - Poll::Ready(self.in_progress.push_batch(idx, batch)) + Some(Ok(batch_cursor)) => { + Poll::Ready(self.in_progress.push_batch(idx, batch_cursor)) } } } @@ -212,7 +137,7 @@ impl SortPreservingMergeStream { fn poll_next_inner( &mut self, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>>> { if self.aborted { return Poll::Ready(None); } @@ -242,11 +167,12 @@ impl SortPreservingMergeStream { self.aborted = true; return Poll::Ready(Some(Err(e))); } + self.update_loser_tree(); } let stream_idx = self.loser_tree[0]; - if self.advance(stream_idx) { + if self.in_progress.advance(stream_idx) { self.loser_tree_adjusted = false; self.in_progress.push_row(stream_idx); @@ -259,8 +185,7 @@ impl SortPreservingMergeStream { } self.produced += self.in_progress.len(); - - return Poll::Ready(self.in_progress.build_record_batch().transpose()); + return Poll::Ready(self.in_progress.yield_sort_order().transpose()); } } @@ -270,30 +195,6 @@ impl 
SortPreservingMergeStream { .unwrap_or(false) } - fn advance(&mut self, stream_idx: usize) -> bool { - let slot = &mut self.cursors[stream_idx]; - match slot.as_mut() { - Some(c) => { - c.advance(); - if c.is_finished() { - *slot = None; - } - true - } - None => false, - } - } - - /// Returns `true` if the cursor at index `a` is greater than at index `b` - #[inline] - fn is_gt(&self, a: usize, b: usize) -> bool { - match (&self.cursors[a], &self.cursors[b]) { - (None, _) => true, - (_, None) => false, - (Some(ac), Some(bc)) => ac.cmp(bc).then_with(|| a.cmp(&b)).is_gt(), - } - } - /// Find the leaf node index in the loser tree for the given cursor index /// /// Note that this is not necessarily a leaf node in the tree, but it can @@ -324,7 +225,7 @@ impl SortPreservingMergeStream { /// #[inline] fn lt_leaf_node_index(&self, cursor_index: usize) -> usize { - (self.cursors.len() + cursor_index) / 2 + (self.num_leaf_nodes() + cursor_index) / 2 } /// Find the parent node index for the given node index @@ -333,17 +234,22 @@ impl SortPreservingMergeStream { node_idx / 2 } + #[inline] + fn num_leaf_nodes(&self) -> usize { + self.streams.partitions() + } + /// Attempts to initialize the loser tree with one value from each /// non exhausted input, if possible fn init_loser_tree(&mut self) { // Init loser tree - self.loser_tree = vec![usize::MAX; self.cursors.len()]; - for i in 0..self.cursors.len() { + self.loser_tree = vec![usize::MAX; self.num_leaf_nodes()]; + for i in 0..self.num_leaf_nodes() { let mut winner = i; let mut cmp_node = self.lt_leaf_node_index(i); while cmp_node != 0 && self.loser_tree[cmp_node] != usize::MAX { let challenger = self.loser_tree[cmp_node]; - if self.is_gt(winner, challenger) { + if self.in_progress.is_gt(winner, challenger) { self.loser_tree[cmp_node] = winner; winner = challenger; } @@ -362,7 +268,7 @@ impl SortPreservingMergeStream { let mut cmp_node = self.lt_leaf_node_index(winner); while cmp_node != 0 { let challenger = self.loser_tree[cmp_node]; - if self.is_gt(winner, challenger) { + if self.in_progress.is_gt(winner, challenger) { self.loser_tree[cmp_node] = winner; winner = challenger; } @@ -374,19 +280,12 @@ impl SortPreservingMergeStream { } impl Stream for SortPreservingMergeStream { - type Item = Result; + type Item = Result>; fn poll_next( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - let poll = self.poll_next_inner(cx); - self.metrics.record_poll(poll) - } -} - -impl RecordBatchStream for SortPreservingMergeStream { - fn schema(&self) -> SchemaRef { - self.in_progress.schema().clone() + self.poll_next_inner(cx) } } diff --git a/datafusion/physical-plan/src/sorts/mod.rs b/datafusion/physical-plan/src/sorts/mod.rs index dff39db423f0..6dcbdd7b02f0 100644 --- a/datafusion/physical-plan/src/sorts/mod.rs +++ b/datafusion/physical-plan/src/sorts/mod.rs @@ -17,13 +17,16 @@ //! 
Sort functionalities +mod batch_cursor; mod builder; +mod cascade; mod cursor; mod index; -pub mod merge; +mod merge; pub mod sort; pub mod sort_preserving_merge; mod stream; +pub mod streaming_merge; pub use index::RowIndex; -pub(crate) use merge::streaming_merge; +pub(crate) use streaming_merge::streaming_merge; diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 703f80d90d2b..6f0103901c07 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -24,8 +24,8 @@ use crate::expressions::PhysicalSortExpr; use crate::metrics::{ BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, }; -use crate::sorts::merge::streaming_merge; -use crate::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; +use crate::sorts::streaming_merge::streaming_merge; +use crate::stream::{ReceiverStream, RecordBatchStreamAdapter}; use crate::topk::TopK; use crate::{ DisplayAs, DisplayFormatType, Distribution, EmptyRecordBatchStream, ExecutionPlan, @@ -613,7 +613,7 @@ pub(crate) fn read_spill_as_stream( path: RefCountedTempFile, schema: SchemaRef, ) -> Result { - let mut builder = RecordBatchReceiverStream::builder(schema, 2); + let mut builder = ReceiverStream::builder(schema, 2); let sender = builder.tx(); builder.spawn_blocking(move || { diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index 5b485e0b68e4..d49e3b1cfded 100644 --- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -282,7 +282,7 @@ mod tests { use crate::memory::MemoryExec; use crate::metrics::{MetricValue, Timestamp}; use crate::sorts::sort::SortExec; - use crate::stream::RecordBatchReceiverStream; + use crate::stream::ReceiverStream; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; use crate::test::{self, assert_is_pending, make_partition}; use crate::{collect, common}; @@ -291,8 +291,32 @@ mod tests { use super::*; - #[tokio::test] - async fn test_merge_interleave() { + macro_rules! run_test_in_threaded_envs { + ($name: ident, ret = $ret:ty, $($test:tt)*) => { + paste::paste! { + #[tokio::test] + async fn $name() -> $ret + $($test)* + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn [<$name _multithreaded>]() -> $ret + $($test)* + } + }; + ($name: ident, $($test:tt)*) => { + paste::paste! 
{ + #[tokio::test] + async fn $name() + $($test)* + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn [<$name _multithreaded>]() + $($test)* + } + }; + } + + run_test_in_threaded_envs!(test_merge_interleave, { let task_ctx = Arc::new(TaskContext::default()); let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ @@ -337,10 +361,9 @@ mod tests { task_ctx, ) .await; - } + }); - #[tokio::test] - async fn test_merge_some_overlap() { + run_test_in_threaded_envs!(test_merge_some_overlap, { let task_ctx = Arc::new(TaskContext::default()); let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ @@ -385,10 +408,9 @@ mod tests { task_ctx, ) .await; - } + }); - #[tokio::test] - async fn test_merge_no_overlap() { + run_test_in_threaded_envs!(test_merge_no_overlap, { let task_ctx = Arc::new(TaskContext::default()); let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ @@ -433,10 +455,9 @@ mod tests { task_ctx, ) .await; - } + }); - #[tokio::test] - async fn test_merge_three_partitions() { + run_test_in_threaded_envs!(test_merge_three_partitions, { let task_ctx = Arc::new(TaskContext::default()); let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ @@ -498,7 +519,7 @@ mod tests { task_ctx, ) .await; - } + }); async fn _test_merge( partitions: &[Vec], @@ -556,8 +577,7 @@ mod tests { result.remove(0) } - #[tokio::test] - async fn test_partition_sort() -> Result<()> { + run_test_in_threaded_envs!(test_partition_sort, ret = Result<()>, { let task_ctx = Arc::new(TaskContext::default()); let partitions = 4; let csv = test::scan_partitioned(partitions); @@ -587,7 +607,7 @@ mod tests { ); Ok(()) - } + }); // Split the provided record batch into multiple batch_size record batches fn split_batch(sorted: &RecordBatch, batch_size: usize) -> Vec { @@ -628,8 +648,7 @@ mod tests { )) } - #[tokio::test] - async fn test_partition_sort_streaming_input() -> Result<()> { + run_test_in_threaded_envs!(test_partition_sort_streaming_input, ret = Result<()>, { let task_ctx = Arc::new(TaskContext::default()); let schema = make_partition(11).schema(); let sort = vec![PhysicalSortExpr { @@ -656,10 +675,9 @@ mod tests { assert_eq!(basic, partition); Ok(()) - } + }); - #[tokio::test] - async fn test_partition_sort_streaming_input_output() -> Result<()> { + run_test_in_threaded_envs!(test_partition_sort_streaming_input_output, ret = Result<()>, { let schema = make_partition(11).schema(); let sort = vec![PhysicalSortExpr { expr: col("i", &schema).unwrap(), @@ -696,10 +714,9 @@ mod tests { assert_eq!(basic, partition); Ok(()) - } + }); - #[tokio::test] - async fn test_nulls() { + run_test_in_threaded_envs!(test_nulls, { let task_ctx = Arc::new(TaskContext::default()); let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ @@ -777,10 +794,9 @@ mod tests { ], collected.as_slice() ); - } + }); - #[tokio::test] - async fn test_async() -> Result<()> { + run_test_in_threaded_envs!(test_async, ret = Result<()>, { let task_ctx = Arc::new(TaskContext::default()); let schema = make_partition(11).schema(); let sort = vec![PhysicalSortExpr { @@ -792,10 +808,11 @@ mod tests { sorted_partitioned_input(sort.clone(), &[5, 7, 3], task_ctx.clone()).await?; let partition_count = 
batches.output_partitioning().partition_count(); - let mut streams = Vec::with_capacity(partition_count); + let mut streams: Vec = + Vec::with_capacity(partition_count); for partition in 0..partition_count { - let mut builder = RecordBatchReceiverStream::builder(schema.clone(), 1); + let mut builder = ReceiverStream::::builder(schema.clone(), 1); let sender = builder.tx(); @@ -848,10 +865,9 @@ mod tests { ); Ok(()) - } + }); - #[tokio::test] - async fn test_merge_metrics() { + run_test_in_threaded_envs!(test_merge_metrics, { let task_ctx = Arc::new(TaskContext::default()); let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"), Some("c")])); @@ -904,14 +920,13 @@ mod tests { assert!(saw_start); assert!(saw_end); - } + }); fn nanos_from_timestamp(ts: &Timestamp) -> i64 { ts.value().unwrap().timestamp_nanos_opt().unwrap() } - #[tokio::test] - async fn test_drop_cancel() -> Result<()> { + run_test_in_threaded_envs!(test_drop_cancel, ret = Result<()>, { let task_ctx = Arc::new(TaskContext::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -934,10 +949,9 @@ mod tests { assert_strong_count_converges_to_zero(refs).await; Ok(()) - } + }); - #[tokio::test] - async fn test_stable_sort() { + run_test_in_threaded_envs!(test_stable_sort, { let task_ctx = Arc::new(TaskContext::default()); // Create record batches like: @@ -1013,5 +1027,5 @@ mod tests { ], collected.as_slice() ); - } + }); } diff --git a/datafusion/physical-plan/src/sorts/stream.rs b/datafusion/physical-plan/src/sorts/stream.rs index a7f9e7380c47..0cec3bbb3603 100644 --- a/datafusion/physical-plan/src/sorts/stream.rs +++ b/datafusion/physical-plan/src/sorts/stream.rs @@ -15,9 +15,11 @@ // specific language governing permissions and limitations // under the License. -use crate::sorts::cursor::{FieldArray, FieldCursor, RowCursor}; +use crate::sorts::builder::SortOrder; +use crate::sorts::cursor::{Cursor, FieldArray, FieldCursor, RowCursor}; use crate::SendableRecordBatchStream; use crate::{PhysicalExpr, PhysicalSortExpr}; +use ahash::RandomState; use arrow::array::Array; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; @@ -25,10 +27,47 @@ use arrow::row::{RowConverter, SortField}; use datafusion_common::Result; use datafusion_execution::memory_pool::MemoryReservation; use futures::stream::{Fuse, StreamExt}; +use parking_lot::Mutex; +use std::collections::{HashMap, VecDeque}; use std::marker::PhantomData; +use std::ops::Range; +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use std::task::{ready, Context, Poll}; +use super::batch_cursor::{BatchCursor, BatchId, SlicedBatchCursorIdentifier}; +use super::builder::YieldedSortOrder; + +/// A fallible [`PartitionedStream`] of record batches. +/// +/// Each [`Cursor`] and [`RecordBatch`] represents a single record batch. +pub(crate) trait BatchCursorStream: + PartitionedStream + Send + 'static +{ + /// Acquire ownership over a subset of the partitioned streams. + /// + /// Like `Vec::take()`, this removes the indexed positions. + fn take_partitions(&mut self, range: Range) -> Box + where + Self: Sized; +} + +/// A [`PartitionedStream`] representing partial record batches (a.k.a. [`BatchCursor`]). +/// +/// Each merge node (a `SortPreservingMergeStream` loser tree) will consume a [`CursorStream`]. +pub(crate) type CursorStream = + Box>> + Send>; + +/// A fallible stream of yielded [`SortOrder`]s is a [`MergeStream`]. 
+/// +/// Within a cascade of merge nodes, (each node being a `SortPreservingMergeStream` loser tree), +/// the merge node will yield a SortOrder as well as any partial record batches from the SortOrder. +/// +/// [`YieldedCursorStream`] then converts an output [`MergeStream`] +/// into an input [`CursorStream`] for the next merge. +pub(crate) type MergeStream = + std::pin::Pin>> + Send>>; + /// A [`Stream`](futures::Stream) that has multiple partitions that can /// be polled separately but not concurrently /// @@ -80,7 +119,7 @@ impl FusedStreams { #[derive(Debug)] pub struct RowCursorStream { /// Converter to convert output of physical expressions - converter: RowConverter, + converter: Arc, /// The physical expressions to sort by column_expressions: Vec>, /// Input streams @@ -105,7 +144,7 @@ impl RowCursorStream { .collect::>>()?; let streams = streams.into_iter().map(|s| s.fuse()).collect(); - let converter = RowConverter::new(sort_fields)?; + let converter = Arc::new(RowConverter::new(sort_fields)?); Ok(Self { converter, reservation, @@ -114,6 +153,15 @@ impl RowCursorStream { }) } + fn new_from_streams(&self, streams: FusedStreams) -> Self { + Self { + converter: self.converter.clone(), + column_expressions: self.column_expressions.clone(), + reservation: self.reservation.new_empty(), + streams, + } + } + fn convert_batch(&mut self, batch: &RecordBatch) -> Result { let cols = self .column_expressions @@ -152,6 +200,13 @@ impl PartitionedStream for RowCursorStream { } } +impl BatchCursorStream for RowCursorStream { + fn take_partitions(&mut self, range: Range) -> Box { + let streams_slice = self.streams.0.drain(range).collect::>(); + Box::new(self.new_from_streams(FusedStreams(streams_slice))) + } +} + /// Specialized stream for sorts on single primitive columns pub struct FieldCursorStream { /// The physical expressions to sort by @@ -179,6 +234,14 @@ impl FieldCursorStream { } } + fn new_from_streams(&self, streams: FusedStreams) -> Self { + Self { + sort: self.sort.clone(), + phantom: self.phantom, + streams, + } + } + fn convert_batch(&mut self, batch: &RecordBatch) -> Result> { let value = self.sort.expr.evaluate(batch)?; let array = value.into_array(batch.num_rows()); @@ -207,3 +270,264 @@ impl PartitionedStream for FieldCursorStream { })) } } + +impl BatchCursorStream for FieldCursorStream { + fn take_partitions(&mut self, range: Range) -> Box { + let streams_slice = self.streams.0.drain(range).collect::>(); + Box::new(self.new_from_streams(FusedStreams(streams_slice))) + } +} + +/// A wrapper around [`CursorStream`] that collects the [`RecordBatch`] per poll, +/// and only passes along the [`BatchCursor`]. +/// +/// This is used in the leaf nodes of the cascading merge tree. +pub(crate) struct BatchTrackerStream< + C: Cursor, + S: BatchCursorStream>, +> { + // Partitioned Input stream. 
+ streams: Box, + record_batch_holder: Arc, +} + +impl>> + BatchTrackerStream +{ + pub fn new(streams: Box, record_batch_holder: Arc) -> Self { + Self { + streams, + record_batch_holder, + } + } +} + +impl>> + PartitionedStream for BatchTrackerStream +{ + type Output = Result>; + + fn partitions(&self) -> usize { + self.streams.partitions() + } + + fn poll_next( + &mut self, + cx: &mut Context<'_>, + stream_idx: usize, + ) -> Poll> { + Poll::Ready(ready!(self.streams.poll_next(cx, stream_idx)).map(|r| { + r.and_then(|(cursor, batch)| { + let batch_id = self.record_batch_holder.add_batch(batch)?; + Ok(BatchCursor::new(batch_id, cursor)) + }) + })) + } +} + +impl>> + std::fmt::Debug for BatchTrackerStream +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OffsetCursorStream").finish() + } +} + +/// Converts a [`BatchCursorStream`] into a [`CursorStream`]. +/// +/// While storing the record batches outside of the cascading merge tree. +/// Should be used with a Mutex. +pub struct BatchTracker { + /// Monotonically increasing batch id + monotonic_counter: AtomicU64, + /// Write once, read many [`RecordBatch`]s + batches: Mutex, RandomState>>, + /// Accounts for memory used by buffered batches + reservation: Mutex, +} + +impl BatchTracker { + pub fn new(reservation: MemoryReservation) -> Self { + Self { + monotonic_counter: AtomicU64::new(0), + batches: Mutex::new(HashMap::with_hasher(RandomState::new())), + reservation: Mutex::new(reservation), + } + } + + pub fn add_batch(&self, batch: RecordBatch) -> Result { + self.reservation + .lock() + .try_grow(batch.get_array_memory_size())?; + let batch_id = self.monotonic_counter.fetch_add(1, Ordering::Relaxed); + self.batches.lock().insert(batch_id, Arc::new(batch)); + Ok(batch_id) + } + + pub fn get_batches(&self, batch_ids: &[BatchId]) -> Vec> { + let batches = self.batches.lock(); + batch_ids.iter().map(|id| batches[id].clone()).collect() + } + + pub fn remove_batches(&self, batch_ids: &[BatchId]) { + let mut batches = self.batches.lock(); + for id in batch_ids { + batches.remove(id); + } + } +} + +impl std::fmt::Debug for BatchTracker { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BatchTracker").finish() + } +} + +/// A newtype wrapper around a set of fused [`MergeStream`] +/// that implements debug, and skips over empty inner poll results +struct FusedMergeStreams(Vec>>); + +impl FusedMergeStreams { + fn poll_next( + &mut self, + cx: &mut Context<'_>, + stream_idx: usize, + ) -> Poll>>> { + loop { + match ready!(self.0[stream_idx].poll_next_unpin(cx)) { + Some(Ok((_, sort_order))) if sort_order.is_empty() => continue, + r => return Poll::Ready(r), + } + } + } +} + +impl std::fmt::Debug for FusedMergeStreams { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FusedMergeStreams").finish() + } +} + +/// [`YieldedCursorStream`] converts an output [`MergeStream`] +/// into an input [`CursorStream`] for the next merge. +pub struct YieldedCursorStream { + // Inner polled batch cursors, per stream_idx, which represents partially yielded batches. 
+ cursors: Vec>>>, + /// Streams being polled + streams: FusedMergeStreams, +} + +impl YieldedCursorStream { + pub fn new(streams: Vec>) -> Self { + let stream_cnt = streams.len(); + Self { + cursors: (0..stream_cnt).map(|_| None).collect(), + streams: FusedMergeStreams(streams.into_iter().map(|s| s.fuse()).collect()), + } + } + + fn incr_next_batch(&mut self, stream_idx: usize) -> Option> { + self.cursors[stream_idx] + .as_mut() + .and_then(|queue| queue.pop_front()) + } + + // The input [`SortOrder`] is across batches. + // We need to further parse the cursors into smaller batches. + // + // Input: + // - sort_order: Vec<(SlicedBatchCursorIdentifier, row_idx)> = [(0,0), (0,1), (1,0), (0,2), (0,3)] + // - cursors: Vec = [cursor_0, cursor_1] + // + // Output stream: + // Needs to be yielded to the next merge in three partial batches: + // [(0,0),(0,1)] with cursor => then [(1,0)] with cursor => then [(0,2),(0,3)] with cursor + // + // This additional parsing is only required when streaming into another merge node, + // and not required when yielding to the final interleave step. + // (Performance slightly decreases when doing this additional parsing for all SortOrderBuilder yields.) + fn try_parse_batches( + &mut self, + stream_idx: usize, + cursors: Vec>, + sort_order: Vec, + ) -> Result<()> { + let mut cursors_per_batch: HashMap< + SlicedBatchCursorIdentifier, + BatchCursor, + RandomState, + > = HashMap::with_capacity_and_hasher(cursors.len(), RandomState::new()); + for cursor in cursors { + cursors_per_batch.insert(cursor.identifier(), cursor); + } + + let mut parsed_cursors: Vec> = + Vec::with_capacity(sort_order.len()); + let ((mut prev_batch_id, mut prev_batch_offset), mut prev_row_idx) = + sort_order[0]; + let mut len = 0; + + for ((batch_id, batch_offset), row_idx) in sort_order.iter() { + if prev_batch_id == *batch_id && batch_offset.0 == prev_batch_offset.0 { + len += 1; + continue; + } else { + // parse cursor + if let Some(cursor) = + cursors_per_batch.get(&(prev_batch_id, prev_batch_offset)) + { + let parsed_cursor = cursor.slice(prev_row_idx, len)?; + parsed_cursors.push(parsed_cursor); + } else { + unreachable!("cursor not found"); + } + + prev_batch_id = *batch_id; + prev_row_idx = *row_idx; + prev_batch_offset = *batch_offset; + len = 1; + } + } + if let Some(cursor) = cursors_per_batch.get(&(prev_batch_id, prev_batch_offset)) { + let parsed_cursor = cursor.slice(prev_row_idx, len)?; + parsed_cursors.push(parsed_cursor); + } + + self.cursors[stream_idx] = Some(VecDeque::from(parsed_cursors)); + Ok(()) + } +} + +impl PartitionedStream for YieldedCursorStream { + type Output = Result>; + + fn partitions(&self) -> usize { + self.streams.0.len() + } + + fn poll_next( + &mut self, + cx: &mut Context<'_>, + stream_idx: usize, + ) -> Poll> { + match self.incr_next_batch(stream_idx) { + None => match ready!(self.streams.poll_next(cx, stream_idx)) { + None => Poll::Ready(None), + Some(Err(e)) => Poll::Ready(Some(Err(e))), + Some(Ok((cursors, sort_order))) => { + self.try_parse_batches(stream_idx, cursors, sort_order)?; + Poll::Ready((Ok(self.incr_next_batch(stream_idx))).transpose()) + } + }, + Some(r) => Poll::Ready(Some(Ok(r))), + } + } +} + +impl std::fmt::Debug for YieldedCursorStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("YieldedCursorStream") + .field("num_partitions", &self.partitions()) + .finish() + } +} diff --git a/datafusion/physical-plan/src/sorts/streaming_merge.rs 
b/datafusion/physical-plan/src/sorts/streaming_merge.rs new file mode 100644 index 000000000000..e026ad683cdf --- /dev/null +++ b/datafusion/physical-plan/src/sorts/streaming_merge.rs @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Merge that deals with an arbitrary size of streaming inputs. +//! This is an order-preserving merge. + +use crate::metrics::BaselineMetrics; +use crate::sorts::{ + cascade::SortPreservingCascadeStream, + stream::{FieldCursorStream, RowCursorStream}, +}; +use crate::{PhysicalSortExpr, SendableRecordBatchStream}; +use arrow::datatypes::{DataType, SchemaRef}; +use arrow_array::*; +use datafusion_common::Result; +use datafusion_execution::memory_pool::MemoryReservation; + +macro_rules! primitive_merge_helper { + ($t:ty, $($v:ident),+) => { + merge_helper!(PrimitiveArray<$t>, $($v),+) + }; +} + +macro_rules! merge_helper { + ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident, $fetch:ident, $reservation:ident) => {{ + let streams = FieldCursorStream::<$t>::new($sort, $streams); + return Ok(Box::pin(SortPreservingCascadeStream::new( + streams, + $schema, + $tracking_metrics, + $batch_size, + $fetch, + $reservation, + ))); + }}; +} + +/// Perform a streaming merge of [`SendableRecordBatchStream`] based on provided sort expressions +/// while preserving order. +pub fn streaming_merge( + streams: Vec, + schema: SchemaRef, + expressions: &[PhysicalSortExpr], + metrics: BaselineMetrics, + batch_size: usize, + fetch: Option, + reservation: MemoryReservation, +) -> Result { + // Special case single column comparisons with optimized cursor implementations + if expressions.len() == 1 { + let sort = expressions[0].clone(); + let data_type = sort.expr.data_type(schema.as_ref())?; + downcast_primitive! 
{ + data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch, reservation), + DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch, reservation) + DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch, reservation) + DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation) + DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation) + _ => {} + } + } + + let streams = RowCursorStream::try_new( + schema.as_ref(), + expressions, + streams, + reservation.new_empty(), + )?; + + Ok(Box::pin(SortPreservingCascadeStream::new( + streams, + schema, + metrics, + batch_size, + fetch, + reservation, + ))) +} diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs index a3fb856c326d..dc3aab33636b 100644 --- a/datafusion/physical-plan/src/stream.rs +++ b/datafusion/physical-plan/src/stream.rs @@ -38,22 +38,22 @@ use tokio::task::JoinSet; use super::metrics::BaselineMetrics; use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream}; -/// Builder for [`RecordBatchReceiverStream`] that propagates errors +/// Builder for [`ReceiverStream`] that propagates errors /// and panic's correctly. /// -/// [`RecordBatchReceiverStream`] is used to spawn one or more tasks -/// that produce `RecordBatch`es and send them to a single -/// `Receiver` which can improve parallelism. +/// [`ReceiverStream`] is used to spawn one or more tasks +/// that produce `O` output, (such as a `RecordBatch`), +/// and sends them to a single `Receiver` which can improve parallelism. /// /// This also handles propagating panic`s and canceling the tasks. -pub struct RecordBatchReceiverStreamBuilder { - tx: Sender>, - rx: Receiver>, +pub struct ReceiverStreamBuilder { + tx: Sender>, + rx: Receiver>, schema: SchemaRef, join_set: JoinSet>, } -impl RecordBatchReceiverStreamBuilder { +impl ReceiverStreamBuilder { /// create new channels with the specified buffer size pub fn new(schema: SchemaRef, capacity: usize) -> Self { let (tx, rx) = tokio::sync::mpsc::channel(capacity); @@ -66,8 +66,8 @@ impl RecordBatchReceiverStreamBuilder { } } - /// Get a handle for sending [`RecordBatch`]es to the output - pub fn tx(&self) -> Sender> { + /// Get a handle for sending `O` to the output + pub fn tx(&self) -> Sender> { self.tx.clone() } @@ -104,59 +104,17 @@ impl RecordBatchReceiverStreamBuilder { /// sent to the output stream and no further results are sent. pub(crate) fn run_input( &mut self, - input: Arc, + input: Arc> + Send + Sync>, partition: usize, context: Arc, ) { let output = self.tx(); - self.spawn(async move { - let mut stream = match input.execute(partition, context) { - Err(e) => { - // If send fails, the plan being torn down, there - // is no place to send the error and no reason to continue. - output.send(Err(e)).await.ok(); - debug!( - "Stopping execution: error executing input: {}", - displayable(input.as_ref()).one_line() - ); - return Ok(()); - } - Ok(stream) => stream, - }; - - // Transfer batches from inner stream to the output tx - // immediately. - while let Some(item) = stream.next().await { - let is_err = item.is_err(); - - // If send fails, plan being torn down, there is no - // place to send the error and no reason to continue. 
- if output.send(item).await.is_err() { - debug!( - "Stopping execution: output is gone, plan cancelling: {}", - displayable(input.as_ref()).one_line() - ); - return Ok(()); - } - - // stop after the first error is encontered (don't - // drive all streams to completion) - if is_err { - debug!( - "Stopping execution: plan returned error: {}", - displayable(input.as_ref()).one_line() - ); - return Ok(()); - } - } - - Ok(()) - }); + self.spawn(async move { input.call(output, partition, context).await }); } - /// Create a stream of all `RecordBatch`es written to `tx` - pub fn build(self) -> SendableRecordBatchStream { + /// Create a stream of all `O`es written to `tx` + pub fn build(self) -> Pin>> { let Self { tx, rx, @@ -214,34 +172,33 @@ impl RecordBatchReceiverStreamBuilder { // produces the batch let inner = futures::stream::select(rx_stream, check_stream).boxed(); - Box::pin(RecordBatchReceiverStream { schema, inner }) + Box::pin(ReceiverStream::::new(schema, inner)) } } -/// A [`SendableRecordBatchStream`] that combines [`RecordBatch`]es from multiple inputs, -/// on new tokio Tasks, increasing the potential parallelism. +/// A SendableStream that combines multiple inputs, on new tokio Tasks, +/// increasing the potential parallelism. +/// +/// When `O` is a [`RecordBatch`], this is a [`SendableRecordBatchStream`]. /// /// This structure also handles propagating panics and cancelling the /// underlying tasks correctly. /// /// Use [`Self::builder`] to construct one. -pub struct RecordBatchReceiverStream { +pub struct ReceiverStream { schema: SchemaRef, - inner: BoxStream<'static, Result>, + inner: BoxStream<'static, Result>, } -impl RecordBatchReceiverStream { +impl ReceiverStream { /// Create a builder with an internal buffer of capacity batches. - pub fn builder( - schema: SchemaRef, - capacity: usize, - ) -> RecordBatchReceiverStreamBuilder { - RecordBatchReceiverStreamBuilder::new(schema, capacity) + pub fn builder(schema: SchemaRef, capacity: usize) -> ReceiverStreamBuilder { + ReceiverStreamBuilder::::new(schema, capacity) } } -impl Stream for RecordBatchReceiverStream { - type Item = Result; +impl Stream for ReceiverStream { + type Item = Result; fn poll_next( mut self: Pin<&mut Self>, @@ -251,12 +208,95 @@ impl Stream for RecordBatchReceiverStream { } } -impl RecordBatchStream for RecordBatchReceiverStream { +impl ReceiverStream { + /// Create a new [`ReceiverStream`] from the provided schema and stream + pub fn new(schema: SchemaRef, inner: BoxStream<'static, Result>) -> Self { + Self { schema, inner } + } +} + +impl RecordBatchStream for ReceiverStream { fn schema(&self) -> SchemaRef { self.schema.clone() } } +#[async_trait::async_trait] +pub(crate) trait StreamAdaptor: Send + Sync { + type Item: Send + Sync; + + async fn call( + &self, + output: Sender, + partition: usize, + context: Arc, + ) -> Result<()>; +} + +pub(crate) struct RecordBatchReceiverStreamAdaptor { + input: Arc, +} + +impl RecordBatchReceiverStreamAdaptor { + pub fn new(input: Arc) -> Self { + Self { input } + } +} + +#[async_trait::async_trait] +impl StreamAdaptor for RecordBatchReceiverStreamAdaptor { + type Item = Result; + + async fn call( + &self, + output: Sender, + partition: usize, + context: Arc, + ) -> Result<()> { + let mut stream = match self.input.execute(partition, context) { + Err(e) => { + // If send fails, the plan being torn down, there + // is no place to send the error and no reason to continue. 
+ output.send(Err(e)).await.ok(); + debug!( + "Stopping execution: error executing input: {}", + displayable(self.input.as_ref()).one_line() + ); + return Ok(()); + } + Ok(stream) => stream, + }; + + // Transfer batches from inner stream to the output tx + // immediately. + while let Some(item) = stream.next().await { + let is_err = item.is_err(); + + // If send fails, plan being torn down, there is no + // place to send the error and no reason to continue. + if output.send(item).await.is_err() { + debug!( + "Stopping execution: output is gone, plan cancelling: {}", + displayable(self.input.as_ref()).one_line() + ); + return Ok(()); + } + + // stop after the first error is encontered (don't + // drive all streams to completion) + if is_err { + debug!( + "Stopping execution: plan returned error: {}", + displayable(self.input.as_ref()).one_line() + ); + return Ok(()); + } + } + + Ok(()) + } +} + pin_project! { /// Combines a [`Stream`] with a [`SchemaRef`] implementing /// [`RecordBatchStream`] for the combination @@ -429,8 +469,9 @@ mod test { let refs = input.refs(); // Configure a RecordBatchReceiverStream to consume the input - let mut builder = RecordBatchReceiverStream::builder(schema, 2); - builder.run_input(Arc::new(input), 0, task_ctx.clone()); + let mut builder = ReceiverStream::builder(schema, 2); + let input = Arc::new(RecordBatchReceiverStreamAdaptor::new(Arc::new(input))); + builder.run_input(input, 0, task_ctx.clone()); let stream = builder.build(); // input should still be present @@ -454,8 +495,11 @@ mod test { MockExec::new(vec![exec_err!("Test1"), exec_err!("Test2")], schema.clone()) .with_use_task(false); - let mut builder = RecordBatchReceiverStream::builder(schema, 2); - builder.run_input(Arc::new(error_stream), 0, task_ctx.clone()); + let mut builder = ReceiverStream::builder(schema, 2); + let input = Arc::new(RecordBatchReceiverStreamAdaptor::new(Arc::new( + error_stream, + ))); + builder.run_input(input, 0, task_ctx.clone()); let mut stream = builder.build(); // get the first result, which should be an error @@ -478,10 +522,13 @@ mod test { let num_partitions = input.output_partitioning().partition_count(); // Configure a RecordBatchReceiverStream to consume all the input partitions - let mut builder = - RecordBatchReceiverStream::builder(input.schema(), num_partitions); + let mut builder = ReceiverStream::builder(input.schema(), num_partitions); for partition in 0..num_partitions { - builder.run_input(input.clone(), partition, task_ctx.clone()); + builder.run_input( + Arc::new(RecordBatchReceiverStreamAdaptor::new(input.clone())), + partition, + task_ctx.clone(), + ); } let mut stream = builder.build(); diff --git a/datafusion/physical-plan/src/test/exec.rs b/datafusion/physical-plan/src/test/exec.rs index a1f40c7ba909..8c09a96deb3e 100644 --- a/datafusion/physical-plan/src/test/exec.rs +++ b/datafusion/physical-plan/src/test/exec.rs @@ -32,7 +32,8 @@ use arrow::{ use futures::Stream; use crate::{ - common, stream::RecordBatchReceiverStream, stream::RecordBatchStreamAdapter, + common, + stream::{ReceiverStream, RecordBatchStreamAdapter}, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, }; @@ -211,7 +212,7 @@ impl ExecutionPlan for MockExec { .collect(); if self.use_task { - let mut builder = RecordBatchReceiverStream::builder(self.schema(), 2); + let mut builder = ReceiverStream::builder(self.schema(), 2); // send data in order but in a separate task (to ensure // the batches are not available 
without the stream
        // yielding).
@@ -346,7 +347,7 @@ impl ExecutionPlan for BarrierExec {
    ) -> Result<SendableRecordBatchStream> {
        assert!(partition < self.data.len());
-        let mut builder = RecordBatchReceiverStream::builder(self.schema(), 2);
+        let mut builder = ReceiverStream::builder(self.schema(), 2);
        // task simply sends data in order after barrier is reached
        let data = self.data[partition].clone();
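// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the upstream diff).
//
// The `streaming_merge` entry point added in streaming_merge.rs above is how an
// operator such as a sort-preserving merge would turn several already-sorted
// partition streams into a single, order-preserving output stream. The wrapper
// function and variable names below (`merge_sorted_partitions`,
// `sorted_partitions`, `sort_exprs`, `batch_size`) are assumptions made only
// for this example, as is the exact import path of `streaming_merge`; the call
// signature itself matches the one introduced above.

use arrow::datatypes::SchemaRef;
use datafusion_common::Result;
use datafusion_execution::memory_pool::MemoryReservation;

use crate::metrics::BaselineMetrics;
use crate::sorts::streaming_merge::streaming_merge; // assumed module path
use crate::{PhysicalSortExpr, SendableRecordBatchStream};

/// Merge `sorted_partitions` (each individually sorted on `sort_exprs`) into a
/// single sorted stream, yielding batches of at most `batch_size` rows.
fn merge_sorted_partitions(
    sorted_partitions: Vec<SendableRecordBatchStream>,
    schema: SchemaRef,
    sort_exprs: &[PhysicalSortExpr],
    metrics: BaselineMetrics,
    batch_size: usize,
    reservation: MemoryReservation,
) -> Result<SendableRecordBatchStream> {
    streaming_merge(
        sorted_partitions,
        schema,
        sort_exprs,
        metrics,
        batch_size,
        None, // `fetch`: no LIMIT pushed down in this example
        reservation,
    )
}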