From 886735304f0fa83a24ddc3b3e356aeff4aab78f2 Mon Sep 17 00:00:00 2001 From: Zhang Li Date: Fri, 22 Apr 2022 11:43:02 +0800 Subject: [PATCH] Implement merge join (#2242) * Implement Sort-Merge join (#141) * Complete doc comments and pass cargo clippy * Implement metrics for SMJ * Support join columns with different sort options * Update datafusion/core/src/physical_plan/sort_merge_join.rs Add detailed comments of the ordering requirements of two input children. Co-authored-by: Andrew Lamb * use indices instead of ArrayBuilders for constructing output record batches * Support timestamp/decimal types in join columns * Add fuzz test and fix edge cases * Support float32/64 data types in comparison * Fix lint issues Co-authored-by: zhangli20 Co-authored-by: Andrew Lamb --- datafusion/core/src/physical_plan/mod.rs | 1 + .../core/src/physical_plan/sort_merge_join.rs | 1750 +++++++++++++++++ datafusion/core/tests/join_fuzz.rs | 223 +++ 3 files changed, 1974 insertions(+) create mode 100644 datafusion/core/src/physical_plan/sort_merge_join.rs create mode 100644 datafusion/core/tests/join_fuzz.rs diff --git a/datafusion/core/src/physical_plan/mod.rs b/datafusion/core/src/physical_plan/mod.rs index e38b9d51212b..98485916870a 100644 --- a/datafusion/core/src/physical_plan/mod.rs +++ b/datafusion/core/src/physical_plan/mod.rs @@ -566,6 +566,7 @@ pub mod metrics; pub mod planner; pub mod projection; pub mod repartition; +pub mod sort_merge_join; pub mod sorts; pub mod stream; pub mod type_coercion; diff --git a/datafusion/core/src/physical_plan/sort_merge_join.rs b/datafusion/core/src/physical_plan/sort_merge_join.rs new file mode 100644 index 000000000000..9c769bd96165 --- /dev/null +++ b/datafusion/core/src/physical_plan/sort_merge_join.rs @@ -0,0 +1,1750 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the Sort-Merge join execution plan. +//! A sort-merge join plan consumes two sorted children plan and produces +//! joined output by given join type and other options. + +use std::any::Any; +use std::cmp::Ordering; +use std::collections::VecDeque; +use std::ops::Range; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::array::*; +use arrow::compute::{take, SortOptions}; +use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; +use arrow::error::{ArrowError, Result as ArrowResult}; +use arrow::record_batch::RecordBatch; +use async_trait::async_trait; +use futures::{Stream, StreamExt}; + +use crate::error::DataFusionError; +use crate::error::Result; +use crate::execution::context::TaskContext; +use crate::logical_plan::JoinType; +use crate::physical_plan::common::combine_batches; +use crate::physical_plan::expressions::Column; +use crate::physical_plan::expressions::PhysicalSortExpr; +use crate::physical_plan::join_utils::{build_join_schema, check_join_is_valid, JoinOn}; +use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; +use crate::physical_plan::{ + metrics, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, + Statistics, +}; + +/// join execution plan executes partitions in parallel and combines them into a set of +/// partitions. 
+#[derive(Debug)] +pub struct SortMergeJoinExec { + /// Left sorted joining execution plan + left: Arc, + /// Right sorting joining execution plan + right: Arc, + /// Set of common columns used to join on + on: JoinOn, + /// How the join is performed + join_type: JoinType, + /// The schema once the join is applied + schema: SchemaRef, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, + /// Sort options of join columns used in sorting left and right execution plans + sort_options: Vec, + /// If null_equals_null is true, null == null else null != null + null_equals_null: bool, +} + +impl SortMergeJoinExec { + /// Tries to create a new [SortMergeJoinExec]. + /// The inputs are sorted using `sort_options` are applied to the columns in the `on` + /// # Error + /// This function errors when it is not possible to join the left and right sides on keys `on`. + pub fn try_new( + left: Arc, + right: Arc, + on: JoinOn, + join_type: JoinType, + sort_options: Vec, + null_equals_null: bool, + ) -> Result { + let left_schema = left.schema(); + let right_schema = right.schema(); + + check_join_is_valid(&left_schema, &right_schema, &on)?; + if sort_options.len() != on.len() { + return Err(DataFusionError::Plan(format!( + "Expected number of sort options: {}, actual: {}", + on.len(), + sort_options.len() + ))); + } + + let schema = + Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0); + + Ok(Self { + left, + right, + on, + join_type, + schema, + metrics: ExecutionPlanMetricsSet::new(), + sort_options, + null_equals_null, + }) + } +} + +#[async_trait] +impl ExecutionPlan for SortMergeJoinExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn output_partitioning(&self) -> Partitioning { + self.right.output_partitioning() + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + self.right.output_ordering() + } + + fn children(&self) -> Vec> { + vec![self.left.clone(), 
self.right.clone()] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + match &children[..] { + [left, right] => Ok(Arc::new(SortMergeJoinExec::try_new( + left.clone(), + right.clone(), + self.on.clone(), + self.join_type, + self.sort_options.clone(), + self.null_equals_null, + )?)), + _ => Err(DataFusionError::Internal( + "SortMergeJoin wrong number of children".to_string(), + )), + } + } + + async fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let (streamed, buffered, on_streamed, on_buffered) = match self.join_type { + JoinType::Inner + | JoinType::Left + | JoinType::Full + | JoinType::Anti + | JoinType::Semi => ( + self.left.clone(), + self.right.clone(), + self.on.iter().map(|on| on.0.clone()).collect(), + self.on.iter().map(|on| on.1.clone()).collect(), + ), + JoinType::Right => ( + self.right.clone(), + self.left.clone(), + self.on.iter().map(|on| on.1.clone()).collect(), + self.on.iter().map(|on| on.0.clone()).collect(), + ), + }; + + // execute children plans + let streamed = streamed.execute(partition, context.clone()).await?; + let buffered = buffered.execute(partition, context.clone()).await?; + + // create output buffer + let batch_size = context.session_config().batch_size; + + // create join stream + Ok(Box::pin(SMJStream::try_new( + self.schema.clone(), + self.sort_options.clone(), + self.null_equals_null, + streamed, + buffered, + on_streamed, + on_buffered, + self.join_type, + batch_size, + SortMergeJoinMetrics::new(partition, &self.metrics), + )?)) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Statistics { + todo!() + } +} + +/// Metrics for SortMergeJoinExec +#[allow(dead_code)] +struct SortMergeJoinMetrics { + /// Total time for joining probe-side batches to the build-side batches + join_time: metrics::Time, + /// Number of batches consumed by this operator + input_batches: metrics::Count, + /// Number of rows consumed by this 
operator + input_rows: metrics::Count, + /// Number of batches produced by this operator + output_batches: metrics::Count, + /// Number of rows produced by this operator + output_rows: metrics::Count, +} + +impl SortMergeJoinMetrics { + #[allow(dead_code)] + pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { + let join_time = MetricBuilder::new(metrics).subset_time("join_time", partition); + let input_batches = + MetricBuilder::new(metrics).counter("input_batches", partition); + let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + let output_batches = + MetricBuilder::new(metrics).counter("output_batches", partition); + let output_rows = MetricBuilder::new(metrics).output_rows(partition); + + Self { + join_time, + input_batches, + input_rows, + output_batches, + output_rows, + } + } +} + +/// State of SMJ stream +#[derive(Debug, PartialEq, Eq)] +enum SMJState { + /// Init joining with a new streamed row or a new buffered batches + Init, + /// Polling one streamed row or one buffered batch, or both + Polling, + /// Joining polled data and making output + JoinOutput, + /// No more output + Exhausted, +} + +/// State of streamed data stream +#[derive(Debug, PartialEq, Eq)] +enum StreamedState { + /// Init polling + Init, + /// Polling one streamed row + Polling, + /// Ready to produce one streamed row + Ready, + /// No more streamed row + Exhausted, +} + +/// State of buffered data stream +#[derive(Debug, PartialEq, Eq)] +enum BufferedState { + /// Init polling + Init, + /// Polling first row in the next batch + PollingFirst, + /// Polling rest rows in the next batch + PollingRest, + /// Ready to produce one batch + Ready, + /// No more buffered batches + Exhausted, +} + +/// A buffered batch that contains contiguous rows with same join key +#[derive(Debug)] +struct BufferedBatch { + /// The buffered record batch + pub batch: RecordBatch, + /// The range in which the rows share the same join key + pub range: Range, + 
/// Array refs of the join key + pub join_arrays: Vec, +} +impl BufferedBatch { + fn new(batch: RecordBatch, range: Range, on_column: &[Column]) -> Self { + let join_arrays = join_arrays(&batch, on_column); + BufferedBatch { + batch, + range, + join_arrays, + } + } +} + +/// Sort-merge join stream that consumes streamed and buffered data stream +/// and produces joined output +struct SMJStream { + /// Current state of the stream + pub state: SMJState, + /// Output schema + pub schema: SchemaRef, + /// Sort options of join columns used to sort streamed and buffered data stream + pub sort_options: Vec, + /// null == null? + pub null_equals_null: bool, + /// Input schema of streamed + pub streamed_schema: SchemaRef, + /// Input schema of buffered + pub buffered_schema: SchemaRef, + /// Streamed data stream + pub streamed: SendableRecordBatchStream, + /// Buffered data stream + pub buffered: SendableRecordBatchStream, + /// Current processing record batch of streamed + pub streamed_batch: RecordBatch, + /// Current processing streamed join arrays + pub streamed_join_arrays: Vec, + /// Current processing row of streamed + pub streamed_idx: usize, + /// Currrent buffered data + pub buffered_data: BufferedData, + /// (used in outer join) Is current streamed row joined at least once? + pub streamed_joined: bool, + /// (used in outer join) Is current buffered batches joined at least once? 
+ pub buffered_joined: bool, + /// State of streamed + pub streamed_state: StreamedState, + /// State of buffered + pub buffered_state: BufferedState, + /// The comparison result of current streamed row and buffered batches + pub current_ordering: Ordering, + /// Join key columns of streamed + pub on_streamed: Vec, + /// Join key columns of buffered + pub on_buffered: Vec, + /// Staging output array builders + pub output_record_batches: Vec, + /// Staging output size + pub output_size: usize, + /// Target output batch size + pub batch_size: usize, + /// How the join is performed + pub join_type: JoinType, + /// Metrics + pub join_metrics: SortMergeJoinMetrics, +} + +impl RecordBatchStream for SMJStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +impl Stream for SMJStream { + type Item = ArrowResult; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.join_metrics.join_time.timer(); + loop { + match &self.state { + SMJState::Init => { + let streamed_exhausted = + self.streamed_state == StreamedState::Exhausted; + let buffered_exhausted = + self.buffered_state == BufferedState::Exhausted; + self.state = if streamed_exhausted && buffered_exhausted { + SMJState::Exhausted + } else { + match self.current_ordering { + Ordering::Less | Ordering::Equal => { + if !streamed_exhausted { + self.streamed_joined = false; + self.streamed_state = StreamedState::Init; + } + } + Ordering::Greater => { + if !buffered_exhausted { + self.buffered_joined = false; + self.buffered_state = BufferedState::Init; + } + } + } + SMJState::Polling + }; + } + SMJState::Polling => { + if ![StreamedState::Exhausted, StreamedState::Ready] + .contains(&self.streamed_state) + { + match self.poll_streamed_row(cx)? { + Poll::Ready(_) => {} + Poll::Pending => return Poll::Pending, + } + } + + if ![BufferedState::Exhausted, BufferedState::Ready] + .contains(&self.buffered_state) + { + match self.poll_buffered_batches(cx)? 
{ + Poll::Ready(_) => {} + Poll::Pending => return Poll::Pending, + } + } + let streamed_exhausted = + self.streamed_state == StreamedState::Exhausted; + let buffered_exhausted = + self.buffered_state == BufferedState::Exhausted; + if streamed_exhausted && buffered_exhausted { + self.state = SMJState::Exhausted; + continue; + } + self.current_ordering = self.compare_streamed_buffered()?; + self.state = SMJState::JoinOutput; + } + SMJState::JoinOutput => { + let output_indices = self.join_partial()?; + if !output_indices.is_empty() { + self.output_partial(&output_indices)?; + } + + if self.output_size < self.batch_size { + if self.buffered_data.scanning_finished() { + self.buffered_data.scanning_reset(); + self.state = SMJState::Init; + } + } else { + let record_batch = self.output_record_batch_and_reset()?; + return Poll::Ready(Some(Ok(record_batch))); + } + } + SMJState::Exhausted => { + if !self.output_record_batches.is_empty() { + let record_batch = self.output_record_batch_and_reset()?; + return Poll::Ready(Some(Ok(record_batch))); + } + return Poll::Ready(None); + } + } + } + } +} + +impl SMJStream { + #[allow(clippy::too_many_arguments)] + pub fn try_new( + schema: SchemaRef, + sort_options: Vec, + null_equals_null: bool, + streamed: SendableRecordBatchStream, + buffered: SendableRecordBatchStream, + on_streamed: Vec, + on_buffered: Vec, + join_type: JoinType, + batch_size: usize, + join_metrics: SortMergeJoinMetrics, + ) -> Result { + Ok(Self { + state: SMJState::Init, + sort_options, + null_equals_null, + schema: schema.clone(), + streamed_schema: streamed.schema(), + buffered_schema: buffered.schema(), + streamed, + buffered, + streamed_batch: RecordBatch::new_empty(schema), + streamed_join_arrays: vec![], + streamed_idx: 0, + buffered_data: BufferedData::default(), + streamed_joined: false, + buffered_joined: false, + streamed_state: StreamedState::Init, + buffered_state: BufferedState::Init, + current_ordering: Ordering::Equal, + on_streamed, + 
on_buffered, + output_record_batches: vec![], + output_size: 0, + batch_size, + join_type, + join_metrics, + }) + } + + /// Poll next streamed row + fn poll_streamed_row(&mut self, cx: &mut Context) -> Poll>> { + loop { + match &self.streamed_state { + StreamedState::Init => { + if self.streamed_idx + 1 < self.streamed_batch.num_rows() { + self.streamed_idx += 1; + self.streamed_state = StreamedState::Ready; + return Poll::Ready(Some(Ok(()))); + } else { + self.streamed_state = StreamedState::Polling; + } + continue; + } + StreamedState::Polling => match self.streamed.poll_next_unpin(cx)? { + Poll::Pending => { + return Poll::Pending; + } + Poll::Ready(None) => { + self.streamed_state = StreamedState::Exhausted; + } + Poll::Ready(Some(batch)) => { + if batch.num_rows() > 0 { + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + self.streamed_batch = batch; + self.streamed_join_arrays = + join_arrays(&self.streamed_batch, &self.on_streamed); + self.streamed_idx = 0; + self.streamed_state = StreamedState::Ready; + } + } + }, + StreamedState::Ready => { + return Poll::Ready(Some(Ok(()))); + } + StreamedState::Exhausted => { + return Poll::Ready(None); + } + } + } + } + + /// Poll next buffered batches + fn poll_buffered_batches( + &mut self, + cx: &mut Context, + ) -> Poll>> { + loop { + match &self.buffered_state { + BufferedState::Init => { + // pop previous buffered batches + while !self.buffered_data.batches.is_empty() { + let head_batch = self.buffered_data.head_batch(); + if head_batch.range.end == head_batch.batch.num_rows() { + self.buffered_data.batches.pop_front(); + } else { + break; + } + } + if self.buffered_data.batches.is_empty() { + self.buffered_state = BufferedState::PollingFirst; + } else { + let tail_batch = self.buffered_data.tail_batch_mut(); + tail_batch.range.start = tail_batch.range.end; + tail_batch.range.end += 1; + self.buffered_state = BufferedState::PollingRest; + } + } + 
BufferedState::PollingFirst => match self.buffered.poll_next_unpin(cx)? { + Poll::Pending => { + return Poll::Pending; + } + Poll::Ready(None) => { + self.buffered_state = BufferedState::Exhausted; + return Poll::Ready(None); + } + Poll::Ready(Some(batch)) => { + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + if batch.num_rows() > 0 { + self.buffered_data.batches.push_back(BufferedBatch::new( + batch, + 0..1, + &self.on_buffered, + )); + self.buffered_state = BufferedState::PollingRest; + } + } + }, + BufferedState::PollingRest => { + if self.buffered_data.tail_batch().range.end + < self.buffered_data.tail_batch().batch.num_rows() + { + while self.buffered_data.tail_batch().range.end + < self.buffered_data.tail_batch().batch.num_rows() + { + if is_join_arrays_equal( + &self.buffered_data.head_batch().join_arrays, + self.buffered_data.head_batch().range.start, + &self.buffered_data.tail_batch().join_arrays, + self.buffered_data.tail_batch().range.end, + )? { + self.buffered_data.tail_batch_mut().range.end += 1; + } else { + self.buffered_state = BufferedState::Ready; + return Poll::Ready(Some(Ok(()))); + } + } + } else { + match self.buffered.poll_next_unpin(cx)? 
{ + Poll::Pending => { + return Poll::Pending; + } + Poll::Ready(None) => { + self.buffered_state = BufferedState::Ready; + } + Poll::Ready(Some(batch)) => { + self.join_metrics.input_batches.add(1); + if batch.num_rows() > 0 { + self.join_metrics.input_rows.add(batch.num_rows()); + self.buffered_data.batches.push_back( + BufferedBatch::new( + batch, + 0..0, + &self.on_buffered, + ), + ); + } + } + } + } + } + BufferedState::Ready => { + return Poll::Ready(Some(Ok(()))); + } + BufferedState::Exhausted => { + return Poll::Ready(None); + } + } + } + } + + /// Get comparison result of streamed row and buffered batches + fn compare_streamed_buffered(&self) -> ArrowResult { + if self.streamed_state == StreamedState::Exhausted { + return Ok(Ordering::Greater); + } + if !self.buffered_data.has_buffered_rows() { + return Ok(Ordering::Less); + } + + return compare_join_arrays( + &self.streamed_join_arrays, + self.streamed_idx, + &self.buffered_data.head_batch().join_arrays, + self.buffered_data.head_batch().range.start, + &self.sort_options, + self.null_equals_null, + ); + } + + /// Produce join and fill output buffer until reaching target batch size + /// or the join is finished + fn join_partial(&mut self) -> ArrowResult> { + let mut join_streamed = false; + let mut join_buffered = false; + + // determine whether we need to join streamed/buffered rows + match self.current_ordering { + Ordering::Less => { + if matches!( + self.join_type, + JoinType::Left | JoinType::Right | JoinType::Full | JoinType::Anti + ) { + join_streamed = !self.streamed_joined; + } + } + Ordering::Equal => { + if matches!(self.join_type, JoinType::Semi) { + join_streamed = !self.streamed_joined; + } + if matches!( + self.join_type, + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full + ) { + join_streamed = true; + join_buffered = true; + }; + } + Ordering::Greater => { + if matches!(self.join_type, JoinType::Full) { + join_buffered = !self.buffered_joined; + }; + } + } + if 
!join_streamed && !join_buffered { + // no joined data + self.buffered_data.scanning_finish(); + return Ok(vec![]); + } + + let mut output_indices = vec![]; + + if join_buffered { + // joining streamed/nulls and buffered + let streamed_idx = if join_streamed { + Some(self.streamed_idx) + } else { + None + }; + while !self.buffered_data.scanning_finished() + && self.output_size < self.batch_size + { + output_indices.push(OutputIndex { + streamed_idx, + buffered_idx: Some(( + self.buffered_data.scanning_batch_idx, + self.buffered_data.scanning_idx(), + )), + }); + self.output_size += 1; + self.buffered_data.scanning_advance(); + + if self.buffered_data.scanning_finished() { + self.streamed_joined = join_streamed; + self.buffered_joined = true; + } + } + } else { + // joining streamed and nulls + output_indices.push(OutputIndex { + streamed_idx: Some(self.streamed_idx), + buffered_idx: None, + }); + self.output_size += 1; + self.buffered_data.scanning_finish(); + self.streamed_joined = true; + } + Ok(output_indices) + } + + fn output_record_batch_and_reset(&mut self) -> ArrowResult { + assert!(!self.output_record_batches.is_empty()); + + let record_batch = + combine_batches(&self.output_record_batches, self.schema.clone())?.unwrap(); + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(record_batch.num_rows()); + self.output_size = 0; + self.output_record_batches.clear(); + Ok(record_batch) + } + + fn output_partial(&mut self, output_indices: &[OutputIndex]) -> ArrowResult<()> { + match self.join_type { + JoinType::Inner => { + self.output_partial_streamed_joining_buffered(output_indices)?; + } + JoinType::Left | JoinType::Right => { + self.output_partial_streamed_joining_buffered(output_indices)?; + self.output_partial_streamed_joining_null(output_indices)?; + } + JoinType::Full => { + self.output_partial_streamed_joining_buffered(output_indices)?; + self.output_partial_streamed_joining_null(output_indices)?; + 
self.output_partial_null_joining_buffered(output_indices)?; + } + JoinType::Semi | JoinType::Anti => { + self.output_partial_streamed_joining_null(output_indices)?; + } + } + Ok(()) + } + + fn output_partial_streamed_joining_buffered( + &mut self, + output_indices: &[OutputIndex], + ) -> ArrowResult<()> { + let mut output = |buffered_batch_idx: usize, indices: &[OutputIndex]| { + if indices.is_empty() { + return ArrowResult::Ok(()); + } + + // take streamed columns + let streamed_indices = UInt64Array::from_iter_values( + indices + .iter() + .map(|index| index.streamed_idx.unwrap() as u64), + ); + let mut streamed_columns = self + .streamed_batch + .columns() + .iter() + .map(|column| take(column, &streamed_indices, None)) + .collect::>>()?; + + // take buffered columns + let buffered_indices = UInt64Array::from_iter_values( + indices + .iter() + .map(|index| index.buffered_idx.unwrap().1 as u64), + ); + let mut buffered_columns = self.buffered_data.batches[buffered_batch_idx] + .batch + .columns() + .iter() + .map(|column| take(column, &buffered_indices, None)) + .collect::>>()?; + + // combine columns and produce record batch + let columns = match self.join_type { + JoinType::Inner | JoinType::Left | JoinType::Full => { + streamed_columns.extend(buffered_columns); + streamed_columns + } + JoinType::Right => { + buffered_columns.extend(streamed_columns); + buffered_columns + } + JoinType::Semi | JoinType::Anti => { + unreachable!() + } + }; + let record_batch = RecordBatch::try_new(self.schema.clone(), columns)?; + self.output_record_batches.push(record_batch); + Ok(()) + }; + + let mut buffered_batch_idx = 0; + let mut indices = vec![]; + for &index in output_indices + .iter() + .filter(|index| index.streamed_idx.is_some()) + .filter(|index| index.buffered_idx.is_some()) + { + let buffered_idx = index.buffered_idx.unwrap(); + if index.buffered_idx.unwrap().0 != buffered_batch_idx { + output(buffered_batch_idx, &indices)?; + buffered_batch_idx = buffered_idx.0; + 
indices.clear(); + } + indices.push(index); + } + output(buffered_batch_idx, &indices)?; + Ok(()) + } + + fn output_partial_streamed_joining_null( + &mut self, + output_indices: &[OutputIndex], + ) -> ArrowResult<()> { + // streamed joining null + let streamed_indices = UInt64Array::from_iter_values( + output_indices + .iter() + .filter(|index| index.streamed_idx.is_some()) + .filter(|index| index.buffered_idx.is_none()) + .map(|index| index.streamed_idx.unwrap() as u64), + ); + let mut streamed_columns = self + .streamed_batch + .columns() + .iter() + .map(|column| take(column, &streamed_indices, None)) + .collect::>>()?; + + let mut buffered_columns = self + .buffered_schema + .fields() + .iter() + .map(|f| new_null_array(f.data_type(), streamed_indices.len())) + .collect::>(); + + let columns = match self.join_type { + JoinType::Inner => { + unreachable!() + } + JoinType::Left | JoinType::Full => { + streamed_columns.extend(buffered_columns); + streamed_columns + } + JoinType::Right => { + buffered_columns.extend(streamed_columns); + buffered_columns + } + JoinType::Anti | JoinType::Semi => streamed_columns, + }; + + if !streamed_indices.is_empty() { + let record_batch = RecordBatch::try_new(self.schema.clone(), columns)?; + self.output_record_batches.push(record_batch); + } + Ok(()) + } + + fn output_partial_null_joining_buffered( + &mut self, + output_indices: &[OutputIndex], + ) -> ArrowResult<()> { + let mut output = |buffered_batch_idx: usize, indices: &[OutputIndex]| { + if indices.is_empty() { + return ArrowResult::Ok(()); + } + + // take buffered columns + let buffered_indices = UInt64Array::from_iter_values( + indices + .iter() + .map(|index| index.buffered_idx.unwrap().1 as u64), + ); + let buffered_columns = self.buffered_data.batches[buffered_batch_idx] + .batch + .columns() + .iter() + .map(|column| take(column, &buffered_indices, None)) + .collect::>>()?; + + // create null streamed columns + let mut streamed_columns = self + .streamed_schema + 
.fields() + .iter() + .map(|f| new_null_array(f.data_type(), buffered_indices.len())) + .collect::>(); + + // combine columns and produce record batch + let columns = match self.join_type { + JoinType::Full => { + streamed_columns.extend(buffered_columns); + streamed_columns + } + JoinType::Inner + | JoinType::Left + | JoinType::Right + | JoinType::Semi + | JoinType::Anti => { + unreachable!() + } + }; + let record_batch = RecordBatch::try_new(self.schema.clone(), columns)?; + self.output_record_batches.push(record_batch); + Ok(()) + }; + + let mut buffered_batch_idx = 0; + let mut indices = vec![]; + for &index in output_indices + .iter() + .filter(|index| index.streamed_idx.is_none()) + .filter(|index| index.buffered_idx.is_some()) + { + let buffered_idx = index.buffered_idx.unwrap(); + if buffered_idx.0 != buffered_batch_idx { + output(buffered_batch_idx, &indices)?; + buffered_batch_idx = buffered_idx.0; + indices.clear(); + } + indices.push(index); + } + output(buffered_batch_idx, &indices)?; + Ok(()) + } +} + +/// Buffered data contains all buffered batches with one unique join key +#[derive(Debug, Default)] +struct BufferedData { + /// Buffered batches with the same key + pub batches: VecDeque, + /// current scanning batch index used in join_partial() + pub scanning_batch_idx: usize, + /// current scanning offset used in join_partial() + pub scanning_offset: usize, +} +impl BufferedData { + pub fn head_batch(&self) -> &BufferedBatch { + self.batches.front().unwrap() + } + + pub fn tail_batch(&self) -> &BufferedBatch { + self.batches.back().unwrap() + } + + pub fn tail_batch_mut(&mut self) -> &mut BufferedBatch { + self.batches.back_mut().unwrap() + } + + pub fn has_buffered_rows(&self) -> bool { + self.batches.iter().any(|batch| !batch.range.is_empty()) + } + + pub fn scanning_reset(&mut self) { + self.scanning_batch_idx = 0; + self.scanning_offset = 0; + } + + pub fn scanning_advance(&mut self) { + self.scanning_offset += 1; + while 
!self.scanning_finished() && self.scanning_batch_finished() { + self.scanning_batch_idx += 1; + self.scanning_offset = 0; + } + } + + pub fn scanning_batch(&self) -> &BufferedBatch { + &self.batches[self.scanning_batch_idx] + } + + pub fn scanning_idx(&self) -> usize { + self.scanning_batch().range.start + self.scanning_offset + } + + pub fn scanning_batch_finished(&self) -> bool { + self.scanning_offset == self.scanning_batch().range.len() + } + + pub fn scanning_finished(&self) -> bool { + self.scanning_batch_idx == self.batches.len() + } + + pub fn scanning_finish(&mut self) { + self.scanning_batch_idx = self.batches.len(); + self.scanning_offset = 0; + } +} + +#[derive(Clone, Copy, Debug)] +struct OutputIndex { + /// joined streamed row index + streamed_idx: Option, + /// joined buffered batch index and row index + buffered_idx: Option<(usize, usize)>, +} + +/// Get join array refs of given batch and join columns +fn join_arrays(batch: &RecordBatch, on_column: &[Column]) -> Vec { + on_column + .iter() + .map(|c| batch.column(c.index()).clone()) + .collect() +} + +/// Get comparison result of two rows of join arrays +fn compare_join_arrays( + left_arrays: &[ArrayRef], + left: usize, + right_arrays: &[ArrayRef], + right: usize, + sort_options: &[SortOptions], + null_equals_null: bool, +) -> ArrowResult { + let mut res = Ordering::Equal; + for ((left_array, right_array), sort_options) in + left_arrays.iter().zip(right_arrays).zip(sort_options) + { + macro_rules! 
compare_value { + ($T:ty) => {{ + let left_array = left_array.as_any().downcast_ref::<$T>().unwrap(); + let right_array = right_array.as_any().downcast_ref::<$T>().unwrap(); + match (left_array.is_null(left), right_array.is_null(right)) { + (false, false) => { + let left_value = &left_array.value(left); + let right_value = &right_array.value(right); + res = left_value.partial_cmp(right_value).unwrap(); + if sort_options.descending { + res = res.reverse(); + } + } + (true, false) => { + res = if sort_options.nulls_first { + Ordering::Less + } else { + Ordering::Greater + }; + } + (false, true) => { + res = if sort_options.nulls_first { + Ordering::Greater + } else { + Ordering::Less + }; + } + _ => { + res = if null_equals_null { + Ordering::Equal + } else { + Ordering::Less + }; + } + } + }}; + } + + match left_array.data_type() { + DataType::Null => {} + DataType::Boolean => compare_value!(BooleanArray), + DataType::Int8 => compare_value!(Int8Array), + DataType::Int16 => compare_value!(Int16Array), + DataType::Int32 => compare_value!(Int32Array), + DataType::Int64 => compare_value!(Int64Array), + DataType::UInt8 => compare_value!(UInt8Array), + DataType::UInt16 => compare_value!(UInt16Array), + DataType::UInt32 => compare_value!(UInt32Array), + DataType::UInt64 => compare_value!(UInt64Array), + DataType::Float32 => compare_value!(Float32Array), + DataType::Float64 => compare_value!(Float64Array), + DataType::Utf8 => compare_value!(StringArray), + DataType::LargeUtf8 => compare_value!(LargeStringArray), + DataType::Decimal(..) 
=> compare_value!(DecimalArray),
        DataType::Timestamp(time_unit, None) => match time_unit {
            TimeUnit::Second => compare_value!(TimestampSecondArray),
            TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray),
            TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray),
            TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray),
        },
        _ => {
            return Err(ArrowError::NotYetImplemented(
                "Unsupported data type in sort merge join comparator".to_owned(),
            ));
        }
        }
        // Keys are compared column by column; stop at the first column that
        // decides the ordering.
        if !res.is_eq() {
            break;
        }
    }
    Ok(res)
}

/// A faster version of compare_join_arrays() that only outputs whether
/// the given two rows are equal.
///
/// Unlike `compare_join_arrays`, nulls never compare equal here (there is no
/// `null_equals_null` option and no sort-order handling), so this is only
/// suitable for equality probing, not for merge ordering.
fn is_join_arrays_equal(
    left_arrays: &[ArrayRef],
    left: usize,
    right_arrays: &[ArrayRef],
    right: usize,
) -> ArrowResult<bool> {
    let mut is_equal = true;
    for (left_array, right_array) in left_arrays.iter().zip(right_arrays) {
        macro_rules! compare_value {
            ($T:ty) => {{
                match (left_array.is_null(left), right_array.is_null(right)) {
                    (false, false) => {
                        let left_array =
                            left_array.as_any().downcast_ref::<$T>().unwrap();
                        let right_array =
                            right_array.as_any().downcast_ref::<$T>().unwrap();
                        if left_array.value(left) != right_array.value(right) {
                            is_equal = false;
                        }
                    }
                    // a null never equals a non-null value
                    (true, false) => is_equal = false,
                    (false, true) => is_equal = false,
                    // both null: treated as equal for this fast path
                    _ => {}
                }
            }};
        }

        match left_array.data_type() {
            DataType::Null => {}
            DataType::Boolean => compare_value!(BooleanArray),
            DataType::Int8 => compare_value!(Int8Array),
            DataType::Int16 => compare_value!(Int16Array),
            DataType::Int32 => compare_value!(Int32Array),
            DataType::Int64 => compare_value!(Int64Array),
            DataType::UInt8 => compare_value!(UInt8Array),
            DataType::UInt16 => compare_value!(UInt16Array),
            DataType::UInt32 => compare_value!(UInt32Array),
            DataType::UInt64 => compare_value!(UInt64Array),
            DataType::Float32 => compare_value!(Float32Array),
            DataType::Float64 => compare_value!(Float64Array),
            DataType::Utf8 => compare_value!(StringArray),
            DataType::LargeUtf8 => compare_value!(LargeStringArray),
            DataType::Decimal(..) => compare_value!(DecimalArray),
            DataType::Timestamp(time_unit, None) => match time_unit {
                TimeUnit::Second => compare_value!(TimestampSecondArray),
                TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray),
                TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray),
                TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray),
            },
            _ => {
                return Err(ArrowError::NotYetImplemented(
                    "Unsupported data type in sort merge join comparator".to_owned(),
                ));
            }
        }
        // short-circuit on the first differing key column
        if !is_equal {
            return Ok(false);
        }
    }
    Ok(true)
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::Int32Array;
    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;

    use crate::assert_batches_sorted_eq;
    use crate::error::Result;
    use crate::logical_plan::JoinType;
    use crate::physical_plan::expressions::Column;
    use crate::physical_plan::join_utils::JoinOn;
    use crate::physical_plan::memory::MemoryExec;
    use crate::physical_plan::sort_merge_join::SortMergeJoinExec;
    use crate::physical_plan::{common, ExecutionPlan};
    use crate::prelude::{SessionConfig, SessionContext};
    use crate::test::{build_table_i32, columns};

    /// Builds a single-partition in-memory table with 3 non-nullable i32 columns.
    fn build_table(
        a: (&str, &Vec<i32>),
        b: (&str, &Vec<i32>),
        c: (&str, &Vec<i32>),
    ) -> Arc<dyn ExecutionPlan> {
        let batch = build_table_i32(a, b, c);
        let schema = batch.schema();
        Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
    }

    /// returns a table with 3 columns of i32 in memory
    pub fn build_table_i32_nullable(
        a: (&str, &Vec<Option<i32>>),
        b: (&str, &Vec<Option<i32>>),
        c: (&str, &Vec<Option<i32>>),
    ) -> Arc<dyn ExecutionPlan> {
        let schema = Arc::new(Schema::new(vec![
            Field::new(a.0, DataType::Int32, true),
            Field::new(b.0, DataType::Int32, true),
            Field::new(c.0, DataType::Int32, true),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(a.1.clone())),
                Arc::new(Int32Array::from(b.1.clone())),
                Arc::new(Int32Array::from(c.1.clone())),
            ],
        )
        .unwrap();
        Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
    }

    /// Creates a SMJ plan with default (ascending, nulls-first) sort options.
    fn join(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        on: JoinOn,
        join_type: JoinType,
    ) -> Result<SortMergeJoinExec> {
        let sort_options = vec![SortOptions::default(); on.len()];
        SortMergeJoinExec::try_new(left, right, on, join_type, sort_options, false)
    }

    /// Creates a SMJ plan with explicit sort options and null-equality semantics.
    fn join_with_options(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        on: JoinOn,
        join_type: JoinType,
        sort_options: Vec<SortOptions>,
        null_equals_null: bool,
    ) -> Result<SortMergeJoinExec> {
        SortMergeJoinExec::try_new(
            left,
            right,
            on,
            join_type,
            sort_options,
            null_equals_null,
        )
    }

    /// Runs the join and collects output column names plus all result batches.
    async fn join_collect(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        on: JoinOn,
        join_type: JoinType,
    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
        let sort_options = vec![SortOptions::default(); on.len()];
        join_collect_with_options(left, right, on, join_type, sort_options, false).await
    }

    async fn join_collect_with_options(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        on: JoinOn,
        join_type: JoinType,
        sort_options: Vec<SortOptions>,
        null_equals_null: bool,
    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let join = join_with_options(
            left,
            right,
            on,
            join_type,
            sort_options,
            null_equals_null,
        )?;
        let columns = columns(&join.schema());

        let stream = join.execute(0, task_ctx).await?;
        let batches = common::collect(stream).await?;
        Ok((columns, batches))
    }

    /// Same as `join_collect`, but forces a batch size of 2 so output splitting
    /// across multiple record batches is exercised.
    async fn join_collect_batch_size_equals_two(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        on: JoinOn,
        join_type: JoinType,
    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
        let session_ctx =
            SessionContext::with_config(SessionConfig::new().with_batch_size(2));
        let task_ctx = session_ctx.task_ctx();
        let join = join(left, right, on, join_type)?;
        let columns = columns(&join.schema());

        let stream = join.execute(0, task_ctx).await?;
        let batches = common::collect(stream).await?;
        Ok((columns, batches))
    }

    #[tokio::test]
    async fn join_inner_one() -> Result<()> {
        let left = build_table(
            ("a1", &vec![1, 2, 3]),
            ("b1", &vec![4, 5, 5]), // this has a repetition
            ("c1", &vec![7, 8, 9]),
        );
        let right = build_table(
            ("a2", &vec![10, 20, 30]),
            ("b1", &vec![4, 5, 6]),
            ("c2", &vec![70, 80, 90]),
        );

        let on = vec![(
            Column::new_with_schema("b1", &left.schema())?,
            Column::new_with_schema("b1", &right.schema())?,
        )];

        let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?;

        let expected = vec![
            "+----+----+----+----+----+----+",
            "| a1 | b1 | c1 | a2 | b1 | c2 |",
            "+----+----+----+----+----+----+",
            "| 1  | 4  | 7  | 10 | 4  | 70 |",
            "| 2  | 5  | 8  | 20 | 5  | 80 |",
            "| 3  | 5  | 9  | 20 | 5  | 80 |",
            "+----+----+----+----+----+----+",
        ];
        assert_batches_sorted_eq!(expected, &batches);
        Ok(())
    }

    #[tokio::test]
    async fn join_inner_two() -> Result<()> {
        let left = build_table(
            ("a1", &vec![1, 2, 2]),
            ("b2", &vec![1, 2, 2]),
            ("c1", &vec![7, 8, 9]),
        );
        let right = build_table(
            ("a1", &vec![1, 2, 3]),
            ("b2", &vec![1, 2, 2]),
            ("c2", &vec![70, 80, 90]),
        );
        let on = vec![
            (
                Column::new_with_schema("a1", &left.schema())?,
                Column::new_with_schema("a1", &right.schema())?,
            ),
            (
                Column::new_with_schema("b2", &left.schema())?,
                Column::new_with_schema("b2", &right.schema())?,
            ),
        ];

        let (_columns, batches) = join_collect(left, right, on, JoinType::Inner).await?;
        let expected = vec![
            "+----+----+----+----+----+----+",
            "| a1 | b2 | c1 | a1 | b2 | c2 |",
            "+----+----+----+----+----+----+",
            "| 1  | 1  | 7  | 1  | 1  | 70 |",
            "| 2  | 2  | 8  | 2  | 2  | 80 |",
            "| 2  | 2  | 9  | 2  | 2  | 80 |",
            "+----+----+----+----+----+----+",
        ];
        assert_batches_sorted_eq!(expected, &batches);
        Ok(())
    }

    #[tokio::test]
    async fn join_inner_two_two() -> Result<()> {
        let left = build_table(
            ("a1", &vec![1, 1, 2]),
            ("b2", &vec![1, 1, 2]),
            ("c1", &vec![7, 8, 9]),
        );
        let right = build_table(
            ("a1", &vec![1, 1, 3]),
            ("b2", &vec![1, 1, 2]),
            ("c2", &vec![70, 80, 90]),
        );
        let on = vec![
            (
                Column::new_with_schema("a1", &left.schema())?,
                Column::new_with_schema("a1", &right.schema())?,
            ),
            (
                Column::new_with_schema("b2", &left.schema())?,
                Column::new_with_schema("b2", &right.schema())?,
            ),
        ];

        let (_columns, batches) = join_collect(left, right, on, JoinType::Inner).await?;
        let expected = vec![
            "+----+----+----+----+----+----+",
            "| a1 | b2 | c1 | a1 | b2 | c2 |",
            "+----+----+----+----+----+----+",
            "| 1  | 1  | 7  | 1  | 1  | 70 |",
            "| 1  | 1  | 7  | 1  | 1  | 80 |",
            "| 1  | 1  | 8  | 1  | 1  | 70 |",
            "| 1  | 1  | 8  | 1  | 1  | 80 |",
            "+----+----+----+----+----+----+",
        ];
        assert_batches_sorted_eq!(expected, &batches);
        Ok(())
    }

    #[tokio::test]
    async fn join_inner_with_nulls() -> Result<()> {
        let left = build_table_i32_nullable(
            ("a1", &vec![Some(1), Some(1), Some(2), Some(2)]),
            ("b2", &vec![None, Some(1), Some(2), Some(2)]), // null in key field
            ("c1", &vec![Some(1), None, Some(8), Some(9)]), // null in non-key field
        );
        let right = build_table_i32_nullable(
            ("a1", &vec![Some(1), Some(1), Some(2), Some(3)]),
            ("b2", &vec![None, Some(1), Some(2), Some(2)]),
            ("c2", &vec![Some(10), Some(70), Some(80), Some(90)]),
        );
        let on = vec![
            (
                Column::new_with_schema("a1", &left.schema())?,
                Column::new_with_schema("a1", &right.schema())?,
            ),
            (
                Column::new_with_schema("b2", &left.schema())?,
                Column::new_with_schema("b2", &right.schema())?,
            ),
        ];

        let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?;
        // null keys do not match because null_equals_null is false here
        let expected = vec![
            "+----+----+----+----+----+----+",
            "| a1 | b2 | c1 | a1 | b2 | c2 |",
            "+----+----+----+----+----+----+",
            "| 1  | 1  |    | 1  | 1  | 70 |",
            "| 2  | 2  | 8  | 2  | 2  | 80 |",
            "| 2  | 2  | 9  | 2  | 2  | 80 |",
            "+----+----+----+----+----+----+",
        ];
        assert_batches_sorted_eq!(expected, &batches);
        Ok(())
    }

    #[tokio::test]
    async fn join_inner_with_nulls_with_options() -> Result<()> {
        let left = build_table_i32_nullable(
            ("a1", &vec![Some(2), Some(2), Some(1), Some(1)]),
            ("b2", &vec![Some(2), Some(2), Some(1), None]), // null in key field
            ("c1", &vec![Some(9), Some(8), None, Some(1)]), // null in non-key field
        );
        let right = build_table_i32_nullable(
            ("a1", &vec![Some(3), Some(2), Some(1), Some(1)]),
            ("b2", &vec![Some(2), Some(2), Some(1), None]),
            ("c2", &vec![Some(90), Some(80), Some(70), Some(10)]),
        );
        let on = vec![
            (
                Column::new_with_schema("a1", &left.schema())?,
                Column::new_with_schema("a1", &right.schema())?,
            ),
            (
                Column::new_with_schema("b2", &left.schema())?,
                Column::new_with_schema("b2", &right.schema())?,
            ),
        ];
        // descending, nulls-last ordering with null_equals_null = true:
        // the (1, NULL) keys on both sides now join
        let (_, batches) = join_collect_with_options(
            left,
            right,
            on,
            JoinType::Inner,
            vec![
                SortOptions {
                    descending: true,
                    nulls_first: false
                };
                2
            ],
            true,
        )
        .await?;
        let expected = vec![
            "+----+----+----+----+----+----+",
            "| a1 | b2 | c1 | a1 | b2 | c2 |",
            "+----+----+----+----+----+----+",
            "| 1  |    | 1  | 1  |    | 10 |",
            "| 1  | 1  |    | 1  | 1  | 70 |",
            "| 2  | 2  | 8  | 2  | 2  | 80 |",
            "| 2  | 2  | 9  | 2  | 2  | 80 |",
            "+----+----+----+----+----+----+",
        ];
        assert_batches_sorted_eq!(expected, &batches);
        Ok(())
    }

    #[tokio::test]
    async fn join_inner_output_two_batches() -> Result<()> {
        let left = build_table(
            ("a1", &vec![1, 2, 2]),
            ("b2", &vec![1, 2, 2]),
            ("c1", &vec![7, 8, 9]),
        );
        let right = build_table(
            ("a1", &vec![1, 2, 3]),
            ("b2", &vec![1, 2, 2]),
            ("c2", &vec![70, 80, 90]),
        );
        let on = vec![
            (
                Column::new_with_schema("a1", &left.schema())?,
                Column::new_with_schema("a1", &right.schema())?,
            ),
            (
                Column::new_with_schema("b2", &left.schema())?,
                Column::new_with_schema("b2", &right.schema())?,
            ),
        ];

        let (_, batches) =
            join_collect_batch_size_equals_two(left, right, on, JoinType::Inner).await?;
        let expected = vec![
            "+----+----+----+----+----+----+",
            "| a1 | b2 | c1 | a1 | b2 | c2 |",
            "+----+----+----+----+----+----+",
            "| 1  | 1  | 7  | 1  | 1  | 70 |",
            "| 2  | 2  | 8  | 2  | 2  | 80 |",
            "| 2  | 2  | 9  | 2  | 2  | 80 |",
            "+----+----+----+----+----+----+",
        ];
        // 3 output rows with batch_size=2 must split into a 2-row and a 1-row batch
        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].num_rows(), 2);
        assert_eq!(batches[1].num_rows(), 1);
        assert_batches_sorted_eq!(expected, &batches);
        Ok(())
    }

    #[tokio::test]
    async fn join_left_one() -> Result<()> {
        let left = build_table(
            ("a1", &vec![1, 2, 3]),
            ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
            ("c1", &vec![7, 8, 9]),
        );
        let right = build_table(
            ("a2", &vec![10, 20, 30]),
            ("b1", &vec![4, 5, 6]),
            ("c2", &vec![70, 80, 90]),
        );
        let on = vec![(
            Column::new_with_schema("b1", &left.schema())?,
            Column::new_with_schema("b1", &right.schema())?,
        )];

        let (_, batches) = join_collect(left, right, on, JoinType::Left).await?;
        let expected = vec![
            "+----+----+----+----+----+----+",
            "| a1 | b1 | c1 | a2 | b1 | c2 |",
            "+----+----+----+----+----+----+",
            "| 1  | 4  | 7  | 10 | 4  | 70 |",
            "| 2  | 5  | 8  | 20 | 5  | 80 |",
            "| 3  | 7  | 9  |    |    |    |",
            "+----+----+----+----+----+----+",
        ];
        assert_batches_sorted_eq!(expected, &batches);
        Ok(())
    }

    #[tokio::test]
    async fn join_right_one() -> Result<()> {
        let left = build_table(
            ("a1", &vec![1, 2, 3]),
            ("b1", &vec![4, 5, 7]),
            ("c1", &vec![7, 8, 9]),
        );
        let right = build_table(
            ("a2", &vec![10, 20, 30]),
            ("b1", &vec![4, 5, 6]), // 6 does not exist on the left
            ("c2", &vec![70, 80, 90]),
        );
        let on = vec![(
            Column::new_with_schema("b1", &left.schema())?,
            Column::new_with_schema("b1", &right.schema())?,
        )];

        let (_, batches) = join_collect(left, right, on, JoinType::Right).await?;
        let expected = vec![
            "+----+----+----+----+----+----+",
            "| a1 | b1 | c1 | a2 | b1 | c2 |",
            "+----+----+----+----+----+----+",
            "|    |    |    | 30 | 6  | 90 |",
            "| 1  | 4  | 7  | 10 | 4  | 70 |",
            "| 2  | 5  | 8  | 20 | 5  | 80 |",
            "+----+----+----+----+----+----+",
        ];
        assert_batches_sorted_eq!(expected, &batches);
        Ok(())
    }

    #[tokio::test]
    async fn join_full_one() -> Result<()> {
        let left = build_table(
            ("a1", &vec![1, 2, 3]),
            ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
            ("c1", &vec![7, 8, 9]),
        );
        let right = build_table(
            ("a2", &vec![10, 20, 30]),
            ("b2", &vec![4, 5, 6]),
            ("c2", &vec![70, 80, 90]),
        );
        let on = vec![(
            Column::new_with_schema("b1", &left.schema()).unwrap(),
            Column::new_with_schema("b2", &right.schema()).unwrap(),
        )];

        let (_, batches) = join_collect(left, right, on, JoinType::Full).await?;
        let expected = vec![
            "+----+----+----+----+----+----+",
            "| a1 | b1 | c1 | a2 | b2 | c2 |",
            "+----+----+----+----+----+----+",
            "|    |    |    | 30 | 6  | 90 |",
            "| 1  | 4  | 7  | 10 | 4  | 70 |",
            "| 2  | 5  | 8  | 20 | 5  | 80 |",
            "| 3  | 7  | 9  |    |    |    |",
            "+----+----+----+----+----+----+",
        ];
        assert_batches_sorted_eq!(expected, &batches);
        Ok(())
    }

    #[tokio::test]
    async fn join_anti() -> Result<()> {
        let left = build_table(
            ("a1", &vec![1, 2, 2, 3, 5]),
            ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right
            ("c1", &vec![7, 8, 8, 9, 11]),
        );
        let right = build_table(
            ("a2", &vec![10, 20, 30]),
            ("b1", &vec![4, 5, 6]),
            ("c2", &vec![70, 80, 90]),
        );
        let on = vec![(
            Column::new_with_schema("b1", &left.schema())?,
            Column::new_with_schema("b1", &right.schema())?,
        )];

        let (_, batches) = join_collect(left, right, on, JoinType::Anti).await?;
        let expected = vec![
            "+----+----+----+",
            "| a1 | b1 | c1 |",
            "+----+----+----+",
            "| 3  | 7  | 9  |",
            "| 5  | 7  | 11 |",
            "+----+----+----+",
        ];
        assert_batches_sorted_eq!(expected, &batches);
        Ok(())
    }

    #[tokio::test]
    async fn join_semi() -> Result<()> {
        let left = build_table(
            ("a1", &vec![1, 2, 2, 3]),
            ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right
            ("c1", &vec![7, 8, 8, 9]),
        );
        let right = build_table(
            ("a2", &vec![10, 20, 30]),
            ("b1", &vec![4, 5, 6]), // 5 is double on the right
            ("c2", &vec![70, 80, 90]),
        );
        let on = vec![(
            Column::new_with_schema("b1", &left.schema())?,
            Column::new_with_schema("b1", &right.schema())?,
        )];

        let (_, batches) = join_collect(left, right, on, JoinType::Semi).await?;
        let expected = vec![
            "+----+----+----+",
            "| a1 | b1 | c1 |",
            "+----+----+----+",
            "| 1  | 4  | 7  |",
            "| 2  | 5  | 8  |",
            "| 2  | 5  | 8  |",
            "+----+----+----+",
        ];
        assert_batches_sorted_eq!(expected, &batches);
        Ok(())
    }

    #[tokio::test]
    async fn join_with_duplicated_column_names() -> Result<()> {
        let left = build_table(
            ("a", &vec![1, 2, 3]),
            ("b", &vec![4, 5, 7]),
            ("c", &vec![7, 8, 9]),
        );
        let right = build_table(
            ("a", &vec![10, 20, 30]),
            ("b", &vec![1, 2, 7]),
            ("c", &vec![70, 80, 90]),
        );
        let on = vec![(
            // join on a=b so there are duplicate column names on unjoined columns
            Column::new_with_schema("a", &left.schema())?,
            Column::new_with_schema("b", &right.schema())?,
        )];

        let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?;
        let expected = vec![
            "+---+---+---+----+---+----+",
            "| a | b | c | a  | b | c  |",
            "+---+---+---+----+---+----+",
            "| 1 | 4 | 7 | 10 | 1 | 70 |",
            "| 2 | 5 | 8 | 20 | 2 | 80 |",
            "+---+---+---+----+---+----+",
        ];
        assert_batches_sorted_eq!(expected, &batches);
        Ok(())
    }
}
diff --git a/datafusion/core/tests/join_fuzz.rs b/datafusion/core/tests/join_fuzz.rs
new file mode 100644
index 000000000000..2f47356236b6
--- /dev/null
+++ b/datafusion/core/tests/join_fuzz.rs
@@ -0,0 +1,223 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int32Array}; +use arrow::compute::SortOptions; +use arrow::record_batch::RecordBatch; +use arrow::util::pretty::pretty_format_batches; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; + +use datafusion::logical_plan::JoinType; +use datafusion::physical_plan::collect; +use datafusion::physical_plan::expressions::Column; +use datafusion::physical_plan::hash_join::{HashJoinExec, PartitionMode}; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::sort_merge_join::SortMergeJoinExec; + +use datafusion::prelude::{SessionConfig, SessionContext}; +use fuzz_utils::add_empty_batches; + +#[tokio::test] +async fn test_inner_join_1k() { + run_join_test( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::Inner, + ) + .await +} + +#[tokio::test] +async fn test_left_join_1k() { + run_join_test( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::Left, + ) + .await +} + +#[tokio::test] +async fn test_right_join_1k() { + run_join_test( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::Right, + ) + .await +} + +#[tokio::test] +async fn test_full_join_1k() { + run_join_test( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::Full, + ) + .await +} + +#[tokio::test] +async fn test_semi_join_1k() { + run_join_test( + 
make_staggered_batches(10000), + make_staggered_batches(10000), + JoinType::Semi, + ) + .await +} + +#[tokio::test] +async fn test_anti_join_1k() { + run_join_test( + make_staggered_batches(10000), + make_staggered_batches(10000), + JoinType::Anti, + ) + .await +} + +/// Perform sort-merge join and hash join on same input +/// and verify two outputs are equal +async fn run_join_test( + input1: Vec, + input2: Vec, + join_type: JoinType, +) { + let batch_sizes = [1, 2, 7, 49, 50, 51, 100]; + for batch_size in batch_sizes { + let session_config = SessionConfig::new().with_batch_size(batch_size); + let ctx = SessionContext::with_config(session_config); + let task_ctx = ctx.task_ctx(); + + let schema1 = input1[0].schema(); + let schema2 = input2[0].schema(); + let on_columns = vec![ + ( + Column::new_with_schema("a", &schema1).unwrap(), + Column::new_with_schema("a", &schema2).unwrap(), + ), + ( + Column::new_with_schema("b", &schema1).unwrap(), + Column::new_with_schema("b", &schema2).unwrap(), + ), + ]; + + // sort-merge join + let left = Arc::new( + MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(), + ); + let right = Arc::new( + MemoryExec::try_new(&[input2.clone()], schema2.clone(), None).unwrap(), + ); + let smj = Arc::new( + SortMergeJoinExec::try_new( + left, + right, + on_columns.clone(), + join_type, + vec![SortOptions::default(), SortOptions::default()], + false, + ) + .unwrap(), + ); + let smj_collected = collect(smj, task_ctx.clone()).await.unwrap(); + + // hash join + let left = Arc::new( + MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(), + ); + let right = Arc::new( + MemoryExec::try_new(&[input2.clone()], schema2.clone(), None).unwrap(), + ); + let hj = Arc::new( + HashJoinExec::try_new( + left, + right, + on_columns.clone(), + &join_type, + PartitionMode::Partitioned, + &false, + ) + .unwrap(), + ); + let hj_collected = collect(hj, task_ctx.clone()).await.unwrap(); + + // compare + let smj_formatted = 
pretty_format_batches(&smj_collected).unwrap().to_string(); + let hj_formatted = pretty_format_batches(&hj_collected).unwrap().to_string(); + + let mut smj_formatted_sorted: Vec<&str> = smj_formatted.trim().lines().collect(); + smj_formatted_sorted.sort_unstable(); + + let mut hj_formatted_sorted: Vec<&str> = hj_formatted.trim().lines().collect(); + hj_formatted_sorted.sort_unstable(); + + for (i, (smj_line, hj_line)) in smj_formatted_sorted + .iter() + .zip(&hj_formatted_sorted) + .enumerate() + { + assert_eq!((i, smj_line), (i, hj_line)); + } + } +} + +/// Return randomly sized record batches with: +/// two sorted int32 columns 'a', 'b' ranged from 0..99 as join columns +/// two random int32 columns 'x', 'y' as other columns +fn make_staggered_batches(len: usize) -> Vec { + let mut rng = rand::thread_rng(); + let mut input12: Vec<(i32, i32)> = vec![(0, 0); len]; + let mut input3: Vec = vec![0; len]; + let mut input4: Vec = vec![0; len]; + input12 + .iter_mut() + .for_each(|v| *v = (rng.gen_range(0..100), rng.gen_range(0..100))); + rng.fill(&mut input3[..]); + rng.fill(&mut input4[..]); + input12.sort_unstable(); + let input1 = Int32Array::from_iter_values(input12.clone().into_iter().map(|k| k.0)); + let input2 = Int32Array::from_iter_values(input12.clone().into_iter().map(|k| k.1)); + let input3 = Int32Array::from_iter_values(input3.into_iter()); + let input4 = Int32Array::from_iter_values(input4.into_iter()); + + // split into several record batches + let mut remainder = RecordBatch::try_from_iter(vec![ + ("a", Arc::new(input1) as ArrayRef), + ("b", Arc::new(input2) as ArrayRef), + ("x", Arc::new(input3) as ArrayRef), + ("y", Arc::new(input4) as ArrayRef), + ]) + .unwrap(); + + let mut batches = vec![]; + + // use a random number generator to pick a random sized output + let mut rng = StdRng::seed_from_u64(42); + while remainder.num_rows() > 0 { + let batch_size = rng.gen_range(0..remainder.num_rows() + 1); + + batches.push(remainder.slice(0, batch_size)); + 
remainder = remainder.slice(batch_size, remainder.num_rows() - batch_size); + } + + add_empty_batches(batches, &mut rng) +}