diff --git a/datafusion/core/src/physical_plan/coalesce_partitions.rs b/datafusion/core/src/physical_plan/coalesce_partitions.rs
index fe667d1e6e2a..3fcae0174114 100644
--- a/datafusion/core/src/physical_plan/coalesce_partitions.rs
+++ b/datafusion/core/src/physical_plan/coalesce_partitions.rs
@@ -19,16 +19,17 @@
 //! into a single partition
 
 use std::any::Any;
+use std::panic;
 use std::sync::Arc;
 use std::task::Poll;
 
-use futures::Stream;
+use futures::{Future, Stream};
 use tokio::sync::mpsc;
 
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
+use tokio::task::JoinSet;
 
-use super::common::AbortOnDropMany;
 use super::expressions::PhysicalSortExpr;
 use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
 use super::{RecordBatchStream, Statistics};
@@ -142,21 +143,22 @@ impl ExecutionPlan for CoalescePartitionsExec {
 
         // spawn independent tasks whose resulting streams (of batches)
         // are sent to the channel for consumption.
-        let mut join_handles = Vec::with_capacity(input_partitions);
+        let mut tasks = JoinSet::new();
         for part_i in 0..input_partitions {
-            join_handles.push(spawn_execution(
+            spawn_execution(
+                &mut tasks,
                 self.input.clone(),
                 sender.clone(),
                 part_i,
                 context.clone(),
-            ));
+            );
         }
 
         Ok(Box::pin(MergeStream {
             input: receiver,
             schema: self.schema(),
             baseline_metrics,
-            drop_helper: AbortOnDropMany(join_handles),
+            tasks,
         }))
     }
 }
@@ -187,8 +189,7 @@ struct MergeStream {
     schema: SchemaRef,
     input: mpsc::Receiver<Result<RecordBatch>>,
     baseline_metrics: BaselineMetrics,
-    #[allow(unused)]
-    drop_helper: AbortOnDropMany<()>,
+    tasks: JoinSet<()>,
 }
 
 impl Stream for MergeStream {
@@ -199,6 +200,28 @@ impl Stream for MergeStream {
         cx: &mut std::task::Context<'_>,
     ) -> Poll<Option<Self::Item>> {
         let poll = self.input.poll_recv(cx);
+
+        // If the input stream is done, wait for all tasks to finish and return
+        // the failure if any.
+        if let Poll::Ready(None) = poll {
+            let fut = self.tasks.join_next();
+            tokio::pin!(fut);
+
+            match fut.poll(cx) {
+                Poll::Ready(task_poll) => {
+                    if let Some(Err(e)) = task_poll {
+                        if e.is_panic() {
+                            panic::resume_unwind(e.into_panic());
+                        }
+                        return Poll::Ready(Some(Err(DataFusionError::Execution(
+                            format!("{e:?}"),
+                        ))));
+                    }
+                }
+                Poll::Pending => {}
+            }
+        }
+
         self.baseline_metrics.record_poll(poll)
     }
 }
@@ -218,7 +241,9 @@ mod tests {
     use super::*;
     use crate::physical_plan::{collect, common};
     use crate::prelude::SessionContext;
-    use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
+    use crate::test::exec::{
+        assert_strong_count_converges_to_zero, BlockingExec, PanickingExec,
+    };
     use crate::test::{self, assert_is_pending};
 
     #[tokio::test]
@@ -270,4 +295,19 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    #[should_panic(expected = "PanickingStream did panic")]
+    async fn test_panic() {
+        let session_ctx = SessionContext::new();
+        let task_ctx = session_ctx.task_ctx();
+        let schema =
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
+
+        let panicking_exec = Arc::new(PanickingExec::new(Arc::clone(&schema), 2));
+        let coalesce_partitions_exec =
+            Arc::new(CoalescePartitionsExec::new(panicking_exec));
+
+        collect(coalesce_partitions_exec, task_ctx).await.unwrap();
+    }
 }
diff --git a/datafusion/core/src/physical_plan/common.rs b/datafusion/core/src/physical_plan/common.rs
index e766c225b51c..a9c267f123a8 100644
--- a/datafusion/core/src/physical_plan/common.rs
+++ b/datafusion/core/src/physical_plan/common.rs
@@ -38,7 +38,7 @@ use std::path::{Path, PathBuf};
 use std::sync::Arc;
 use std::task::{Context, Poll};
 use tokio::sync::mpsc;
-use tokio::task::JoinHandle;
+use tokio::task::{JoinHandle, JoinSet};
 
 /// [`MemoryReservation`] used across query execution streams
 pub(crate) type SharedMemoryReservation = Arc<Mutex<MemoryReservation>>;
@@ -98,12 +98,13 @@ fn build_file_list_recurse(
 
 /// Spawns a task to the tokio threadpool and writes its outputs to the provided mpsc sender
 pub(crate) fn spawn_execution(
+    join_set: &mut JoinSet<()>,
     input: Arc<dyn ExecutionPlan>,
     output: mpsc::Sender<Result<RecordBatch>>,
     partition: usize,
     context: Arc<TaskContext>,
-) -> JoinHandle<()> {
-    tokio::spawn(async move {
+) {
+    join_set.spawn(async move {
         let mut stream = match input.execute(partition, context) {
             Err(e) => {
                 // If send fails, plan being torn down,
@@ -129,7 +130,7 @@ pub(crate) fn spawn_execution(
                 return;
             }
         }
-    })
+    });
 }
 
 /// If running in a tokio context spawns the execution of `stream` to a separate task
diff --git a/datafusion/core/src/test/exec.rs b/datafusion/core/src/test/exec.rs
index bce7d08a5c56..13f3dc6a16c8 100644
--- a/datafusion/core/src/test/exec.rs
+++ b/datafusion/core/src/test/exec.rs
@@ -643,3 +643,108 @@ pub async fn assert_strong_count_converges_to_zero<T>(refs: Weak<T>) {
     .await
     .unwrap();
 }
+
+/// Execution plan that emits streams that panic.
+///
+/// This is useful to test panic handling of certain execution plans.
+#[derive(Debug)]
+pub struct PanickingExec {
+    /// Schema that is mocked by this plan.
+    schema: SchemaRef,
+
+    /// Number of output partitions.
+    n_partitions: usize,
+}
+
+impl PanickingExec {
+    /// Create new [`PanickingExec`] with a given schema and number of partitions.
+    pub fn new(schema: SchemaRef, n_partitions: usize) -> Self {
+        Self {
+            schema,
+            n_partitions,
+        }
+    }
+}
+
+impl ExecutionPlan for PanickingExec {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+
+    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+        // this is a leaf node and has no children
+        vec![]
+    }
+
+    fn output_partitioning(&self) -> Partitioning {
+        Partitioning::UnknownPartitioning(self.n_partitions)
+    }
+
+    fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
+        None
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        Err(DataFusionError::Internal(format!(
+            "Children cannot be replaced in {:?}",
+            self
+        )))
+    }
+
+    fn execute(
+        &self,
+        _partition: usize,
+        _context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        Ok(Box::pin(PanickingStream {
+            schema: Arc::clone(&self.schema),
+        }))
+    }
+
+    fn fmt_as(
+        &self,
+        t: DisplayFormatType,
+        f: &mut std::fmt::Formatter,
+    ) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default => {
+                write!(f, "PanickingExec",)
+            }
+        }
+    }
+
+    fn statistics(&self) -> Statistics {
+        unimplemented!()
+    }
+}
+
+/// A [`RecordBatchStream`] that panics on first poll.
+#[derive(Debug)]
+pub struct PanickingStream {
+    /// Schema mocked by this stream.
+    schema: SchemaRef,
+}
+
+impl Stream for PanickingStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        panic!("PanickingStream did panic")
+    }
+}
+
+impl RecordBatchStream for PanickingStream {
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+}
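Note: below is a minimal, standalone sketch (illustrative only, not DataFusion code; the channel capacity, task bodies, and names such as `part_i` are assumptions) of the JoinSet-based panic-propagation pattern this diff adopts: producer tasks are spawned into a `tokio::task::JoinSet`, the consumer drains the channel, and once the channel is exhausted it awaits `join_next()` and re-raises any worker panic with `std::panic::resume_unwind`, mirroring what `MergeStream::poll_next` now does.

// Illustrative sketch of the panic-propagation pattern; not part of the patch.
use std::panic;

use tokio::sync::mpsc;
use tokio::task::JoinSet;

#[tokio::main]
async fn main() {
    let (sender, mut receiver) = mpsc::channel::<usize>(2);
    let mut tasks = JoinSet::new();

    // Two "partitions": one produces a value, the other panics
    // (playing the role of PanickingExec's stream).
    for part_i in 0..2usize {
        let sender = sender.clone();
        tasks.spawn(async move {
            if part_i == 1 {
                panic!("PanickingStream did panic");
            }
            let _ = sender.send(part_i).await;
        });
    }
    // Drop the original sender so the receiver observes end-of-stream.
    drop(sender);

    // Drain the channel first, like MergeStream polling its receiver.
    while let Some(value) = receiver.recv().await {
        println!("received {value}");
    }

    // Channel exhausted: surface task failures. Panics are re-raised on the
    // consumer side; other join errors are merely reported here.
    while let Some(result) = tasks.join_next().await {
        if let Err(e) = result {
            if e.is_panic() {
                panic::resume_unwind(e.into_panic());
            }
            eprintln!("task failed: {e:?}");
        }
    }
}

Running this sketch aborts with the worker's "PanickingStream did panic" payload, which is the behaviour the new `test_panic` test asserts for `CoalescePartitionsExec`.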