Skip to content

Commit

Permalink
refine code for hash join exec
Browse files Browse the repository at this point in the history
  • Loading branch information
liukun4515 committed Dec 9, 2022
1 parent 2d1e9ad commit 18c6152
Show file tree
Hide file tree
Showing 6 changed files with 268 additions and 451 deletions.
226 changes: 10 additions & 216 deletions datafusion/core/src/physical_plan/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@ use arrow::{
DictionaryArray, LargeStringArray, PrimitiveArray, Time32MillisecondArray,
Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampSecondArray,
UInt32BufferBuilder, UInt32Builder, UInt64BufferBuilder,
UInt32BufferBuilder, UInt64BufferBuilder,
},
compute,
datatypes::{
Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type,
UInt8Type,
Expand All @@ -41,7 +40,7 @@ use std::{time::Instant, vec};

use futures::{ready, Stream, StreamExt, TryStreamExt};

use arrow::array::{new_null_array, Array};
use arrow::array::Array;
use arrow::datatypes::{ArrowNativeType, DataType};
use arrow::datatypes::{Schema, SchemaRef};
use arrow::error::{ArrowError, Result as ArrowResult};
Expand All @@ -53,7 +52,7 @@ use arrow::array::{
UInt8Array,
};

use datafusion_common::cast::{as_boolean_array, as_dictionary_array, as_string_array};
use datafusion_common::cast::{as_dictionary_array, as_string_array};

use hashbrown::raw::RawTable;

Expand All @@ -66,7 +65,7 @@ use crate::physical_plan::{
joins::utils::{
adjust_right_output_partitioning, build_join_schema, check_join_is_valid,
combine_join_equivalence_properties, estimate_join_statistics,
partitioned_join_output_partitioning, ColumnIndex, JoinFilter, JoinOn, JoinSide,
partitioned_join_output_partitioning, ColumnIndex, JoinFilter, JoinOn,
},
metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet},
DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, Partitioning,
Expand All @@ -84,6 +83,10 @@ use super::{
utils::{OnceAsync, OnceFut},
PartitionMode,
};
use crate::physical_plan::joins::utils::{
adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices,
get_final_indices, need_produce_result_in_final,
};
use log::debug;
use std::fmt;
use std::task::Poll;
Expand Down Expand Up @@ -647,50 +650,6 @@ impl RecordBatchStream for HashJoinStream {
}
}

/// Assembles a [RecordBatch] with [Schema] `schema` by gathering rows from the
/// `left` and `right` batches at the positions given by `left_indices` and
/// `right_indices`.
///
/// Each entry of `column_indices` names the side and the column position an
/// output column is read from; output columns are produced in that order.
///
/// # Errors
///
/// Propagates any error from `compute::take` or from `RecordBatch::try_new`.
///
/// # Panics
///
/// Panics if a source column is empty while the index array for that side
/// still contains a non-null entry.
fn build_batch_from_indices(
    schema: &Schema,
    left: &RecordBatch,
    right: &RecordBatch,
    left_indices: UInt64Array,
    right_indices: UInt32Array,
    column_indices: &[ColumnIndex],
) -> ArrowResult<RecordBatch> {
    let columns = column_indices
        .iter()
        .map(|column_index| -> ArrowResult<Arc<dyn Array>> {
            match column_index.side {
                JoinSide::Left => {
                    let source = left.column(column_index.index);
                    let all_null = left_indices.null_count() == left_indices.len();
                    if source.is_empty() || all_null {
                        // An outer join emits a null index for every row that found no
                        // match on this side, so the source may be empty while we still
                        // owe the caller an all-null column as long as the index array.
                        assert_eq!(left_indices.null_count(), left_indices.len());
                        Ok(new_null_array(source.data_type(), left_indices.len()))
                    } else {
                        compute::take(source.as_ref(), &left_indices, None)
                    }
                }
                JoinSide::Right => {
                    let source = right.column(column_index.index);
                    let all_null = right_indices.null_count() == right_indices.len();
                    if source.is_empty() || all_null {
                        // Same reasoning as the left side: nothing to take, emit nulls.
                        assert_eq!(right_indices.null_count(), right_indices.len());
                        Ok(new_null_array(source.data_type(), right_indices.len()))
                    } else {
                        compute::take(source.as_ref(), &right_indices, None)
                    }
                }
            }
        })
        // Stops at the first failing column, like the `?` inside the original loop.
        .collect::<ArrowResult<Vec<_>>>()?;

    RecordBatch::try_new(Arc::new(schema.clone()), columns)
}

// Get the left and right indices that satisfy the join's ON condition (both the equality condition and the join filter)
#[allow(clippy::too_many_arguments)]
fn build_join_indices(
Expand Down Expand Up @@ -821,41 +780,6 @@ fn build_equal_condition_join_indices(
))
}

/// Evaluates the join `filter` against the row pairs addressed by
/// `left_indices` / `right_indices` and returns only the index pairs selected
/// by the resulting boolean mask.
///
/// When both index arrays are empty they are returned untouched and the filter
/// expression is never evaluated.
///
/// # Errors
///
/// Propagates failures from building the intermediate batch, evaluating the
/// filter expression, casting its result to a boolean array, or filtering the
/// index arrays.
fn apply_join_filter_to_indices(
    left: &RecordBatch,
    right: &RecordBatch,
    left_indices: UInt64Array,
    right_indices: UInt32Array,
    filter: &JoinFilter,
) -> Result<(UInt64Array, UInt32Array)> {
    // Nothing matched, so there is nothing to filter.
    if left_indices.is_empty() && right_indices.is_empty() {
        return Ok((left_indices, right_indices));
    }

    // `build_batch_from_indices` takes the index arrays by value, so hand it
    // copies built from the underlying array data and keep the originals for
    // the masking step below.
    let left_for_batch = PrimitiveArray::from(left_indices.data().clone());
    let right_for_batch = PrimitiveArray::from(right_indices.data().clone());
    let intermediate_batch = build_batch_from_indices(
        filter.schema(),
        left,
        right,
        left_for_batch,
        right_for_batch,
        filter.column_indices(),
    )?;

    // Evaluate the filter expression once over the intermediate batch.
    let evaluated = filter.expression().evaluate(&intermediate_batch)?;
    let filter_result = evaluated.into_array(intermediate_batch.num_rows());
    let mask = as_boolean_array(&filter_result)?;

    // Apply the same mask to both index arrays so the pairs stay aligned.
    let kept_left = compute::filter(&left_indices, mask)?;
    let kept_right = compute::filter(&right_indices, mask)?;

    Ok((
        PrimitiveArray::<UInt64Type>::from(kept_left.data().clone()),
        PrimitiveArray::<UInt32Type>::from(kept_right.data().clone()),
    ))
}

macro_rules! equal_rows_elem {
($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $null_equals_null: ident) => {{
let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap();
Expand Down Expand Up @@ -1186,138 +1110,6 @@ fn equal_rows(
err.unwrap_or(Ok(res))
}

// Takes the matched (left, right) index pairs found for one right-side batch
// and reshapes them into the index pairs that batch should emit for `join_type`.
//
// `count_right_batch` is the number of rows in the right batch; it bounds the
// right-side row positions considered when computing semi / anti indices.
fn adjust_indices_by_join_type(
    left_indices: UInt64Array,
    right_indices: UInt32Array,
    count_right_batch: usize,
    join_type: JoinType,
) -> (UInt64Array, UInt32Array) {
    match join_type {
        // Matched pairs pass straight through. For `Left`, the unmatched left
        // rows are produced at the end of the loop from the left visited bitmap.
        JoinType::Inner | JoinType::Left => (left_indices, right_indices),

        // Matched pairs plus every right row of this batch that found no match.
        JoinType::Right | JoinType::Full => {
            let unmatched_right = get_anti_indices(count_right_batch, &right_indices);
            append_right_indices(left_indices, right_indices, unmatched_right)
        }

        // Each matched right row once, de-duplicated.
        // `left_indices` is not used later for a `right semi` join.
        JoinType::RightSemi => {
            let deduped_right = get_semi_indices(count_right_batch, &right_indices);
            (left_indices, deduped_right)
        }

        // Each right row that matched nothing, once.
        // `left_indices` is not used later for a `right anti` join.
        JoinType::RightAnti => {
            let unmatched_right = get_anti_indices(count_right_batch, &right_indices);
            (left_indices, unmatched_right)
        }

        // Matched or unmatched left rows are produced at the end of the loop,
        // so nothing is emitted for this batch.
        // TODO: left semi can be optimized: matched left rows could be output
        // while visiting the right batch instead of waiting for the end of the loop.
        JoinType::LeftSemi | JoinType::LeftAnti => {
            let no_left = UInt64Array::from_iter_values(vec![]);
            let no_right = UInt32Array::from_iter_values(vec![]);
            (no_left, no_right)
        }
    }
}

// Extends the matched index pairs with extra right-side rows that have no
// left partner: each appended right index is paired with a null left index.
//
// The inputs are expected not to contain null values. When there is nothing
// to append, the inputs are returned as they are.
fn append_right_indices(
    left_indices: UInt64Array,
    right_indices: UInt32Array,
    appended_right_indices: UInt32Array,
) -> (UInt64Array, UInt32Array) {
    if appended_right_indices.is_empty() {
        return (left_indices, right_indices);
    }

    // new left indices  = left_indices  + one null per appended right row
    let null_padding = std::iter::repeat(None).take(appended_right_indices.len());
    let new_left_indices: UInt64Array = left_indices.iter().chain(null_padding).collect();

    // new right indices = right_indices + appended_right_indices
    let new_right_indices: UInt32Array = right_indices
        .iter()
        .chain(appended_right_indices.iter())
        .collect();

    (new_left_indices, new_right_indices)
}

// Returns, in ascending order, every row position in `0..row_count` that does
// NOT occur among the non-null values of `input_indices`.
fn get_anti_indices(row_count: usize, input_indices: &UInt32Array) -> UInt32Array {
    // One bit per row, all initially cleared.
    let mut seen = BooleanBufferBuilder::new(row_count);
    seen.append_n(row_count, false);

    // Mark every row referenced by the input; null entries are skipped.
    for row in input_indices.iter().flatten() {
        seen.set_bit(row as usize, true);
    }

    // Keep the rows whose bit was never set.
    (0..row_count)
        .filter(|&row| !seen.get_bit(row))
        .map(|row| row as u32)
        .collect::<UInt32Array>()
}

// Returns, in ascending order and without duplicates, every row position in
// `0..row_count` that occurs among the non-null values of `input_indices`.
fn get_semi_indices(row_count: usize, input_indices: &UInt32Array) -> UInt32Array {
    // One bit per row, all initially cleared.
    let mut seen = BooleanBufferBuilder::new(row_count);
    seen.append_n(row_count, false);

    // Mark every row referenced by the input; null entries are skipped.
    for row in input_indices.iter().flatten() {
        seen.set_bit(row as usize, true);
    }

    // Keep the rows whose bit was set at least once.
    (0..row_count)
        .filter(|&row| seen.get_bit(row))
        .map(|row| row as u32)
        .collect::<UInt32Array>()
}

// Returns `true` for `Left`, `LeftAnti`, `LeftSemi` and `Full` joins, and
// `false` for every other join type.
fn need_produce_result_in_final(join_type: JoinType) -> bool {
    match join_type {
        JoinType::Left | JoinType::LeftAnti | JoinType::LeftSemi | JoinType::Full => true,
        JoinType::Inner
        | JoinType::Right
        | JoinType::RightSemi
        | JoinType::RightAnti => false,
    }
}

// Builds the index pair from `left_bit_map`: the selected left row positions,
// each paired with a null right index.
//
// For `LeftSemi` the rows whose bit is set are selected; for every other join
// type (in practice `Left`, `LeftAnti` and `Full`, which produce the unmatched
// left rows at the end) the rows whose bit is NOT set are selected.
fn get_final_indices(
    left_bit_map: &BooleanBufferBuilder,
    join_type: JoinType,
) -> (UInt64Array, UInt32Array) {
    // Whether a row is emitted when its bit is set (`LeftSemi`) or cleared (all others).
    let emit_when_set = join_type == JoinType::LeftSemi;

    let left_indices = (0..left_bit_map.len())
        .filter(|&row| left_bit_map.get_bit(row) == emit_when_set)
        .map(|row| row as u64)
        .collect::<UInt64Array>();

    // Every right-side entry is null, one per selected left row.
    let emitted = left_indices.len();
    let mut right_builder = UInt32Builder::with_capacity(emitted);
    right_builder.append_nulls(emitted);

    (left_indices, right_builder.finish())
}

impl HashJoinStream {
/// Separate implementation function that unpins the [`HashJoinStream`] so
/// that partial borrows work correctly
Expand Down Expand Up @@ -1469,12 +1261,14 @@ mod tests {
test::exec::MockExec,
test::{build_table_i32, columns},
};
use arrow::array::UInt32Builder;
use arrow::array::UInt64Builder;
use arrow::datatypes::Field;
use arrow::error::ArrowError;
use datafusion_expr::Operator;

use super::*;
use crate::physical_plan::joins::utils::JoinSide;
use crate::prelude::SessionContext;
use datafusion_common::ScalarValue;
use datafusion_physical_expr::expressions::Literal;
Expand Down
Loading

0 comments on commit 18c6152

Please sign in to comment.