refactor: runtime filter (#13842)
* refactor runtime filter

* executor core

* make storage ok

* create filter

* add runtime filter pruner

* add to async ReadParquetDataSource

* add to async ReadParquetDataSource

* fix

* fix table schema

* fix filter push down

* fix lint

* dedup inlist

* dedup inlist

* fix cluster

* define a RuntimeFilter trait to reduce intrusion into the processor core

* fix other source

* fix

* broadcast join

* fix ut

* remove executor logic

* redesign

* resolve comments

* add check
xudong963 authored Dec 11, 2023
1 parent 5de46f1 commit 4b94823
Showing 44 changed files with 510 additions and 25 deletions.
6 changes: 4 additions & 2 deletions benchmark/tpch/prepare_table.sh
@@ -9,6 +9,8 @@ options="$1"
# Create Database
echo "CREATE DATABASE IF NOT EXISTS ${MYSQL_DATABASE}" | $BENDSQL_CLIENT_CONNECT_DEFAULT

echo "use ${MYSQL_DATABASE}" | $BENDSQL_CLIENT_CONNECT_DEFAULT

for t in customer lineitem nation orders partsupp part region supplier; do
echo "DROP TABLE IF EXISTS $t" | $BENDSQL_CLIENT_CONNECT
done
@@ -112,6 +114,6 @@ echo "CREATE TABLE IF NOT EXISTS lineitem
for t in customer lineitem nation orders partsupp part region supplier
do
echo "$t"
insert_sql="insert into $t file_format = (type = CSV skip_header = 0 field_delimiter = '|' record_delimiter = '\n')"
curl -s -u root: -XPUT "http://localhost:${QUERY_HTTP_HANDLER_PORT}/v1/streaming_load" -H "database: tpch" -H "insert_sql: ${insert_sql}" -F 'upload=@"./data/'$t'.tbl"' > /dev/null 2>&1
insert_sql="insert into ${MYSQL_DATABASE}.$t file_format = (type = CSV skip_header = 0 field_delimiter = '|' record_delimiter = '\n')"
curl -s -u root: -XPUT "http://localhost:${QUERY_HTTP_HANDLER_PORT}/v1/streaming_load" -H "database: tpch" -H "insert_sql: ${insert_sql}" -F 'upload=@"./data/'$t'.tbl"'
done
2 changes: 2 additions & 0 deletions src/query/catalog/src/plan/datasource/datasource_plan.rs
@@ -47,6 +47,8 @@ pub struct DataSourcePlan {

// data mask policy for `output_schema` columns
pub data_mask_policy: Option<BTreeMap<FieldIndex, RemoteExpr>>,

pub table_index: usize,
}

impl DataSourcePlan {
5 changes: 5 additions & 0 deletions src/query/catalog/src/table_context.rs
@@ -26,6 +26,7 @@ use common_base::base::ProgressValues;
use common_exception::ErrorCode;
use common_exception::Result;
use common_expression::DataBlock;
use common_expression::Expr;
use common_expression::FunctionContext;
use common_io::prelude::FormatSettings;
use common_meta_app::principal::FileFormatParams;
@@ -228,4 +229,8 @@ pub trait TableContext: Send + Sync {

/// Get license key from context, return empty if license is not found or error happened.
fn get_license_key(&self) -> String;

fn set_runtime_filter(&self, filters: (usize, Vec<Expr<String>>));

fn get_runtime_filter_with_id(&self, id: usize) -> Vec<Expr<String>>;
}
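The two new `TableContext` methods form a producer/consumer pair: the join build side publishes filter expressions keyed by the probe-side table index, and the scan of that table later retrieves them for pruning. Below is a minimal standalone sketch of that pattern; the `SimpleContext` type, the string-typed filters, and the locking scheme are illustrative assumptions, not the actual `QueryContext` implementation.

```rust
use std::collections::HashMap;
use std::sync::Mutex;

// Hypothetical stand-in for the query context. Databend stores
// Expr<String> filters; plain strings keep the sketch self-contained.
struct SimpleContext {
    runtime_filters: Mutex<HashMap<usize, Vec<String>>>,
}

impl SimpleContext {
    fn new() -> Self {
        Self { runtime_filters: Mutex::new(HashMap::new()) }
    }

    // Producer side: the hash-join build calls this once its filters are ready.
    fn set_runtime_filter(&self, filters: (usize, Vec<String>)) {
        self.runtime_filters.lock().unwrap().insert(filters.0, filters.1);
    }

    // Consumer side: the scan of the table identified by `id` asks for filters.
    fn get_runtime_filter_with_id(&self, id: usize) -> Vec<String> {
        self.runtime_filters
            .lock()
            .unwrap()
            .get(&id)
            .cloned()
            .unwrap_or_default()
    }
}

fn main() {
    let ctx = SimpleContext::new();
    // The join build side publishes an IN-list filter for table index 3.
    ctx.set_runtime_filter((3, vec!["l_orderkey IN (1, 7, 42)".to_string()]));
    // The scan for table 3 picks it up; other tables see nothing.
    assert_eq!(ctx.get_runtime_filter_with_id(3).len(), 1);
    assert!(ctx.get_runtime_filter_with_id(99).is_empty());
}
```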
6 changes: 6 additions & 0 deletions src/query/expression/src/block.rs
@@ -216,6 +216,12 @@ impl DataBlock {
self.num_rows() == 0
}

// Full empty means no row, no column, no meta
#[inline]
pub fn is_full_empty(&self) -> bool {
self.is_empty() && self.meta.is_none() && self.columns.is_empty()
}

#[inline]
pub fn domains(&self) -> Vec<Domain> {
self.columns
4 changes: 4 additions & 0 deletions src/query/pipeline/sources/src/sync_source.rs
@@ -97,6 +97,10 @@ impl<T: 'static + SyncSource> Processor for SyncSourcer<T> {
match self.inner.generate()? {
None => self.is_finish = true,
Some(data_block) => {
if data_block.is_full_empty() {
// A part was pruned by runtime filter
return Ok(());
}
let progress_values = ProgressValues {
rows: data_block.num_rows(),
bytes: data_block.memory_size(),
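The `is_full_empty` sentinel added in block.rs is what lets `SyncSourcer` skip a part that the runtime filter pruned without treating it as end-of-stream: a block with no rows, no columns, and no meta means "nothing to emit, ask for the next part", while an empty block that still carries columns or meta is real output. A simplified sketch of the distinction, with the `DataBlock` fields reduced to stand-ins for illustration:

```rust
// Simplified stand-ins: the real DataBlock holds typed column entries
// and an optional BlockMetaInfo.
struct DataBlock {
    columns: Vec<Vec<i64>>,
    num_rows: usize,
    meta: Option<String>,
}

impl DataBlock {
    fn is_empty(&self) -> bool {
        self.num_rows == 0
    }

    // "Full empty" means no rows, no columns, and no meta: the sentinel a
    // source returns for a part that the runtime filter pruned entirely.
    fn is_full_empty(&self) -> bool {
        self.is_empty() && self.meta.is_none() && self.columns.is_empty()
    }
}

fn main() {
    // A pruned part: the processor reports no progress and simply
    // schedules the next part.
    let pruned = DataBlock { columns: vec![], num_rows: 0, meta: None };
    assert!(pruned.is_full_empty());

    // An empty block that still carries meta (for example, information
    // attached for a later pipeline stage) is real output and must pass through.
    let empty_with_meta = DataBlock {
        columns: vec![],
        num_rows: 0,
        meta: Some("attached-info".to_string()),
    };
    assert!(empty_with_meta.is_empty() && !empty_with_meta.is_full_empty());
}
```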
2 changes: 0 additions & 2 deletions src/query/profile/src/prof.rs
@@ -71,7 +71,6 @@ pub enum OperatorType {
Window,
RowFetch,
Exchange,
RuntimeFilter,
Insert,
ConstantTableScan,
Udf,
@@ -94,7 +93,6 @@ impl Display for OperatorType {
OperatorType::Window => write!(f, "Window"),
OperatorType::RowFetch => write!(f, "RowFetch"),
OperatorType::Exchange => write!(f, "Exchange"),
OperatorType::RuntimeFilter => write!(f, "RuntimeFilter"),
OperatorType::Insert => write!(f, "Insert"),
OperatorType::CteScan => write!(f, "CteScan"),
OperatorType::ConstantTableScan => write!(f, "ConstantTableScan"),
6 changes: 4 additions & 2 deletions src/query/service/src/pipelines/builders/builder_join.rs
@@ -124,18 +124,20 @@ impl PipelineBuilder {
}

pub(crate) fn build_join(&mut self, join: &HashJoin) -> Result<()> {
let state = self.build_join_state(join)?;
let id = join.probe.get_table_index();
let state = self.build_join_state(join, id)?;
self.expand_build_side_pipeline(&join.build, join, state.clone())?;
self.build_join_probe(join, state)
}

fn build_join_state(&mut self, join: &HashJoin) -> Result<Arc<HashJoinState>> {
fn build_join_state(&mut self, join: &HashJoin, id: IndexType) -> Result<Arc<HashJoinState>> {
HashJoinState::try_create(
self.ctx.clone(),
join.build.output_schema()?,
&join.build_projections,
HashJoinDesc::create(join)?,
&join.probe_to_build,
id,
)
}

@@ -76,6 +76,7 @@ impl PipelineBuilder {
base_block_ids: None,
update_stream_columns: table.change_tracking_enabled(),
data_mask_policy: None,
table_index: usize::MAX,
};

self.ctx.set_partitions(plan.parts.clone())?;
2 changes: 1 addition & 1 deletion src/query/service/src/pipelines/executor/executor_graph.rs
@@ -58,7 +58,7 @@ enum State {
Finished,
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct EdgeInfo {
input_index: usize,
output_index: usize,
@@ -23,6 +23,9 @@ pub struct BuildState {
pub(crate) outer_scan_map: Vec<Vec<bool>>,
/// LeftMarkScan map, initialized at `HashJoinBuildState`, used in `HashJoinProbeState`
pub(crate) mark_scan_map: Vec<Vec<u8>>,
/// A copy of build chunks, used by runtime filter.
/// After finishing creating filters, clear it.
pub(crate) build_chunks: Vec<DataBlock>,
}

impl BuildState {
Expand All @@ -31,6 +34,7 @@ impl BuildState {
generation_state: BuildBlockGenerationState::new(),
outer_scan_map: Vec::new(),
mark_scan_map: Vec::new(),
build_chunks: vec![],
}
}
}
@@ -39,6 +39,9 @@ pub struct HashJoinDesc {
pub(crate) marker_join_desc: MarkJoinDesc,
/// Whether the Join are derived from correlated subquery.
pub(crate) from_correlated_subquery: bool,
pub(crate) probe_keys_rt: Vec<Expr<String>>,
// Under cluster, mark if the join is broadcast join.
pub broadcast: bool,
}

impl HashJoinDesc {
Expand All @@ -56,6 +59,12 @@ impl HashJoinDesc {
.map(|k| k.as_expr(&BUILTIN_FUNCTIONS))
.collect();

let probe_keys_rt: Vec<Expr<String>> = join
.probe_keys_rt
.iter()
.map(|k| k.as_expr(&BUILTIN_FUNCTIONS))
.collect();

Ok(HashJoinDesc {
join_type: join.join_type.clone(),
build_keys,
@@ -66,6 +75,8 @@ impl HashJoinDesc {
// marker_index: join.marker_index,
},
from_correlated_subquery: join.from_correlated_subquery,
probe_keys_rt,
broadcast: join.broadcast,
})
}

@@ -63,6 +63,8 @@ use crate::pipelines::processors::transforms::hash_join::SingleStringHashJoinHas
use crate::pipelines::processors::HashJoinState;
use crate::sessions::QueryContext;

const INLIST_RUNTIME_FILTER_THRESHOLD: usize = 10_000;

/// Define some shared states for all hash join build threads.
pub struct HashJoinBuildState {
pub(crate) ctx: Arc<QueryContext>,
@@ -227,6 +229,17 @@ impl HashJoinBuildState {
.build_num_rows
};

let build_chunks =
&mut unsafe { &mut *self.hash_join_state.build_state.get() }.build_chunks;
if build_num_rows <= INLIST_RUNTIME_FILTER_THRESHOLD {
*build_chunks = unsafe {
(*self.hash_join_state.build_state.get())
.generation_state
.chunks
.clone()
};
}

if self.hash_join_state.hash_join_desc.join_type == JoinType::Cross {
return Ok(());
}
@@ -677,11 +690,21 @@ impl HashJoinBuildState {
.fetch_sub(1, Ordering::Relaxed);
if old_count == 1 {
let build_state = unsafe { &mut *self.hash_join_state.build_state.get() };
info!(
"finish build hash table with {} rows",
build_state.generation_state.build_num_rows
);
let build_num_rows = build_state.generation_state.build_num_rows;
info!("finish build hash table with {} rows", build_num_rows);

let data_blocks = &mut build_state.generation_state.chunks;

if self.hash_join_state.hash_join_desc.join_type == JoinType::Inner
&& self.ctx.get_settings().get_join_spilling_threshold()? == 0
{
let is_cluster = !self.ctx.get_cluster().is_empty();
let is_broadcast_join = self.hash_join_state.hash_join_desc.broadcast;
if !is_cluster || is_broadcast_join {
self.hash_join_state.generate_runtime_filters()?;
}
}

if !data_blocks.is_empty()
&& self.hash_join_state.hash_join_desc.join_type != JoinType::Cross
{
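Taken together, the conditions in the build-side finalization above gate runtime-filter generation to cases where the build side is both complete and usable: an inner join, join spilling disabled, and either a single-node query or a broadcast join (so every node holds the entire build side). A standalone sketch of that gate follows; the free-function signature and plain-value parameters are assumptions for illustration, not the actual settings or cluster objects.

```rust
#[derive(PartialEq)]
enum JoinType {
    Inner,
    Left,
    Cross,
}

// Decide whether the finished build side should generate runtime filters.
// The checks mirror the conditions in the diff above.
fn should_generate_runtime_filters(
    join_type: &JoinType,
    join_spilling_threshold: usize,
    is_cluster: bool,
    is_broadcast_join: bool,
) -> bool {
    // Spilling would mean the in-memory build chunks are incomplete,
    // so a filter built from them could wrongly exclude probe rows.
    let spilling_disabled = join_spilling_threshold == 0;
    // In a cluster, only a broadcast join guarantees that every node
    // holds the whole build side, so the IN-list is complete.
    let build_side_complete = !is_cluster || is_broadcast_join;
    *join_type == JoinType::Inner && spilling_disabled && build_side_complete
}

fn main() {
    // Single node, inner join, no spilling: generate filters.
    assert!(should_generate_runtime_filters(&JoinType::Inner, 0, false, false));
    // Cluster: only the broadcast-join variant qualifies.
    assert!(should_generate_runtime_filters(&JoinType::Inner, 0, true, true));
    assert!(!should_generate_runtime_filters(&JoinType::Inner, 0, true, false));
    // Non-inner joins and enabled spilling both disable the optimization.
    assert!(!should_generate_runtime_filters(&JoinType::Left, 0, false, false));
    assert!(!should_generate_runtime_filters(&JoinType::Inner, 1024, false, false));
}
```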
@@ -23,6 +23,7 @@ use std::sync::Arc;
use common_base::base::tokio::sync::watch;
use common_base::base::tokio::sync::watch::Receiver;
use common_base::base::tokio::sync::watch::Sender;
use common_catalog::table_context::TableContext;
use common_exception::ErrorCode;
use common_exception::Result;
use common_expression::DataSchemaRef;
@@ -34,12 +35,14 @@ use common_hashtable::HashtableKeyable;
use common_hashtable::StringHashJoinHashMap;
use common_sql::plans::JoinType;
use common_sql::ColumnSet;
use common_sql::IndexType;
use ethnum::U256;
use parking_lot::RwLock;

use crate::pipelines::processors::transforms::hash_join::build_state::BuildState;
use crate::pipelines::processors::transforms::hash_join::row::RowSpace;
use crate::pipelines::processors::transforms::hash_join::util::build_schema_wrap_nullable;
use crate::pipelines::processors::transforms::hash_join::util::inlist_filter;
use crate::pipelines::processors::HashJoinDesc;
use crate::sessions::QueryContext;

@@ -74,6 +77,7 @@ pub enum HashJoinHashTable {
/// It will like a bridge to connect build and probe.
/// Such as build side will pass hash table to probe side by it
pub struct HashJoinState {
pub(crate) ctx: Arc<QueryContext>,
/// A shared big hash table stores all the rows from build side
pub(crate) hash_table: SyncUnsafeCell<HashJoinHashTable>,
/// It will be increased by 1 when a new hash join build processor is created.
@@ -115,6 +119,9 @@ pub struct HashJoinState {
/// tell build processors to restore data in the partition
/// If partition_id is -1, it means all partitions are spilled.
pub(crate) partition_id: AtomicI8,

/// If the join node generates runtime filters, the scan node for this table will use them for pruning.
pub(crate) table_index: IndexType,
}

impl HashJoinState {
@@ -124,6 +131,7 @@ impl HashJoinState {
build_projections: &ColumnSet,
hash_join_desc: HashJoinDesc,
probe_to_build: &[(usize, (bool, bool))],
table_index: IndexType,
) -> Result<Arc<HashJoinState>> {
if matches!(
hash_join_desc.join_type,
@@ -137,6 +145,7 @@ impl HashJoinState {
let (build_done_watcher, _build_done_dummy_receiver) = watch::channel(0);
let (continue_build_watcher, _continue_build_dummy_receiver) = watch::channel(false);
Ok(Arc::new(HashJoinState {
ctx: ctx.clone(),
hash_table: SyncUnsafeCell::new(HashJoinHashTable::Null),
hash_table_builders: AtomicUsize::new(0),
build_done_watcher,
Expand All @@ -151,6 +160,7 @@ impl HashJoinState {
continue_build_watcher,
_continue_build_dummy_receiver,
partition_id: AtomicI8::new(-2),
table_index,
}))
}

@@ -237,4 +247,34 @@ impl HashJoinState {
}
build_state.generation_state.is_build_projected = true;
}

// Generate runtime filters
pub(crate) fn generate_runtime_filters(&self) -> Result<()> {
// If the build side has fewer than 10k rows, use an inlist filter
// TODO: otherwise use a bloom filter
let func_ctx = self.ctx.get_function_context()?;
let build_state = unsafe { &mut *self.build_state.get() };
let data_blocks = &mut build_state.build_chunks;

let num_rows = build_state.generation_state.build_num_rows;
if num_rows > 10_000 {
data_blocks.clear();
return Ok(());
}
let mut runtime_filters = Vec::with_capacity(self.hash_join_desc.build_keys.len());
for (build_key, probe_key) in self
.hash_join_desc
.build_keys
.iter()
.zip(self.hash_join_desc.probe_keys_rt.iter())
{
if let Some(filter) = inlist_filter(&func_ctx, build_key, probe_key, data_blocks)? {
runtime_filters.push(filter);
}
}
self.ctx
.set_runtime_filter((self.table_index, runtime_filters));
data_blocks.clear();
Ok(())
}
}
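`generate_runtime_filters` delegates the predicate construction to `inlist_filter`, whose body is not part of this diff. Below is a plausible, simplified sketch of the idea: deduplicate the build-side key values (the "dedup inlist" step from the commit log), give up once the build side exceeds the 10,000-row threshold, and emit a `probe_key IN (...)` predicate. The string-based expression and the function signature here are assumptions for illustration, not the real `inlist_filter` helper.

```rust
use std::collections::BTreeSet;

const INLIST_RUNTIME_FILTER_THRESHOLD: usize = 10_000;

// Build an IN-list predicate for the probe-side key from the build-side
// key values. Returns None when the build side is too large; the real
// code plans to fall back to bloom filters in that case (see the TODO).
fn inlist_filter(probe_key: &str, build_key_values: &[i64]) -> Option<String> {
    if build_key_values.len() > INLIST_RUNTIME_FILTER_THRESHOLD {
        return None;
    }
    // Deduplicate the build keys ("dedup inlist" in the commit log).
    let distinct: BTreeSet<i64> = build_key_values.iter().copied().collect();
    if distinct.is_empty() {
        return None;
    }
    let list: Vec<String> = distinct.iter().map(|v| v.to_string()).collect();
    Some(format!("{} IN ({})", probe_key, list.join(", ")))
}

fn main() {
    let build_keys = vec![42, 7, 7, 1];
    let filter = inlist_filter("l_orderkey", &build_keys);
    assert_eq!(filter.as_deref(), Some("l_orderkey IN (1, 7, 42)"));

    // An oversized build side yields no IN-list filter.
    let big: Vec<i64> = (0..20_000).collect();
    assert!(inlist_filter("l_orderkey", &big).is_none());
}
```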