Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(query): integrate hash join with new filter framework #14689

Merged
merged 13 commits into from
Feb 26, 2024
6 changes: 5 additions & 1 deletion src/query/expression/src/filter/filter_executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,11 @@ impl FilterExecutor {
true_idx
}

pub fn mut_true_selection(&mut self) -> &mut [u32] {
/// Returns a read-only view of the executor's `true_selection` buffer.
///
/// Takes `&self` because nothing is mutated here; callers that need to write
/// into the buffer should use `mutable_true_selection` instead.
// NOTE(review): the buffer presumably holds the row indices that passed the
// most recent filter/select call — confirm against the executor's writers.
pub fn true_selection(&self) -> &[u32] {
    &self.true_selection
}

/// Returns a writable view of the executor's `true_selection` buffer.
pub fn mutable_true_selection(&mut self) -> &mut [u32] {
    let selection = &mut self.true_selection[..];
    selection
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,16 @@ use databend_common_exception::Result;
use databend_common_expression::arrow::or_validities;
use databend_common_expression::types::nullable::NullableColumn;
use databend_common_expression::types::AnyType;
use databend_common_expression::types::BooleanType;
use databend_common_expression::types::DataType;
use databend_common_expression::BlockEntry;
use databend_common_expression::Column;
use databend_common_expression::DataBlock;
use databend_common_expression::Evaluator;
use databend_common_expression::Expr;
use databend_common_expression::FilterExecutor;
use databend_common_expression::FunctionContext;
use databend_common_expression::Value;
use databend_common_functions::BUILTIN_FUNCTIONS;
use databend_common_sql::executor::cast_expr_to_non_null_boolean;

use super::desc::MARKER_KIND_FALSE;
use super::desc::MARKER_KIND_NULL;
Expand Down Expand Up @@ -93,28 +92,36 @@ impl HashJoinProbeState {
Ok(DataBlock::new_from_columns(vec![marker_column]))
}

// return an (option bitmap, all_true, all_false).
pub(crate) fn get_other_filters(
// return (result data block, filtered indices, all_true, all_false).
pub(crate) fn get_other_predicate_result_block<'a>(
&self,
merged_block: &DataBlock,
filter: &Expr,
func_ctx: &FunctionContext,
) -> Result<(Option<Bitmap>, bool, bool)> {
let filter = cast_expr_to_non_null_boolean(filter.clone())?;
let evaluator = Evaluator::new(merged_block, func_ctx, &BUILTIN_FUNCTIONS);
let predicates = evaluator
.run(&filter)?
.try_downcast::<BooleanType>()
.unwrap();
filter_executor: &'a mut FilterExecutor,
data_block: DataBlock,
) -> Result<(DataBlock, &'a [u32], bool, bool)> {
let origin_count = data_block.num_rows();
let result_block = filter_executor.filter(data_block)?;
let result_count = result_block.num_rows();
Ok((
result_block,
&filter_executor.true_selection()[0..result_count],
result_count == origin_count,
result_count == 0,
))
}

match predicates {
Value::Scalar(v) => Ok((None, v, !v)),
Value::Column(s) => {
let count_zeros = s.unset_bits();
let all_false = s.len() == count_zeros;
Ok((Some(s), count_zeros == 0, all_false))
}
}
// Runs `filter_executor.select` on `data_block` without building a filtered
// block and returns (selected indices, all_true, all_false):
// - selected indices: the first `selected_rows` entries of the executor's
//   true selection, where `selected_rows` is the count returned by `select`;
// - all_true: the selected count equals the block's row count;
// - all_false: the selected count is zero.
pub(crate) fn get_other_predicate_selection<'a>(
    &self,
    filter_executor: &'a mut FilterExecutor,
    data_block: &DataBlock,
) -> Result<(&'a [u32], bool, bool)> {
    let total_rows = data_block.num_rows();
    let selected_rows = filter_executor.select(data_block)?;
    // Slice first (as the only operation that can panic), then derive the flags.
    let selection = &filter_executor.true_selection()[..selected_rows];
    let all_true = selected_rows == total_rows;
    let all_false = selected_rows == 0;
    Ok((selection, all_true, all_false))
}

pub(crate) fn get_nullable_filter_column(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ use std::cell::SyncUnsafeCell;
use std::sync::atomic::AtomicU8;
use std::sync::atomic::Ordering;

use databend_common_arrow::arrow::bitmap::Bitmap;
use databend_common_catalog::plan::compute_row_id_prefix;
use databend_common_catalog::plan::split_prefix;
use databend_common_exception::ErrorCode;
Expand Down Expand Up @@ -157,7 +156,7 @@ impl HashJoinProbeState {
&self,
build_indexes: &[RowPtr],
matched_idx: usize,
valids: &Bitmap,
selection: Option<&[u32]>,
) -> Result<()> {
// merge into target table as build side.
if self
Expand All @@ -176,45 +175,70 @@ impl HashJoinProbeState {

let pointer = &merge_into_state.atomic_pointer;
// add matched indexes.
for (idx, row_ptr) in build_indexes[0..matched_idx].iter().enumerate() {
unsafe {
if !valids.get_bit_unchecked(idx) {
continue;
}
}
let offset = if row_ptr.chunk_index == 0 {
row_ptr.row_index as usize
} else {
chunk_offsets[(row_ptr.chunk_index - 1) as usize] as usize
+ row_ptr.row_index as usize
};
if let Some(selection) = selection {
for idx in selection {
let row_ptr = unsafe { build_indexes.get_unchecked(*idx as usize) };
let offset = if row_ptr.chunk_index == 0 {
row_ptr.row_index as usize
} else {
chunk_offsets[(row_ptr.chunk_index - 1) as usize] as usize
+ row_ptr.row_index as usize
};

let mut old_mactehd_counts =
unsafe { (*pointer.0.add(offset)).load(Ordering::Relaxed) };
let mut new_matched_count = old_mactehd_counts + 1;
loop {
if old_mactehd_counts > 0 {
return Err(ErrorCode::UnresolvableConflict(
"multi rows from source match one and the same row in the target_table multi times in probe phase",
));
}
let mut old_mactehd_counts =
unsafe { (*pointer.0.add(offset)).load(Ordering::Relaxed) };
loop {
if old_mactehd_counts > 0 {
return Err(ErrorCode::UnresolvableConflict(
"multi rows from source match one and the same row in the target_table multi times in probe phase",
));
}

let res = unsafe {
(*pointer.0.add(offset)).compare_exchange_weak(
old_mactehd_counts,
new_matched_count,
Ordering::SeqCst,
Ordering::SeqCst,
)
let res = unsafe {
(*pointer.0.add(offset)).compare_exchange_weak(
old_mactehd_counts,
old_mactehd_counts + 1,
Ordering::SeqCst,
Ordering::SeqCst,
)
};
match res {
Ok(_) => break,
Err(x) => old_mactehd_counts = x,
};
}
}
} else {
for row_ptr in &build_indexes[0..matched_idx] {
let offset = if row_ptr.chunk_index == 0 {
row_ptr.row_index as usize
} else {
chunk_offsets[(row_ptr.chunk_index - 1) as usize] as usize
+ row_ptr.row_index as usize
};

match res {
Ok(_) => break,
Err(x) => {
old_mactehd_counts = x;
new_matched_count = old_mactehd_counts + 1;
let mut old_mactehd_counts =
unsafe { (*pointer.0.add(offset)).load(Ordering::Relaxed) };
loop {
if old_mactehd_counts > 0 {
return Err(ErrorCode::UnresolvableConflict(
"multi rows from source match one and the same row in the target_table multi times in probe phase",
));
}
};

let res = unsafe {
(*pointer.0.add(offset)).compare_exchange_weak(
old_mactehd_counts,
old_mactehd_counts + 1,
Ordering::SeqCst,
Ordering::SeqCst,
)
};
match res {
Ok(_) => break,
Err(x) => old_mactehd_counts = x,
};
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,10 @@ use std::sync::atomic::Ordering;

use databend_common_exception::ErrorCode;
use databend_common_exception::Result;
use databend_common_expression::types::BooleanType;
use databend_common_expression::types::DataType;
use databend_common_expression::DataBlock;
use databend_common_expression::Evaluator;
use databend_common_expression::KeyAccessor;
use databend_common_functions::BUILTIN_FUNCTIONS;
use databend_common_hashtable::HashJoinHashtableLike;
use databend_common_hashtable::RowPtr;
use databend_common_sql::executor::cast_expr_to_non_null_boolean;

use crate::pipelines::processors::transforms::hash_join::build_state::BuildBlockGenerationState;
use crate::pipelines::processors::transforms::hash_join::common::wrap_true_validity;
Expand Down Expand Up @@ -183,30 +178,19 @@ impl HashJoinProbeState {
)?);
}

match &self.hash_join_state.hash_join_desc.other_predicate {
match &mut probe_state.filter_executor {
None => Ok(result_blocks),
Some(other_predicate) => {
// Wrap `is_true` to `other_predicate`
let other_predicate = cast_expr_to_non_null_boolean(other_predicate.clone())?;
assert_eq!(other_predicate.data_type(), &DataType::Boolean);

Some(filter_executor) => {
let mut filtered_blocks = Vec::with_capacity(result_blocks.len());
for result_block in result_blocks {
if self.hash_join_state.interrupt.load(Ordering::Relaxed) {
return Err(ErrorCode::AbortedQuery(
"Aborted query, because the server is shutting down or the query was killed.",
));
}

let evaluator =
Evaluator::new(&result_block, &self.func_ctx, &BUILTIN_FUNCTIONS);
let predicate = evaluator
.run(&other_predicate)?
.try_downcast::<BooleanType>()
.unwrap();
let res = result_block.filter_boolean_value(&predicate)?;
if !res.is_empty() {
filtered_blocks.push(res);
let result_block = filter_executor.filter(result_block)?;
if !result_block.is_empty() {
filtered_blocks.push(result_block);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::sync::atomic::Ordering;
use databend_common_exception::ErrorCode;
use databend_common_exception::Result;
use databend_common_expression::DataBlock;
use databend_common_expression::Expr;
use databend_common_expression::FilterExecutor;
use databend_common_expression::KeyAccessor;
use databend_common_hashtable::HashJoinHashtableLike;
use databend_common_hashtable::RowPtr;
Expand Down Expand Up @@ -112,12 +112,7 @@ impl HashJoinProbeState {

// For anti join, it defaults to false.
let mut row_state = vec![false; input.num_rows()];
let other_predicate = self
.hash_join_state
.hash_join_desc
.other_predicate
.as_ref()
.unwrap();
let filter_executor = probe_state.filter_executor.as_mut().unwrap();

// Results.
let mut matched_idx = 0;
Expand Down Expand Up @@ -152,8 +147,8 @@ impl HashJoinProbeState {
build_indexes,
&mut probe_state.generation_state,
&build_state.generation_state,
other_predicate,
&mut row_state,
filter_executor,
)?;
(matched_idx, incomplete_ptr) = self.fill_probe_and_build_indexes::<_, false>(
hash_table,
Expand Down Expand Up @@ -193,8 +188,8 @@ impl HashJoinProbeState {
build_indexes,
&mut probe_state.generation_state,
&build_state.generation_state,
other_predicate,
&mut row_state,
filter_executor,
)?;
(matched_idx, incomplete_ptr) = self.fill_probe_and_build_indexes::<_, false>(
hash_table,
Expand All @@ -217,8 +212,8 @@ impl HashJoinProbeState {
build_indexes,
&mut probe_state.generation_state,
&build_state.generation_state,
other_predicate,
&mut row_state,
filter_executor,
)?;
}

Expand Down Expand Up @@ -251,8 +246,8 @@ impl HashJoinProbeState {
build_indexes: &[RowPtr],
probe_state: &mut ProbeBlockGenerationState,
build_state: &BuildBlockGenerationState,
other_predicate: &Expr,
row_state: &mut [bool],
filter_executor: &mut FilterExecutor,
) -> Result<()> {
if self.hash_join_state.interrupt.load(Ordering::Relaxed) {
return Err(ErrorCode::AbortedQuery(
Expand Down Expand Up @@ -284,9 +279,9 @@ impl HashJoinProbeState {
let result_block = self.merge_eq_block(probe_block.clone(), build_block, matched_idx);
self.update_row_state(
&result_block,
other_predicate,
&probe_indexes[0..matched_idx],
row_state,
filter_executor,
)?;

Ok(())
Expand Down
Loading
Loading