Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support nested loop join with the initial version #4562

Merged
merged 7 commits into from
Dec 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 14 additions & 218 deletions datafusion/core/src/physical_plan/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@ use arrow::{
DictionaryArray, LargeStringArray, PrimitiveArray, Time32MillisecondArray,
Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampSecondArray,
UInt32BufferBuilder, UInt32Builder, UInt64BufferBuilder,
UInt32BufferBuilder, UInt64BufferBuilder,
},
compute,
datatypes::{
Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type,
UInt8Type,
Expand All @@ -41,7 +40,7 @@ use std::{time::Instant, vec};

use futures::{ready, Stream, StreamExt, TryStreamExt};

use arrow::array::{new_null_array, Array};
use arrow::array::Array;
use arrow::datatypes::{ArrowNativeType, DataType};
use arrow::datatypes::{Schema, SchemaRef};
use arrow::error::{ArrowError, Result as ArrowResult};
Expand All @@ -53,7 +52,7 @@ use arrow::array::{
UInt8Array,
};

use datafusion_common::cast::{as_boolean_array, as_dictionary_array, as_string_array};
use datafusion_common::cast::{as_dictionary_array, as_string_array};

use hashbrown::raw::RawTable;

Expand All @@ -66,7 +65,7 @@ use crate::physical_plan::{
joins::utils::{
adjust_right_output_partitioning, build_join_schema, check_join_is_valid,
combine_join_equivalence_properties, estimate_join_statistics,
partitioned_join_output_partitioning, ColumnIndex, JoinFilter, JoinOn, JoinSide,
partitioned_join_output_partitioning, ColumnIndex, JoinFilter, JoinOn,
},
metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet},
DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, Partitioning,
Expand All @@ -84,6 +83,10 @@ use super::{
utils::{OnceAsync, OnceFut},
PartitionMode,
};
use crate::physical_plan::joins::utils::{
adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices,
get_final_indices_from_bit_map, need_produce_result_in_final,
};
use log::debug;
use std::fmt;
use std::task::Poll;
Expand Down Expand Up @@ -647,50 +650,6 @@ impl RecordBatchStream for HashJoinStream {
}
}

/// Returns a new [RecordBatch] by combining the `left` and `right` according to `indices`.
/// The resulting batch has [Schema] `schema`.
fn build_batch_from_indices(
schema: &Schema,
left: &RecordBatch,
right: &RecordBatch,
left_indices: UInt64Array,
right_indices: UInt32Array,
column_indices: &[ColumnIndex],
) -> ArrowResult<RecordBatch> {
// build the columns of the new [RecordBatch]:
// 1. pick whether the column is from the left or right
// 2. based on the pick, `take` items from the different RecordBatches
let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(schema.fields().len());

for column_index in column_indices {
let array = match column_index.side {
JoinSide::Left => {
let array = left.column(column_index.index);
if array.is_empty() || left_indices.null_count() == left_indices.len() {
// Outer join would generate a null index when finding no match at our side.
// Therefore, it's possible we are empty but need to populate an n-length null array,
// where n is the length of the index array.
assert_eq!(left_indices.null_count(), left_indices.len());
new_null_array(array.data_type(), left_indices.len())
} else {
compute::take(array.as_ref(), &left_indices, None)?
}
}
JoinSide::Right => {
let array = right.column(column_index.index);
if array.is_empty() || right_indices.null_count() == right_indices.len() {
assert_eq!(right_indices.null_count(), right_indices.len());
new_null_array(array.data_type(), right_indices.len())
} else {
compute::take(array.as_ref(), &right_indices, None)?
}
}
};
columns.push(array);
}
RecordBatch::try_new(Arc::new(schema.clone()), columns)
}

// Get left and right indices which is satisfies the on condition (include equal_conditon and filter_in_join) in the Join
#[allow(clippy::too_many_arguments)]
fn build_join_indices(
Expand Down Expand Up @@ -821,41 +780,6 @@ fn build_equal_condition_join_indices(
))
}

fn apply_join_filter_to_indices(
left: &RecordBatch,
right: &RecordBatch,
left_indices: UInt64Array,
right_indices: UInt32Array,
filter: &JoinFilter,
) -> Result<(UInt64Array, UInt32Array)> {
if left_indices.is_empty() && right_indices.is_empty() {
return Ok((left_indices, right_indices));
};

let intermediate_batch = build_batch_from_indices(
filter.schema(),
left,
right,
PrimitiveArray::from(left_indices.data().clone()),
PrimitiveArray::from(right_indices.data().clone()),
filter.column_indices(),
)?;
let filter_result = filter
.expression()
.evaluate(&intermediate_batch)?
.into_array(intermediate_batch.num_rows());
let mask = as_boolean_array(&filter_result)?;

let left_filtered = PrimitiveArray::<UInt64Type>::from(
compute::filter(&left_indices, mask)?.data().clone(),
);
let right_filtered = PrimitiveArray::<UInt32Type>::from(
compute::filter(&right_indices, mask)?.data().clone(),
);

Ok((left_filtered, right_filtered))
}

macro_rules! equal_rows_elem {
($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $null_equals_null: ident) => {{
let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap();
Expand Down Expand Up @@ -1186,138 +1110,6 @@ fn equal_rows(
err.unwrap_or(Ok(res))
}

// The input is the matched indices for left and right.
// Adjust the indices according to the join type
fn adjust_indices_by_join_type(
left_indices: UInt64Array,
right_indices: UInt32Array,
count_right_batch: usize,
join_type: JoinType,
) -> (UInt64Array, UInt32Array) {
match join_type {
JoinType::Inner => {
// matched
(left_indices, right_indices)
}
JoinType::Left => {
// matched
(left_indices, right_indices)
// unmatched left row will be produced in the end of loop, and it has been set in the left visited bitmap
}
JoinType::Right | JoinType::Full => {
// matched
// unmatched right row will be produced in this batch
let right_unmatched_indices =
get_anti_indices(count_right_batch, &right_indices);
// combine the matched and unmatched right result together
append_right_indices(left_indices, right_indices, right_unmatched_indices)
}
JoinType::RightSemi => {
// need to remove the duplicated record in the right side
let right_indices = get_semi_indices(count_right_batch, &right_indices);
// the left_indices will not be used later for the `right semi` join
(left_indices, right_indices)
}
JoinType::RightAnti => {
// need to remove the duplicated record in the right side
// get the anti index for the right side
let right_indices = get_anti_indices(count_right_batch, &right_indices);
// the left_indices will not be used later for the `right anti` join
(left_indices, right_indices)
}
JoinType::LeftSemi | JoinType::LeftAnti => {
// matched or unmatched left row will be produced in the end of loop
// TODO: left semi can be optimized.
// When visit the right batch, we can output the matched left row and don't need to wait the end of loop
(
UInt64Array::from_iter_values(vec![]),
UInt32Array::from_iter_values(vec![]),
)
}
}
}

fn append_right_indices(
left_indices: UInt64Array,
right_indices: UInt32Array,
appended_right_indices: UInt32Array,
) -> (UInt64Array, UInt32Array) {
// left_indices, right_indices and appended_right_indices must not contain the null value
if appended_right_indices.is_empty() {
(left_indices, right_indices)
} else {
let unmatched_size = appended_right_indices.len();
// the new left indices: left_indices + null array
// the new right indices: right_indices + appended_right_indices
let new_left_indices = left_indices
.iter()
.chain(std::iter::repeat(None).take(unmatched_size))
.collect::<UInt64Array>();
let new_right_indices = right_indices
.iter()
.chain(appended_right_indices.iter())
.collect::<UInt32Array>();
(new_left_indices, new_right_indices)
}
}

fn get_anti_indices(row_count: usize, input_indices: &UInt32Array) -> UInt32Array {
let mut bitmap = BooleanBufferBuilder::new(row_count);
bitmap.append_n(row_count, false);
input_indices.iter().flatten().for_each(|v| {
bitmap.set_bit(v as usize, true);
});

// get the anti index
(0..row_count)
.filter_map(|idx| (!bitmap.get_bit(idx)).then_some(idx as u32))
.collect::<UInt32Array>()
}

fn get_semi_indices(row_count: usize, input_indices: &UInt32Array) -> UInt32Array {
let mut bitmap = BooleanBufferBuilder::new(row_count);
bitmap.append_n(row_count, false);
input_indices.iter().flatten().for_each(|v| {
bitmap.set_bit(v as usize, true);
});

// get the semi index
(0..row_count)
.filter_map(|idx| (bitmap.get_bit(idx)).then_some(idx as u32))
.collect::<UInt32Array>()
}

fn need_produce_result_in_final(join_type: JoinType) -> bool {
matches!(
join_type,
JoinType::Left | JoinType::LeftAnti | JoinType::LeftSemi | JoinType::Full
)
}

fn get_final_indices(
left_bit_map: &BooleanBufferBuilder,
join_type: JoinType,
) -> (UInt64Array, UInt32Array) {
let left_size = left_bit_map.len();
let left_indices = if join_type == JoinType::LeftSemi {
(0..left_size)
.filter_map(|idx| (left_bit_map.get_bit(idx)).then_some(idx as u64))
.collect::<UInt64Array>()
} else {
// just for `Left`, `LeftAnti` and `Full` join
// `LeftAnti`, `Left` and `Full` will produce the unmatched left row finally
(0..left_size)
.filter_map(|idx| (!left_bit_map.get_bit(idx)).then_some(idx as u64))
.collect::<UInt64Array>()
};
// right_indices
// all the element in the right side is None
let mut builder = UInt32Builder::with_capacity(left_indices.len());
builder.append_nulls(left_indices.len());
let right_indices = builder.finish();
(left_indices, right_indices)
}

impl HashJoinStream {
/// Separate implementation function that unpins the [`HashJoinStream`] so
/// that partial borrows work correctly
Expand Down Expand Up @@ -1413,8 +1205,10 @@ impl HashJoinStream {
if need_produce_result_in_final(self.join_type) && !self.is_exhausted
{
// use the global left bitmap to produce the left indices and right indices
let (left_side, right_side) =
get_final_indices(visited_left_side, self.join_type);
let (left_side, right_side) = get_final_indices_from_bit_map(
visited_left_side,
self.join_type,
);
let empty_right_batch =
RecordBatch::new_empty(self.right.schema());
// use the left and right indices to produce the batch result
Expand Down Expand Up @@ -1469,12 +1263,14 @@ mod tests {
test::exec::MockExec,
test::{build_table_i32, columns},
};
use arrow::array::UInt32Builder;
use arrow::array::UInt64Builder;
use arrow::datatypes::Field;
use arrow::error::ArrowError;
use datafusion_expr::Operator;

use super::*;
use crate::physical_plan::joins::utils::JoinSide;
use crate::prelude::SessionContext;
use datafusion_common::ScalarValue;
use datafusion_physical_expr::expressions::Literal;
Expand Down
2 changes: 2 additions & 0 deletions datafusion/core/src/physical_plan/joins/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

mod cross_join;
mod hash_join;
mod nested_loop_join;
mod sort_merge_join;
pub mod utils;

Expand All @@ -36,6 +37,7 @@ pub enum PartitionMode {

pub use cross_join::CrossJoinExec;
pub use hash_join::HashJoinExec;
pub use nested_loop_join::NestedLoopJoinExec;

// Note: SortMergeJoin is not used in plans yet
pub use sort_merge_join::SortMergeJoinExec;
Loading