From d456081c61978c1b0eb2b654a37a7a868adb5854 Mon Sep 17 00:00:00 2001 From: metesynnada <100111937+metesynnada@users.noreply.github.com> Date: Mon, 2 Jan 2023 17:34:36 +0300 Subject: [PATCH 01/39] Prunable symmetric hash join implementation --- datafusion/core/src/execution/context.rs | 3 + .../physical_optimizer/pipeline_checker.rs | 4 +- .../src/physical_optimizer/pipeline_fixer.rs | 243 ++- .../src/physical_plan/coalesce_batches.rs | 1 + .../src/physical_plan/coalesce_partitions.rs | 2 +- .../core/src/physical_plan/joins/hash_join.rs | 615 +----- .../physical_plan/joins/hash_join_utils.rs | 565 +++++ .../core/src/physical_plan/joins/mod.rs | 16 +- .../physical_plan/joins/nested_loop_join.rs | 5 +- .../joins/symmetric_hash_join.rs | 1916 +++++++++++++++++ .../core/src/physical_plan/joins/utils.rs | 122 +- datafusion/core/src/test_util.rs | 41 + datafusion/core/tests/fifo.rs | 114 +- datafusion/expr/src/operator.rs | 35 + datafusion/physical-expr/Cargo.toml | 2 + .../physical-expr/src/expressions/binary.rs | 6 +- .../physical-expr/src/expressions/cast.rs | 10 +- .../physical-expr/src/expressions/column.rs | 4 +- .../physical-expr/src/intervals/cp_solver.rs | 619 ++++++ .../src/intervals/interval_aritmetics.rs | 697 ++++++ datafusion/physical-expr/src/intervals/mod.rs | 4 + datafusion/physical-expr/src/lib.rs | 2 + .../src/physical_expr_visitor.rs | 98 + datafusion/physical-expr/src/rewrite.rs | 15 + datafusion/physical-expr/src/utils.rs | 188 +- 25 files changed, 4674 insertions(+), 653 deletions(-) create mode 100644 datafusion/core/src/physical_plan/joins/hash_join_utils.rs create mode 100644 datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs create mode 100644 datafusion/physical-expr/src/intervals/cp_solver.rs create mode 100644 datafusion/physical-expr/src/intervals/interval_aritmetics.rs create mode 100644 datafusion/physical-expr/src/intervals/mod.rs create mode 100644 datafusion/physical-expr/src/physical_expr_visitor.rs diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 98fe6ff79758..15619061a7e8 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -1464,6 +1464,9 @@ impl SessionState { // repartitioning and local sorting steps to meet distribution and ordering requirements. // Therefore, it should run before EnforceDistribution and EnforceSorting. Arc::new(JoinSelection::new()), + // Enforce sort before PipelineFixer + Arc::new(EnforceDistribution::new()), + Arc::new(EnforceSorting::new()), // If the query is processing infinite inputs, the PipelineFixer rule applies the // necessary transformations to make the query runnable (if it is not already runnable). // If the query can not be made runnable, the rule emits an error with a diagnostic message. 
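The hunk above reorders the physical optimizer rules: EnforceDistribution and EnforceSorting now run before PipelineFixer, so the fixer observes the plan's final distribution and ordering. As a hedged illustration of why registration order matters (simplified stand-in types, not DataFusion's actual API, which also threads a ConfigOptions argument and returns a Result):

    use std::sync::Arc;

    // Stand-in types for illustration only.
    trait Plan {}
    trait PhysicalOptimizerRule {
        fn optimize(&self, plan: Arc<dyn Plan>) -> Arc<dyn Plan>;
    }

    // Rules run strictly in registration order: each rule rewrites the plan
    // produced by the previous one, so a rule that depends on enforced
    // distribution/ordering must be registered after the enforcement rules.
    fn run_rules(
        mut plan: Arc<dyn Plan>,
        rules: &[Box<dyn PhysicalOptimizerRule>],
    ) -> Arc<dyn Plan> {
        for rule in rules {
            plan = rule.optimize(plan);
        }
        plan
    }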
diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs
index 96f0b0ff6932..8a6b0e003adf 100644
--- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs
+++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs
@@ -305,7 +305,7 @@ mod sql_tests {
             FROM test
             LIMIT 5".to_string(),
             cases: vec![Arc::new(test1), Arc::new(test2)],
-            error_operator: "Window Error".to_string()
+            error_operator: "Sort Error".to_string()
         };
 
         case.run().await?;
@@ -328,7 +328,7 @@ mod sql_tests {
             SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as sum1
             FROM test".to_string(),
             cases: vec![Arc::new(test1), Arc::new(test2)],
-            error_operator: "Window Error".to_string()
+            error_operator: "Sort Error".to_string()
         };
         case.run().await?;
         Ok(())
diff --git a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs
index ab74ad37e03e..d98f8b269310 100644
--- a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs
+++ b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs
@@ -29,11 +29,24 @@ use crate::physical_optimizer::pipeline_checker::{
     check_finiteness_requirements, PipelineStatePropagator,
 };
 use crate::physical_optimizer::PhysicalOptimizerRule;
-use crate::physical_plan::joins::{HashJoinExec, PartitionMode};
+use crate::physical_plan::joins::utils::{JoinFilter, JoinSide};
+use crate::physical_plan::joins::{
+    HashJoinExec, PartitionMode, SortedFilterExpr, SymmetricHashJoinExec,
+};
 use crate::physical_plan::rewrite::TreeNodeRewritable;
 use crate::physical_plan::ExecutionPlan;
+use arrow::datatypes::{DataType, SchemaRef};
 use datafusion_common::DataFusionError;
 use datafusion_expr::logical_plan::JoinType;
+use datafusion_expr::Operator;
+use datafusion_physical_expr::expressions::{BinaryExpr, CastExpr, Column, Literal};
+use datafusion_physical_expr::physical_expr_visitor::{
+    ExprVisitable, PhysicalExpressionVisitor, Recursion,
+};
+use datafusion_physical_expr::rewrite::TreeNodeRewritable as physical;
+use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
+
+use std::collections::HashMap;
 use std::sync::Arc;
 
 /// The [PipelineFixer] rule tries to modify a given plan so that it can
@@ -57,8 +70,10 @@ impl PhysicalOptimizerRule for PipelineFixer {
         _config: &ConfigOptions,
     ) -> Result<Arc<dyn ExecutionPlan>> {
         let pipeline = PipelineStatePropagator::new(plan);
-        let physical_optimizer_subrules: Vec<Box<PipelineFixerSubrule>> =
-            vec![Box::new(hash_join_swap_subrule)];
+        let physical_optimizer_subrules: Vec<Box<PipelineFixerSubrule>> = vec![
+            Box::new(hash_join_convert_symmetric_subrule),
+            Box::new(hash_join_swap_subrule),
+        ];
         let state = pipeline.transform_up(&|p| {
             apply_subrules_and_check_finiteness_requirements(
                 p,
@@ -77,6 +92,228 @@ impl PhysicalOptimizerRule for PipelineFixer {
     }
 }
 
+/// Interval calculations are supported for these operators; this set will be
+/// extended in the future.
+fn is_operator_supported(op: &Operator) -> bool {
+    matches!(
+        op,
+        &Operator::Plus
+            | &Operator::Minus
+            | &Operator::And
+            | &Operator::Gt
+            | &Operator::Lt
+    )
+}
+
+/// Currently, only integer data types are supported.
+fn is_datatype_supported(data_type: &DataType) -> bool {
+    matches!(
+        data_type,
+        &DataType::Int64
+            | &DataType::Int32
+            | &DataType::Int16
+            | &DataType::Int8
+            | &DataType::UInt64
+            | &DataType::UInt32
+            | &DataType::UInt16
+            | &DataType::UInt8
+    )
+}
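The operator restriction above exists because interval arithmetic must be defined for every node of the filter expression. A hedged sketch of bound propagation through Plus and Minus nodes (illustrative semantics only; the patch's real rules live in intervals/interval_aritmetics.rs):

    // Closed integer interval [lower, upper], simplified for illustration.
    #[derive(Clone, Copy, Debug, PartialEq)]
    struct Interval {
        lower: i64,
        upper: i64,
    }

    // Propagate bounds through a Plus node: add endpoints pairwise.
    fn add(a: Interval, b: Interval) -> Interval {
        Interval {
            lower: a.lower + b.lower,
            upper: a.upper + b.upper,
        }
    }

    // Propagate bounds through a Minus node: subtract opposite endpoints.
    fn sub(a: Interval, b: Interval) -> Interval {
        Interval {
            lower: a.lower - b.upper,
            upper: a.upper - b.lower,
        }
    }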
+/// We do not support every type of [PhysicalExpr] for interval calculations, nor do
+/// we support every kind of [Operator] in [BinaryExpr]. This check is subject to
+/// change as additional [PhysicalExpr] and [Operator] support is added.
+///
+/// We support [CastExpr], [BinaryExpr], [Column], and [Literal] for interval calculations.
+#[derive(Debug)]
+pub struct UnsupportedPhysicalExprVisitor<'a> {
+    /// Supported state
+    pub supported: &'a mut bool,
+}
+
+impl PhysicalExpressionVisitor for UnsupportedPhysicalExprVisitor<'_> {
+    fn pre_visit(self, expr: Arc<dyn PhysicalExpr>) -> Result<Recursion<Self>> {
+        if let Some(binary_expr) = expr.as_any().downcast_ref::<BinaryExpr>() {
+            let op_is_supported = is_operator_supported(binary_expr.op());
+            *self.supported = *self.supported && op_is_supported;
+        } else if !(expr.as_any().is::<Column>()
+            || expr.as_any().is::<Literal>()
+            || expr.as_any().is::<CastExpr>())
+        {
+            *self.supported = false;
+        }
+        Ok(Recursion::Continue(self))
+    }
+}
+
+// A filter is suitable for a symmetric hash join only if:
+// - it contains no unsupported [Operator] (anything besides [BinaryExpr], [Column],
+//   [Literal] and the operators listed above is currently unsupported),
+// - every field of the filter schema has a supported data type (floats are
+//   unsupported), and
+// - the sorting information covers every column in the filter expression.
+fn is_suitable_for_symmetric_hash_join(filter: &JoinFilter) -> Result<bool> {
+    let expr: Arc<dyn PhysicalExpr> = filter.expression().clone();
+    let mut is_expr_supported = true;
+    expr.accept(UnsupportedPhysicalExprVisitor {
+        supported: &mut is_expr_supported,
+    })?;
+    let is_fields_supported = filter
+        .schema()
+        .fields()
+        .iter()
+        .all(|f| is_datatype_supported(f.data_type()));
+    Ok(is_expr_supported && is_fields_supported)
+}
+
+// TODO: Implement CollectLeft, CollectRight and CollectBoth modes for SymmetricHashJoin
+fn enforce_symmetric_hash_join(
+    hash_join: &HashJoinExec,
+    sorted_columns: Vec<SortedFilterExpr>,
+) -> Result<Arc<dyn ExecutionPlan>> {
+    let left = hash_join.left();
+    let right = hash_join.right();
+    let new_join = Arc::new(SymmetricHashJoinExec::try_new(
+        Arc::clone(left),
+        Arc::clone(right),
+        hash_join
+            .on()
+            .iter()
+            .map(|(l, r)| (l.clone(), r.clone()))
+            .collect(),
+        hash_join.filter().unwrap().clone(),
+        sorted_columns,
+        hash_join.join_type(),
+        hash_join.null_equals_null(),
+    )?);
+    Ok(new_join)
+}
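To make the visitor's acceptance rule concrete, here is a self-contained sketch over a toy expression type (ours, not DataFusion's AST): an expression passes only if every node is a column, a literal, or a binary expression whose operator is in the supported set.

    // Toy expression tree for illustration only.
    enum Expr {
        Column(&'static str),
        Literal(i64),
        Binary {
            op: &'static str,
            left: Box<Expr>,
            right: Box<Expr>,
        },
    }

    // Mirrors the pre-visit check above: reject any node outside the
    // supported set, recursing into binary operands.
    fn is_supported(expr: &Expr) -> bool {
        match expr {
            Expr::Column(_) | Expr::Literal(_) => true,
            Expr::Binary { op, left, right } => {
                matches!(*op, "+" | "-" | "AND" | ">" | "<")
                    && is_supported(left)
                    && is_supported(right)
            }
        }
    }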
+/// We assume that the sort information is on a column, not on a binary expression;
+/// if `a + b` is sorted, we cannot derive intervals for `a` and `b` individually.
+///
+/// Here, we map the main (child) column information to the filter's intermediate
+/// schema. `child_orders` are derived from the child schemas, but the filter schema
+/// is different: we require a column order for each corresponding [ColumnIndex].
+/// If one is missing, we return `None` and assume this expression is not capable
+/// of pruning.
+fn build_filter_input_order(
+    filter: &JoinFilter,
+    left_main_schema: SchemaRef,
+    right_main_schema: SchemaRef,
+    child_orders: Vec<Option<&[PhysicalSortExpr]>>,
+) -> Result<Option<Vec<SortedFilterExpr>>> {
+    let left_child = child_orders[0];
+    let right_child = child_orders[1];
+    let mut left_hashmap: HashMap<Column, Column> = HashMap::new();
+    let mut right_hashmap: HashMap<Column, Column> = HashMap::new();
+    if left_child.is_none() || right_child.is_none() {
+        return Ok(None);
+    }
+    let left_sort_information = left_child.unwrap();
+    let right_sort_information = right_child.unwrap();
+    for (filter_schema_index, index) in filter.column_indices().iter().enumerate() {
+        if index.side.eq(&JoinSide::Left) {
+            let main_field = left_main_schema.field(index.index);
+            let main_col =
+                Column::new_with_schema(main_field.name(), left_main_schema.as_ref())?;
+            let filter_field = filter.schema().field(filter_schema_index);
+            let filter_col = Column::new(filter_field.name(), filter_schema_index);
+            left_hashmap.insert(main_col, filter_col);
+        } else {
+            let main_field = right_main_schema.field(index.index);
+            let main_col =
+                Column::new_with_schema(main_field.name(), right_main_schema.as_ref())?;
+            let filter_field = filter.schema().field(filter_schema_index);
+            let filter_col = Column::new(filter_field.name(), filter_schema_index);
+            right_hashmap.insert(main_col, filter_col);
+        }
+    }
+    let mut sorted_exps = vec![];
+    for sort_expr in left_sort_information {
+        let expr = sort_expr.expr.clone();
+        let new_expr =
+            expr.transform_up(&|p| convert_filter_columns(p, &left_hashmap))?;
+        sorted_exps.push(SortedFilterExpr::new(
+            JoinSide::Left,
+            sort_expr.expr.clone(),
+            new_expr.clone(),
+            sort_expr.options,
+        ));
+    }
+    for sort_expr in right_sort_information {
+        let expr = sort_expr.expr.clone();
+        let new_expr =
+            expr.transform_up(&|p| convert_filter_columns(p, &right_hashmap))?;
+        sorted_exps.push(SortedFilterExpr::new(
+            JoinSide::Right,
+            sort_expr.expr.clone(),
+            new_expr.clone(),
+            sort_expr.options,
+        ));
+    }
+    Ok(Some(sorted_exps))
+}
+
+fn convert_filter_columns(
+    input: Arc<dyn PhysicalExpr>,
+    column_change_information: &HashMap<Column, Column>,
+) -> Result<Option<Arc<dyn PhysicalExpr>>> {
+    if let Some(col) = input.as_any().downcast_ref::<Column>() {
+        let filter_col = column_change_information.get(col).unwrap().clone();
+        Ok(Some(Arc::new(filter_col)))
+    } else {
+        Ok(Some(input))
+    }
+}
+
+fn hash_join_convert_symmetric_subrule(
+    input: &PipelineStatePropagator,
+) -> Option<Result<PipelineStatePropagator>> {
+    let plan = input.plan.clone();
+    let children = &input.children_unbounded;
+    if let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() {
+        let (left_unbounded, right_unbounded) = (children[0], children[1]);
+        let new_plan = if left_unbounded && right_unbounded {
+            match hash_join.filter() {
+                Some(filter) => match is_suitable_for_symmetric_hash_join(filter) {
+                    Ok(suitable) => {
+                        if suitable {
+                            let children = input.plan.children();
+                            let input_order = children
+                                .iter()
+                                .map(|c| c.output_ordering())
+                                .collect::<Vec<_>>();
+                            match build_filter_input_order(
+                                filter,
+                                hash_join.left().schema(),
+                                hash_join.right().schema(),
+                                input_order,
+                            ) {
+                                Ok(Some(sort_columns)) => {
+                                    enforce_symmetric_hash_join(hash_join, sort_columns)
+                                }
+                                Ok(None) => Ok(plan),
+                                Err(e) => return Some(Err(e)),
+                            }
+                        } else {
+                            Ok(plan)
+                        }
+                    }
+                    Err(e) => return Some(Err(e)),
+                },
+                None => Ok(plan),
+            }
+        } else {
+            Ok(plan)
+        };
+        let new_state = new_plan.map(|plan| PipelineStatePropagator {
+            plan,
+            unbounded: left_unbounded || right_unbounded,
+            children_unbounded: vec![left_unbounded, right_unbounded],
+        });
+        Some(new_state)
+    } else {
+        None
+    }
+}
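The subrule above gates the rewrite on three conditions before anything changes in the plan. A hedged condensation of that decision flow (names are ours, not the PR's):

    // Convert to a symmetric hash join only when both inputs are unbounded,
    // the filter uses supported operators/types, and the children's sort
    // orders could be mapped into the filter's intermediate schema.
    fn should_convert_to_symmetric(
        left_unbounded: bool,
        right_unbounded: bool,
        filter_supported: bool,
        order_mapping_found: bool,
    ) -> bool {
        left_unbounded && right_unbounded && filter_supported && order_mapping_found
    }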
 /// This subrule will swap build/probe sides of a hash join depending on whether its inputs
 /// may produce an infinite stream of records. The rule ensures that the left (build) side
 /// of the hash join always operates on an input stream that will produce a finite set of
 /// records.
diff --git a/datafusion/core/src/physical_plan/coalesce_batches.rs b/datafusion/core/src/physical_plan/coalesce_batches.rs
index b50ac08ba5ad..2fd6bc3833d6 100644
--- a/datafusion/core/src/physical_plan/coalesce_batches.rs
+++ b/datafusion/core/src/physical_plan/coalesce_batches.rs
@@ -236,6 +236,7 @@ impl CoalesceBatchesStream {
                         // reset buffer state
                         self.buffer.clear();
                         self.buffered_rows = 0;
+                        println!("coalesce sending {}", batch.num_rows());
                         // return batch
                         return Poll::Ready(Some(Ok(batch)));
                     }
diff --git a/datafusion/core/src/physical_plan/coalesce_partitions.rs b/datafusion/core/src/physical_plan/coalesce_partitions.rs
index 1b847372eed6..3b88d35c3ddc 100644
--- a/datafusion/core/src/physical_plan/coalesce_partitions.rs
+++ b/datafusion/core/src/physical_plan/coalesce_partitions.rs
@@ -138,7 +138,7 @@ impl ExecutionPlan for CoalescePartitionsExec {
         // least one result in an attempt to maximize
         // parallelism.
         let (sender, receiver) =
-            mpsc::channel::<ArrowResult<RecordBatch>>(input_partitions);
+            mpsc::channel::<ArrowResult<RecordBatch>>(100000000000);
 
         // spawn independent tasks whose resulting streams (of batches)
         // are sent to the channel for consumption.
diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs b/datafusion/core/src/physical_plan/joins/hash_join.rs
index 16f143560173..589c20b80fdd 100644
--- a/datafusion/core/src/physical_plan/joins/hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/hash_join.rs
@@ -18,50 +18,38 @@
 //! Defines the join plan for executing partitions in parallel and then joining the results
 //! into a set of partitions.
-use ahash::RandomState; - -use arrow::{ - array::{ - ArrayData, ArrayRef, BooleanArray, Date32Array, Date64Array, Decimal128Array, - DictionaryArray, LargeStringArray, PrimitiveArray, Time32MillisecondArray, - Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampSecondArray, - UInt32BufferBuilder, UInt64BufferBuilder, - }, - datatypes::{ - Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, - UInt8Type, - }, -}; -use smallvec::{smallvec, SmallVec}; +use std::fmt; use std::sync::Arc; +use std::task::Poll; use std::{any::Any, usize}; use std::{time::Instant, vec}; -use futures::{ready, Stream, StreamExt, TryStreamExt}; +use ahash::RandomState; -use arrow::array::Array; -use arrow::datatypes::{ArrowNativeType, DataType}; use arrow::datatypes::{Schema, SchemaRef}; use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; -use arrow::array::{ - Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, - StringArray, TimestampNanosecondArray, UInt16Array, UInt32Array, UInt64Array, - UInt8Array, -}; - -use datafusion_common::cast::{as_dictionary_array, as_string_array}; - +use futures::{ready, Stream, StreamExt, TryStreamExt}; use hashbrown::raw::RawTable; +use log::debug; +use crate::arrow::array::BooleanBufferBuilder; + +use crate::error::{DataFusionError, Result}; +use crate::execution::context::TaskContext; +use crate::logical_expr::JoinType; +use crate::physical_plan::joins::hash_join_utils; +use crate::physical_plan::joins::hash_join_utils::JoinHashMap; +use crate::physical_plan::joins::utils::{ + adjust_indices_by_join_type, build_batch_from_indices, + get_final_indices_from_bit_map, need_produce_result_in_final, JoinSide, +}; use crate::physical_plan::{ coalesce_batches::concat_batches, coalesce_partitions::CoalescePartitionsExec, expressions::Column, expressions::PhysicalSortExpr, - hash_utils::create_hashes, joins::utils::{ adjust_right_output_partitioning, build_join_schema, check_join_is_valid, combine_join_equivalence_properties, estimate_join_statistics, @@ -72,44 +60,10 @@ use crate::physical_plan::{ PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, }; -use crate::error::{DataFusionError, Result}; -use crate::logical_expr::JoinType; - -use crate::arrow::array::BooleanBufferBuilder; -use crate::arrow::datatypes::TimeUnit; -use crate::execution::context::TaskContext; - use super::{ utils::{OnceAsync, OnceFut}, PartitionMode, }; -use crate::physical_plan::joins::utils::{ - adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, - get_final_indices_from_bit_map, need_produce_result_in_final, -}; -use log::debug; -use std::fmt; -use std::task::Poll; - -// Maps a `u64` hash value based on the left ["on" values] to a list of indices with this key's value. -// -// Note that the `u64` keys are not stored in the hashmap (hence the `()` as key), but are only used -// to put the indices in a certain bucket. -// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the left side, -// we make sure that we don't have to re-hash the hashmap, which needs access to the key (the hash in this case) value. -// E.g. 
1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1 -// As the key is a hash value, we need to check possible hash collisions in the probe stage -// During this stage it might be the case that a row is contained the same hashmap value, -// but the values don't match. Those are checked in the [equal_rows] macro -// TODO: speed up collision check and move away from using a hashbrown HashMap -// https://github.com/apache/arrow-datafusion/issues/50 -struct JoinHashMap(RawTable<(u64, SmallVec<[u64; 1]>)>); - -impl fmt::Debug for JoinHashMap { - fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { - Ok(()) - } -} type JoinLeftData = (JoinHashMap, RecordBatch); @@ -529,7 +483,7 @@ async fn collect_left_input( for batch in batches.iter() { hashes_buffer.clear(); hashes_buffer.resize(batch.num_rows(), 0); - update_hash( + hash_join_utils::update_hash( &on_left, batch, &mut hashmap, @@ -584,7 +538,7 @@ async fn partitioned_left_input( for batch in batches.iter() { hashes_buffer.clear(); hashes_buffer.resize(batch.num_rows(), 0); - update_hash( + hash_join_utils::update_hash( &on_left, batch, &mut hashmap, @@ -608,43 +562,6 @@ async fn partitioned_left_input( Ok((hashmap, single_batch)) } -/// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`, -/// assuming that the [RecordBatch] corresponds to the `index`th -fn update_hash( - on: &[Column], - batch: &RecordBatch, - hash_map: &mut JoinHashMap, - offset: usize, - random_state: &RandomState, - hashes_buffer: &mut Vec, -) -> Result<()> { - // evaluate the keys - let keys_values = on - .iter() - .map(|c| Ok(c.evaluate(batch)?.into_array(batch.num_rows()))) - .collect::>>()?; - - // calculate the hash values - let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; - - // insert hashes to key of the hashmap - for (row, hash_value) in hash_values.iter().enumerate() { - let item = hash_map - .0 - .get_mut(*hash_value, |(hash, _)| *hash_value == *hash); - if let Some((_, indices)) = item { - indices.push((row + offset) as u64); - } else { - hash_map.0.insert( - *hash_value, - (*hash_value, smallvec![(row + offset) as u64]), - |(hash, _)| *hash, - ); - } - } - Ok(()) -} - /// A stream that issues [RecordBatch]es as they arrive from the right of the join. 
struct HashJoinStream { /// Input schema @@ -681,465 +598,6 @@ impl RecordBatchStream for HashJoinStream { } } -// Get left and right indices which is satisfies the on condition (include equal_conditon and filter_in_join) in the Join -#[allow(clippy::too_many_arguments)] -fn build_join_indices( - batch: &RecordBatch, - left_data: &JoinLeftData, - on_left: &[Column], - on_right: &[Column], - filter: Option<&JoinFilter>, - random_state: &RandomState, - null_equals_null: &bool, -) -> Result<(UInt64Array, UInt32Array)> { - // Get the indices which is satisfies the equal join condition, like `left.a1 = right.a2` - let (left_indices, right_indices) = build_equal_condition_join_indices( - left_data, - batch, - on_left, - on_right, - random_state, - null_equals_null, - )?; - if let Some(filter) = filter { - // Filter the indices which is satisfies the non-equal join condition, like `left.b1 = 10` - apply_join_filter_to_indices( - &left_data.1, - batch, - left_indices, - right_indices, - filter, - ) - } else { - Ok((left_indices, right_indices)) - } -} - -// Returns the index of equal condition join result: left_indices and right_indices -// On LEFT.b1 = RIGHT.b2 -// LEFT Table: -// a1 b1 c1 -// 1 1 10 -// 3 3 30 -// 5 5 50 -// 7 7 70 -// 9 8 90 -// 11 8 110 -// 13 10 130 -// RIGHT Table: -// a2 b2 c2 -// 2 2 20 -// 4 4 40 -// 6 6 60 -// 8 8 80 -// 10 10 100 -// 12 10 120 -// The result is -// "+----+----+-----+----+----+-----+", -// "| a1 | b1 | c1 | a2 | b2 | c2 |", -// "+----+----+-----+----+----+-----+", -// "| 11 | 8 | 110 | 8 | 8 | 80 |", -// "| 13 | 10 | 130 | 10 | 10 | 100 |", -// "| 13 | 10 | 130 | 12 | 10 | 120 |", -// "| 9 | 8 | 90 | 8 | 8 | 80 |", -// "+----+----+-----+----+----+-----+" -// And the result of left and right indices -// left indices: 5, 6, 6, 4 -// right indices: 3, 4, 5, 3 -fn build_equal_condition_join_indices( - left_data: &JoinLeftData, - right: &RecordBatch, - left_on: &[Column], - right_on: &[Column], - random_state: &RandomState, - null_equals_null: &bool, -) -> Result<(UInt64Array, UInt32Array)> { - let keys_values = right_on - .iter() - .map(|c| Ok(c.evaluate(right)?.into_array(right.num_rows()))) - .collect::>>()?; - let left_join_values = left_on - .iter() - .map(|c| Ok(c.evaluate(&left_data.1)?.into_array(left_data.1.num_rows()))) - .collect::>>()?; - let hashes_buffer = &mut vec![0; keys_values[0].len()]; - let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; - let left = &left_data.0; - // Using a buffer builder to avoid slower normal builder - let mut left_indices = UInt64BufferBuilder::new(0); - let mut right_indices = UInt32BufferBuilder::new(0); - - // Visit all of the right rows - for (row, hash_value) in hash_values.iter().enumerate() { - // Get the hash and find it in the build index - - // For every item on the left and right we check if it matches - // This possibly contains rows with hash collisions, - // So we have to check here whether rows are equal or not - if let Some((_, indices)) = - left.0.get(*hash_value, |(hash, _)| *hash_value == *hash) - { - for &i in indices { - // Check hash collisions - if equal_rows( - i as usize, - row, - &left_join_values, - &keys_values, - *null_equals_null, - )? 
{ - left_indices.append(i); - right_indices.append(row as u32); - } - } - } - } - let left = ArrayData::builder(DataType::UInt64) - .len(left_indices.len()) - .add_buffer(left_indices.finish()) - .build() - .unwrap(); - let right = ArrayData::builder(DataType::UInt32) - .len(right_indices.len()) - .add_buffer(right_indices.finish()) - .build() - .unwrap(); - - Ok(( - PrimitiveArray::::from(left), - PrimitiveArray::::from(right), - )) -} - -macro_rules! equal_rows_elem { - ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $null_equals_null: ident) => {{ - let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap(); - let right_array = $r.as_any().downcast_ref::<$array_type>().unwrap(); - - match (left_array.is_null($left), right_array.is_null($right)) { - (false, false) => left_array.value($left) == right_array.value($right), - (true, true) => $null_equals_null, - _ => false, - } - }}; -} - -macro_rules! equal_rows_elem_with_string_dict { - ($key_array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $null_equals_null: ident) => {{ - let left_array: &DictionaryArray<$key_array_type> = - as_dictionary_array::<$key_array_type>($l).unwrap(); - let right_array: &DictionaryArray<$key_array_type> = - as_dictionary_array::<$key_array_type>($r).unwrap(); - - let (left_values, left_values_index) = { - let keys_col = left_array.keys(); - if keys_col.is_valid($left) { - let values_index = keys_col - .value($left) - .to_usize() - .expect("Can not convert index to usize in dictionary"); - - ( - as_string_array(left_array.values()).unwrap(), - Some(values_index), - ) - } else { - (as_string_array(left_array.values()).unwrap(), None) - } - }; - let (right_values, right_values_index) = { - let keys_col = right_array.keys(); - if keys_col.is_valid($right) { - let values_index = keys_col - .value($right) - .to_usize() - .expect("Can not convert index to usize in dictionary"); - - ( - as_string_array(right_array.values()).unwrap(), - Some(values_index), - ) - } else { - (as_string_array(right_array.values()).unwrap(), None) - } - }; - - match (left_values_index, right_values_index) { - (Some(left_values_index), Some(right_values_index)) => { - left_values.value(left_values_index) - == right_values.value(right_values_index) - } - (None, None) => $null_equals_null, - _ => false, - } - }}; -} - -/// Left and right row have equal values -/// If more data types are supported here, please also add the data types in can_hash function -/// to generate hash join logical plan. 
-fn equal_rows( - left: usize, - right: usize, - left_arrays: &[ArrayRef], - right_arrays: &[ArrayRef], - null_equals_null: bool, -) -> Result { - let mut err = None; - let res = left_arrays - .iter() - .zip(right_arrays) - .all(|(l, r)| match l.data_type() { - DataType::Null => { - // lhs and rhs are both `DataType::Null`, so the equal result - // is dependent on `null_equals_null` - null_equals_null - } - DataType::Boolean => { - equal_rows_elem!(BooleanArray, l, r, left, right, null_equals_null) - } - DataType::Int8 => { - equal_rows_elem!(Int8Array, l, r, left, right, null_equals_null) - } - DataType::Int16 => { - equal_rows_elem!(Int16Array, l, r, left, right, null_equals_null) - } - DataType::Int32 => { - equal_rows_elem!(Int32Array, l, r, left, right, null_equals_null) - } - DataType::Int64 => { - equal_rows_elem!(Int64Array, l, r, left, right, null_equals_null) - } - DataType::UInt8 => { - equal_rows_elem!(UInt8Array, l, r, left, right, null_equals_null) - } - DataType::UInt16 => { - equal_rows_elem!(UInt16Array, l, r, left, right, null_equals_null) - } - DataType::UInt32 => { - equal_rows_elem!(UInt32Array, l, r, left, right, null_equals_null) - } - DataType::UInt64 => { - equal_rows_elem!(UInt64Array, l, r, left, right, null_equals_null) - } - DataType::Float32 => { - equal_rows_elem!(Float32Array, l, r, left, right, null_equals_null) - } - DataType::Float64 => { - equal_rows_elem!(Float64Array, l, r, left, right, null_equals_null) - } - DataType::Date32 => { - equal_rows_elem!(Date32Array, l, r, left, right, null_equals_null) - } - DataType::Date64 => { - equal_rows_elem!(Date64Array, l, r, left, right, null_equals_null) - } - DataType::Time32(time_unit) => match time_unit { - TimeUnit::Second => { - equal_rows_elem!(Time32SecondArray, l, r, left, right, null_equals_null) - } - TimeUnit::Millisecond => { - equal_rows_elem!(Time32MillisecondArray, l, r, left, right, null_equals_null) - } - _ => { - err = Some(Err(DataFusionError::Internal( - "Unsupported data type in hasher".to_string(), - ))); - false - } - } - DataType::Time64(time_unit) => match time_unit { - TimeUnit::Microsecond => { - equal_rows_elem!(Time64MicrosecondArray, l, r, left, right, null_equals_null) - } - TimeUnit::Nanosecond => { - equal_rows_elem!(Time64NanosecondArray, l, r, left, right, null_equals_null) - } - _ => { - err = Some(Err(DataFusionError::Internal( - "Unsupported data type in hasher".to_string(), - ))); - false - } - } - DataType::Timestamp(time_unit, None) => match time_unit { - TimeUnit::Second => { - equal_rows_elem!( - TimestampSecondArray, - l, - r, - left, - right, - null_equals_null - ) - } - TimeUnit::Millisecond => { - equal_rows_elem!( - TimestampMillisecondArray, - l, - r, - left, - right, - null_equals_null - ) - } - TimeUnit::Microsecond => { - equal_rows_elem!( - TimestampMicrosecondArray, - l, - r, - left, - right, - null_equals_null - ) - } - TimeUnit::Nanosecond => { - equal_rows_elem!( - TimestampNanosecondArray, - l, - r, - left, - right, - null_equals_null - ) - } - }, - DataType::Utf8 => { - equal_rows_elem!(StringArray, l, r, left, right, null_equals_null) - } - DataType::LargeUtf8 => { - equal_rows_elem!(LargeStringArray, l, r, left, right, null_equals_null) - } - DataType::Decimal128(_, lscale) => match r.data_type() { - DataType::Decimal128(_, rscale) => { - if lscale == rscale { - equal_rows_elem!( - Decimal128Array, - l, - r, - left, - right, - null_equals_null - ) - } else { - err = Some(Err(DataFusionError::Internal( - "Inconsistent Decimal data type in hasher, the 
scale should be same".to_string(), - ))); - false - } - } - _ => { - err = Some(Err(DataFusionError::Internal( - "Unsupported data type in hasher".to_string(), - ))); - false - } - }, - DataType::Dictionary(key_type, value_type) - if *value_type.as_ref() == DataType::Utf8 => - { - match key_type.as_ref() { - DataType::Int8 => { - equal_rows_elem_with_string_dict!( - Int8Type, - l, - r, - left, - right, - null_equals_null - ) - } - DataType::Int16 => { - equal_rows_elem_with_string_dict!( - Int16Type, - l, - r, - left, - right, - null_equals_null - ) - } - DataType::Int32 => { - equal_rows_elem_with_string_dict!( - Int32Type, - l, - r, - left, - right, - null_equals_null - ) - } - DataType::Int64 => { - equal_rows_elem_with_string_dict!( - Int64Type, - l, - r, - left, - right, - null_equals_null - ) - } - DataType::UInt8 => { - equal_rows_elem_with_string_dict!( - UInt8Type, - l, - r, - left, - right, - null_equals_null - ) - } - DataType::UInt16 => { - equal_rows_elem_with_string_dict!( - UInt16Type, - l, - r, - left, - right, - null_equals_null - ) - } - DataType::UInt32 => { - equal_rows_elem_with_string_dict!( - UInt32Type, - l, - r, - left, - right, - null_equals_null - ) - } - DataType::UInt64 => { - equal_rows_elem_with_string_dict!( - UInt64Type, - l, - r, - left, - right, - null_equals_null - ) - } - _ => { - // should not happen - err = Some(Err(DataFusionError::Internal( - "Unsupported data type in hasher".to_string(), - ))); - false - } - } - } - other => { - // This is internal because we should have caught this before. - err = Some(Err(DataFusionError::Internal(format!( - "Unsupported data type in hasher: {other}" - )))); - false - } - }); - - err.unwrap_or(Ok(res)) -} - impl HashJoinStream { /// Separate implementation function that unpins the [`HashJoinStream`] so /// that partial borrows work correctly @@ -1169,7 +627,7 @@ impl HashJoinStream { BooleanBufferBuilder::new(0) } }); - + let mut hashes_buffer = vec![]; self.right .poll_next_unpin(cx) .map(|maybe_batch| match maybe_batch { @@ -1180,14 +638,18 @@ impl HashJoinStream { let timer = self.join_metrics.probe_time.timer(); // get the matched two indices for the on condition - let left_right_indices = build_join_indices( + let left_right_indices = hash_join_utils::build_join_indices( &batch, - left_data, + &left_data.0, + &left_data.1, &self.on_left, &self.on_right, self.filter.as_ref(), &self.random_state, &self.null_equals_null, + &mut hashes_buffer, + None, + JoinSide::Left, ); let result = match left_right_indices { @@ -1215,6 +677,7 @@ impl HashJoinStream { left_side, right_side, &self.column_indices, + JoinSide::Left, ); self.join_metrics.output_batches.add(1); self.join_metrics.output_rows.add(batch.num_rows()); @@ -1249,6 +712,7 @@ impl HashJoinStream { left_side, right_side, &self.column_indices, + JoinSide::Left, ); if let Ok(ref batch) = result { @@ -1284,27 +748,31 @@ impl Stream for HashJoinStream { #[cfg(test)] mod tests { + use std::sync::Arc; + + use super::*; use crate::physical_expr::expressions::BinaryExpr; + use crate::physical_plan::joins::hash_join_utils::build_equal_condition_join_indices; + use crate::physical_plan::joins::utils::JoinSide; + use crate::prelude::SessionContext; use crate::{ assert_batches_sorted_eq, physical_plan::{ - common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec, + common, expressions::Column, hash_utils::create_hashes, memory::MemoryExec, + repartition::RepartitionExec, }, test::exec::MockExec, test::{build_table_i32, columns}, }; use 
arrow::array::UInt32Builder; use arrow::array::UInt64Builder; - use arrow::datatypes::Field; + use arrow::array::{ArrayRef, Date32Array, Int32Array}; + use arrow::datatypes::{DataType, Field, Schema}; use arrow::error::ArrowError; - use datafusion_expr::Operator; - - use super::*; - use crate::physical_plan::joins::utils::JoinSide; - use crate::prelude::SessionContext; use datafusion_common::ScalarValue; + use datafusion_expr::Operator; use datafusion_physical_expr::expressions::Literal; - use std::sync::Arc; + use smallvec::smallvec; fn build_table( a: (&str, &Vec), @@ -2648,12 +2116,15 @@ mod tests { let left_data = (JoinHashMap(hashmap_left), left); let (l, r) = build_equal_condition_join_indices( - &left_data, + &left_data.0, + &left_data.1, &right, &[Column::new("a", 0)], &[Column::new("a", 0)], &random_state, &false, + &mut vec![0; right.num_rows()], + None, )?; let mut left_ids = UInt64Builder::with_capacity(0); diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs new file mode 100644 index 000000000000..d5423068c079 --- /dev/null +++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs @@ -0,0 +1,565 @@ +use std::{fmt, usize}; + +use ahash::RandomState; +use arrow::record_batch::RecordBatch; +use arrow::{ + array::{ + Array, ArrayData, ArrayRef, BooleanArray, Date32Array, Date64Array, + Decimal128Array, DictionaryArray, Float32Array, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, LargeStringArray, PrimitiveArray, StringArray, + Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, + Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, + UInt32BufferBuilder, UInt64Array, UInt64BufferBuilder, UInt8Array, + }, + datatypes::{ + ArrowNativeType, DataType, Int16Type, Int32Type, Int64Type, Int8Type, TimeUnit, + UInt16Type, UInt32Type, UInt64Type, UInt8Type, + }, +}; +use hashbrown::raw::RawTable; +use smallvec::{smallvec, SmallVec}; + +use datafusion_common::cast::{as_dictionary_array, as_string_array}; +use datafusion_common::DataFusionError; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::hash_utils::create_hashes; +use datafusion_physical_expr::PhysicalExpr; + +use crate::physical_plan::joins::utils::{ + apply_join_filter_to_indices, JoinFilter, JoinSide, +}; + +/// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`, +/// assuming that the [RecordBatch] corresponds to the `index`th +pub fn update_hash( + on: &[Column], + batch: &RecordBatch, + hash_map: &mut JoinHashMap, + offset: usize, + random_state: &RandomState, + hashes_buffer: &mut Vec, +) -> datafusion_common::Result<()> { + // evaluate the keys + let keys_values = on + .iter() + .map(|c| Ok(c.evaluate(batch)?.into_array(batch.num_rows()))) + .collect::>>()?; + + // calculate the hash values + let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; + + // insert hashes to key of the hashmap + for (row, hash_value) in hash_values.iter().enumerate() { + let item = hash_map + .0 + .get_mut(*hash_value, |(hash, _)| *hash_value == *hash); + if let Some((_, indices)) = item { + indices.push((row + offset) as u64); + } else { + hash_map.0.insert( + *hash_value, + (*hash_value, smallvec![(row + offset) as u64]), + |(hash, _)| *hash, + ); + } + } + Ok(()) +} + +// Get left and right indices which is satisfies the on condition 
(include equal_conditon and filter_in_join) in the Join +#[allow(clippy::too_many_arguments)] +pub fn build_join_indices( + probe_batch: &RecordBatch, + build_hashmap: &JoinHashMap, + build_input_buffer: &RecordBatch, + on_build: &[Column], + on_probe: &[Column], + filter: Option<&JoinFilter>, + random_state: &RandomState, + null_equals_null: &bool, + hashes_buffer: &mut Vec, + offset: Option, + build_side: JoinSide, +) -> datafusion_common::Result<(UInt64Array, UInt32Array)> { + // Get the indices which is satisfies the equal join condition, like `left.a1 = right.a2` + let (build_indices, probe_indices) = build_equal_condition_join_indices( + build_hashmap, + build_input_buffer, + probe_batch, + on_build, + on_probe, + random_state, + null_equals_null, + hashes_buffer, + offset, + )?; + if let Some(filter) = filter { + // Filter the indices which is satisfies the non-equal join condition, like `left.b1 = 10` + apply_join_filter_to_indices( + build_input_buffer, + probe_batch, + build_indices, + probe_indices, + filter, + build_side, + ) + } else { + Ok((build_indices, probe_indices)) + } +} + +macro_rules! equal_rows_elem { + ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $null_equals_null: ident) => {{ + let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap(); + let right_array = $r.as_any().downcast_ref::<$array_type>().unwrap(); + + match (left_array.is_null($left), right_array.is_null($right)) { + (false, false) => left_array.value($left) == right_array.value($right), + (true, true) => $null_equals_null, + _ => false, + } + }}; +} + +macro_rules! equal_rows_elem_with_string_dict { + ($key_array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $null_equals_null: ident) => {{ + let left_array: &DictionaryArray<$key_array_type> = + as_dictionary_array::<$key_array_type>($l).unwrap(); + let right_array: &DictionaryArray<$key_array_type> = + as_dictionary_array::<$key_array_type>($r).unwrap(); + + let (left_values, left_values_index) = { + let keys_col = left_array.keys(); + if keys_col.is_valid($left) { + let values_index = keys_col + .value($left) + .to_usize() + .expect("Can not convert index to usize in dictionary"); + + ( + as_string_array(left_array.values()).unwrap(), + Some(values_index), + ) + } else { + (as_string_array(left_array.values()).unwrap(), None) + } + }; + let (right_values, right_values_index) = { + let keys_col = right_array.keys(); + if keys_col.is_valid($right) { + let values_index = keys_col + .value($right) + .to_usize() + .expect("Can not convert index to usize in dictionary"); + + ( + as_string_array(right_array.values()).unwrap(), + Some(values_index), + ) + } else { + (as_string_array(right_array.values()).unwrap(), None) + } + }; + + match (left_values_index, right_values_index) { + (Some(left_values_index), Some(right_values_index)) => { + left_values.value(left_values_index) + == right_values.value(right_values_index) + } + (None, None) => $null_equals_null, + _ => false, + } + }}; +} + +/// Left and right row have equal values +/// If more data types are supported here, please also add the data types in can_hash function +/// to generate hash join logical plan. 
+pub fn equal_rows( + left: usize, + right: usize, + left_arrays: &[ArrayRef], + right_arrays: &[ArrayRef], + null_equals_null: bool, +) -> datafusion_common::Result { + let mut err = None; + let res = left_arrays + .iter() + .zip(right_arrays) + .all(|(l, r)| match l.data_type() { + DataType::Null => { + // lhs and rhs are both `DataType::Null`, so the equal result + // is dependent on `null_equals_null` + null_equals_null + } + DataType::Boolean => { + equal_rows_elem!(BooleanArray, l, r, left, right, null_equals_null) + } + DataType::Int8 => { + equal_rows_elem!(Int8Array, l, r, left, right, null_equals_null) + } + DataType::Int16 => { + equal_rows_elem!(Int16Array, l, r, left, right, null_equals_null) + } + DataType::Int32 => { + equal_rows_elem!(Int32Array, l, r, left, right, null_equals_null) + } + DataType::Int64 => { + equal_rows_elem!(Int64Array, l, r, left, right, null_equals_null) + } + DataType::UInt8 => { + equal_rows_elem!(UInt8Array, l, r, left, right, null_equals_null) + } + DataType::UInt16 => { + equal_rows_elem!(UInt16Array, l, r, left, right, null_equals_null) + } + DataType::UInt32 => { + equal_rows_elem!(UInt32Array, l, r, left, right, null_equals_null) + } + DataType::UInt64 => { + equal_rows_elem!(UInt64Array, l, r, left, right, null_equals_null) + } + DataType::Float32 => { + equal_rows_elem!(Float32Array, l, r, left, right, null_equals_null) + } + DataType::Float64 => { + equal_rows_elem!(Float64Array, l, r, left, right, null_equals_null) + } + DataType::Date32 => { + equal_rows_elem!(Date32Array, l, r, left, right, null_equals_null) + } + DataType::Date64 => { + equal_rows_elem!(Date64Array, l, r, left, right, null_equals_null) + } + DataType::Time32(time_unit) => match time_unit { + TimeUnit::Second => { + equal_rows_elem!(Time32SecondArray, l, r, left, right, null_equals_null) + } + TimeUnit::Millisecond => { + equal_rows_elem!(Time32MillisecondArray, l, r, left, right, null_equals_null) + } + _ => { + err = Some(Err(DataFusionError::Internal( + "Unsupported data type in hasher".to_string(), + ))); + false + } + } + DataType::Time64(time_unit) => match time_unit { + TimeUnit::Microsecond => { + equal_rows_elem!(Time64MicrosecondArray, l, r, left, right, null_equals_null) + } + TimeUnit::Nanosecond => { + equal_rows_elem!(Time64NanosecondArray, l, r, left, right, null_equals_null) + } + _ => { + err = Some(Err(DataFusionError::Internal( + "Unsupported data type in hasher".to_string(), + ))); + false + } + } + DataType::Timestamp(time_unit, None) => match time_unit { + TimeUnit::Second => { + equal_rows_elem!( + TimestampSecondArray, + l, + r, + left, + right, + null_equals_null + ) + } + TimeUnit::Millisecond => { + equal_rows_elem!( + TimestampMillisecondArray, + l, + r, + left, + right, + null_equals_null + ) + } + TimeUnit::Microsecond => { + equal_rows_elem!( + TimestampMicrosecondArray, + l, + r, + left, + right, + null_equals_null + ) + } + TimeUnit::Nanosecond => { + equal_rows_elem!( + TimestampNanosecondArray, + l, + r, + left, + right, + null_equals_null + ) + } + }, + DataType::Utf8 => { + equal_rows_elem!(StringArray, l, r, left, right, null_equals_null) + } + DataType::LargeUtf8 => { + equal_rows_elem!(LargeStringArray, l, r, left, right, null_equals_null) + } + DataType::Decimal128(_, lscale) => match r.data_type() { + DataType::Decimal128(_, rscale) => { + if lscale == rscale { + equal_rows_elem!( + Decimal128Array, + l, + r, + left, + right, + null_equals_null + ) + } else { + err = Some(Err(DataFusionError::Internal( + "Inconsistent Decimal data 
type in hasher, the scale should be same".to_string(), + ))); + false + } + } + _ => { + err = Some(Err(DataFusionError::Internal( + "Unsupported data type in hasher".to_string(), + ))); + false + } + }, + DataType::Dictionary(key_type, value_type) + if *value_type.as_ref() == DataType::Utf8 => + { + match key_type.as_ref() { + DataType::Int8 => { + equal_rows_elem_with_string_dict!( + Int8Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::Int16 => { + equal_rows_elem_with_string_dict!( + Int16Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::Int32 => { + equal_rows_elem_with_string_dict!( + Int32Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::Int64 => { + equal_rows_elem_with_string_dict!( + Int64Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::UInt8 => { + equal_rows_elem_with_string_dict!( + UInt8Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::UInt16 => { + equal_rows_elem_with_string_dict!( + UInt16Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::UInt32 => { + equal_rows_elem_with_string_dict!( + UInt32Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::UInt64 => { + equal_rows_elem_with_string_dict!( + UInt64Type, + l, + r, + left, + right, + null_equals_null + ) + } + _ => { + // should not happen + err = Some(Err(DataFusionError::Internal( + "Unsupported data type in hasher".to_string(), + ))); + false + } + } + } + other => { + // This is internal because we should have caught this before. + err = Some(Err(DataFusionError::Internal(format!( + "Unsupported data type in hasher: {other}" + )))); + false + } + }); + + err.unwrap_or(Ok(res)) +} + +// Returns the index of equal condition join result: left_indices and right_indices +// On LEFT.b1 = RIGHT.b2 +// LEFT Table: +// a1 b1 c1 +// 1 1 10 +// 3 3 30 +// 5 5 50 +// 7 7 70 +// 9 8 90 +// 11 8 110 +// 13 10 130 +// RIGHT Table: +// a2 b2 c2 +// 2 2 20 +// 4 4 40 +// 6 6 60 +// 8 8 80 +// 10 10 100 +// 12 10 120 +// The result is +// "+----+----+-----+----+----+-----+", +// "| a1 | b1 | c1 | a2 | b2 | c2 |", +// "+----+----+-----+----+----+-----+", +// "| 11 | 8 | 110 | 8 | 8 | 80 |", +// "| 13 | 10 | 130 | 10 | 10 | 100 |", +// "| 13 | 10 | 130 | 12 | 10 | 120 |", +// "| 9 | 8 | 90 | 8 | 8 | 80 |", +// "+----+----+-----+----+----+-----+" +// And the result of left and right indices +// left indices: 5, 6, 6, 4 +// right indices: 3, 4, 5, 3 +#[allow(clippy::too_many_arguments)] +pub fn build_equal_condition_join_indices( + build_hashmap: &JoinHashMap, + build_input_buffer: &RecordBatch, + probe_batch: &RecordBatch, + build_on: &[Column], + probe_on: &[Column], + random_state: &RandomState, + null_equals_null: &bool, + hashes_buffer: &mut Vec, + offset: Option, +) -> datafusion_common::Result<(UInt64Array, UInt32Array)> { + let keys_values = probe_on + .iter() + .map(|c| Ok(c.evaluate(probe_batch)?.into_array(probe_batch.num_rows()))) + .collect::>>()?; + let build_join_values = build_on + .iter() + .map(|c| { + Ok(c.evaluate(build_input_buffer)? 
+ .into_array(build_input_buffer.num_rows())) + }) + .collect::>>()?; + hashes_buffer.clear(); + hashes_buffer.resize(probe_batch.num_rows(), 0); + let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; + // Using a buffer builder to avoid slower normal builder + let mut left_indices = UInt64BufferBuilder::new(0); + let mut right_indices = UInt32BufferBuilder::new(0); + let offset_value = offset.unwrap_or(0); + // Visit all of the right rows + for (row, hash_value) in hash_values.iter().enumerate() { + // Get the hash and find it in the build index + + // For every item on the left and right we check if it matches + // This possibly contains rows with hash collisions, + // So we have to check here whether rows are equal or not + if let Some((_, indices)) = build_hashmap + .0 + .get(*hash_value, |(hash, _)| *hash_value == *hash) + { + for &i in indices { + // Check hash collisions + let offset_build_index = i as usize - offset_value; + // Check hash collisions + if equal_rows( + offset_build_index, + row, + &build_join_values, + &keys_values, + *null_equals_null, + )? { + left_indices.append(offset_build_index as u64); + right_indices.append(row as u32); + } + } + } + } + let left = ArrayData::builder(DataType::UInt64) + .len(left_indices.len()) + .add_buffer(left_indices.finish()) + .build() + .unwrap(); + let right = ArrayData::builder(DataType::UInt32) + .len(right_indices.len()) + .add_buffer(right_indices.finish()) + .build() + .unwrap(); + + Ok(( + PrimitiveArray::::from(left), + PrimitiveArray::::from(right), + )) +} + +// Maps a `u64` hash value based on the left ["on" values] to a list of indices with this key's value. +// +// Note that the `u64` keys are not stored in the hashmap (hence the `()` as key), but are only used +// to put the indices in a certain bucket. +// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the left side, +// we make sure that we don't have to re-hash the hashmap, which needs access to the key (the hash in this case) value. +// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1 +// As the key is a hash value, we need to check possible hash collisions in the probe stage +// During this stage it might be the case that a row is contained the same hashmap value, +// but the values don't match. Those are checked in the [equal_rows] macro +// TODO: speed up collision check and move away from using a hashbrown HashMap +// https://github.com/apache/arrow-datafusion/issues/50 +pub struct JoinHashMap(pub RawTable<(u64, SmallVec<[u64; 1]>)>); + +impl fmt::Debug for JoinHashMap { + fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { + Ok(()) + } +} diff --git a/datafusion/core/src/physical_plan/joins/mod.rs b/datafusion/core/src/physical_plan/joins/mod.rs index 63762ab3cf1b..feeefc1605aa 100644 --- a/datafusion/core/src/physical_plan/joins/mod.rs +++ b/datafusion/core/src/physical_plan/joins/mod.rs @@ -17,10 +17,19 @@ //! 
DataFusion Join implementations +pub use cross_join::CrossJoinExec; +pub use hash_join::HashJoinExec; +pub use nested_loop_join::NestedLoopJoinExec; +// Note: SortMergeJoin is not used in plans yet +pub use sort_merge_join::SortMergeJoinExec; +pub use symmetric_hash_join::{SortedFilterExpr, SymmetricHashJoinExec}; + mod cross_join; mod hash_join; +mod hash_join_utils; mod nested_loop_join; mod sort_merge_join; +mod symmetric_hash_join; pub mod utils; #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -34,10 +43,3 @@ pub enum PartitionMode { /// It will also consider swapping the left and right inputs for the Join Auto, } - -pub use cross_join::CrossJoinExec; -pub use hash_join::HashJoinExec; -pub use nested_loop_join::NestedLoopJoinExec; - -// Note: SortMergeJoin is not used in plans yet -pub use sort_merge_join::SortMergeJoinExec; diff --git a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs index 8343925125c0..e95e51869594 100644 --- a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs +++ b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs @@ -24,7 +24,7 @@ use crate::physical_plan::joins::utils::{ apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, check_join_is_valid, combine_join_equivalence_properties, estimate_join_statistics, get_final_indices_from_bit_map, need_produce_result_in_final, ColumnIndex, - JoinFilter, OnceAsync, OnceFut, + JoinFilter, JoinSide, OnceAsync, OnceFut, }; use crate::physical_plan::{ DisplayFormatType, Distribution, ExecutionPlan, Partitioning, RecordBatchStream, @@ -346,6 +346,7 @@ fn build_join_indices( left_indices, right_indices, filter, + JoinSide::Left, ) } else { Ok((left_indices, right_indices)) @@ -447,6 +448,7 @@ impl NestedLoopJoinStream { left_side, right_side, &self.column_indices, + JoinSide::Left, ); Some(result) } @@ -473,6 +475,7 @@ impl NestedLoopJoinStream { left_side, right_side, &self.column_indices, + JoinSide::Left, ); self.is_exhausted = true; Some(result) diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs new file mode 100644 index 000000000000..16652c04edb4 --- /dev/null +++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs @@ -0,0 +1,1916 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the join plan for executing partitions in parallel and then joining the results +//! into a set of partitions. 
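Before the implementation below, a self-contained sketch of the core idea the new module implements (a plain HashMap of vectors stands in for the real JoinHashMap): each incoming batch is inserted into its own side's hash table and immediately probed against the other side's, so neither input has to be read to completion first.

    use std::collections::HashMap;

    // Symmetric probing: each side both builds and probes, emitting a match
    // as soon as both sides have seen a key.
    fn process_batch(
        rows: &[(u64, i64)],              // (join key, payload)
        own: &mut HashMap<u64, Vec<i64>>, // this side's hash table
        other: &HashMap<u64, Vec<i64>>,   // opposite side's hash table
    ) -> Vec<(i64, i64)> {
        let mut output = Vec::new();
        for (key, payload) in rows {
            // Insert first so later rows from the other side can match this one.
            own.entry(*key).or_default().push(*payload);
            // Probe the opposite table for matches already seen.
            if let Some(matches) = other.get(key) {
                for m in matches {
                    output.push((*payload, *m));
                }
            }
        }
        output
    }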
+ +use std::collections::{HashMap, VecDeque}; +use std::fmt; +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; +use std::task::Poll; +use std::vec; +use std::{any::Any, usize}; + +use ahash::RandomState; +use arrow::array::PrimitiveArray; +use arrow::array::{ArrowPrimitiveType, NativeAdapter, PrimitiveBuilder}; +use arrow::compute::{concat_batches, SortOptions}; +use arrow::datatypes::ArrowNativeType; +use arrow::datatypes::{Schema, SchemaRef}; +use arrow::error::{ArrowError, Result as ArrowResult}; +use arrow::record_batch::RecordBatch; +use futures::{Stream, StreamExt}; +use hashbrown::raw::RawTable; +use hashbrown::HashSet; +use itertools::Itertools; + +use datafusion_common::bisect::bisect; +use datafusion_common::ScalarValue; +use datafusion_physical_expr::intervals::interval_aritmetics::{Interval, Range}; +use datafusion_physical_expr::intervals::ExprIntervalGraph; + +use crate::arrow::array::BooleanBufferBuilder; +use crate::error::{DataFusionError, Result}; +use crate::execution::context::TaskContext; +use crate::logical_expr::JoinType; +use crate::physical_plan::common::merge_batches; +use crate::physical_plan::joins::hash_join_utils::update_hash; +use crate::physical_plan::joins::hash_join_utils::{build_join_indices, JoinHashMap}; +use crate::physical_plan::joins::utils::build_batch_from_indices; +use crate::physical_plan::{ + expressions::Column, + expressions::PhysicalSortExpr, + joins::utils::{ + build_join_schema, check_join_is_valid, combine_join_equivalence_properties, + partitioned_join_output_partitioning, ColumnIndex, JoinFilter, JoinOn, JoinSide, + }, + metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, + DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, Partitioning, + PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, +}; + +#[derive(Debug, Clone)] +/// The SortedFilterExpr struct is used to represent a sorted filter expression in the +/// [SymmetricHashJoinExec] struct. It contains information about the join side, the origin +/// expression, the filter expression, and the sort option. +/// The struct has several methods to access and modify its fields. +pub struct SortedFilterExpr { + // Column side + join_side: JoinSide, + // Sorted expr from a particular join side (child) + origin_expr: Arc, + // For interval calculations, one to one mapping of the columns according to filter expression, + // and column indices. + filter_expr: Arc, + // Sort option + sort_option: SortOptions, + // Interval + interval: Option, +} + +impl SortedFilterExpr { + /// Constructor + pub fn new( + join_side: JoinSide, + origin_expr: Arc, + filter_expr: Arc, + sort_option: SortOptions, + ) -> Self { + Self { + join_side, + origin_expr, + filter_expr, + sort_option, + interval: None, + } + } + /// Get origin expr information + pub fn origin_expr(&self) -> Arc { + self.origin_expr.clone() + } + /// Get filter expr information + pub fn filter_expr(&self) -> Arc { + self.filter_expr.clone() + } + /// Get sort information + pub fn sort_option(&self) -> SortOptions { + self.sort_option + } + /// Get interval information + pub fn interval(&self) -> &Option { + &self.interval + } + /// Sets interval + pub fn set_interval(&mut self, interval: Option) { + self.interval = interval; + } +} + +/// The symmetric hash join is a special type of hash join designed for data streams and large +/// datasets. +/// For each input, create a hash table. +/// - For each new record, hash and insert into inputs hash table. 
+/// - Test if input is equal to a predefined set of other inputs. +/// - If so, output the records, record the visited rows. +/// - If the join type indicates that unmatched rows results must be produced (LEFT, FULL etc.), +/// produce result if a pruning happens or at the end of the data. +pub struct SymmetricHashJoinExec { + /// left side stream + pub(crate) left: Arc, + /// right side stream + pub(crate) right: Arc, + /// Set of common columns used to join on + pub(crate) on: Vec<(Column, Column)>, + /// Filters which are applied while finding matching rows + pub(crate) filter: JoinFilter, + /// How the join is performed + pub(crate) join_type: JoinType, + /// Order information of filter columns + filter_columns: Vec, + /// The schema once the join is applied + schema: SchemaRef, + /// Shares the `RandomState` for the hashing algorithm + random_state: RandomState, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, + /// Information of index and left / right placement of columns + column_indices: Vec, + /// If null_equals_null is true, null == null else null != null + pub(crate) null_equals_null: bool, +} + +/// Metrics for HashJoinExec +#[derive(Debug)] +struct SymmetricHashJoinMetrics { + /// Number of left batches consumed by this operator + left_input_batches: metrics::Count, + /// Number of right batches consumed by this operator + right_input_batches: metrics::Count, + /// Number of left rows consumed by this operator + left_input_rows: metrics::Count, + /// Number of right rows consumed by this operator + right_input_rows: metrics::Count, + /// Number of batches produced by this operator + output_batches: metrics::Count, + /// Number of rows produced by this operator + output_rows: metrics::Count, +} + +impl SymmetricHashJoinMetrics { + pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { + let left_input_batches = + MetricBuilder::new(metrics).counter("left_input_batches", partition); + let right_input_batches = + MetricBuilder::new(metrics).counter("right_input_batches", partition); + + let left_input_rows = + MetricBuilder::new(metrics).counter("left_input_rows", partition); + + let right_input_rows = + MetricBuilder::new(metrics).counter("right_input_rows", partition); + + let output_batches = + MetricBuilder::new(metrics).counter("output_batches", partition); + + let output_rows = MetricBuilder::new(metrics).output_rows(partition); + + Self { + left_input_batches, + right_input_batches, + left_input_rows, + right_input_rows, + output_batches, + output_rows, + } + } +} + +impl SymmetricHashJoinExec { + /// Tries to create a new [SymmetricHashJoinExec]. + /// # Error + /// This function errors when it is not possible to join the left and right sides on keys `on`. + /// TODO: Support for CollectLeft (or CorrectRight options.) 
+ pub fn try_new( + left: Arc, + right: Arc, + on: JoinOn, + filter: JoinFilter, + filter_columns: Vec, + join_type: &JoinType, + null_equals_null: &bool, + ) -> Result { + let left_schema = left.schema(); + let right_schema = right.schema(); + if on.is_empty() { + return Err(DataFusionError::Plan( + "On constraints in HashJoinExec should be non-empty".to_string(), + )); + } + + check_join_is_valid(&left_schema, &right_schema, &on)?; + + let (schema, column_indices) = + build_join_schema(&left_schema, &right_schema, join_type); + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + + Ok(SymmetricHashJoinExec { + left, + right, + on, + filter, + filter_columns, + join_type: *join_type, + schema: Arc::new(schema), + random_state, + metrics: ExecutionPlanMetricsSet::new(), + column_indices, + null_equals_null: *null_equals_null, + }) + } + + /// left stream + pub fn left(&self) -> &Arc { + &self.left + } + + /// right stream + pub fn right(&self) -> &Arc { + &self.right + } + + /// Set of common columns used to join on + pub fn on(&self) -> &[(Column, Column)] { + &self.on + } + + /// Filters applied before join output + pub fn filter(&self) -> &JoinFilter { + &self.filter + } + + /// How the join is performed + pub fn join_type(&self) -> &JoinType { + &self.join_type + } + + /// Get null_equals_null + pub fn null_equals_null(&self) -> &bool { + &self.null_equals_null + } +} + +impl Debug for SymmetricHashJoinExec { + fn fmt(&self, _f: &mut Formatter<'_>) -> fmt::Result { + todo!() + } +} + +impl ExecutionPlan for SymmetricHashJoinExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn required_input_ordering(&self) -> Vec> { + vec![] + } + + fn unbounded_output(&self, children: &[bool]) -> Result { + let (left, right) = (children[0], children[1]); + Ok(left || right) + } + + fn required_input_distribution(&self) -> Vec { + let (left_expr, right_expr) = self + .on + .iter() + .map(|(l, r)| { + ( + Arc::new(l.clone()) as Arc, + Arc::new(r.clone()) as Arc, + ) + }) + .unzip(); + // TODO: This will change when we extend collected executions. + vec![ + if self.left.output_partitioning().partition_count() == 1 { + Distribution::SinglePartition + } else { + Distribution::HashPartitioned(left_expr) + }, + if self.right.output_partitioning().partition_count() == 1 { + Distribution::SinglePartition + } else { + Distribution::HashPartitioned(right_expr) + }, + ] + } + + fn output_partitioning(&self) -> Partitioning { + let left_columns_len = self.left.schema().fields.len(); + partitioned_join_output_partitioning( + self.join_type, + self.left.output_partitioning(), + self.right.output_partitioning(), + left_columns_len, + ) + } + // TODO Output ordering might be kept for some cases. 
+    // For example, if it is an inner join, then the stream side order can be kept.
+    fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
+        None
+    }
+
+    fn equivalence_properties(&self) -> EquivalenceProperties {
+        let left_columns_len = self.left.schema().fields.len();
+        combine_join_equivalence_properties(
+            self.join_type,
+            self.left.equivalence_properties(),
+            self.right.equivalence_properties(),
+            left_columns_len,
+            self.on(),
+            self.schema(),
+        )
+    }
+
+    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+        vec![self.left.clone(), self.right.clone()]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        Ok(Arc::new(SymmetricHashJoinExec::try_new(
+            children[0].clone(),
+            children[1].clone(),
+            self.on.clone(),
+            self.filter.clone(),
+            self.filter_columns.clone(),
+            &self.join_type,
+            &self.null_equals_null,
+        )?))
+    }
+
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
+        match t {
+            DisplayFormatType::Default => {
+                let display_filter = format!(", filter={:?}", self.filter.expression());
+                write!(
+                    f,
+                    "SymmetricHashJoinExec: join_type={:?}, on={:?}{}",
+                    self.join_type, self.on, display_filter
+                )
+            }
+        }
+    }
+
+    fn metrics(&self) -> Option<MetricsSet> {
+        Some(self.metrics.clone_inner())
+    }
+
+    fn statistics(&self) -> Statistics {
+        // TODO stats: it is not possible in general to know the output size of joins
+        Statistics::default()
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        let on_left = self.on.iter().map(|on| on.0.clone()).collect::<Vec<_>>();
+        let on_right = self.on.iter().map(|on| on.1.clone()).collect::<Vec<_>>();
+        // TODO: Currently, working on partitioned left and right. We can also coalesce.
+        let left_side_joiner = OneSideHashJoiner::new(JoinSide::Left, self.left.clone());
+        let right_side_joiner =
+            OneSideHashJoiner::new(JoinSide::Right, self.right.clone());
+        // TODO: Discuss unbounded mpsc and bounded one.
+        let left_stream = self.left.execute(partition, context.clone())?;
+        let right_stream = self.right.execute(partition, context)?;
+        let physical_expr_graph = ExprIntervalGraph::try_new(
+            self.filter.expression().clone(),
+            &self
+                .filter_columns
+                .iter()
+                .map(|sorted_expr| sorted_expr.filter_expr())
+                .collect_vec(),
+        )?;
+
+        Ok(Box::pin(SymmetricHashJoinStream {
+            left_stream,
+            right_stream,
+            schema: self.schema(),
+            on_left,
+            on_right,
+            filter: self.filter.clone(),
+            join_type: self.join_type,
+            random_state: self.random_state.clone(),
+            left: left_side_joiner,
+            right: right_side_joiner,
+            column_indices: self.column_indices.clone(),
+            join_metrics: SymmetricHashJoinMetrics::new(partition, &self.metrics),
+            physical_expr_graph,
+            null_equals_null: self.null_equals_null,
+            filter_columns: self.filter_columns.clone(),
+            final_result: false,
+            data_side: JoinSide::Left,
+        }))
+    }
+}
+
+/// A stream that issues [RecordBatch]es as they arrive from the right of the join.
+struct SymmetricHashJoinStream {
+    left_stream: SendableRecordBatchStream,
+    right_stream: SendableRecordBatchStream,
+    /// Input schema
+    schema: Arc<Schema>,
+    /// columns from the left
+    on_left: Vec<Column>,
+    /// columns from the right used to compute the hash
+    on_right: Vec<Column>,
+    /// join filter
+    filter: JoinFilter,
+    /// type of the join
+    join_type: JoinType,
+    /// left side joiner
+    left: OneSideHashJoiner,
+    /// right side joiner
+    right: OneSideHashJoiner,
+    /// Information of index and left / right placement of columns
+    column_indices: Vec<ColumnIndex>,
+    // Range pruner.
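+    // The interval graph below is built from the join filter; every incoming
+    // batch tightens the intervals of the sorted filter expressions, and
+    // buffered rows that fall outside the tightened intervals can never match
+    // again, so they are pruned (and emitted first for outer join types).
+    // Illustrative example (assumed ascending "ts" columns): with the filter
+    // "left.ts > right.ts - 10", once the right side has reached ts = 100, all
+    // buffered left rows with ts <= 90 are safe to prune.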
+ physical_expr_graph: ExprIntervalGraph, + /// Information of filter columns + filter_columns: Vec, + /// Random state used for hashing initialization + random_state: RandomState, + /// If null_equals_null is true, null == null else null != null + null_equals_null: bool, + /// Metrics + join_metrics: SymmetricHashJoinMetrics, + /// There is nothing to process anymore + final_result: bool, + data_side: JoinSide, +} + +impl RecordBatchStream for SymmetricHashJoinStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +impl Stream for SymmetricHashJoinStream { + type Item = ArrowResult; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.poll_next_impl(cx) + } +} + +fn prune_hash_values( + prune_length: usize, + hashmap: &mut JoinHashMap, + row_hash_values: &mut VecDeque, + offset: u64, +) -> Result<()> { + // Create a (hash)-(row number set) map + let mut hash_value_map: HashMap> = HashMap::new(); + for index in 0..prune_length { + let hash_value = row_hash_values.pop_front().unwrap(); + if let Some(set) = hash_value_map.get_mut(&hash_value) { + set.insert(offset + index as u64); + } else { + let mut set = HashSet::new(); + set.insert(offset + index as u64); + hash_value_map.insert(hash_value, set); + } + } + for (hash_value, index_set) in hash_value_map.iter() { + if let Some((_, separation_chain)) = hashmap + .0 + .get_mut(*hash_value, |(hash, _)| *hash_value == *hash) + { + separation_chain.retain(|n| !index_set.contains(n)); + if separation_chain.is_empty() { + hashmap + .0 + .remove_entry(*hash_value, |(hash, _)| *hash_value == *hash); + } + } + } + Ok(()) +} + +fn prune_visited_rows( + prune_length: usize, + visited_rows: &mut HashSet, + deleted_offset: usize, +) -> Result<()> { + (deleted_offset..(deleted_offset + prune_length)).for_each(|row| { + visited_rows.remove(&row); + }); + Ok(()) +} + +/// We can prune build side when a probe batch comes. +fn column_stats_two_side( + build_input_buffer: &RecordBatch, + probe_batch: &RecordBatch, + filter_columns: &mut [SortedFilterExpr], + build_side: JoinSide, +) -> Result<()> { + for sorted_expr in filter_columns.iter_mut() { + let SortedFilterExpr { + join_side, + origin_expr, + sort_option, + interval, + .. + } = sorted_expr; + let array = if build_side.eq(join_side) { + // Get first value for expr + origin_expr + .evaluate(&build_input_buffer.slice(0, 1))? + .into_array(1) + } else { + // Get last value for expr + origin_expr + .evaluate(&probe_batch.slice(probe_batch.num_rows() - 1, 1))? + .into_array(1) + }; + let value = ScalarValue::try_from_array(&array, 0)?; + let infinite = ScalarValue::try_from(value.get_datatype())?; + *interval = Some(if sort_option.descending { + Interval::Range(Range { + lower: infinite, + upper: value, + }) + } else { + Interval::Range(Range { + lower: value, + upper: infinite, + }) + }); + } + Ok(()) +} + +fn determine_prune_length( + buffer: &RecordBatch, + filter_columns: &[SortedFilterExpr], + build_side: JoinSide, +) -> Result { + Ok(filter_columns + .iter() + .flat_map(|sorted_expr| { + let SortedFilterExpr { + join_side, + origin_expr, + sort_option, + interval, + .. 
+ } = sorted_expr; + if build_side.eq(join_side) { + let batch_arr = origin_expr + .evaluate(buffer) + .unwrap() + .into_array(buffer.num_rows()); + let target = if sort_option.descending { + interval.as_ref()?.upper_value() + } else { + interval.as_ref()?.lower_value() + }; + Some(bisect::(&[batch_arr], &[target], &[*sort_option])) + } else { + None + } + }) + .collect::>>() + .into_iter() + .collect::>>()? + .into_iter() + .min() + .unwrap()) +} + +fn need_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> bool { + if build_side.eq(&JoinSide::Left) { + matches!( + join_type, + JoinType::Left | JoinType::LeftAnti | JoinType::Full | JoinType::LeftSemi + ) + } else { + matches!( + join_type, + JoinType::Right | JoinType::RightAnti | JoinType::Full | JoinType::RightSemi + ) + } +} + +fn get_anti_indices( + prune_length: usize, + deleted_offset: usize, + visited_rows: &HashSet, +) -> PrimitiveArray +where + NativeAdapter: From<::Native>, +{ + let mut bitmap = BooleanBufferBuilder::new(prune_length); + bitmap.append_n(prune_length, false); + (0..prune_length).for_each(|v| { + let row = &(v + deleted_offset); + bitmap.set_bit(v, visited_rows.contains(row)); + }); + // get the anti index + (0..prune_length) + .filter_map(|idx| (!bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx))) + .collect::>() +} + +fn get_semi_indices( + prune_length: usize, + deleted_offset: usize, + visited_rows: &HashSet, +) -> PrimitiveArray +where + NativeAdapter: From<::Native>, +{ + let mut bitmap = BooleanBufferBuilder::new(prune_length); + bitmap.append_n(prune_length, false); + (0..prune_length).for_each(|v| { + let row = &(v + deleted_offset); + bitmap.set_bit(v, visited_rows.contains(row)); + }); + // get the semi index + (0..prune_length) + .filter_map(|idx| (bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx))) + .collect::>() +} + +fn record_visited_indices( + visited: &mut HashSet, + offset: usize, + indices: &PrimitiveArray, +) { + let batch_indices: &[T::Native] = indices.values(); + for i in batch_indices { + visited.insert(i.as_usize() + offset); + } +} + +fn calculate_indices_by_join_type( + build_side: JoinSide, + prune_length: usize, + visited_rows: &HashSet, + deleted_offset: usize, + join_type: JoinType, +) -> Result<(PrimitiveArray, PrimitiveArray)> +where + NativeAdapter: From<::Native>, +{ + let result = match (build_side, join_type) { + (JoinSide::Left, JoinType::Left | JoinType::LeftAnti) + | (JoinSide::Right, JoinType::Right | JoinType::RightAnti) + | (_, JoinType::Full) => { + let build_unmatched_indices = + get_anti_indices(prune_length, deleted_offset, visited_rows); + // right_indices + // all the element in the right side is None + let mut builder = + PrimitiveBuilder::::with_capacity(build_unmatched_indices.len()); + builder.append_nulls(build_unmatched_indices.len()); + let probe_indices = builder.finish(); + (build_unmatched_indices, probe_indices) + } + (JoinSide::Left, JoinType::LeftSemi) | (JoinSide::Right, JoinType::RightSemi) => { + let build_unmatched_indices = + get_semi_indices(prune_length, deleted_offset, visited_rows); + let mut builder = + PrimitiveBuilder::::with_capacity(build_unmatched_indices.len()); + builder.append_nulls(build_unmatched_indices.len()); + let probe_indices = builder.finish(); + (build_unmatched_indices, probe_indices) + } + _ => unreachable!(), + }; + Ok(result) +} + +struct OneSideHashJoiner { + // Build side + build_side: JoinSide, + // Inout record batch buffer + input_buffer: RecordBatch, + /// Hashmap + hashmap: 
JoinHashMap, + /// To optimize hash deleting in case of pruning, we hold them in memory + row_hash_values: VecDeque, + /// Reuse the hashes buffer + hashes_buffer: Vec, + /// Matched rows + visited_rows: HashSet, + /// Offset + offset: usize, + /// Deleted offset + deleted_offset: usize, + /// Side is exhausted + exhausted: bool, +} + +impl OneSideHashJoiner { + pub fn new(build_side: JoinSide, plan: Arc) -> Self { + Self { + build_side, + input_buffer: RecordBatch::new_empty(plan.schema()), + hashmap: JoinHashMap(RawTable::with_capacity(10_000)), + row_hash_values: VecDeque::new(), + hashes_buffer: vec![], + visited_rows: HashSet::new(), + offset: 0, + deleted_offset: 0, + exhausted: false, + } + } + + fn update_internal_state( + &mut self, + on_build: &[Column], + batch: &RecordBatch, + random_state: &RandomState, + ) -> Result<()> { + self.input_buffer = + merge_batches(&self.input_buffer, batch, batch.schema()).unwrap(); + self.hashes_buffer.resize(batch.num_rows(), 0); + update_hash( + on_build, + batch, + &mut self.hashmap, + self.offset, + random_state, + &mut self.hashes_buffer, + )?; + self.hashes_buffer + .drain(0..) + .into_iter() + .for_each(|hash| self.row_hash_values.push_back(hash)); + Ok(()) + } + + #[allow(clippy::too_many_arguments)] + fn record_batch_from_other_side( + &mut self, + schema: SchemaRef, + join_type: JoinType, + on_build: &[Column], + on_probe: &[Column], + filter: Option<&JoinFilter>, + probe_batch: &RecordBatch, + probe_visited: &mut HashSet, + probe_offset: usize, + column_indices: &[ColumnIndex], + random_state: &RandomState, + null_equals_null: &bool, + ) -> ArrowResult> { + if self.input_buffer.num_rows() == 0 { + return Ok(Some(RecordBatch::new_empty(schema))); + } + let (build_side, probe_side) = build_join_indices( + probe_batch, + &self.hashmap, + &self.input_buffer, + on_build, + on_probe, + filter, + random_state, + null_equals_null, + &mut self.hashes_buffer, + Some(self.deleted_offset), + self.build_side, + )?; + if need_produce_result_in_final(self.build_side, join_type) { + record_visited_indices( + &mut self.visited_rows, + self.deleted_offset, + &build_side, + ); + } + if need_produce_result_in_final(self.build_side.negate(), join_type) { + record_visited_indices(probe_visited, probe_offset, &probe_side); + } + match (self.build_side, join_type) { + ( + _, + JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftSemi + | JoinType::RightSemi, + ) => Ok(None), + (_, _) => { + let res = build_batch_from_indices( + schema.as_ref(), + &self.input_buffer, + probe_batch, + build_side.clone(), + probe_side.clone(), + column_indices, + self.build_side, + )?; + Ok(Some(res)) + } + } + } + + fn build_side_determined_results( + &self, + output_schema: SchemaRef, + prune_length: usize, + probe_schema: SchemaRef, + join_type: JoinType, + column_indices: &[ColumnIndex], + ) -> ArrowResult> { + let result = if need_produce_result_in_final(self.build_side, join_type) { + let (build_indices, probe_indices) = calculate_indices_by_join_type( + self.build_side, + prune_length, + &self.visited_rows, + self.deleted_offset, + join_type, + )?; + let empty_probe_batch = RecordBatch::new_empty(probe_schema); + Some(build_batch_from_indices( + output_schema.as_ref(), + &self.input_buffer, + &empty_probe_batch, + build_indices, + probe_indices, + column_indices, + self.build_side, + )?) 
+ } else { + None + }; + Ok(result) + } + + fn prune_build_side( + &mut self, + schema: SchemaRef, + probe_batch: &RecordBatch, + filter_columns: &mut [SortedFilterExpr], + join_type: JoinType, + column_indices: &[ColumnIndex], + physical_expr_graph: &mut ExprIntervalGraph, + ) -> ArrowResult> { + if self.input_buffer.num_rows() == 0 { + return Ok(None); + } + column_stats_two_side( + &self.input_buffer, + probe_batch, + filter_columns, + self.build_side, + )?; + let mut filter_intervals: Vec<(Arc, Interval)> = filter_columns + .iter() + .map(|sorted_expr| { + ( + sorted_expr.filter_expr(), + sorted_expr.interval().as_ref().unwrap().clone(), + ) + }) + .collect_vec(); + physical_expr_graph.calculate_new_intervals(&mut filter_intervals)?; + for (sorted_expr, (_, interval)) in + filter_columns.iter_mut().zip(filter_intervals.into_iter()) + { + sorted_expr.set_interval(Some(interval.clone())) + } + + let prune_length = + determine_prune_length(&self.input_buffer, filter_columns, self.build_side)?; + if prune_length > 0 { + let result = self.build_side_determined_results( + schema, + prune_length, + probe_batch.schema(), + join_type, + column_indices, + ); + prune_hash_values( + prune_length, + &mut self.hashmap, + &mut self.row_hash_values, + self.deleted_offset as u64, + )?; + prune_visited_rows( + prune_length, + &mut self.visited_rows, + self.deleted_offset, + )?; + self.input_buffer = self + .input_buffer + .slice(prune_length, self.input_buffer.num_rows() - prune_length); + self.deleted_offset += prune_length; + result + } else { + Ok(None) + } + } +} + +fn produce_batch_result( + output_schema: SchemaRef, + equal_batch: ArrowResult>, + anti_batch: ArrowResult>, +) -> ArrowResult { + match (equal_batch, anti_batch) { + (Ok(Some(batch)), Ok(None)) | (Ok(None), Ok(Some(batch))) => Ok(batch), + (Err(e), _) | (_, Err(e)) => Err(e), + (Ok(Some(equal_batch)), Ok(Some(anti_batch))) => { + concat_batches(&output_schema, &[equal_batch, anti_batch]) + } + (Ok(None), Ok(None)) => Ok(RecordBatch::new_empty(output_schema)), + } +} + +impl SymmetricHashJoinStream { + /// Separate implementation function that unpins the [`SymmetricHashJoinStream`] so + /// that partial borrows work correctly + fn poll_next_impl( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>> { + if self.final_result { + return Poll::Ready(None); + } + if self.right.exhausted && self.left.exhausted { + let left_result = self.left.build_side_determined_results( + self.schema.clone(), + self.left.input_buffer.num_rows(), + self.right.input_buffer.schema(), + self.join_type, + &self.column_indices, + ); + let right_result = self.right.build_side_determined_results( + self.schema.clone(), + self.right.input_buffer.num_rows(), + self.left.input_buffer.schema(), + self.join_type, + &self.column_indices, + ); + self.final_result = true; + let result = + produce_batch_result(self.schema.clone(), left_result, right_result); + if let Ok(batch) = &result { + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + } + return Poll::Ready(Some(result)); + } + if self.data_side.eq(&JoinSide::Left) { + match self.left_stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(probe_batch))) => { + self.join_metrics.left_input_batches.add(1); + self.join_metrics + .left_input_rows + .add(probe_batch.num_rows()); + match self.left.update_internal_state( + &self.on_left, + &probe_batch, + &self.random_state, + ) { + Ok(_) => {} + Err(e) => { + return Poll::Ready(Some(Err(ArrowError::ComputeError( + 
e.to_string(), + )))) + } + } + // Using right as build side. + let equal_result = self.right.record_batch_from_other_side( + self.schema.clone(), + self.join_type, + &self.on_right, + &self.on_left, + Some(&self.filter), + &probe_batch, + &mut self.left.visited_rows, + self.left.offset, + &self.column_indices, + &self.random_state, + &self.null_equals_null, + ); + self.left.offset += probe_batch.num_rows(); + // Right side will be pruned since the batch coming from left. + let anti_result = self.right.prune_build_side( + self.schema.clone(), + &probe_batch, + &mut self.filter_columns, + self.join_type, + &self.column_indices, + &mut self.physical_expr_graph, + ); + let result = produce_batch_result( + self.schema.clone(), + equal_result, + anti_result, + ); + if let Ok(batch) = &result { + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + } + if !self.right.exhausted { + self.data_side = JoinSide::Right; + } + Poll::Ready(Some(result)) + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + Poll::Ready(None) => { + self.left.exhausted = true; + self.data_side = JoinSide::Right; + Poll::Ready(Some(Ok(RecordBatch::new_empty(self.schema())))) + } + Poll::Pending => Poll::Pending, + } + } else { + match self.right_stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(probe_batch))) => { + self.join_metrics.right_input_batches.add(1); + self.join_metrics + .right_input_rows + .add(probe_batch.num_rows()); + // Right is build side + match self.right.update_internal_state( + &self.on_right, + &probe_batch, + &self.random_state, + ) { + Ok(_) => {} + Err(e) => { + return Poll::Ready(Some(Err(ArrowError::ComputeError( + e.to_string(), + )))) + } + } + let equal_result = self.left.record_batch_from_other_side( + self.schema.clone(), + self.join_type, + &self.on_left, + &self.on_right, + Some(&self.filter), + &probe_batch, + &mut self.right.visited_rows, + self.right.offset, + &self.column_indices, + &self.random_state, + &self.null_equals_null, + ); + self.right.offset += probe_batch.num_rows(); + let anti_result = self.left.prune_build_side( + self.schema.clone(), + &probe_batch, + &mut self.filter_columns, + self.join_type, + &self.column_indices, + &mut self.physical_expr_graph, + ); + let result = produce_batch_result( + self.schema.clone(), + equal_result, + anti_result, + ); + if let Ok(batch) = &result { + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + } + if !self.left.exhausted { + self.data_side = JoinSide::Left; + } + Poll::Ready(Some(result)) + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + Poll::Ready(None) => { + self.right.exhausted = true; + self.data_side = JoinSide::Left; + Poll::Ready(Some(Ok(RecordBatch::new_empty(self.schema())))) + } + Poll::Pending => Poll::Pending, + } + } + } +} + +#[cfg(test)] +mod tests { + use std::fs::File; + + use arrow::array::ArrayRef; + use arrow::array::{Int32Array, TimestampNanosecondArray}; + use arrow::compute::CastOptions; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::util::pretty::pretty_format_batches; + use rstest::*; + use tempfile::TempDir; + + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::{BinaryExpr, CastExpr, Column, Literal}; + use datafusion_physical_expr::utils; + + use crate::physical_plan::collect; + use crate::physical_plan::joins::{HashJoinExec, PartitionMode}; + use crate::physical_plan::{ + common, memory::MemoryExec, repartition::RepartitionExec, + }; + use 
crate::prelude::{SessionConfig, SessionContext}; + use crate::test_util; + + use super::*; + + fn compare_batches(collected_1: &[RecordBatch], collected_2: &[RecordBatch]) { + // compare + let first_formatted = pretty_format_batches(collected_1).unwrap().to_string(); + let second_formatted = pretty_format_batches(collected_2).unwrap().to_string(); + + let mut first_formatted_sorted: Vec<&str> = + first_formatted.trim().lines().collect(); + first_formatted_sorted.sort_unstable(); + + let mut second_formatted_sorted: Vec<&str> = + second_formatted.trim().lines().collect(); + second_formatted_sorted.sort_unstable(); + + for (i, (first_line, second_line)) in first_formatted_sorted + .iter() + .zip(&second_formatted_sorted) + .enumerate() + { + assert_eq!((i, first_line), (i, second_line)); + } + } + #[allow(clippy::too_many_arguments)] + async fn partitioned_sym_join_with_filter( + left: Arc, + right: Arc, + on: JoinOn, + filter: JoinFilter, + sorted_filter_columns: Vec, + join_type: &JoinType, + null_equals_null: bool, + context: Arc, + ) -> Result> { + let partition_count = 1; + + let left_expr: Vec> = on + .iter() + .map(|(l, _)| Arc::new(l.clone()) as Arc) + .collect_vec(); + + let right_expr: Vec> = on + .iter() + .map(|(_, r)| Arc::new(r.clone()) as Arc) + .collect_vec(); + + let join = SymmetricHashJoinExec::try_new( + Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(left_expr, partition_count), + )?), + Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(right_expr, partition_count), + )?), + on, + filter, + sorted_filter_columns, + join_type, + &null_equals_null, + )?; + + let mut batches = vec![]; + for i in 0..partition_count { + let stream = join.execute(i, context.clone())?; + let more_batches = common::collect(stream).await?; + batches.extend( + more_batches + .into_iter() + .filter(|b| b.num_rows() > 0) + .collect::>(), + ); + } + + Ok(batches) + } + #[allow(clippy::too_many_arguments)] + async fn partitioned_hash_join_with_filter( + left: Arc, + right: Arc, + on: JoinOn, + filter: JoinFilter, + join_type: &JoinType, + null_equals_null: bool, + context: Arc, + ) -> Result> { + let partition_count = 1; + + let (left_expr, right_expr) = on + .iter() + .map(|(l, r)| { + ( + Arc::new(l.clone()) as Arc, + Arc::new(r.clone()) as Arc, + ) + }) + .unzip(); + + let join = HashJoinExec::try_new( + Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(left_expr, partition_count), + )?), + Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(right_expr, partition_count), + )?), + on, + Some(filter), + join_type, + PartitionMode::Partitioned, + &null_equals_null, + )?; + + let mut batches = vec![]; + for i in 0..partition_count { + let stream = join.execute(i, context.clone())?; + let more_batches = common::collect(stream).await?; + batches.extend( + more_batches + .into_iter() + .filter(|b| b.num_rows() > 0) + .collect::>(), + ); + } + + Ok(batches) + } + + pub fn split_record_batches( + batch: &RecordBatch, + num_split: usize, + ) -> Result> { + let row_num = batch.num_rows(); + let number_of_batch = row_num / num_split; + let mut sizes = vec![num_split; number_of_batch]; + sizes.push(row_num - (num_split * number_of_batch)); + let mut result = vec![]; + for (i, size) in sizes.iter().enumerate() { + result.push(batch.slice(i * num_split, *size)); + } + Ok(result) + } + + fn build_table(columns: Vec<(&str, ArrayRef)>) -> Arc { + let schema = Schema::new( + columns + .iter() + .map(|(name, array)| Field::new(*name, array.data_type().clone(), 
false)) + .collect(), + ); + let batch = RecordBatch::try_new( + Arc::new(schema), + columns.into_iter().map(|(_, array)| array).collect(), + ) + .unwrap(); + + let schema = batch.schema(); + Arc::new( + MemoryExec::try_new( + &[split_record_batches(&batch, 13).unwrap()], + schema, + None, + ) + .unwrap(), + ) + } + + fn join_expr_tests_fixture( + expr_id: usize, + left_watermark: Arc, + right_watermark: Arc, + ) -> Arc { + match expr_id { + // left_watermark + 1 > right_watermark + 5 AND left_watermark + 3 < right_watermark + 10 + 0 => utils::filter_numeric_expr_generation( + left_watermark.clone(), + right_watermark.clone(), + Operator::Plus, + Operator::Plus, + Operator::Plus, + Operator::Plus, + 1, + 5, + 3, + 10, + ), + // left_watermark - 1 > right_watermark + 5 AND left_watermark + 3 < right_watermark + 10 + 1 => utils::filter_numeric_expr_generation( + left_watermark.clone(), + right_watermark.clone(), + Operator::Minus, + Operator::Plus, + Operator::Plus, + Operator::Plus, + 1, + 5, + 3, + 10, + ), + // left_watermark - 1 > right_watermark + 5 AND left_watermark - 3 < right_watermark + 10 + 2 => utils::filter_numeric_expr_generation( + left_watermark.clone(), + right_watermark.clone(), + Operator::Minus, + Operator::Plus, + Operator::Minus, + Operator::Plus, + 1, + 5, + 3, + 10, + ), + // left_watermark - 10 > right_watermark - 5 AND left_watermark - 3 < right_watermark + 10 + 3 => utils::filter_numeric_expr_generation( + left_watermark.clone(), + right_watermark.clone(), + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Plus, + 10, + 5, + 3, + 10, + ), + // left_watermark - 10 > right_watermark - 5 AND left_watermark - 30 < right_watermark - 3 + 4 => utils::filter_numeric_expr_generation( + left_watermark.clone(), + right_watermark.clone(), + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Minus, + 10, + 5, + 30, + 3, + ), + _ => unreachable!(), + } + } + + fn create_memory_table( + table_size: i32, + key_cardinality: (i32, i32), + ) -> (Arc, Arc) { + let initial_range = 0..table_size; + let left = build_table(vec![ + ( + "la1", + Arc::new(Int32Array::from_iter( + initial_range.clone().collect::>(), + )), + ), + ( + "lb1", + Arc::new(Int32Array::from_iter( + initial_range.clone().map(|x| x % 4).collect::>(), + )), + ), + ( + "lc1", + Arc::new(Int32Array::from_iter( + initial_range + .clone() + .map(|x| x % key_cardinality.0) + .collect::>(), + )), + ), + ( + "lt1", + Arc::new(TimestampNanosecondArray::from( + initial_range + .clone() + .map(|x| 1664264591000000000 + (5000000000 * (x as i64))) + .collect::>(), + )), + ), + ( + "la2", + Arc::new(Int32Array::from_iter( + initial_range.clone().collect::>(), + )), + ), + ( + "la1_des", + Arc::new(Int32Array::from_iter( + initial_range.clone().rev().collect::>(), + )), + ), + ]); + let right = build_table(vec![ + ( + "ra1", + Arc::new(Int32Array::from_iter( + initial_range.clone().collect::>(), + )), + ), + ( + "rb1", + Arc::new(Int32Array::from_iter( + initial_range.clone().map(|x| x % 7).collect::>(), + )), + ), + ( + "rc1", + Arc::new(Int32Array::from_iter( + initial_range + .clone() + .map(|x| x % key_cardinality.1) + .collect::>(), + )), + ), + ( + "rt1", + Arc::new(TimestampNanosecondArray::from( + initial_range + .clone() + .map(|x| 1664264591000000000 + (5000000000 * (x as i64))) + .collect::>(), + )), + ), + ( + "ra2", + Arc::new(Int32Array::from_iter( + initial_range.clone().collect::>(), + )), + ), + ( + "ra1_des", + Arc::new(Int32Array::from_iter( + initial_range.rev().collect::>(), + 
)), + ), + ]); + (left, right) + } + + fn complicated_fiter() -> Arc { + let left_expr = BinaryExpr { + left: Arc::new(CastExpr { + expr: Arc::new(BinaryExpr { + left: Arc::new(Column { + name: "0".to_string(), + index: 0, + }), + op: Operator::Plus, + right: Arc::new(Column { + name: "1".to_string(), + index: 1, + }), + }), + cast_type: DataType::Int64, + cast_options: CastOptions { safe: false }, + }), + op: Operator::Gt, + right: Arc::new(BinaryExpr { + left: Arc::new(CastExpr { + expr: Arc::new(Column { + name: "2".to_string(), + index: 2, + }), + cast_type: DataType::Int64, + cast_options: CastOptions { safe: false }, + }), + op: Operator::Plus, + right: Arc::new(Literal::new(ScalarValue::Int64(Some(10)))), + }), + }; + let right_expr = BinaryExpr { + left: Arc::new(CastExpr { + expr: Arc::new(BinaryExpr { + left: Arc::new(Column { + name: "0".to_string(), + index: 0, + }), + op: Operator::Plus, + right: Arc::new(Column { + name: "1".to_string(), + index: 1, + }), + }), + cast_type: DataType::Int64, + cast_options: CastOptions { safe: false }, + }), + op: Operator::Lt, + right: Arc::new(BinaryExpr { + left: Arc::new(CastExpr { + expr: Arc::new(Column { + name: "2".to_string(), + index: 2, + }), + cast_type: DataType::Int64, + cast_options: CastOptions { safe: false }, + }), + op: Operator::Plus, + right: Arc::new(Literal::new(ScalarValue::Int64(Some(100)))), + }), + }; + Arc::new(BinaryExpr::new( + Arc::new(left_expr), + Operator::And, + Arc::new(right_expr), + )) + } + + #[rstest] + #[tokio::test(flavor = "multi_thread", worker_threads = 3)] + async fn complex_join_all_one_ascending_numeric( + #[values( + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::RightSemi, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::Full + )] + join_type: JoinType, + #[values( + (4, 5), + (11, 21), + (31, 71), + (99, 12), + )] + cardinality: (i32, i32), + ) -> Result<()> { + // a + b > c + 10 AND a + b < c + 100 + let config = SessionConfig::new().with_repartition_joins(false); + let session_ctx = SessionContext::with_config(config); + let task_ctx = session_ctx.task_ctx(); + let (left, right) = create_memory_table(1000, cardinality); + + let on = vec![( + Column::new_with_schema("lc1", &left.schema())?, + Column::new_with_schema("rc1", &right.schema())?, + )]; + + let filter_col_0 = Arc::new(Column::new("0", 0)); + let filter_col_1 = Arc::new(Column::new("1", 1)); + let filter_col_2 = Arc::new(Column::new("2", 2)); + + let main_sorted_binary = Arc::new(BinaryExpr { + left: Arc::new(Column::new_with_schema("la1", &left.schema())?), + op: Operator::Plus, + right: Arc::new(Column::new_with_schema("la2", &left.schema())?), + }); + + let filter_sorted_binary = Arc::new(BinaryExpr { + left: filter_col_0.clone(), + op: Operator::Plus, + right: filter_col_1.clone(), + }); + + let right_main_sorted_col = + Arc::new(Column::new_with_schema("ra1", &right.schema())?); + + let sorted_filter_columns = vec![ + SortedFilterExpr::new( + JoinSide::Left, + main_sorted_binary, + filter_sorted_binary, + SortOptions::default(), + ), + SortedFilterExpr::new( + JoinSide::Right, + right_main_sorted_col.clone(), + filter_col_2.clone(), + SortOptions::default(), + ), + ]; + + let column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 4, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ]; + let intermediate_schema = Schema::new(vec![ + Field::new(filter_col_0.name(), DataType::Int32, true), + 
Field::new(filter_col_1.name(), DataType::Int32, true), + Field::new(filter_col_2.name(), DataType::Int32, true), + ]); + + let filter_expr = complicated_fiter(); + + let filter = JoinFilter::new( + filter_expr, + column_indices.clone(), + intermediate_schema.clone(), + ); + + let first_batches = partitioned_sym_join_with_filter( + left.clone(), + right.clone(), + on.clone(), + filter.clone(), + sorted_filter_columns, + &join_type, + false, + task_ctx.clone(), + ) + .await?; + let second_batches = partitioned_hash_join_with_filter( + left.clone(), + right.clone(), + on.clone(), + filter.clone(), + &join_type, + false, + task_ctx.clone(), + ) + .await?; + compare_batches(&first_batches, &second_batches); + Ok(()) + } + + #[rstest] + #[tokio::test(flavor = "multi_thread", worker_threads = 3)] + async fn join_all_one_ascending_numeric( + #[values( + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::RightSemi, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::Full + )] + join_type: JoinType, + #[values( + (4, 5), + (11, 21), + (31, 71), + (99, 12), + )] + cardinality: (i32, i32), + #[values(0, 1, 2, 3, 4)] case_expr: usize, + ) -> Result<()> { + let config = SessionConfig::new().with_repartition_joins(false); + let session_ctx = SessionContext::with_config(config); + let task_ctx = session_ctx.task_ctx(); + let (left, right) = create_memory_table(1000, cardinality); + + let on = vec![( + Column::new_with_schema("lc1", &left.schema())?, + Column::new_with_schema("rc1", &right.schema())?, + )]; + + let left_col = Arc::new(Column::new("left", 0)); + let right_col = Arc::new(Column::new("right", 1)); + + let sorted_filter_columns = vec![ + SortedFilterExpr::new( + JoinSide::Left, + Arc::new(Column::new_with_schema("la1", &left.schema())?), + left_col.clone(), + SortOptions::default(), + ), + SortedFilterExpr::new( + JoinSide::Right, + Arc::new(Column::new_with_schema("ra1", &right.schema())?), + right_col.clone(), + SortOptions::default(), + ), + ]; + + let column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ]; + let intermediate_schema = Schema::new(vec![ + Field::new(left_col.name(), DataType::Int32, true), + Field::new(right_col.name(), DataType::Int32, true), + ]); + + let filter_expr = + join_expr_tests_fixture(case_expr, left_col.clone(), right_col.clone()); + + let filter = JoinFilter::new( + filter_expr, + column_indices.clone(), + intermediate_schema.clone(), + ); + + let first_batches = partitioned_sym_join_with_filter( + left.clone(), + right.clone(), + on.clone(), + filter.clone(), + sorted_filter_columns, + &join_type, + false, + task_ctx.clone(), + ) + .await?; + let second_batches = partitioned_hash_join_with_filter( + left.clone(), + right.clone(), + on.clone(), + filter.clone(), + &join_type, + false, + task_ctx.clone(), + ) + .await?; + compare_batches(&first_batches, &second_batches); + Ok(()) + } + + #[rstest] + #[tokio::test(flavor = "multi_thread", worker_threads = 3)] + async fn join_all_one_descending_numeric_particular( + #[values( + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::RightSemi, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::Full + )] + join_type: JoinType, + #[values( + (4, 5), + (11, 21), + (31, 71), + (99, 12), + )] + cardinality: (i32, i32), + #[values(0, 1, 2, 3, 4)] case_expr: usize, + ) -> Result<()> { + let config = SessionConfig::new().with_repartition_joins(false); + let 
session_ctx = SessionContext::with_config(config); + let task_ctx = session_ctx.task_ctx(); + let (left, right) = create_memory_table(20, cardinality); + + let on = vec![( + Column::new_with_schema("lc1", &left.schema())?, + Column::new_with_schema("rc1", &right.schema())?, + )]; + + let left_col = Arc::new(Column::new("left", 0)); + let right_col = Arc::new(Column::new("right", 1)); + + let sorted_filter_columns = vec![ + SortedFilterExpr::new( + JoinSide::Left, + Arc::new(Column::new_with_schema("la1_des", &left.schema())?), + left_col.clone(), + SortOptions { + descending: true, + nulls_first: true, + }, + ), + SortedFilterExpr::new( + JoinSide::Right, + Arc::new(Column::new_with_schema("ra1_des", &right.schema())?), + right_col.clone(), + SortOptions { + descending: true, + nulls_first: true, + }, + ), + ]; + + let column_indices = vec![ + ColumnIndex { + index: 5, + side: JoinSide::Left, + }, + ColumnIndex { + index: 5, + side: JoinSide::Right, + }, + ]; + let intermediate_schema = Schema::new(vec![ + Field::new(left_col.name(), DataType::Int32, true), + Field::new(right_col.name(), DataType::Int32, true), + ]); + + let filter_expr = + join_expr_tests_fixture(case_expr, left_col.clone(), right_col.clone()); + + let filter = JoinFilter::new( + filter_expr, + column_indices.clone(), + intermediate_schema.clone(), + ); + + let first_batches = partitioned_sym_join_with_filter( + left.clone(), + right.clone(), + on.clone(), + filter.clone(), + sorted_filter_columns, + &join_type, + false, + task_ctx.clone(), + ) + .await?; + let second_batches = partitioned_hash_join_with_filter( + left.clone(), + right.clone(), + on.clone(), + filter.clone(), + &join_type, + false, + task_ctx.clone(), + ) + .await?; + compare_batches(&first_batches, &second_batches); + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 20)] + async fn join_change_in_planner() -> Result<()> { + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::with_config(config); + let tmp_dir = TempDir::new().unwrap(); + let left_file_path = tmp_dir.path().join("left.csv"); + File::create(left_file_path.clone()).unwrap(); + test_util::test_create_unbounded_sorted_file( + &ctx, + left_file_path.clone(), + "left", + ) + .await?; + let right_file_path = tmp_dir.path().join("right.csv"); + File::create(right_file_path.clone()).unwrap(); + test_util::test_create_unbounded_sorted_file( + &ctx, + right_file_path.clone(), + "right", + ) + .await?; + let df = ctx.sql("EXPLAIN SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10").await?; + let physical_plan = df.create_physical_plan().await?; + let task_ctx = ctx.task_ctx(); + let results = collect(physical_plan.clone(), task_ctx).await.unwrap(); + let formatted = pretty_format_batches(&results).unwrap().to_string(); + println!("Query Output:\n\n{formatted}"); + let found = formatted + .lines() + .any(|line| line.contains("SymmetricHashJoinExec")); + assert!(found); + Ok(()) + } +} diff --git a/datafusion/core/src/physical_plan/joins/utils.rs b/datafusion/core/src/physical_plan/joins/utils.rs index 3f568f98ec9d..cc7cb52142be 100644 --- a/datafusion/core/src/physical_plan/joins/utils.rs +++ b/datafusion/core/src/physical_plan/joins/utils.rs @@ -17,10 +17,13 @@ //! 
Join related functionality used both on logical and physical plans -use crate::error::{DataFusionError, Result}; -use crate::logical_expr::JoinType; -use crate::physical_plan::expressions::Column; -use crate::physical_plan::SchemaRef; +use std::cmp::max; +use std::collections::HashSet; +use std::future::Future; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::usize; + use arrow::array::{ new_null_array, Array, BooleanBufferBuilder, PrimitiveArray, UInt32Array, UInt32Builder, UInt64Array, @@ -29,22 +32,24 @@ use arrow::compute; use arrow::datatypes::{Field, Schema, UInt32Type, UInt64Type}; use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; -use datafusion_common::cast::as_boolean_array; -use datafusion_common::ScalarValue; -use datafusion_physical_expr::{EquivalentClass, PhysicalExpr}; use futures::future::{BoxFuture, Shared}; use futures::{ready, FutureExt}; use parking_lot::Mutex; -use std::cmp::max; -use std::collections::HashSet; -use std::future::Future; -use std::sync::Arc; -use std::task::{Context, Poll}; +use datafusion_common::cast::as_boolean_array; +use datafusion_common::ScalarValue; + +use datafusion_physical_expr::rewrite::TreeNodeRewritable; +use datafusion_physical_expr::{EquivalentClass, PhysicalExpr}; + +use crate::error::{DataFusionError, Result}; +use crate::logical_expr::JoinType; +use crate::physical_plan::expressions::Column; + +use crate::physical_plan::SchemaRef; use crate::physical_plan::{ ColumnStatistics, EquivalenceProperties, ExecutionPlan, Partitioning, Statistics, }; -use datafusion_physical_expr::rewrite::TreeNodeRewritable; /// The on clause of the join, as vector of (left, right) columns. pub type JoinOn = Vec<(Column, Column)>; @@ -222,7 +227,7 @@ pub fn cross_join_equivalence_properties( } /// Used in ColumnIndex to distinguish which side the index is for -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Copy)] pub enum JoinSide { /// Left side of the join Left, @@ -230,6 +235,16 @@ pub enum JoinSide { Right, } +impl JoinSide { + /// Inverse the join side + pub fn negate(&self) -> Self { + match self { + JoinSide::Left => JoinSide::Right, + JoinSide::Right => JoinSide::Left, + } + } +} + /// Information about the index and placement (left or right) of the columns #[derive(Debug, Clone)] pub struct ColumnIndex { @@ -742,26 +757,26 @@ pub(crate) fn get_final_indices_from_bit_map( (left_indices, right_indices) } -/// Use the `left_indices` and `right_indices` to restructure tuples, and apply the `filter` to -/// all of them to get the matched left and right indices. 
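+/// Use the `build_indices` and `probe_indices` to restructure tuples from the
+/// build and probe batches, and apply the `filter` to all of them to get the
+/// matched build and probe indices; `build_side` marks which join input plays
+/// the build role.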
pub(crate) fn apply_join_filter_to_indices( - left: &RecordBatch, - right: &RecordBatch, - left_indices: UInt64Array, - right_indices: UInt32Array, + build_input_buffer: &RecordBatch, + probe_batch: &RecordBatch, + build_indices: UInt64Array, + probe_indices: UInt32Array, filter: &JoinFilter, + build_side: JoinSide, ) -> Result<(UInt64Array, UInt32Array)> { - if left_indices.is_empty() && right_indices.is_empty() { - return Ok((left_indices, right_indices)); + if build_indices.is_empty() && probe_indices.is_empty() { + return Ok((build_indices, probe_indices)); }; let intermediate_batch = build_batch_from_indices( filter.schema(), - left, - right, - PrimitiveArray::from(left_indices.data().clone()), - PrimitiveArray::from(right_indices.data().clone()), + build_input_buffer, + probe_batch, + PrimitiveArray::from(build_indices.data().clone()), + PrimitiveArray::from(probe_indices.data().clone()), filter.column_indices(), + build_side, )?; let filter_result = filter .expression() @@ -770,12 +785,11 @@ pub(crate) fn apply_join_filter_to_indices( let mask = as_boolean_array(&filter_result)?; let left_filtered = PrimitiveArray::::from( - compute::filter(&left_indices, mask)?.data().clone(), + compute::filter(&build_indices, mask)?.data().clone(), ); let right_filtered = PrimitiveArray::::from( - compute::filter(&right_indices, mask)?.data().clone(), + compute::filter(&probe_indices, mask)?.data().clone(), ); - Ok((left_filtered, right_filtered)) } @@ -783,11 +797,12 @@ pub(crate) fn apply_join_filter_to_indices( /// The resulting batch has [Schema] `schema`. pub(crate) fn build_batch_from_indices( schema: &Schema, - left: &RecordBatch, - right: &RecordBatch, - left_indices: UInt64Array, - right_indices: UInt32Array, + build_input_buffer: &RecordBatch, + probe_batch: &RecordBatch, + build_indices: UInt64Array, + probe_indices: UInt32Array, column_indices: &[ColumnIndex], + build_side: JoinSide, ) -> ArrowResult { // build the columns of the new [RecordBatch]: // 1. pick whether the column is from the left or right @@ -795,27 +810,24 @@ pub(crate) fn build_batch_from_indices( let mut columns: Vec> = Vec::with_capacity(schema.fields().len()); for column_index in column_indices { - let array = match column_index.side { - JoinSide::Left => { - let array = left.column(column_index.index); - if array.is_empty() || left_indices.null_count() == left_indices.len() { - // Outer join would generate a null index when finding no match at our side. - // Therefore, it's possible we are empty but need to populate an n-length null array, - // where n is the length of the index array. - assert_eq!(left_indices.null_count(), left_indices.len()); - new_null_array(array.data_type(), left_indices.len()) - } else { - compute::take(array.as_ref(), &left_indices, None)? - } + let array = if column_index.side == build_side { + let array = build_input_buffer.column(column_index.index); + if array.is_empty() || build_indices.null_count() == build_indices.len() { + // Outer join would generate a null index when finding no match at our side. + // Therefore, it's possible we are empty but need to populate an n-length null array, + // where n is the length of the index array. + assert_eq!(build_indices.null_count(), build_indices.len()); + new_null_array(array.data_type(), build_indices.len()) + } else { + compute::take(array.as_ref(), &build_indices, None)? 
} - JoinSide::Right => { - let array = right.column(column_index.index); - if array.is_empty() || right_indices.null_count() == right_indices.len() { - assert_eq!(right_indices.null_count(), right_indices.len()); - new_null_array(array.data_type(), right_indices.len()) - } else { - compute::take(array.as_ref(), &right_indices, None)? - } + } else { + let array = probe_batch.column(column_index.index); + if array.is_empty() || probe_indices.null_count() == probe_indices.len() { + assert_eq!(probe_indices.null_count(), probe_indices.len()); + new_null_array(array.data_type(), probe_indices.len()) + } else { + compute::take(array.as_ref(), &probe_indices, None)? } }; columns.push(array); @@ -936,10 +948,12 @@ pub(crate) fn get_semi_indices( #[cfg(test)] mod tests { - use super::*; use arrow::datatypes::DataType; + use datafusion_common::ScalarValue; + use super::*; + fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> { let left = left .iter() diff --git a/datafusion/core/src/test_util.rs b/datafusion/core/src/test_util.rs index 2db6f3bc46ba..3eb5ec576fb8 100644 --- a/datafusion/core/src/test_util.rs +++ b/datafusion/core/src/test_util.rs @@ -26,6 +26,7 @@ use crate::datasource::{empty::EmptyTable, provider_as_source, TableProvider}; use crate::execution::context::SessionState; use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE}; use crate::physical_plan::ExecutionPlan; +use crate::prelude::{CsvReadOptions, SessionContext}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; use datafusion_common::DataFusionError; @@ -383,3 +384,43 @@ mod tests { assert!(PathBuf::from(res).is_dir()); } } +/// It creates unbounded sorted files for tests +pub async fn test_create_unbounded_sorted_file( + ctx: &SessionContext, + file_path: PathBuf, + table_name: &str, +) -> datafusion_common::Result<()> { + // Register table + let right_schema = Arc::new(Schema::new(vec![ + Field::new("a1", DataType::UInt32, false), + Field::new("a2", DataType::UInt32, false), + ])); + // The sort order is specified + let file_sort_order = [datafusion_expr::col("a1")] + .into_iter() + .map(|e| { + let ascending = true; + let nulls_first = false; + e.sort(ascending, nulls_first) + }) + .collect::>(); + // Mark infinite and provide schema + let fifo_options = CsvReadOptions::new() + .schema(right_schema.as_ref()) + .has_header(false) + .mark_infinite(true); + // Get listing options + let options_sort = fifo_options + .to_listing_options(ctx.copied_config().target_partitions()) + .with_file_sort_order(Some(file_sort_order)); + // Register table + ctx.register_listing_table( + table_name, + file_path.as_os_str().to_str().unwrap(), + options_sort.clone(), + Some(right_schema), + None, + ) + .await?; + Ok(()) +} diff --git a/datafusion/core/tests/fifo.rs b/datafusion/core/tests/fifo.rs index 7b524454d2ba..1d182835dc09 100644 --- a/datafusion/core/tests/fifo.rs +++ b/datafusion/core/tests/fifo.rs @@ -18,8 +18,11 @@ //! This test demonstrates the DataFusion FIFO capabilities. //! 
#[cfg(not(target_os = "windows"))] +#[cfg(test)] mod unix_test { + use arrow::array::Array; use arrow::datatypes::{DataType, Field, Schema}; + use datafusion::test_util::test_create_unbounded_sorted_file; use datafusion::{ prelude::{CsvReadOptions, SessionConfig, SessionContext}, test_util::{aggr_test_schema, arrow_test_data}, @@ -29,6 +32,8 @@ mod unix_test { use itertools::enumerate; use nix::sys::stat; use nix::unistd; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; use rstest::*; use std::fs::{File, OpenOptions}; use std::io::Write; @@ -38,8 +43,10 @@ mod unix_test { use std::sync::mpsc::{Receiver, Sender}; use std::sync::{Arc, Mutex}; use std::thread; + use std::thread::JoinHandle; use std::time::{Duration, Instant}; use tempfile::TempDir; + // ! For the sake of the test, do not alter the numbers. ! // Session batch size const TEST_BATCH_SIZE: usize = 20; @@ -133,7 +140,7 @@ mod unix_test { } // This test provides a relatively realistic end-to-end scenario where - // we ensure that we swap join sides correctly to accommodate a FIFO source. + // we ensure that unbounded execution happened #[rstest] #[timeout(std::time::Duration::from_secs(30))] #[tokio::test(flavor = "multi_thread", worker_threads = 5)] @@ -147,7 +154,7 @@ mod unix_test { let (tx, rx): (Sender, Receiver) = mpsc::channel(); // Create a new temporary FIFO file let tmp_dir = TempDir::new()?; - let fifo_path = create_fifo_file(&tmp_dir, "fisrt_fifo.csv")?; + let fifo_path = create_fifo_file(&tmp_dir, "first_fifo.csv")?; // Prevent move let fifo_path_thread = fifo_path.clone(); // Timeout for a long period of BrokenPipe error @@ -217,4 +224,107 @@ mod unix_test { assert_eq!(interleave(&result), unbounded_file); Ok(()) } + #[derive(Debug, PartialEq)] + enum JoinOperation { + LeftUnmatched, + RightUnmatched, + Equal, + } + + // This test provides a relatively realistic end-to-end scenario where + // we ensure that we change join into [SymmetricHashJoin] correctly to accommodate the FIFO sources. + // Creates 2 FIFO files with unbounded source. 
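+    // Both FIFO sources are registered as unbounded and sorted, so the planner
+    // must rewrite the plan to use [SymmetricHashJoinExec]; the assertions below
+    // then verify that unmatched rows stream out incrementally rather than
+    // accumulating until end-of-input.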
+    #[rstest]
+    #[timeout(std::time::Duration::from_secs(30))]
+    #[tokio::test(flavor = "multi_thread")]
+    async fn unbounded_file_with_symmetric_join() -> Result<()> {
+        // To make the unbounded execution deterministic
+        let waiting = Arc::new(Mutex::new(true));
+        let thread_bools = vec![waiting.clone(), waiting.clone()];
+        // Create new temporary FIFO files
+        let tmp_dir = TempDir::new()?;
+        let file_names = vec!["first_fifo.csv", "second_fifo.csv"];
+        // The sender endpoint can be copied
+        let (threads, file_paths): (Vec<JoinHandle<()>>, Vec<PathBuf>) = file_names
+            .iter()
+            .zip(thread_bools.iter())
+            .map(|(file_name, lock)| {
+                let waiting_thread = lock.clone();
+                let fifo_path = create_fifo_file(&tmp_dir, file_name).unwrap();
+                let return_path = fifo_path.clone();
+                // Timeout for a long period of BrokenPipe error
+                let broken_pipe_timeout = Duration::from_secs(5);
+                // Spawn a new thread to write to the FIFO file
+                let fifo_writer = thread::spawn(move || {
+                    let mut rng = StdRng::seed_from_u64(42);
+                    let file = OpenOptions::new()
+                        .write(true)
+                        .open(fifo_path.clone())
+                        .unwrap();
+                    // Reference time to use when deciding to fail the test
+                    let execution_start = Instant::now();
+                    // Join filter
+                    let a1_iter = (0..TEST_DATA_SIZE).clone().map(|x| {
+                        let change = rng.gen_range(0.0..1.0);
+                        if change < 0.3 {
+                            x - 1
+                        } else {
+                            x
+                        }
+                    });
+                    // Join key
+                    let a2_iter = (0..TEST_DATA_SIZE).clone().map(|x| x % 10);
+                    for (cnt, (a1, a2)) in a1_iter.zip(a2_iter).enumerate() {
+                        // Wait for a reading signal during unbounded execution;
+                        // after the first batch is read from the FIFO, wait until
+                        // a result batch is produced before writing more rows.
+                        while *waiting_thread.lock().unwrap() && TEST_BATCH_SIZE < cnt {
+                            thread::sleep(Duration::from_millis(200));
+                        }
+                        let line = format!("{},{}\n", a1, a2).to_owned();
+                        write_to_fifo(&file, &line, execution_start, broken_pipe_timeout)
+                            .unwrap();
+                    }
+                });
+                (fifo_writer, return_path)
+            })
+            .unzip();
+        let config = SessionConfig::new()
+            .with_batch_size(TEST_BATCH_SIZE)
+            .set_bool("datafusion.execution.coalesce_batches", false)
+            .with_target_partitions(1);
+        let ctx = SessionContext::with_config(config);
+        test_create_unbounded_sorted_file(&ctx, file_paths[0].clone(), "left").await?;
+        test_create_unbounded_sorted_file(&ctx, file_paths[1].clone(), "right").await?;
+        // Execute the query
+        let df = ctx.sql("SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10").await?;
+        let mut stream = df.execute_stream().await?;
+        let mut operations = vec![];
+        while let Some(Ok(batch)) = stream.next().await {
+            *waiting.lock().unwrap() = false;
+            let op = if batch.column(0).null_count() > 0 {
+                JoinOperation::LeftUnmatched
+            } else if batch.column(2).null_count() > 0 {
+                JoinOperation::RightUnmatched
+            } else {
+                JoinOperation::Equal
+            };
+            operations.push(op);
+        }
+        // Check that unmatched rows for both sides are produced while the streams
+        // are still running; in unbounded mode they must not appear only at the
+        // end of the input.
+        assert!(
+            operations
+                .iter()
+                .filter(|&n| JoinOperation::RightUnmatched.eq(n))
+                .count()
+                > 1
+                && operations
+                    .iter()
+                    .filter(|&n| JoinOperation::LeftUnmatched.eq(n))
+                    .count()
+                    > 1
+        );
+        threads.into_iter().for_each(|j| j.join().unwrap());
+        Ok(())
+    }
 }
diff --git a/datafusion/expr/src/operator.rs b/datafusion/expr/src/operator.rs
index fac81654c4a8..d4a40ea083e6 100644
--- a/datafusion/expr/src/operator.rs
+++ b/datafusion/expr/src/operator.rs
@@ -110,6 +110,41 @@ impl Operator {
         }
     }
 
+    /// Return true if the operator is a comparison operator.
+ /// + /// For example, 'Binary(a, >, b)' would be a comparison. + pub fn is_comparison_operator(&self) -> bool { + matches!( + self, + Operator::Eq + | Operator::NotEq + | Operator::Lt + | Operator::LtEq + | Operator::Gt + | Operator::GtEq + | Operator::RegexMatch + | Operator::RegexIMatch + | Operator::RegexNotMatch + | Operator::RegexNotIMatch + ) + } + + /// Return true if the operator is a logic operator. + /// + /// For example, 'Binary(Binary(a, >, b), AND, Binary(a, <, b + 3))' would be a logic. + pub fn is_logic_operator(&self) -> bool { + matches!( + self, + Operator::And + | Operator::Or + | Operator::BitwiseAnd + | Operator::BitwiseOr + | Operator::BitwiseXor + | Operator::BitwiseShiftRight + | Operator::BitwiseShiftLeft + ) + } + /// Return the operator where swapping lhs and rhs wouldn't change the result. /// /// For example `Binary(50, >=, a)` could also be represented as `Binary(a, <=, 50)`. diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 3fb632ebd034..c23d07c6d2de 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -57,6 +57,7 @@ lazy_static = { version = "^1.4.0" } md-5 = { version = "^0.10.0", optional = true } num-traits = { version = "0.2", default-features = false } paste = "^1.0" +petgraph = "0.6.2" rand = "0.8" regex = { version = "^1.4.3", optional = true } sha2 = { version = "^0.10.1", optional = true } @@ -66,6 +67,7 @@ uuid = { version = "^1.2", features = ["v4"] } [dev-dependencies] criterion = "0.4" rand = "0.8" +rstest = "0.16.0" [[bench]] harness = false diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 149eee593424..08d648c63741 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -81,9 +81,9 @@ use datafusion_expr::{ColumnarValue, Operator}; /// Binary expression #[derive(Debug)] pub struct BinaryExpr { - left: Arc, - op: Operator, - right: Arc, + pub left: Arc, + pub op: Operator, + pub right: Arc, } impl BinaryExpr { diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 4d01131353d8..f8d0ba28a0c1 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -38,11 +38,11 @@ pub const DEFAULT_DATAFUSION_CAST_OPTIONS: CastOptions = CastOptions { safe: fal #[derive(Debug)] pub struct CastExpr { /// The expression to cast - expr: Arc, + pub expr: Arc, /// The data type to cast to - cast_type: DataType, + pub cast_type: DataType, /// Cast options - cast_options: CastOptions, + pub cast_options: CastOptions, } impl CastExpr { @@ -68,6 +68,10 @@ impl CastExpr { pub fn cast_type(&self) -> &DataType { &self.cast_type } + /// The data type to cast to + pub fn cast_options(&self) -> &CastOptions { + &self.cast_options + } } impl fmt::Display for CastExpr { diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index eb2be5ef217c..c7e153da6867 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -33,8 +33,8 @@ use datafusion_expr::ColumnarValue; /// Represents the column at a given index in a RecordBatch #[derive(Debug, Hash, PartialEq, Eq, Clone)] pub struct Column { - name: String, - index: usize, + pub name: String, + pub index: usize, } impl Column { diff --git 
diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs
new file mode 100644
index 000000000000..bf316a6294e5
--- /dev/null
+++ b/datafusion/physical-expr/src/intervals/cp_solver.rs
@@ -0,0 +1,619 @@
+use crate::expressions::{BinaryExpr, CastExpr, Literal};
+use crate::intervals::interval_aritmetics::{
+    apply_operator, negate_ops, propagate_logical_operators, Interval,
+};
+use crate::utils::{build_physical_expr_graph, ExprTreeNode};
+use crate::PhysicalExpr;
+use std::fmt::{Display, Formatter};
+
+use datafusion_common::Result;
+use datafusion_expr::Operator;
+use petgraph::graph::NodeIndex;
+use petgraph::visit::{Bfs, DfsPostOrder};
+use petgraph::Outgoing;
+
+use petgraph::stable_graph::StableGraph;
+use std::ops::{Index, IndexMut};
+use std::sync::Arc;
+
+#[derive(Clone, Debug)]
+/// Graph nodes carry interval information.
+pub struct ExprIntervalGraphNode {
+    expr: Arc<dyn PhysicalExpr>,
+    interval: Option<Interval>,
+}
+
+impl Display for ExprIntervalGraphNode {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.expr)
+    }
+}
+
+impl ExprIntervalGraphNode {
+    /// Construct an ExprIntervalGraphNode with no interval information yet
+    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
+        ExprIntervalGraphNode {
+            expr,
+            interval: None,
+        }
+    }
+    /// Specify the interval of this node
+    pub fn with_interval(mut self, interval: Interval) -> Self {
+        self.interval = Some(interval);
+        self
+    }
+    /// Get the interval of this node
+    pub fn interval(&self) -> &Option<Interval> {
+        &self.interval
+    }
+
+    /// This builder customizes how the graph constructor turns a PhysicalExpr
+    /// into a graph node: literals start out with a singleton interval, all
+    /// other expressions start without interval information.
+    pub fn expr_node_builder(input: Arc<ExprTreeNode>) -> ExprIntervalGraphNode {
+        let binding = input.expr();
+        let plan_any = binding.as_any();
+        if let Some(literal) = plan_any.downcast_ref::<Literal>() {
+            // Create the singleton interval
+            let interval = Interval::Singleton(literal.value().clone());
+            ExprIntervalGraphNode::new(input.expr().clone()).with_interval(interval)
+        } else {
+            ExprIntervalGraphNode::new(input.expr().clone())
+        }
+    }
+}
+
+impl PartialEq for ExprIntervalGraphNode {
+    fn eq(&self, other: &ExprIntervalGraphNode) -> bool {
+        self.expr.eq(&other.expr)
+    }
+}
+
+pub struct ExprIntervalGraph(
+    pub NodeIndex,
+    pub StableGraph<ExprIntervalGraphNode, usize>,
+    pub Vec<(Arc<dyn PhysicalExpr>, usize)>,
+);
+
+/// This function calculates the interval of the current node by intersecting
+/// its known interval with the one implied by its parent and sibling.
+fn calculate_node_interval(
+    parent: &Interval,
+    current_child_interval: &Interval,
+    negated_op: Operator,
+    other_child_interval: &Interval,
+) -> Result<Interval> {
+    Ok(if current_child_interval.is_singleton() {
+        current_child_interval.clone()
+    } else {
+        apply_operator(parent, &negated_op, other_child_interval)?
+            .intersect(current_child_interval)?
+    })
+}
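Before the implementation details, a minimal usage sketch of the solver API that follows (it mirrors the `experiment` helper in the tests at the bottom of this file; the column names and seed intervals are illustrative only):

```rust
use std::sync::Arc;

use datafusion_common::{Result, ScalarValue};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::intervals::interval_aritmetics::{Interval, Range};
use datafusion_physical_expr::intervals::ExprIntervalGraph;
use datafusion_physical_expr::PhysicalExpr;

fn shrink_column_intervals(filter_expr: Arc<dyn PhysicalExpr>) -> Result<()> {
    // Suppose the filter references these two columns.
    let left: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_watermark", 0));
    let right: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_watermark", 0));
    // Seed intervals: both columns are known to lie in [100, +inf).
    let unbounded = ScalarValue::Int32(None);
    let mut stats = vec![
        (
            left.clone(),
            Interval::Range(Range {
                lower: ScalarValue::Int32(Some(100)),
                upper: unbounded.clone(),
            }),
        ),
        (
            right.clone(),
            Interval::Range(Range {
                lower: ScalarValue::Int32(Some(100)),
                upper: unbounded,
            }),
        ),
    ];
    let mut graph = ExprIntervalGraph::try_new(filter_expr, &[left, right])?;
    // After this call, `stats` holds the shrunken column intervals.
    graph.calculate_new_intervals(&mut stats)?;
    Ok(())
}
```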
+impl ExprIntervalGraph {
+    /// We initiate [ExprIntervalGraph] with a [PhysicalExpr] and the set of
+    /// sub-expressions for which the caller will provide intervals.
+    /// For example, for the expression "a@0 + b@1 > 3", initial interval
+    /// information can be supplied for "a@0" and "b@1", and the solver will
+    /// shrink those intervals using the constraint implied by the filter.
+    pub fn try_new(
+        expr: Arc<dyn PhysicalExpr>,
+        provided_expr: &[Arc<dyn PhysicalExpr>],
+    ) -> Result<Self> {
+        let (root_node, mut graph) =
+            build_physical_expr_graph(expr, &ExprIntervalGraphNode::expr_node_builder)?;
+        let mut bfs = Bfs::new(&graph, root_node);
+        let mut will_be_removed = vec![];
+        let mut expr_node_indices: Vec<(Arc<dyn PhysicalExpr>, usize)> = provided_expr
+            .iter()
+            .map(|e| (e.clone(), usize::MAX))
+            .collect();
+        while let Some(node) = bfs.next(&graph) {
+            // Get the expression of this node
+            let input = graph.index(node);
+            let expr = input.expr.clone();
+            // If this is an expression the caller provides intervals for,
+            // remove its children, since we cannot propagate intervals below it.
+            if let Some(value) = provided_expr.iter().position(|e| expr.eq(e)) {
+                // Remove all child nodes here, since we already have information
+                // about them. If there are no children, the iterator yields nothing.
+                let expr_node_index = expr_node_indices.get_mut(value).unwrap();
+                *expr_node_index = (expr_node_index.0.clone(), node.index());
+                let mut edges = graph.neighbors_directed(node, Outgoing).detach();
+                while let Some(n_index) = edges.next_node(&graph) {
+                    // Do not delete while traversing; defer the removal
+                    will_be_removed.push(n_index);
+                }
+            }
+        }
+        for node in will_be_removed {
+            graph.remove_node(node);
+        }
+        Ok(Self(root_node, graph, expr_node_indices))
+    }
+
+    pub fn calculate_new_intervals(
+        &mut self,
+        expr_stats: &mut [(Arc<dyn PhysicalExpr>, Interval)],
+    ) -> Result<()> {
+        self.post_order_interval_calculation(expr_stats)?;
+        self.pre_order_interval_propagation(expr_stats)?;
+        Ok(())
+    }
+
+    fn post_order_interval_calculation(
+        &mut self,
+        expr_stats: &[(Arc<dyn PhysicalExpr>, Interval)],
+    ) -> Result<()> {
+        let mut dfs = DfsPostOrder::new(&self.1, self.0);
+        while let Some(node) = dfs.next(&self.1) {
+            // Outgoing (children) edges
+            let mut edges = self.1.neighbors_directed(node, Outgoing).detach();
+            // Get the PhysicalExpr
+            let expr = self.1[node].expr.clone();
+            // If we already have interval information for this PhysicalExpr,
+            // propagate it upward directly.
+            if let Some((_, interval)) =
+                expr_stats.iter().find(|(e, _interval)| expr.eq(e))
+            {
+                let input = self.1.index_mut(node);
+                input.interval = Some(interval.clone());
+                continue;
+            }
+
+            let expr_any = expr.as_any();
+            if let Some(binary) = expr_any.downcast_ref::<BinaryExpr>() {
+                // Access the right child immutably
+                let second_child_node_index = edges.next_node(&self.1).unwrap();
+                // Access the left child immutably
+                let first_child_node_index = edges.next_node(&self.1).unwrap();
+                // Do not keep any reference into the graph; we MUST clone here.
+                let left_interval =
+                    match self.1.index(first_child_node_index).interval().as_ref() {
+                        Some(v) => v.clone(),
+                        None => continue,
+                    };
+                let right_interval =
+                    match self.1.index(second_child_node_index).interval().as_ref() {
+                        Some(v) => v.clone(),
+                        None => continue,
+                    };
+                // Since the references are released, we can borrow mutably.
+                let input = self.1.index_mut(node);
+                // Calculate and replace the interval
+                input.interval = Some(apply_operator(
+                    &left_interval,
+                    binary.op(),
+                    &right_interval,
+                )?);
+            } else if let Some(CastExpr {
+                cast_type,
+                cast_options,
+                ..
+            }) = expr_any.downcast_ref::<CastExpr>()
+            {
+                // Access the child immutably
+                let child_index = edges.next_node(&self.1).unwrap();
+                let child = self.1.index(child_index);
+                // Cast the interval
+                let new_interval = match &child.interval {
+                    Some(interval) => interval.cast_to(cast_type, cast_options)?,
+                    // If there is no child interval, continue
+                    None => continue,
+                };
+                // Update the interval
+                let input = self.1.index_mut(node);
+                input.interval = Some(new_interval);
+            }
+        }
+        Ok(())
+    }
+
+    pub fn pre_order_interval_propagation(
+        &mut self,
+        expr_stats: &mut [(Arc<dyn PhysicalExpr>, Interval)],
+    ) -> Result<()> {
+        let mut bfs = Bfs::new(&self.1, self.0);
+        while let Some(node) = bfs.next(&self.1) {
+            // Get the expression of this node
+            let input = self.1.index(node);
+            let expr = input.expr.clone();
+            // Get the interval calculated in the bottom-up pass; a BinaryExpr
+            // propagates intervals to its children according to this.
+            let parent_calculated_interval = match input.interval().as_ref() {
+                Some(i) => i.clone(),
+                None => continue,
+            };
+
+            if let Some(index) = expr_stats.iter().position(|(e, _interval)| expr.eq(e)) {
+                let col_interval = expr_stats.get_mut(index).unwrap();
+                *col_interval = (col_interval.0.clone(), parent_calculated_interval);
+                continue;
+            }
+
+            // Outgoing (children) edges
+            let mut edges = self.1.neighbors_directed(node, Outgoing).detach();
+
+            let expr_any = expr.as_any();
+            if let Some(binary) = expr_any.downcast_ref::<BinaryExpr>() {
+                // Get the right child node.
+                let second_child_node_index = edges.next_node(&self.1).unwrap();
+                let second_child_interval =
+                    match self.1.index(second_child_node_index).interval().as_ref() {
+                        Some(v) => v,
+                        None => continue,
+                    };
+                // Get the left child node.
+                let first_child_node_index = edges.next_node(&self.1).unwrap();
+                let first_child_interval =
+                    match self.1.index(first_child_node_index).interval().as_ref() {
+                        Some(v) => v,
+                        None => continue,
+                    };
+
+                let (shrink_left_interval, shrink_right_interval) =
+                    if parent_calculated_interval.is_boolean() {
+                        propagate_logical_operators(
+                            first_child_interval,
+                            binary.op(),
+                            second_child_interval,
+                        )?
+                    } else {
+                        let negated_op = negate_ops(*binary.op());
+                        let new_right_operator = calculate_node_interval(
+                            &parent_calculated_interval,
+                            second_child_interval,
+                            negated_op,
+                            first_child_interval,
+                        )?;
+                        let new_left_operator = calculate_node_interval(
+                            &parent_calculated_interval,
+                            first_child_interval,
+                            negated_op,
+                            &new_right_operator,
+                        )?;
+                        (new_left_operator, new_right_operator)
+                    };
+                let mutable_first_child = self.1.index_mut(first_child_node_index);
+                mutable_first_child.interval = Some(shrink_left_interval.clone());
+                let mutable_second_child = self.1.index_mut(second_child_node_index);
+                mutable_second_child.interval = Some(shrink_right_interval.clone())
+            } else if let Some(cast) = expr_any.downcast_ref::<CastExpr>() {
+                // Calculate the new child interval
+                let child_index = edges.next_node(&self.1).unwrap();
+                let child = self.1.index(child_index);
+                // Get the child's internal datatype; if unknown, skip this node.
+                let cast_type = match child.interval() {
+                    Some(interval) => interval.get_datatype(),
+                    None => continue,
+                };
+                let cast_options = cast.cast_options();
+                let new_child_interval = match &child.interval {
+                    Some(_) => {
+                        parent_calculated_interval.cast_to(&cast_type, cast_options)?
+ } + None => continue, + }; + let mutable_child = self.1.index_mut(child_index); + mutable_child.interval = Some(new_child_interval.clone()); + } + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::utils::filter_numeric_expr_generation; + use itertools::Itertools; + + use crate::expressions::Column; + use crate::intervals::interval_aritmetics::Range; + use datafusion_common::ScalarValue; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + use rstest::*; + + fn experiment( + expr: Arc, + exprs: (Arc, Arc), + left_interval: (Option, Option), + right_interval: (Option, Option), + left_waited: (Option, Option), + right_waited: (Option, Option), + ) -> Result<()> { + let mut col_stats = vec![ + ( + exprs.0.clone(), + Interval::Range(Range { + lower: ScalarValue::Int32(left_interval.0), + upper: ScalarValue::Int32(left_interval.1), + }), + ), + ( + exprs.1.clone(), + Interval::Range(Range { + lower: ScalarValue::Int32(right_interval.0), + upper: ScalarValue::Int32(right_interval.1), + }), + ), + ]; + let expected = vec![ + ( + exprs.0.clone(), + Interval::Range(Range { + lower: ScalarValue::Int32(left_waited.0), + upper: ScalarValue::Int32(left_waited.1), + }), + ), + ( + exprs.1.clone(), + Interval::Range(Range { + lower: ScalarValue::Int32(right_waited.0), + upper: ScalarValue::Int32(right_waited.1), + }), + ), + ]; + let mut graph = ExprIntervalGraph::try_new( + expr, + &col_stats.iter().map(|(e, _)| e.clone()).collect_vec(), + )?; + graph.calculate_new_intervals(&mut col_stats)?; + col_stats + .iter() + .zip(expected.iter()) + .for_each(|((_, res), (_, expected))| assert_eq!(res, expected)); + Ok(()) + } + + fn generate_case( + expr: Arc, + left_col: Arc, + right_col: Arc, + seed: u64, + expr_left: i32, + expr_right: i32, + ) -> Result<()> { + let mut r = StdRng::seed_from_u64(seed); + + let (left_interval, right_interval, left_waited, right_waited) = if ASC { + let left = (Some(r.gen_range(0..1000)), None); + let right = (Some(r.gen_range(0..1000)), None); + ( + left, + right, + ( + Some(std::cmp::max(left.0.unwrap(), right.0.unwrap() + expr_left)), + None, + ), + ( + Some(std::cmp::max( + right.0.unwrap(), + left.0.unwrap() + expr_right, + )), + None, + ), + ) + } else { + let left = (None, Some(r.gen_range(0..1000))); + let right = (None, Some(r.gen_range(0..1000))); + ( + left, + right, + ( + None, + Some(std::cmp::min(left.1.unwrap(), right.1.unwrap() + expr_left)), + ), + ( + None, + Some(std::cmp::min( + right.1.unwrap(), + left.1.unwrap() + expr_right, + )), + ), + ) + }; + experiment( + expr, + (left_col, right_col), + left_interval, + right_interval, + left_waited, + right_waited, + )?; + Ok(()) + } + #[rstest] + #[test] + fn case_1( + #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 123, 4123)] seed: u64, + ) -> Result<()> { + let left_col = Arc::new(Column::new("left_watermark", 0)); + let right_col = Arc::new(Column::new("right_watermark", 0)); + // left_watermark + 1 > right_watermark + 11 AND left_watermark + 3 < right_watermark + 33 + let expr = filter_numeric_expr_generation( + left_col.clone(), + right_col.clone(), + Operator::Plus, + Operator::Plus, + Operator::Plus, + Operator::Plus, + 1, + 11, + 3, + 33, + ); + // l > r + 10 AND r > l - 30 + let l_gt_r = 10; + let r_gt_l = -30; + generate_case::( + expr.clone(), + left_col.clone(), + right_col.clone(), + seed, + l_gt_r, + r_gt_l, + )?; + // Descending tests + // r < l - 10 AND l < r + 30 + let r_lt_l = -l_gt_r; + let l_lt_r = -r_gt_l; + generate_case::(expr, left_col, right_col, seed, 
l_lt_r, r_lt_l)?; + + Ok(()) + } + #[rstest] + #[test] + fn case_2( + #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 123, 4123)] seed: u64, + ) -> Result<()> { + let left_col = Arc::new(Column::new("left_watermark", 0)); + let right_col = Arc::new(Column::new("right_watermark", 0)); + // left_watermark - 1 > right_watermark + 5 AND left_watermark + 3 < right_watermark + 10 + let expr = filter_numeric_expr_generation( + left_col.clone(), + right_col.clone(), + Operator::Minus, + Operator::Plus, + Operator::Plus, + Operator::Plus, + 1, + 5, + 3, + 10, + ); + // l > r + 6 AND r > l - 7 + let l_gt_r = 6; + let r_gt_l = -7; + generate_case::( + expr.clone(), + left_col.clone(), + right_col.clone(), + seed, + l_gt_r, + r_gt_l, + )?; + // Descending tests + // r < l - 6 AND l < r + 7 + let r_lt_l = -l_gt_r; + let l_lt_r = -r_gt_l; + generate_case::(expr, left_col, right_col, seed, l_lt_r, r_lt_l)?; + + Ok(()) + } + + #[rstest] + #[test] + fn case_3( + #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 123, 4123)] seed: u64, + ) -> Result<()> { + let left_col = Arc::new(Column::new("left_watermark", 0)); + let right_col = Arc::new(Column::new("right_watermark", 0)); + // left_watermark - 1 > right_watermark + 5 AND left_watermark - 3 < right_watermark + 10 + let expr = filter_numeric_expr_generation( + left_col.clone(), + right_col.clone(), + Operator::Minus, + Operator::Plus, + Operator::Minus, + Operator::Plus, + 1, + 5, + 3, + 10, + ); + // l > r + 6 AND r > l - 13 + let l_gt_r = 6; + let r_gt_l = -13; + generate_case::( + expr.clone(), + left_col.clone(), + right_col.clone(), + seed, + l_gt_r, + r_gt_l, + )?; + // Descending tests + // r < l - 6 AND l < r + 13 + let r_lt_l = -l_gt_r; + let l_lt_r = -r_gt_l; + generate_case::(expr, left_col, right_col, seed, l_lt_r, r_lt_l)?; + + Ok(()) + } + #[rstest] + #[test] + fn case_4( + #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 123, 4123)] seed: u64, + ) -> Result<()> { + let left_col = Arc::new(Column::new("left_watermark", 0)); + let right_col = Arc::new(Column::new("right_watermark", 0)); + // left_watermark - 10 > right_watermark - 5 AND left_watermark - 3 < right_watermark + 10 + let expr = filter_numeric_expr_generation( + left_col.clone(), + right_col.clone(), + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Plus, + 10, + 5, + 3, + 10, + ); + // l > r + 5 AND r > l - 13 + let l_gt_r = 5; + let r_gt_l = -13; + generate_case::( + expr.clone(), + left_col.clone(), + right_col.clone(), + seed, + l_gt_r, + r_gt_l, + )?; + // Descending tests + // r < l - 5 AND l < r + 13 + let r_lt_l = -l_gt_r; + let l_lt_r = -r_gt_l; + generate_case::(expr, left_col, right_col, seed, l_lt_r, r_lt_l)?; + Ok(()) + } + + #[rstest] + #[test] + fn case_5( + #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 123, 4123)] seed: u64, + ) -> Result<()> { + let left_col = Arc::new(Column::new("left_watermark", 0)); + let right_col = Arc::new(Column::new("right_watermark", 0)); + // left_watermark - 10 > right_watermark - 5 AND left_watermark - 30 < right_watermark - 3 + + let expr = filter_numeric_expr_generation( + left_col.clone(), + right_col.clone(), + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Minus, + 10, + 5, + 30, + 3, + ); + // l > r + 5 AND r > l - 27 + let l_gt_r = 5; + let r_gt_l = -27; + generate_case::( + expr.clone(), + left_col.clone(), + right_col.clone(), + seed, + l_gt_r, + r_gt_l, + )?; + // Descending tests + // r < l - 5 AND l < r + 27 + let r_lt_l = -l_gt_r; + let l_lt_r = -r_gt_l; + generate_case::(expr, 
left_col, right_col, seed, l_lt_r, r_lt_l)?;
+
+        Ok(())
+    }
+}
diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetics.rs b/datafusion/physical-expr/src/intervals/interval_aritmetics.rs
new file mode 100644
index 000000000000..39ce14a401d1
--- /dev/null
+++ b/datafusion/physical-expr/src/intervals/interval_aritmetics.rs
@@ -0,0 +1,697 @@
+use std::borrow::Borrow;
+
+use arrow::datatypes::DataType;
+
+use datafusion_common::ScalarValue;
+use datafusion_common::{DataFusionError, Result};
+
+use datafusion_expr::Operator;
+
+use std::fmt;
+use std::fmt::{Display, Formatter};
+
+use crate::aggregate::min_max::{max, min};
+
+use arrow::compute::{kernels, CastOptions};
+use std::ops::{Add, Sub};
+
+#[derive(Debug, PartialEq, Clone, Eq, Hash)]
+pub struct Range {
+    pub lower: ScalarValue,
+    pub upper: ScalarValue,
+}
+
+impl Range {
+    fn new(lower: ScalarValue, upper: ScalarValue) -> Range {
+        Range { lower, upper }
+    }
+}
+
+impl Add for Range {
+    type Output = Range;
+
+    fn add(self, other: Range) -> Range {
+        Range::new(
+            self.lower.add(&other.lower).unwrap(),
+            self.upper.add(&other.upper).unwrap(),
+        )
+    }
+}
+
+impl Sub for Range {
+    type Output = Range;
+
+    fn sub(self, other: Range) -> Range {
+        // [a, b] - [c, d] = [a - d, b - c]
+        Range::new(
+            self.lower.sub(&other.upper).unwrap(),
+            self.upper.sub(&other.lower).unwrap(),
+        )
+    }
+}
+
+impl Display for Interval {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        match self {
+            Interval::Singleton(value) => write!(f, "Singleton [{}]", value),
+            Interval::Range(Range { lower, upper }) => {
+                write!(f, "Range [{}, {}]", lower, upper)
+            }
+        }
+    }
+}
+
+/// Cast a scalar value to the given data type using an arrow kernel.
+fn cast_scalar_value(
+    value: &ScalarValue,
+    data_type: &DataType,
+    cast_options: &CastOptions,
+) -> Result<ScalarValue> {
+    let scalar_array = value.to_array();
+    let cast_array =
+        kernels::cast::cast_with_options(&scalar_array, data_type, cast_options)?;
+    ScalarValue::try_from_array(&cast_array, 0)
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub enum Interval {
+    Singleton(ScalarValue),
+    Range(Range),
+}
+
+impl Interval {
+    pub(crate) fn cast_to(
+        &self,
+        data_type: &DataType,
+        cast_options: &CastOptions,
+    ) -> Result<Interval> {
+        Ok(match self {
+            Interval::Range(Range { lower, upper }) => Interval::Range(Range {
+                lower: cast_scalar_value(lower, data_type, cast_options)?,
+                upper: cast_scalar_value(upper, data_type, cast_options)?,
+            }),
+            Interval::Singleton(value) => {
+                Interval::Singleton(cast_scalar_value(value, data_type, cast_options)?)
+            }
+        })
+    }
+    pub(crate) fn is_boolean(&self) -> bool {
+        matches!(
+            self,
+            &Interval::Range(Range {
+                lower: ScalarValue::Boolean(_),
+                upper: ScalarValue::Boolean(_),
+            })
+        )
+    }
+
+    pub fn lower_value(&self) -> ScalarValue {
+        match self {
+            Interval::Range(Range { lower, .. }) => lower.clone(),
+            Interval::Singleton(val) => val.clone(),
+        }
+    }
+
+    pub fn upper_value(&self) -> ScalarValue {
+        match self {
+            Interval::Range(Range { upper, .. }) => upper.clone(),
+            Interval::Singleton(val) => val.clone(),
+        }
+    }
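To make the range arithmetic concrete, a small worked sketch using the `add`/`sub` methods on `Interval` defined further below in this file:

```rust
use datafusion_common::{Result, ScalarValue};
use datafusion_physical_expr::intervals::interval_aritmetics::{Interval, Range};

fn demo() -> Result<()> {
    let x = Interval::Range(Range {
        lower: ScalarValue::Int32(Some(2)),
        upper: ScalarValue::Int32(Some(5)),
    });
    let y = Interval::Range(Range {
        lower: ScalarValue::Int32(Some(1)),
        upper: ScalarValue::Int32(Some(3)),
    });
    // [2, 5] + [1, 3] = [2 + 1, 5 + 3] = [3, 8]
    println!("{}", x.add(&y)?);
    // [2, 5] - [1, 3] = [2 - 3, 5 - 1] = [-1, 4]; note that each bound pairs
    // with the *opposite* bound of the other interval.
    println!("{}", x.sub(&y)?);
    Ok(())
}
```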
+    pub(crate) fn get_datatype(&self) -> DataType {
+        match self {
+            Interval::Range(Range { lower, .. }) => lower.get_datatype(),
+            Interval::Singleton(val) => val.get_datatype(),
+        }
+    }
+    pub(crate) fn is_singleton(&self) -> bool {
+        match self {
+            Interval::Range(_) => false,
+            Interval::Singleton(_) => true,
+        }
+    }
+    pub(crate) fn is_strict_false(&self) -> bool {
+        matches!(
+            self,
+            &Interval::Range(Range {
+                lower: ScalarValue::Boolean(Some(false)),
+                upper: ScalarValue::Boolean(Some(false)),
+            })
+        )
+    }
+    pub(crate) fn gt(&self, other: &Interval) -> Interval {
+        let bools = match (self, other) {
+            (Interval::Singleton(val), Interval::Singleton(val2)) => {
+                let result = val > val2;
+                (result, result)
+            }
+            (Interval::Singleton(val), Interval::Range(Range { lower, upper })) => {
+                if val < lower {
+                    (false, false)
+                } else if val > upper {
+                    (true, true)
+                } else {
+                    (false, true)
+                }
+            }
+            (Interval::Range(Range { lower, upper }), Interval::Singleton(val)) => {
+                if val > upper {
+                    (false, false)
+                } else if lower > val {
+                    (true, true)
+                } else {
+                    (false, true)
+                }
+            }
+            (
+                Interval::Range(Range { lower, upper }),
+                Interval::Range(Range {
+                    lower: lower2,
+                    upper: upper2,
+                }),
+            ) => {
+                if upper < lower2 {
+                    (false, false)
+                } else if lower > upper2 {
+                    (true, true)
+                } else {
+                    (false, true)
+                }
+            }
+        };
+        Interval::Range(Range {
+            lower: ScalarValue::Boolean(Some(bools.0)),
+            upper: ScalarValue::Boolean(Some(bools.1)),
+        })
+    }
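For intuition, comparing two intervals yields a boolean interval: certainly false, certainly true, or undecided. A short sketch via the public `apply_operator` helper defined later in this file:

```rust
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Operator;
use datafusion_physical_expr::intervals::interval_aritmetics::{
    apply_operator, Interval, Range,
};

fn demo() -> Result<()> {
    let int = |v| ScalarValue::Int32(Some(v));
    let boolean = |b| ScalarValue::Boolean(Some(b));
    let range = |lower, upper| Interval::Range(Range { lower, upper });
    // Every value in [4, 6] exceeds every value in [1, 3]: certainly true.
    assert_eq!(
        apply_operator(&range(int(4), int(6)), &Operator::Gt, &range(int(1), int(3)))?,
        range(boolean(true), boolean(true))
    );
    // [1, 3] and [2, 6] overlap, so the comparison is undecided.
    assert_eq!(
        apply_operator(&range(int(1), int(3)), &Operator::Gt, &range(int(2), int(6)))?,
        range(boolean(false), boolean(true))
    );
    Ok(())
}
```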
+
+    pub(crate) fn and(&self, other: &Interval) -> Result<Interval> {
+        let bools = match (self, other) {
+            (
+                Interval::Singleton(ScalarValue::Boolean(Some(val))),
+                Interval::Singleton(ScalarValue::Boolean(Some(val2))),
+            ) => {
+                let result = *val && *val2;
+                (result, result)
+            }
+            (
+                Interval::Singleton(ScalarValue::Boolean(Some(val))),
+                Interval::Range(Range {
+                    lower: ScalarValue::Boolean(Some(lower)),
+                    upper: ScalarValue::Boolean(Some(upper)),
+                }),
+            )
+            | (
+                Interval::Range(Range {
+                    lower: ScalarValue::Boolean(Some(lower)),
+                    upper: ScalarValue::Boolean(Some(upper)),
+                }),
+                Interval::Singleton(ScalarValue::Boolean(Some(val))),
+            ) => {
+                // AND is monotone in both arguments, so the bounds are simply
+                // the pairwise conjunctions.
+                (*val && *lower, *val && *upper)
+            }
+            (
+                Interval::Range(Range {
+                    lower: ScalarValue::Boolean(Some(lower)),
+                    upper: ScalarValue::Boolean(Some(upper)),
+                }),
+                Interval::Range(Range {
+                    lower: ScalarValue::Boolean(Some(lower2)),
+                    upper: ScalarValue::Boolean(Some(upper2)),
+                }),
+            ) => (*lower && *lower2, *upper && *upper2),
+            _ => {
+                return Err(DataFusionError::Internal(
+                    "Booleans cannot be None.".to_string(),
+                ))
+            }
+        };
+        Ok(Interval::Range(Range {
+            lower: ScalarValue::Boolean(Some(bools.0)),
+            upper: ScalarValue::Boolean(Some(bools.1)),
+        }))
+    }
+
+    pub(crate) fn equal(&self, other: &Interval) -> Interval {
+        let bools = match (self, other) {
+            (Interval::Singleton(val), Interval::Singleton(val2)) => {
+                let result = val == val2;
+                (result, result)
+            }
+            (Interval::Singleton(val), Interval::Range(Range { lower, upper }))
+            | (Interval::Range(Range { lower, upper }), Interval::Singleton(val)) => {
+                if val < lower || val > upper {
+                    // The singleton is outside the range: certainly not equal.
+                    (false, false)
+                } else if lower == upper {
+                    // The range collapses onto the singleton: certainly equal.
+                    (true, true)
+                } else {
+                    (false, true)
+                }
+            }
+            (
+                Interval::Range(Range { lower, upper }),
+                Interval::Range(Range {
+                    lower: lower2,
+                    upper: upper2,
+                }),
+            ) => {
+                if upper < lower2 || lower > upper2 {
+                    // Disjoint ranges: certainly not equal.
+                    (false, false)
+                } else if lower == upper && lower2 == upper2 && lower == lower2 {
+                    // Both ranges collapse to the same value: certainly equal.
+                    (true, true)
+                } else {
+                    (false, true)
+                }
+            }
+        };
+        Interval::Range(Range {
+            lower: ScalarValue::Boolean(Some(bools.0)),
+            upper: ScalarValue::Boolean(Some(bools.1)),
+        })
+    }
+
+    pub(crate) fn lt(&self, other: &Interval) -> Interval {
+        let bools = match (self, other) {
+            (Interval::Singleton(val), Interval::Singleton(val2)) => {
+                let result = val < val2;
+                (result, result)
+            }
+            (Interval::Singleton(val), Interval::Range(Range { lower, upper })) => {
+                if val >= upper {
+                    (false, false)
+                } else if val < lower {
+                    (true, true)
+                } else {
+                    (false, true)
+                }
+            }
+            (Interval::Range(Range { lower, upper }), Interval::Singleton(val)) => {
+                if upper < val {
+                    (true, true)
+                } else if lower >= val {
+                    (false, false)
+                } else {
+                    (false, true)
+                }
+            }
+            (
+                Interval::Range(Range { lower, upper }),
+                Interval::Range(Range {
+                    lower: lower2,
+                    upper: upper2,
+                }),
+            ) => {
+                if upper < lower2 {
+                    (true, true)
+                } else if lower >= upper2 {
+                    (false, false)
+                } else {
+                    (false, true)
+                }
+            }
+        };
+        Interval::Range(Range {
+            lower: ScalarValue::Boolean(Some(bools.0)),
+            upper: ScalarValue::Boolean(Some(bools.1)),
+        })
+    }
+
+    pub(crate) fn intersect(&self, other: &Interval) -> Result<Interval> {
+        let result = match (self, other) {
+            (Interval::Singleton(val1), Interval::Singleton(val2)) => {
+                if val1 == val2 {
+                    Interval::Singleton(val1.clone())
+                } else {
+                    Interval::Range(Range {
+                        lower: ScalarValue::Boolean(Some(false)),
+                        upper: ScalarValue::Boolean(Some(false)),
+                    })
+                }
+            }
+            (Interval::Singleton(val), Interval::Range(range)) => {
+                if val >= &range.lower && val <= &range.upper {
+                    Interval::Singleton(val.clone())
+                } else {
+                    Interval::Range(Range {
+                        lower: ScalarValue::Boolean(Some(false)),
+                        upper: ScalarValue::Boolean(Some(false)),
+                    })
+                }
+            }
+            (Interval::Range(Range { lower, upper }), Interval::Singleton(val)) => {
+                if (lower.is_null() && upper.is_null()) || (val >= lower && val <= upper)
+                {
+                    Interval::Singleton(val.clone())
+                } else {
+                    Interval::Range(Range {
+                        lower: ScalarValue::Boolean(Some(false)),
+                        upper: ScalarValue::Boolean(Some(false)),
+                    })
+                }
+            }
+            (
+                Interval::Range(Range {
+                    lower: ScalarValue::Boolean(Some(val)),
+                    upper: ScalarValue::Boolean(Some(val2)),
+                }),
+                Interval::Range(Range {
+                    lower: ScalarValue::Boolean(Some(val3)),
+                    upper: ScalarValue::Boolean(Some(val4)),
+                }),
+            ) => Interval::Range(Range {
+                // For booleans (false < true), intersection tightens the lower
+                // bound with OR and the upper bound with AND.
+                lower: ScalarValue::Boolean(Some(*val || *val3)),
+                upper: ScalarValue::Boolean(Some(*val2 && *val4)),
+            }),
+            (
+                Interval::Range(Range {
+                    lower: lower1,
+                    upper: upper1,
+                }),
+                Interval::Range(Range {
+                    lower: lower2,
+                    upper: upper2,
+                }),
+            ) => {
+                let lower = if lower1.is_null() {
+                    lower2.clone()
+                } else if lower2.is_null() {
+                    lower1.clone()
+                } else {
+                    max(lower1, lower2)?
+                };
+                let upper = if upper1.is_null() {
+                    upper2.clone()
+                } else if upper2.is_null() {
+                    upper1.clone()
+                } else {
+                    min(upper1, upper2)?
+                };
+                if !upper.is_null() && lower > upper {
+                    // The ranges do not overlap; signal an empty intersection.
+                    Interval::Range(Range {
+                        lower: ScalarValue::Boolean(Some(false)),
+                        upper: ScalarValue::Boolean(Some(false)),
+                    })
+                } else {
+                    Interval::Range(Range { lower, upper })
+                }
+            }
+        };
+        Ok(result)
+    }
+
+    pub(crate) fn arithmetic_negate(&self) -> Result<Interval> {
+        Ok(match self {
+            Interval::Singleton(value) => Interval::Singleton(value.arithmetic_negate()?),
+            Interval::Range(Range { lower, upper }) => Interval::Range(Range {
+                lower: upper.arithmetic_negate()?,
+                upper: lower.arithmetic_negate()?,
+            }),
+        })
+    }
+
+    pub fn add<T: Borrow<Interval>>(&self, other: T) -> Result<Interval> {
+        let rhs = other.borrow();
+        let result = match (self, rhs) {
+            (Interval::Singleton(val), Interval::Singleton(val2)) => {
+                Interval::Singleton(val.add(val2)?)
+            }
+            (Interval::Singleton(val), Interval::Range(Range { lower, upper }))
+            | (Interval::Range(Range { lower, upper }), Interval::Singleton(val)) => {
+                let new_lower = if lower.is_null() {
+                    lower.clone()
+                } else {
+                    lower.add(val)?
+                };
+                let new_upper = if upper.is_null() {
+                    upper.clone()
+                } else {
+                    upper.add(val)?
+                };
+                Interval::Range(Range {
+                    lower: new_lower,
+                    upper: new_upper,
+                })
+            }
+            (
+                Interval::Range(Range { lower, upper }),
+                Interval::Range(Range {
+                    lower: lower2,
+                    upper: upper2,
+                }),
+            ) => {
+                let new_lower = if lower.is_null() || lower2.is_null() {
+                    ScalarValue::try_from(lower.get_datatype())?
+                } else {
+                    lower.add(lower2)?
+                };
+
+                let new_upper = if upper.is_null() || upper2.is_null() {
+                    ScalarValue::try_from(upper.get_datatype())?
+                } else {
+                    upper.add(upper2)?
+                };
+                Interval::Range(Range {
+                    lower: new_lower,
+                    upper: new_upper,
+                })
+            }
+        };
+        Ok(result)
+    }
+
+    pub fn sub<T: Borrow<Interval>>(&self, other: T) -> Result<Interval> {
+        let rhs = other.borrow();
+        let result = match (self, rhs) {
+            (Interval::Singleton(val), Interval::Singleton(val2)) => {
+                Interval::Singleton(val.sub(val2)?)
+            }
+            (Interval::Singleton(val), Interval::Range(Range { lower, upper })) => {
+                // val - [lower, upper] = [val - upper, val - lower]
+                let new_lower = if upper.is_null() {
+                    upper.clone()
+                } else {
+                    val.sub(upper)?
+                };
+                let new_upper = if lower.is_null() {
+                    lower.clone()
+                } else {
+                    val.sub(lower)?
+                };
+                Interval::Range(Range {
+                    lower: new_lower,
+                    upper: new_upper,
+                })
+            }
+            (Interval::Range(Range { lower, upper }), Interval::Singleton(val)) => {
+                let new_lower = if lower.is_null() {
+                    lower.clone()
+                } else {
+                    lower.sub(val)?
+                };
+                let new_upper = if upper.is_null() {
+                    upper.clone()
+                } else {
+                    upper.sub(val)?
+                };
+                Interval::Range(Range {
+                    lower: new_lower,
+                    upper: new_upper,
+                })
+            }
+            (
+                Interval::Range(Range { lower, upper }),
+                Interval::Range(Range {
+                    lower: lower2,
+                    upper: upper2,
+                }),
+            ) => {
+                let new_lower = if lower.is_null() || upper2.is_null() {
+                    ScalarValue::try_from(&lower.get_datatype())?
+                } else {
+                    lower.sub(upper2)?
+                };
+
+                let new_upper = if upper.is_null() || lower2.is_null() {
+                    ScalarValue::try_from(&upper.get_datatype())?
+                } else {
+                    upper.sub(lower2)?
+                };
+
+                Interval::Range(Range {
+                    lower: new_lower,
+                    upper: new_upper,
+                })
+            }
+        };
+        Ok(result)
+    }
+
+    pub fn mul<T: Borrow<Interval>>(&self, _other: T) -> Result<Interval> {
+        // Interval multiplication requires taking the minimum and maximum of
+        // all four bound products; `apply_operator` does not need it yet, so
+        // return an explicit error instead of an incorrect result.
+        Err(DataFusionError::NotImplemented(
+            "Interval multiplication is not supported yet".to_string(),
+        ))
+    }
+}
+
+pub fn apply_operator(lhs: &Interval, op: &Operator, rhs: &Interval) -> Result<Interval> {
+    match *op {
+        Operator::Eq => Ok(lhs.equal(rhs)),
+        Operator::Gt => Ok(lhs.gt(rhs)),
+        Operator::Lt => Ok(lhs.lt(rhs)),
+        Operator::And => lhs.and(rhs),
+        Operator::Plus => lhs.add(rhs),
+        Operator::Minus => lhs.sub(rhs),
+        _ => Ok(Interval::Singleton(ScalarValue::Null)),
+    }
+}
+
+pub fn target_parent_interval(datatype: &DataType, op: &Operator) -> Result<Interval> {
+    let inf = ScalarValue::try_from(datatype)?;
+    let zero = ScalarValue::try_from_string("0".to_string(), datatype)?;
+    let parent_interval = match *op {
+        // If [x1, y1] > [x2, y2], then
+        // [x1, y1] + [-y2, -x2] = [0, inf], so
+        // [x1_new, y1_new] = ([0, inf] - [-y2, -x2]) intersect [x1, y1]
+        // [-y2_new, -x2_new] = ([0, inf] - [x1_new, y1_new]) intersect [-y2, -x2]
+        Operator::Gt => {
+            // [0, inf]
+            Interval::Range(Range {
+                lower: zero,
+                upper: inf,
+            })
+        }
+        // If [x1, y1] < [x2, y2], then
+        // [x1, y1] + [-y2, -x2] = [-inf, 0]
+        Operator::Lt => {
+            // [-inf, 0]
+            Interval::Range(Range {
+                lower: inf,
+                upper: zero,
+            })
+        }
+        _ => unreachable!(),
+    };
+    Ok(parent_interval)
+}
+
+pub fn propagate_logical_operators(
+    left_interval: &Interval,
+    op: &Operator,
+    right_interval: &Interval,
+) -> Result<(Interval, Interval)> {
+    if left_interval.is_boolean() || right_interval.is_boolean() {
+        return Ok((
+            Interval::Range(Range {
+                lower: ScalarValue::Boolean(Some(true)),
+                upper: ScalarValue::Boolean(Some(true)),
+            }),
+            Interval::Range(Range {
+                lower: ScalarValue::Boolean(Some(true)),
+                upper: ScalarValue::Boolean(Some(true)),
+            }),
+        ));
+    }
+
+    let intersection = left_interval.clone().intersect(right_interval)?;
+    if intersection.is_strict_false() {
+        return Ok((left_interval.clone(), right_interval.clone()));
+    }
+    let parent_interval = target_parent_interval(&left_interval.get_datatype(), op)?;
+    let negate_right = right_interval.arithmetic_negate()?;
+    let new_left = parent_interval
+        .sub(&negate_right)?
+ .intersect(left_interval)?; + let new_right = parent_interval.sub(&new_left)?.intersect(&negate_right)?; + let range = (new_left, new_right.arithmetic_negate()?); + Ok(range) +} + +pub fn negate_ops(op: Operator) -> Operator { + match op { + Operator::Plus => Operator::Minus, + Operator::Minus => Operator::Plus, + _ => unreachable!(), + } +} diff --git a/datafusion/physical-expr/src/intervals/mod.rs b/datafusion/physical-expr/src/intervals/mod.rs new file mode 100644 index 000000000000..11b7db4a2316 --- /dev/null +++ b/datafusion/physical-expr/src/intervals/mod.rs @@ -0,0 +1,4 @@ +mod cp_solver; +pub mod interval_aritmetics; + +pub use cp_solver::ExprIntervalGraph; diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index c9658a048ca8..e15da3a1e753 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -26,8 +26,10 @@ pub mod execution_props; pub mod expressions; pub mod functions; pub mod hash_utils; +pub mod intervals; pub mod math_expressions; mod physical_expr; +pub mod physical_expr_visitor; pub mod planner; #[cfg(feature = "regex_expressions")] pub mod regex_expressions; diff --git a/datafusion/physical-expr/src/physical_expr_visitor.rs b/datafusion/physical-expr/src/physical_expr_visitor.rs new file mode 100644 index 000000000000..e916e0d8b178 --- /dev/null +++ b/datafusion/physical-expr/src/physical_expr_visitor.rs @@ -0,0 +1,98 @@ +use crate::expressions::{ + BinaryExpr, CastExpr, Column, DateTimeIntervalExpr, GetIndexedFieldExpr, InListExpr, +}; +use crate::PhysicalExpr; + +use datafusion_common::Result; + +use std::sync::Arc; + +/// Controls how the visitor recursion should proceed. +pub enum Recursion { + /// Attempt to visit all the children, recursively, of this expression. + Continue(V), + /// Do not visit the children of this expression, though the walk + /// of parents of this expression will not be affected + Stop(V), +} + +/// Encode the traversal of an expression tree. When passed to +/// `Expr::accept`, `PhysicalExpressionVisitor::visit` is invoked +/// recursively on all nodes of an expression tree. See the comments +/// on `Expr::accept` for details on its use +pub trait PhysicalExpressionVisitor>: + Sized +{ + /// Invoked before any children of `expr` are visited. + fn pre_visit(self, expr: E) -> Result> + where + Self: PhysicalExpressionVisitor; + + /// Invoked after all children of `expr` are visited. Default + /// implementation does nothing. + fn post_visit(self, _expr: E) -> Result { + Ok(self) + } +} + +/// trait for types that can be visited by [`ExpressionVisitor`] +pub trait ExprVisitable: Sized { + /// accept a visitor, calling `visit` on all children of this + fn accept>(self, visitor: V) -> Result; +} + +// TODO: Widen the options. +impl ExprVisitable for Arc { + fn accept>(self, visitor: V) -> Result { + let visitor = match visitor.pre_visit(self.clone())? 
{ + Recursion::Continue(visitor) => visitor, + // If the recursion should stop, do not visit children + Recursion::Stop(visitor) => return Ok(visitor), + }; + let visitor = + if let Some(binary_expr) = self.as_any().downcast_ref::() { + let left_expr = binary_expr.left().clone(); + let right_expr = binary_expr.right().clone(); + let visitor = left_expr.accept(visitor)?; + right_expr.accept(visitor) + } else if let Some(datatime_expr) = + self.as_any().downcast_ref::() + { + let lhs = datatime_expr.lhs().clone(); + let rhs = datatime_expr.rhs().clone(); + let visitor = lhs.accept(visitor)?; + rhs.accept(visitor) + } else if let Some(cast) = self.as_any().downcast_ref::() { + let expr = cast.expr().clone(); + expr.accept(visitor) + } else if let Some(get_index) = + self.as_any().downcast_ref::() + { + let arg = get_index.arg().clone(); + arg.accept(visitor) + } else if let Some(inlist_expr) = self.as_any().downcast_ref::() { + let expr = inlist_expr.expr().clone(); + let list = inlist_expr.list(); + let visitor = expr.clone().accept(visitor)?; + list.iter() + .try_fold(visitor, |visitor, arg| arg.clone().accept(visitor)) + } else { + Ok(visitor) + }?; + visitor.post_visit(self.clone()) + } +} + +#[derive(Default, Debug)] +pub struct PostOrderPhysicalColumnCollector { + pub columns: Vec, +} + +impl PhysicalExpressionVisitor for PostOrderPhysicalColumnCollector { + fn pre_visit(mut self, expr: Arc) -> Result> { + if let Some(column) = expr.as_any().downcast_ref::() { + self.columns.push(column.clone()) + } + Ok(Recursion::Continue(self)) + } +} diff --git a/datafusion/physical-expr/src/rewrite.rs b/datafusion/physical-expr/src/rewrite.rs index 327eabd4b0d0..c729fafb1062 100644 --- a/datafusion/physical-expr/src/rewrite.rs +++ b/datafusion/physical-expr/src/rewrite.rs @@ -113,6 +113,21 @@ pub trait TreeNodeRewritable: Clone { Ok(new_node) } + fn mutable_transform_up(self, op: &mut F) -> Result + where + F: FnMut(Self) -> Result>, + { + let after_op_children = + self.map_children(|node| node.mutable_transform_up(op))?; + + let after_op_children_clone = after_op_children.clone(); + let new_node = match op(after_op_children)? { + Some(value) => value, + None => after_op_children_clone, + }; + Ok(new_node) + } + /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) fn map_children(self, transform: F) -> Result where diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs index ea64d739769f..7baaf9f91042 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils.rs @@ -16,16 +16,18 @@ // under the License. 
use crate::equivalence::EquivalentClass; -use crate::expressions::BinaryExpr; use crate::expressions::Column; use crate::expressions::UnKnownColumn; +use crate::expressions::{BinaryExpr, Literal}; use crate::rewrite::TreeNodeRewritable; use crate::PhysicalExpr; use crate::PhysicalSortExpr; -use datafusion_expr::Operator; - use arrow::datatypes::SchemaRef; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::Operator; +use petgraph::graph::NodeIndex; +use petgraph::stable_graph::StableGraph; use std::collections::HashMap; use std::sync::Arc; @@ -185,6 +187,186 @@ pub fn normalize_sort_expr_with_equivalence_properties( } } +#[derive(Clone, Debug)] +pub struct ExprTreeNode { + expr: Arc, + node: Option, + child_nodes: Vec>, +} + +impl PartialEq for ExprTreeNode { + fn eq(&self, other: &ExprTreeNode) -> bool { + self.expr.eq(&other.expr) + } +} + +impl ExprTreeNode { + pub fn new(plan: Arc) -> Self { + ExprTreeNode { + expr: plan, + node: None, + child_nodes: vec![], + } + } + + pub fn node(&self) -> &Option { + &self.node + } + + pub fn child_nodes(&self) -> &Vec> { + &self.child_nodes + } + + pub fn expr(&self) -> Arc { + self.expr.clone() + } + + pub fn children(&self) -> Vec> { + let plan_children = self.expr.children(); + let child_intervals: Vec> = self.child_nodes.to_vec(); + if plan_children.len() == child_intervals.len() { + child_intervals + } else { + plan_children + .into_iter() + .map(|child| { + Arc::new(ExprTreeNode { + expr: child.clone(), + node: None, + child_nodes: vec![], + }) + }) + .collect() + } + } +} + +impl TreeNodeRewritable for Arc { + fn map_children(self, transform: F) -> Result + where + F: FnMut(Self) -> Result, + { + let children = self.children(); + if !children.is_empty() { + let new_children: Vec> = children + .into_iter() + .map(transform) + .collect::>>>() + .unwrap(); + Ok(Arc::new(ExprTreeNode { + expr: self.expr.clone(), + node: None, + child_nodes: new_children, + })) + } else { + Ok(self) + } + } +} + +fn add_physical_expr_to_graph( + input: Arc, + visited: &mut Vec<(Arc, NodeIndex)>, + graph: &mut StableGraph, + input_node: T, +) -> NodeIndex { + match visited + .iter() + .find(|(visited_expr, _node)| visited_expr.eq(&input.expr())) + { + Some((_, idx)) => *idx, + None => { + let node_idx = graph.add_node(input_node); + visited.push((input.expr().clone(), node_idx)); + input.child_nodes().iter().for_each(|expr_node| { + graph.add_edge(node_idx, expr_node.node().unwrap(), 0); + }); + node_idx + } + } +} + +fn post_order_tree_traverse_graph_create( + input: Arc, + graph: &mut StableGraph, + constructor: F, + visited: &mut Vec<(Arc, NodeIndex)>, +) -> Result>> +where + F: Fn(Arc) -> T, +{ + let node = constructor(input.clone()); + let node_idx = add_physical_expr_to_graph(input.clone(), visited, graph, node); + Ok(Some(Arc::new(ExprTreeNode { + expr: input.expr.clone(), + node: Some(node_idx), + child_nodes: input.child_nodes.to_vec(), + }))) +} + +pub fn build_physical_expr_graph( + expr: Arc, + constructor: &F, +) -> Result<(NodeIndex, StableGraph)> +where + F: Fn(Arc) -> T, +{ + let init = Arc::new(ExprTreeNode::new(expr.clone())); + let mut graph: StableGraph = StableGraph::new(); + // TODO: Make membership check O(1) + let mut visited_plans: Vec<(Arc, NodeIndex)> = vec![]; + // Create a graph + let root_tree_node = init.mutable_transform_up(&mut |expr| { + post_order_tree_traverse_graph_create( + expr, + &mut graph, + constructor, + &mut visited_plans, + ) + })?; + Ok((root_tree_node.node.unwrap(), graph)) +} + 
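For orientation, a hedged sketch of how `build_physical_expr_graph` is meant to be called; the payload type (a display string per node) and the assertion are illustrative only:

```rust
use std::sync::Arc;

use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
use datafusion_physical_expr::utils::{build_physical_expr_graph, ExprTreeNode};
use datafusion_physical_expr::PhysicalExpr;

fn demo() -> Result<()> {
    // Build the expression "a@0 + 1".
    let expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("a", 0)),
        Operator::Plus,
        Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
    ));
    // The constructor decides what payload each graph node carries; here we
    // simply store each sub-expression's display string.
    let (root, graph) = build_physical_expr_graph(expr, &|node: Arc<ExprTreeNode>| {
        node.expr().to_string()
    })?;
    // Three nodes: the binary expression, the column and the literal.
    assert_eq!(graph.node_count(), 3);
    println!("root payload: {}", graph[root]);
    Ok(())
}
```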
+#[allow(clippy::too_many_arguments)] +/// left_col (op_1) a > right_col (op_2) b AND left_col (op_3) c < right_col (op_4) d +pub fn filter_numeric_expr_generation( + left_col: Arc, + right_col: Arc, + op_1: Operator, + op_2: Operator, + op_3: Operator, + op_4: Operator, + a: i32, + b: i32, + c: i32, + d: i32, +) -> Arc { + let left_and_1 = Arc::new(BinaryExpr::new( + left_col.clone(), + op_1, + Arc::new(Literal::new(ScalarValue::Int32(Some(a)))), + )); + let left_and_2 = Arc::new(BinaryExpr::new( + right_col.clone(), + op_2, + Arc::new(Literal::new(ScalarValue::Int32(Some(b)))), + )); + + let right_and_1 = Arc::new(BinaryExpr::new( + left_col.clone(), + op_3, + Arc::new(Literal::new(ScalarValue::Int32(Some(c)))), + )); + let right_and_2 = Arc::new(BinaryExpr::new( + right_col.clone(), + op_4, + Arc::new(Literal::new(ScalarValue::Int32(Some(d)))), + )); + let left_expr = Arc::new(BinaryExpr::new(left_and_1, Operator::Gt, left_and_2)); + let right_expr = Arc::new(BinaryExpr::new(right_and_1, Operator::Lt, right_and_2)); + Arc::new(BinaryExpr::new(left_expr, Operator::And, right_expr)) +} + #[cfg(test)] mod tests { From b156bc43ce25b90310d6d20242fb96787255a38c Mon Sep 17 00:00:00 2001 From: metesynnada <100111937+metesynnada@users.noreply.github.com> Date: Fri, 20 Jan 2023 17:17:30 +0300 Subject: [PATCH 02/39] Minor changes after merge --- datafusion-cli/Cargo.lock | 140 ++++--- .../physical_plan/joins/hash_join_utils.rs | 61 ++- .../joins/symmetric_hash_join.rs | 369 ++++++++++-------- .../core/src/physical_plan/joins/utils.rs | 14 +- .../physical-expr/src/intervals/cp_solver.rs | 138 +++---- .../src/intervals/interval_aritmetics.rs | 99 ++--- datafusion/physical-expr/src/intervals/mod.rs | 20 + .../src/physical_expr_visitor.rs | 17 + datafusion/physical-expr/src/utils.rs | 3 +- 9 files changed, 479 insertions(+), 382 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 7e1c7d25f045..7055d89a90ed 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -290,9 +290,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.61" +version = "0.1.62" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "705339e0e4a9690e2908d2b3d049d85682cf19fbd5782494498fbf7003a6a282" +checksum = "689894c2db1ea643a50834b999abf1c110887402542955ff5451dab8f861f9ed" dependencies = [ "proc-macro2", "quote", @@ -316,12 +316,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" -[[package]] -name = "base64" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" - [[package]] name = "base64" version = "0.20.0" @@ -385,9 +379,9 @@ dependencies = [ [[package]] name = "brotli-decompressor" -version = "2.3.2" +version = "2.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59ad2d4653bf5ca36ae797b1f4bb4dbddb60ce49ca4aed8a2ce4829f60425b80" +checksum = "4b6561fd3f895a11e8f72af2cb7d22e08366bebc2b6b57f7744c4bda27034744" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -407,9 +401,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.11.1" +version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "572f695136211188308f16ad2ca5c851a712c464060ae6974944458eb83880ba" +checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535" 
[[package]] name = "byteorder" @@ -635,9 +629,9 @@ dependencies = [ [[package]] name = "cxx" -version = "1.0.86" +version = "1.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d1075c37807dcf850c379432f0df05ba52cc30f279c5cfc43cc221ce7f8579" +checksum = "b61a7545f753a88bcbe0a70de1fcc0221e10bfc752f576754fa91e663db1622e" dependencies = [ "cc", "cxxbridge-flags", @@ -647,9 +641,9 @@ dependencies = [ [[package]] name = "cxx-build" -version = "1.0.86" +version = "1.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5044281f61b27bc598f2f6647d480aed48d2bf52d6eb0b627d84c0361b17aa70" +checksum = "f464457d494b5ed6905c63b0c4704842aba319084a0a3561cdc1359536b53200" dependencies = [ "cc", "codespan-reporting", @@ -662,15 +656,15 @@ dependencies = [ [[package]] name = "cxxbridge-flags" -version = "1.0.86" +version = "1.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61b50bc93ba22c27b0d31128d2d130a0a6b3d267ae27ef7e4fae2167dfe8781c" +checksum = "43c7119ce3a3701ed81aca8410b9acf6fc399d2629d057b87e2efa4e63a3aaea" [[package]] name = "cxxbridge-macro" -version = "1.0.86" +version = "1.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39e61fda7e62115119469c7b3591fd913ecca96fb766cfd3f2e2502ab7bc87a5" +checksum = "65e07508b90551e610910fa648a1878991d367064997a596135b86df30daf07e" dependencies = [ "proc-macro2", "quote", @@ -813,6 +807,7 @@ dependencies = [ "md-5", "num-traits", "paste", + "petgraph", "rand", "regex", "sha2", @@ -984,6 +979,12 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + [[package]] name = "flatbuffers" version = "22.9.29" @@ -1341,9 +1342,9 @@ checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" [[package]] name = "io-lifetimes" -version = "1.0.3" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46112a93252b123d31a119a8d1a1ac19deac4fac6e0e8b0df58f0d4e5870e63c" +checksum = "e7d6c6f8c91b4b9ed43484ad1a938e393caf35960fce7f82a040497207bd8e9e" dependencies = [ "libc", "windows-sys", @@ -1613,10 +1614,11 @@ dependencies = [ [[package]] name = "nix" -version = "0.24.3" +version = "0.25.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa52e972a9a719cecb6864fb88568781eb706bac2cd1d4f04a648542dbf78069" +checksum = "f346ff70e7dbfd675fe90590b92d59ef2de15a8779ae305ebcbfd3f0caf59be4" dependencies = [ + "autocfg", "bitflags", "cfg-if", "libc", @@ -1649,9 +1651,9 @@ dependencies = [ [[package]] name = "num-complex" -version = "0.4.2" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ae39348c8bc5fbd7f40c727a9925f03517afd2ab27d46702108b6a7e5414c19" +checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d" dependencies = [ "num-traits", ] @@ -1826,6 +1828,16 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" +[[package]] +name = "petgraph" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5014253a1331579ce62aa67443b4a658c5e7dd03d4bc6d302b94474888143" +dependencies = [ + "fixedbitset", + "indexmap", +] + [[package]] name = "pin-project-lite" version = "0.2.9" @@ 
-1882,9 +1894,9 @@ checksum = "dc375e1527247fe1a97d8b7156678dfe7c1af2fc075c9a4db3690ecd2a148068" [[package]] name = "proc-macro2" -version = "1.0.49" +version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57a8eca9f9c4ffde41714334dee777596264c7825420f521abc92b5b5deb63a5" +checksum = "6ef7d57beacfaf2d8aee5937dab7b7f28de3cb8b1828479bb5de2a7106f2bae2" dependencies = [ "unicode-ident", ] @@ -2002,11 +2014,11 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.11.13" +version = "0.11.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68cc60575865c7831548863cc02356512e3f1dc2f3f82cb837d7fc4cc8f3c97c" +checksum = "21eed90ec8570952d53b772ecf8f206aa1ec9a3d76b2521c56c42973f2d91ee9" dependencies = [ - "base64 0.13.1", + "base64 0.21.0", "bytes", "encoding_rs", "futures-core", @@ -2035,6 +2047,7 @@ dependencies = [ "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "webpki-roots", "winreg", @@ -2057,9 +2070,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.36.6" +version = "0.36.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4feacf7db682c6c329c4ede12649cd36ecab0f3be5b7d74e6a20304725db4549" +checksum = "d4fdebc4b395b7fbb9ab11e462e20ed9051e7b16e42d24042c776eca0ac81b03" dependencies = [ "bitflags", "errno", @@ -2071,9 +2084,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.20.7" +version = "0.20.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "539a2bfe908f471bfa933876bd1eb6a19cf2176d375f82ef7f99530a40e48c2c" +checksum = "fff78fc74d175294f4e83b28343315ffcfb114b156f0185e9741cb5570f50e2f" dependencies = [ "log", "ring", @@ -2098,9 +2111,9 @@ checksum = "5583e89e108996506031660fe09baa5011b9dd0341b89029313006d1fb508d70" [[package]] name = "rustyline" -version = "10.0.0" +version = "10.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d1cd5ae51d3f7bf65d7969d579d502168ef578f289452bd8ccc91de28fda20e" +checksum = "c1e83c32c3f3c33b08496e0d1df9ea8c64d39adb8eb36a1ebb1440c690697aef" dependencies = [ "bitflags", "cfg-if", @@ -2366,9 +2379,9 @@ dependencies = [ [[package]] name = "termcolor" -version = "1.1.3" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bab24d30b911b2376f3a13cc2cd443142f0c81dda04c118693e35b3835757755" +checksum = "be55cf8942feac5c765c2c993422806843c9a9a45d4d5c407ad6dd2ea95eb9b6" dependencies = [ "winapi-util", ] @@ -2436,9 +2449,9 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" [[package]] name = "tokio" -version = "1.24.1" +version = "1.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d9f76183f91ecfb55e1d7d5602bd1d979e38a3a522fe900241cf195624d67ae" +checksum = "597a12a59981d9e3c38d216785b0c37399f6e415e8d0712047620f189371b0bb" dependencies = [ "autocfg", "bytes", @@ -2562,9 +2575,9 @@ checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" [[package]] name = "unicode-bidi" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "099b7128301d285f79ddd55b9a83d5e6b9e97c92e0ea0daebee7263e932de992" +checksum = "0046be40136ef78dc325e0edefccf84ccddacd0afcc1ca54103fa3c61bbdab1d" [[package]] name = "unicode-ident" @@ -2724,6 +2737,19 @@ version = "0.2.83" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"1c38c045535d93ec4f0b4defec448e4291638ee608530863b1e2ba115d4fff7f" +[[package]] +name = "wasm-streams" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bbae3363c08332cadccd13b67db371814cd214c2524020932f0804b8cf7c078" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.60" @@ -2801,45 +2827,45 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.42.0" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41d2aa71f6f0cbe00ae5167d90ef3cfe66527d6f613ca78ac8024c3ccab9a19e" +checksum = "8c9864e83243fdec7fc9c5444389dcbbfd258f745e7853198f365e3c4968a608" [[package]] name = "windows_aarch64_msvc" -version = "0.42.0" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd0f252f5a35cac83d6311b2e795981f5ee6e67eb1f9a7f64eb4500fbc4dcdb4" +checksum = "4c8b1b673ffc16c47a9ff48570a9d85e25d265735c503681332589af6253c6c7" [[package]] name = "windows_i686_gnu" -version = "0.42.0" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbeae19f6716841636c28d695375df17562ca208b2b7d0dc47635a50ae6c5de7" +checksum = "de3887528ad530ba7bdbb1faa8275ec7a1155a45ffa57c37993960277145d640" [[package]] name = "windows_i686_msvc" -version = "0.42.0" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84c12f65daa39dd2babe6e442988fc329d6243fdce47d7d2d155b8d874862246" +checksum = "bf4d1122317eddd6ff351aa852118a2418ad4214e6613a50e0191f7004372605" [[package]] name = "windows_x86_64_gnu" -version = "0.42.0" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf7b1b21b5362cbc318f686150e5bcea75ecedc74dd157d874d754a2ca44b0ed" +checksum = "c1040f221285e17ebccbc2591ffdc2d44ee1f9186324dd3e84e99ac68d699c45" [[package]] name = "windows_x86_64_gnullvm" -version = "0.42.0" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09d525d2ba30eeb3297665bd434a54297e4170c7f1a44cad4ef58095b4cd2028" +checksum = "628bfdf232daa22b0d64fdb62b09fcc36bb01f05a3939e20ab73aaf9470d0463" [[package]] name = "windows_x86_64_msvc" -version = "0.42.0" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f40009d85759725a34da6d89a94e63d7bdc50a862acf0dbc7c8e488f1edcb6f5" +checksum = "447660ad36a13288b1db4d4248e857b510e8c3a225c822ba4fb748c0aafecffd" [[package]] name = "winreg" diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs index d5423068c079..170732293223 100644 --- a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs +++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs @@ -1,3 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Hash Join related functionality used both on logical and physical plans +//! use std::{fmt, usize}; use ahash::RandomState; @@ -67,7 +86,7 @@ pub fn update_hash( Ok(()) } -// Get left and right indices which is satisfies the on condition (include equal_conditon and filter_in_join) in the Join +// Get build and probe indices which is satisfies the on condition (include equal_conditon and filter_in_join) in the Join #[allow(clippy::too_many_arguments)] pub fn build_join_indices( probe_batch: &RecordBatch, @@ -438,7 +457,7 @@ pub fn equal_rows( err.unwrap_or(Ok(res)) } -// Returns the index of equal condition join result: left_indices and right_indices +// Returns the index of equal condition join result: build_indices and probe_indices // On LEFT.b1 = RIGHT.b2 // LEFT Table: // a1 b1 c1 @@ -466,9 +485,9 @@ pub fn equal_rows( // "| 13 | 10 | 130 | 12 | 10 | 120 |", // "| 9 | 8 | 90 | 8 | 8 | 80 |", // "+----+----+-----+----+----+-----+" -// And the result of left and right indices -// left indices: 5, 6, 6, 4 -// right indices: 3, 4, 5, 3 +// And the result of build and probe indices +// build indices: 5, 6, 6, 4 +// probe indices: 3, 4, 5, 3 #[allow(clippy::too_many_arguments)] pub fn build_equal_condition_join_indices( build_hashmap: &JoinHashMap, @@ -496,14 +515,14 @@ pub fn build_equal_condition_join_indices( hashes_buffer.resize(probe_batch.num_rows(), 0); let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; // Using a buffer builder to avoid slower normal builder - let mut left_indices = UInt64BufferBuilder::new(0); - let mut right_indices = UInt32BufferBuilder::new(0); + let mut build_indices = UInt64BufferBuilder::new(0); + let mut probe_indices = UInt32BufferBuilder::new(0); let offset_value = offset.unwrap_or(0); - // Visit all of the right rows + // Visit all of the probe rows for (row, hash_value) in hash_values.iter().enumerate() { // Get the hash and find it in the build index - // For every item on the left and right we check if it matches + // For every item on the build and probe we check if it matches // This possibly contains rows with hash collisions, // So we have to check here whether rows are equal or not if let Some((_, indices)) = build_hashmap @@ -521,34 +540,34 @@ pub fn build_equal_condition_join_indices( &keys_values, *null_equals_null, )? 
{ - left_indices.append(offset_build_index as u64); - right_indices.append(row as u32); + build_indices.append(offset_build_index as u64); + probe_indices.append(row as u32); } } } } - let left = ArrayData::builder(DataType::UInt64) - .len(left_indices.len()) - .add_buffer(left_indices.finish()) + let build = ArrayData::builder(DataType::UInt64) + .len(build_indices.len()) + .add_buffer(build_indices.finish()) .build() .unwrap(); - let right = ArrayData::builder(DataType::UInt32) - .len(right_indices.len()) - .add_buffer(right_indices.finish()) + let probe = ArrayData::builder(DataType::UInt32) + .len(probe_indices.len()) + .add_buffer(probe_indices.finish()) .build() .unwrap(); Ok(( - PrimitiveArray::::from(left), - PrimitiveArray::::from(right), + PrimitiveArray::::from(build), + PrimitiveArray::::from(probe), )) } -// Maps a `u64` hash value based on the left ["on" values] to a list of indices with this key's value. +// Maps a `u64` hash value based on the build ["on" values] to a list of indices with this key's value. // // Note that the `u64` keys are not stored in the hashmap (hence the `()` as key), but are only used // to put the indices in a certain bucket. -// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the left side, +// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the build side, // we make sure that we don't have to re-hash the hashmap, which needs access to the key (the hash in this case) value. // E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1 // As the key is a hash value, we need to check possible hash collisions in the probe stage diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs index 16652c04edb4..f31e7b95b5e4 100644 --- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs @@ -49,8 +49,9 @@ use crate::error::{DataFusionError, Result}; use crate::execution::context::TaskContext; use crate::logical_expr::JoinType; use crate::physical_plan::common::merge_batches; -use crate::physical_plan::joins::hash_join_utils::update_hash; -use crate::physical_plan::joins::hash_join_utils::{build_join_indices, JoinHashMap}; +use crate::physical_plan::joins::hash_join_utils::{ + build_join_indices, update_hash, JoinHashMap, +}; use crate::physical_plan::joins::utils::build_batch_from_indices; use crate::physical_plan::{ expressions::Column, @@ -80,7 +81,9 @@ pub struct SortedFilterExpr { // Sort option sort_option: SortOptions, // Interval - interval: Option, + interval: Interval, + // NodeIndex in Graph + node_index: usize, } impl SortedFilterExpr { @@ -96,7 +99,8 @@ impl SortedFilterExpr { origin_expr, filter_expr, sort_option, - interval: None, + interval: Interval::default(), + node_index: 0, } } /// Get origin expr information @@ -112,13 +116,17 @@ impl SortedFilterExpr { self.sort_option } /// Get interval information - pub fn interval(&self) -> &Option { + pub fn interval(&self) -> &Interval { &self.interval } /// Sets interval - pub fn set_interval(&mut self, interval: Option) { + pub fn set_interval(&mut self, interval: Interval) { self.interval = interval; } + /// Node index in ExprIntervalGraph + pub fn node_index(&self) -> usize { + self.node_index + } } /// The symmetric hash join is a special type of hash join designed for data streams and large @@ -407,6 +415,7 @@ impl 
ExecutionPlan for SymmetricHashJoinExec {
         // TODO: Discuss unbounded mpsc and bounded one.
         let left_stream = self.left.execute(partition, context.clone())?;
         let right_stream = self.right.execute(partition, context)?;
+
         let physical_expr_graph = ExprIntervalGraph::try_new(
             self.filter.expression().clone(),
             &self
@@ -415,6 +424,18 @@ impl ExecutionPlan for SymmetricHashJoinExec {
                 .map(|sorted_expr| sorted_expr.filter_expr())
                 .collect_vec(),
         )?;
+        // We inject the calculated node indexes into the SortedFilterExprs. In graph calculations,
+        // we will use the node index to put the calculated intervals into Columns or BinaryExprs.
+        let node_injected_filter_exprs = self
+            .filter_columns
+            .iter()
+            .zip(physical_expr_graph.2.iter())
+            .map(|(sorted_expr, (_, index))| {
+                let mut temp = sorted_expr.clone();
+                temp.node_index = *index;
+                temp
+            })
+            .collect_vec();
 
         Ok(Box::pin(SymmetricHashJoinStream {
             left_stream,
@@ -431,7 +452,7 @@ impl ExecutionPlan for SymmetricHashJoinExec {
             join_metrics: SymmetricHashJoinMetrics::new(partition, &self.metrics),
             physical_expr_graph,
             null_equals_null: self.null_equals_null,
-            filter_columns: self.filter_columns.clone(),
+            filter_columns: node_injected_filter_exprs,
             final_result: false,
             data_side: JoinSide::Left,
         }))
@@ -563,7 +584,7 @@ fn column_stats_two_side(
         };
         let value = ScalarValue::try_from_array(&array, 0)?;
         let infinite = ScalarValue::try_from(value.get_datatype())?;
-        *interval = Some(if sort_option.descending {
+        *interval = if sort_option.descending {
             Interval::Range(Range {
                 lower: infinite,
                 upper: value,
@@ -573,7 +594,7 @@ fn column_stats_two_side(
                 lower: value,
                 upper: infinite,
             })
-        });
+        };
     }
     Ok(())
 }
@@ -599,9 +620,9 @@ fn determine_prune_length(
                 .unwrap()
                 .into_array(buffer.num_rows());
             let target = if sort_option.descending {
-                interval.as_ref()?.upper_value()
+                interval.upper_value()
             } else {
-                interval.as_ref()?.lower_value()
+                interval.lower_value()
             };
             Some(bisect::(&[batch_arr], &[target], &[*sort_option]))
         } else {
@@ -893,20 +914,26 @@ impl OneSideHashJoiner {
             filter_columns,
             self.build_side,
         )?;
-        let mut filter_intervals: Vec<(Arc<dyn PhysicalExpr>, Interval)> = filter_columns
+
+        // We use a Vec<(usize, Interval)> instead of a HashMap since the expected
+        // number of filter expressions is relatively low, and converting between
+        // a Vec and a HashMap back and forth may be slower.
+
+        // Gather the calculated intervals from the SortedFilterExprs into a Vec.
+        // TODO (for Ozan): we could pass the Vec directly without this computation, but that
+        // creates a cyclic dependency between the files. We could move SortedFilterExpr into
+        // common; then it could be used directly inside the graph without this computation.
+        let mut filter_intervals: Vec<(usize, Interval)> = filter_columns
            .iter()
-            .map(|sorted_expr| {
-                (
-                    sorted_expr.filter_expr(),
-                    sorted_expr.interval().as_ref().unwrap().clone(),
-                )
-            })
+            .map(|sorted_expr| (sorted_expr.node_index(), sorted_expr.interval().clone()))
             .collect_vec();
+        // Use this vector to seed the child PhysicalExpr intervals.
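
For intuition, here is a minimal, self-contained sketch of the seeding contract used by the `calculate_new_intervals` call below: each `(node_index, interval)` pair is read by the graph's bottom-up pass and overwritten with a tightened interval by its top-down pass. The `Interval` stand-in with integer bounds and the node indices are illustrative only, not the actual DataFusion types.

```rust
// Simplified stand-in for the ScalarValue-based Interval in this patch.
#[derive(Clone, Debug)]
struct Interval {
    lower: i64,
    upper: i64,
}

fn main() {
    // (node_index, interval) seeds, analogous to `filter_intervals` above;
    // the node indices are hypothetical.
    let mut filter_intervals: Vec<(usize, Interval)> = vec![
        (2, Interval { lower: 10, upper: i64::MAX }),
        (5, Interval { lower: i64::MIN, upper: 42 }),
    ];
    // The contract is an in-place update of the same Vec: after propagation,
    // each entry holds the (possibly) shrunk interval for its graph node.
    for (_, interval) in filter_intervals.iter_mut() {
        interval.lower = interval.lower.max(0); // pretend propagation tightened it
    }
    println!("{filter_intervals:?}");
}
```
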
physical_expr_graph.calculate_new_intervals(&mut filter_intervals)?;
+        // Write the newly calculated intervals back into each SortedFilterExpr.
         for (sorted_expr, (_, interval)) in
             filter_columns.iter_mut().zip(filter_intervals.into_iter())
         {
-            sorted_expr.set_interval(Some(interval.clone()))
+            sorted_expr.set_interval(interval.clone())
         }
 
         let prune_length =
@@ -963,161 +990,177 @@ impl SymmetricHashJoinStream {
         &mut self,
         cx: &mut std::task::Context<'_>,
     ) -> Poll<Option<ArrowResult<RecordBatch>>> {
-        if self.final_result {
-            return Poll::Ready(None);
-        }
-        if self.right.exhausted && self.left.exhausted {
-            let left_result = self.left.build_side_determined_results(
-                self.schema.clone(),
-                self.left.input_buffer.num_rows(),
-                self.right.input_buffer.schema(),
-                self.join_type,
-                &self.column_indices,
-            );
-            let right_result = self.right.build_side_determined_results(
-                self.schema.clone(),
-                self.right.input_buffer.num_rows(),
-                self.left.input_buffer.schema(),
-                self.join_type,
-                &self.column_indices,
-            );
-            self.final_result = true;
-            let result =
-                produce_batch_result(self.schema.clone(), left_result, right_result);
-            if let Ok(batch) = &result {
-                self.join_metrics.output_batches.add(1);
-                self.join_metrics.output_rows.add(batch.num_rows());
+        loop {
+            if self.final_result {
+                return Poll::Ready(None);
             }
-            return Poll::Ready(Some(result));
-        }
-        if self.data_side.eq(&JoinSide::Left) {
-            match self.left_stream.poll_next_unpin(cx) {
-                Poll::Ready(Some(Ok(probe_batch))) => {
-                    self.join_metrics.left_input_batches.add(1);
-                    self.join_metrics
-                        .left_input_rows
-                        .add(probe_batch.num_rows());
-                    match self.left.update_internal_state(
-                        &self.on_left,
-                        &probe_batch,
-                        &self.random_state,
-                    ) {
-                        Ok(_) => {}
-                        Err(e) => {
-                            return Poll::Ready(Some(Err(ArrowError::ComputeError(
-                                e.to_string(),
-                            ))))
+            if self.right.exhausted && self.left.exhausted {
+                let left_result = self.left.build_side_determined_results(
+                    self.schema.clone(),
+                    self.left.input_buffer.num_rows(),
+                    self.right.input_buffer.schema(),
+                    self.join_type,
+                    &self.column_indices,
+                );
+                let right_result = self.right.build_side_determined_results(
+                    self.schema.clone(),
+                    self.right.input_buffer.num_rows(),
+                    self.left.input_buffer.schema(),
+                    self.join_type,
+                    &self.column_indices,
+                );
+                self.final_result = true;
+                let result =
+                    produce_batch_result(self.schema.clone(), left_result, right_result);
+                if let Ok(batch) = &result {
+                    self.join_metrics.output_batches.add(1);
+                    self.join_metrics.output_rows.add(batch.num_rows());
+                }
+                return Poll::Ready(Some(result));
+            }
+            if self.data_side.eq(&JoinSide::Left) {
+                match self.left_stream.poll_next_unpin(cx) {
+                    Poll::Ready(Some(Ok(probe_batch))) => {
+                        self.join_metrics.left_input_batches.add(1);
+                        self.join_metrics
+                            .left_input_rows
+                            .add(probe_batch.num_rows());
+                        match self.left.update_internal_state(
+                            &self.on_left,
+                            &probe_batch,
+                            &self.random_state,
+                        ) {
+                            Ok(_) => {}
+                            Err(e) => {
+                                return Poll::Ready(Some(Err(ArrowError::ComputeError(
+                                    e.to_string(),
+                                ))))
+                            }
                         }
+                        // Using right as build side.
+                        let equal_result = self.right.record_batch_from_other_side(
+                            self.schema.clone(),
+                            self.join_type,
+                            &self.on_right,
+                            &self.on_left,
+                            Some(&self.filter),
+                            &probe_batch,
+                            &mut self.left.visited_rows,
+                            self.left.offset,
+                            &self.column_indices,
+                            &self.random_state,
+                            &self.null_equals_null,
+                        );
+                        self.left.offset += probe_batch.num_rows();
+                        // The right side will be pruned, since the batch is coming from the left side.
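
The prune step that follows uses the propagated interval bound of the build side's sorted expression to find how many leading rows can never produce a match again; the real code bisects an Arrow array via `bisect` with the expression's `SortOptions`. A toy sketch of the idea, assuming an ascending integer column (names are illustrative):

```rust
// For an ascending build-side column, every leading row strictly below the
// viable lower bound can never match a future probe row again.
fn prune_length(sorted_build_col: &[i64], lower_bound: i64) -> usize {
    // `partition_point` is a binary search, like the `bisect` call above.
    sorted_build_col.partition_point(|v| *v < lower_bound)
}

fn main() {
    let build_col = [1, 3, 5, 7, 9, 11];
    // With a propagated lower bound of 6, the first three rows are prunable.
    assert_eq!(prune_length(&build_col, 6), 3);
}
```
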
+ let anti_result = self.right.prune_build_side( + self.schema.clone(), + &probe_batch, + &mut self.filter_columns, + self.join_type, + &self.column_indices, + &mut self.physical_expr_graph, + ); + let result = produce_batch_result( + self.schema.clone(), + equal_result, + anti_result, + ); + if let Ok(batch) = &result { + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + } + if !self.right.exhausted { + self.data_side = JoinSide::Right; + } + return Poll::Ready(Some(result)); } - // Using right as build side. - let equal_result = self.right.record_batch_from_other_side( - self.schema.clone(), - self.join_type, - &self.on_right, - &self.on_left, - Some(&self.filter), - &probe_batch, - &mut self.left.visited_rows, - self.left.offset, - &self.column_indices, - &self.random_state, - &self.null_equals_null, - ); - self.left.offset += probe_batch.num_rows(); - // Right side will be pruned since the batch coming from left. - let anti_result = self.right.prune_build_side( - self.schema.clone(), - &probe_batch, - &mut self.filter_columns, - self.join_type, - &self.column_indices, - &mut self.physical_expr_graph, - ); - let result = produce_batch_result( - self.schema.clone(), - equal_result, - anti_result, - ); - if let Ok(batch) = &result { - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - if !self.right.exhausted { + Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), + Poll::Ready(None) => { + self.left.exhausted = true; self.data_side = JoinSide::Right; + continue; } - Poll::Ready(Some(result)) - } - Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), - Poll::Ready(None) => { - self.left.exhausted = true; - self.data_side = JoinSide::Right; - Poll::Ready(Some(Ok(RecordBatch::new_empty(self.schema())))) - } - Poll::Pending => Poll::Pending, - } - } else { - match self.right_stream.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(probe_batch))) => { - self.join_metrics.right_input_batches.add(1); - self.join_metrics - .right_input_rows - .add(probe_batch.num_rows()); - // Right is build side - match self.right.update_internal_state( - &self.on_right, - &probe_batch, - &self.random_state, - ) { - Ok(_) => {} - Err(e) => { - return Poll::Ready(Some(Err(ArrowError::ComputeError( - e.to_string(), - )))) + Poll::Pending => { + if !self.right.exhausted { + self.data_side = JoinSide::Right; + continue; + } else { + return Poll::Pending; } } - let equal_result = self.left.record_batch_from_other_side( - self.schema.clone(), - self.join_type, - &self.on_left, - &self.on_right, - Some(&self.filter), - &probe_batch, - &mut self.right.visited_rows, - self.right.offset, - &self.column_indices, - &self.random_state, - &self.null_equals_null, - ); - self.right.offset += probe_batch.num_rows(); - let anti_result = self.left.prune_build_side( - self.schema.clone(), - &probe_batch, - &mut self.filter_columns, - self.join_type, - &self.column_indices, - &mut self.physical_expr_graph, - ); - let result = produce_batch_result( - self.schema.clone(), - equal_result, - anti_result, - ); - if let Ok(batch) = &result { - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); + } + } else { + match self.right_stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(probe_batch))) => { + self.join_metrics.right_input_batches.add(1); + self.join_metrics + .right_input_rows + .add(probe_batch.num_rows()); + // Right is build side + match self.right.update_internal_state( + 
&self.on_right, + &probe_batch, + &self.random_state, + ) { + Ok(_) => {} + Err(e) => { + return Poll::Ready(Some(Err(ArrowError::ComputeError( + e.to_string(), + )))) + } + } + let equal_result = self.left.record_batch_from_other_side( + self.schema.clone(), + self.join_type, + &self.on_left, + &self.on_right, + Some(&self.filter), + &probe_batch, + &mut self.right.visited_rows, + self.right.offset, + &self.column_indices, + &self.random_state, + &self.null_equals_null, + ); + self.right.offset += probe_batch.num_rows(); + let anti_result = self.left.prune_build_side( + self.schema.clone(), + &probe_batch, + &mut self.filter_columns, + self.join_type, + &self.column_indices, + &mut self.physical_expr_graph, + ); + let result = produce_batch_result( + self.schema.clone(), + equal_result, + anti_result, + ); + if let Ok(batch) = &result { + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + } + if !self.left.exhausted { + self.data_side = JoinSide::Left; + } + return Poll::Ready(Some(result)); } - if !self.left.exhausted { + Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), + Poll::Ready(None) => { + self.right.exhausted = true; self.data_side = JoinSide::Left; + continue; + } + Poll::Pending => { + if !self.left.exhausted { + self.data_side = JoinSide::Left; + continue; + } else { + return Poll::Pending; + } } - Poll::Ready(Some(result)) - } - Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), - Poll::Ready(None) => { - self.right.exhausted = true; - self.data_side = JoinSide::Left; - Poll::Ready(Some(Ok(RecordBatch::new_empty(self.schema())))) } - Poll::Pending => Poll::Pending, } } } diff --git a/datafusion/core/src/physical_plan/joins/utils.rs b/datafusion/core/src/physical_plan/joins/utils.rs index f25ec93abc69..80c882badcb8 100644 --- a/datafusion/core/src/physical_plan/joins/utils.rs +++ b/datafusion/core/src/physical_plan/joins/utils.rs @@ -17,12 +17,6 @@ //! Join related functionality used both on logical and physical plans -use std::cmp::max; -use std::collections::HashSet; -use std::future::Future; -use std::sync::Arc; -use std::task::{Context, Poll}; -use std::usize; use arrow::array::{ new_null_array, Array, BooleanBufferBuilder, PrimitiveArray, UInt32Array, UInt32Builder, UInt64Array, @@ -34,9 +28,15 @@ use arrow::record_batch::RecordBatch; use futures::future::{BoxFuture, Shared}; use futures::{ready, FutureExt}; use parking_lot::Mutex; +use std::cmp::max; +use std::collections::HashSet; +use std::future::Future; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::usize; use datafusion_common::cast::as_boolean_array; -use datafusion_common::ScalarValue; +use datafusion_common::{ScalarValue, SharedResult}; use datafusion_physical_expr::rewrite::TreeNodeRewritable; use datafusion_physical_expr::{EquivalentClass, PhysicalExpr}; diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index bf316a6294e5..7dd2d82f6371 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -1,3 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +//http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! It solves constraint propagation problem for the custom PhysicalExpr +//! use crate::expressions::{BinaryExpr, CastExpr, Literal}; use crate::intervals::interval_aritmetics::{ apply_operator, negate_ops, propagate_logical_operators, Interval, @@ -20,7 +39,7 @@ use std::sync::Arc; /// Graph nodes contains interval information. pub struct ExprIntervalGraphNode { expr: Arc, - interval: Option, + interval: Interval, } impl Display for ExprIntervalGraphNode { @@ -34,16 +53,15 @@ impl ExprIntervalGraphNode { pub fn new(expr: Arc) -> Self { ExprIntervalGraphNode { expr, - interval: None, + interval: Interval::default(), } } /// Specify interval - pub fn with_interval(mut self, interval: Interval) -> Self { - self.interval = Some(interval); - self + pub fn new_with_interval(expr: Arc, interval: Interval) -> Self { + ExprIntervalGraphNode { expr, interval } } /// Get interval - pub fn interval(&self) -> &Option { + pub fn interval(&self) -> &Interval { &self.interval } @@ -54,7 +72,7 @@ impl ExprIntervalGraphNode { if let Some(literal) = plan_any.downcast_ref::() { // Create interval let interval = Interval::Singleton(literal.value().clone()); - ExprIntervalGraphNode::new(input.expr().clone()).with_interval(interval) + ExprIntervalGraphNode::new_with_interval(input.expr().clone(), interval) } else { ExprIntervalGraphNode::new(input.expr().clone()) } @@ -70,7 +88,7 @@ impl PartialEq for ExprIntervalGraphNode { pub struct ExprIntervalGraph( NodeIndex, StableGraph, - Vec<(Arc, usize)>, + pub Vec<(Arc, usize)>, ); /// This function calculates the current node interval @@ -101,6 +119,7 @@ impl ExprIntervalGraph { build_physical_expr_graph(expr, &ExprIntervalGraphNode::expr_node_builder)?; let mut bfs = Bfs::new(&graph, root_node); let mut will_be_removed = vec![]; + // We preserve the order with Vec in SymmetricHashJoin. let mut expr_node_indices: Vec<(Arc, usize)> = provided_expr .iter() .map(|e| (e.clone(), usize::MAX)) @@ -130,7 +149,7 @@ impl ExprIntervalGraph { pub fn calculate_new_intervals( &mut self, - expr_stats: &mut [(Arc, Interval)], + expr_stats: &mut [(usize, Interval)], ) -> Result<()> { self.post_order_interval_calculation(expr_stats)?; self.pre_order_interval_propagation(expr_stats)?; @@ -139,7 +158,7 @@ impl ExprIntervalGraph { fn post_order_interval_calculation( &mut self, - expr_stats: &[(Arc, Interval)], + expr_stats: &[(usize, Interval)], ) -> Result<()> { let mut dfs = DfsPostOrder::new(&self.1, self.0); while let Some(node) = dfs.next(&self.1) { @@ -150,10 +169,10 @@ impl ExprIntervalGraph { // Check if we have a interval information about given PhysicalExpr, if so, directly // propagate it to the upper. if let Some((_, interval)) = - expr_stats.iter().find(|(e, _interval)| expr.eq(e)) + expr_stats.iter().find(|(e, _)| *e == node.index()) { let input = self.1.index_mut(node); - input.interval = Some(interval.clone()); + input.interval = interval.clone(); continue; } @@ -165,23 +184,14 @@ impl ExprIntervalGraph { let first_child_node_index = edges.next_node(&self.1).unwrap(); // Do not use any reference from graph, MUST clone here. 
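
The bottom-up pass below combines the two child intervals through `apply_operator`. For intuition, a self-contained sketch of the `+` case with plain integer endpoints; the actual implementation works on `ScalarValue`-based `Interval`s and supports the other operators analogously:

```rust
// Bottom-up: the parent's interval follows from its children, here for `+`:
// [a, b] + [c, d] = [a + c, b + d].
fn add_intervals(l: (i64, i64), r: (i64, i64)) -> (i64, i64) {
    (l.0 + r.0, l.1 + r.1)
}

fn main() {
    let a = (0, 10); // interval of the left child, e.g. column `a`
    let b = (5, 7);  // interval of the right child, e.g. column `b`
    assert_eq!(add_intervals(a, b), (5, 17)); // interval of `a + b`
}
```
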
let left_interval = - match self.1.index(first_child_node_index).interval().as_ref() { - Some(v) => v.clone(), - None => continue, - }; + self.1.index(first_child_node_index).interval().clone(); let right_interval = - match self.1.index(second_child_node_index).interval().as_ref() { - Some(v) => v.clone(), - None => continue, - }; + self.1.index(second_child_node_index).interval().clone(); // Since we release the reference, we can get mutable reference. let input = self.1.index_mut(node); // Calculate and replace the interval - input.interval = Some(apply_operator( - &left_interval, - binary.op(), - &right_interval, - )?); + input.interval = + apply_operator(&left_interval, binary.op(), &right_interval)?; } else if let Some(CastExpr { cast_type, cast_options, @@ -192,14 +202,10 @@ impl ExprIntervalGraph { let child_index = edges.next_node(&self.1).unwrap(); let child = self.1.index(child_index); // Cast the interval - let new_interval = match &child.interval { - Some(interval) => interval.cast_to(cast_type, cast_options)?, - // If there is no child, continue - None => continue, - }; + let new_interval = child.interval.cast_to(cast_type, cast_options)?; // Update the interval let input = self.1.index_mut(node); - input.interval = Some(new_interval); + input.interval = new_interval; } } Ok(()) @@ -207,7 +213,7 @@ impl ExprIntervalGraph { pub fn pre_order_interval_propagation( &mut self, - expr_stats: &mut [(Arc, Interval)], + expr_stats: &mut [(usize, Interval)], ) -> Result<()> { let mut bfs = Bfs::new(&self.1, self.0); while let Some(node) = bfs.next(&self.1) { @@ -216,14 +222,12 @@ impl ExprIntervalGraph { let expr = input.expr.clone(); // Get calculated interval. BinaryExpr will propagate the interval according to // this. - let parent_calculated_interval = match input.interval().as_ref() { - Some(i) => i.clone(), - None => continue, - }; + let node_interval = input.interval().clone(); - if let Some(index) = expr_stats.iter().position(|(e, _interval)| expr.eq(e)) { - let col_interval = expr_stats.get_mut(index).unwrap(); - *col_interval = (col_interval.0.clone(), parent_calculated_interval); + if let Some((_, interval)) = + expr_stats.iter_mut().find(|(e, _)| *e == node.index()) + { + *interval = node_interval; continue; } @@ -235,20 +239,14 @@ impl ExprIntervalGraph { // Get right node. let second_child_node_index = edges.next_node(&self.1).unwrap(); let second_child_interval = - match self.1.index(second_child_node_index).interval().as_ref() { - Some(v) => v, - None => continue, - }; + self.1.index(second_child_node_index).interval(); // Get left node. 
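
The top-down propagation below runs the operator in reverse: given the parent's interval, each child can be tightened against its sibling with the negated operator used by `calculate_node_interval`. A minimal sketch of that inversion for `+`, again with integer endpoints (assumed, simplified):

```rust
// Top-down: for p = l + r, the sibling bound is r' = r ∩ (p - l), where
// [a, b] - [c, d] = [a - d, b - c].
fn sub_intervals(l: (i64, i64), r: (i64, i64)) -> (i64, i64) {
    (l.0 - r.1, l.1 - r.0)
}

fn intersect(a: (i64, i64), b: (i64, i64)) -> (i64, i64) {
    (a.0.max(b.0), a.1.min(b.1))
}

fn main() {
    let parent = (5, 9); // known interval of `a + b`, e.g. from the filter
    let a = (0, 10);
    let b = (0, 10);
    // b tightens to (0, 10) ∩ (-5, 9) = (0, 9).
    assert_eq!(intersect(b, sub_intervals(parent, a)), (0, 9));
}
```
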
let first_child_node_index = edges.next_node(&self.1).unwrap(); let first_child_interval = - match self.1.index(first_child_node_index).interval().as_ref() { - Some(v) => v, - None => continue, - }; + self.1.index(first_child_node_index).interval(); let (shrink_left_interval, shrink_right_interval) = - if parent_calculated_interval.is_boolean() { + if node_interval.is_boolean() { propagate_logical_operators( first_child_interval, binary.op(), @@ -257,13 +255,13 @@ impl ExprIntervalGraph { } else { let negated_op = negate_ops(*binary.op()); let new_right_operator = calculate_node_interval( - &parent_calculated_interval, + &node_interval, second_child_interval, negated_op, first_child_interval, )?; let new_left_operator = calculate_node_interval( - &parent_calculated_interval, + &node_interval, first_child_interval, negated_op, &new_right_operator, @@ -271,27 +269,20 @@ impl ExprIntervalGraph { (new_left_operator, new_right_operator) }; let mutable_first_child = self.1.index_mut(first_child_node_index); - mutable_first_child.interval = Some(shrink_left_interval.clone()); + mutable_first_child.interval = shrink_left_interval.clone(); let mutable_second_child = self.1.index_mut(second_child_node_index); - mutable_second_child.interval = Some(shrink_right_interval.clone()) + mutable_second_child.interval = shrink_right_interval.clone(); } else if let Some(cast) = expr_any.downcast_ref::() { // Calculate new interval let child_index = edges.next_node(&self.1).unwrap(); let child = self.1.index(child_index); - // Get child's internal datatype. If None, propagate None. - let cast_type = match child.interval() { - Some(interval) => interval.get_datatype(), - None => continue, - }; + // Get child's internal datatype. + let cast_type = child.interval().get_datatype(); let cast_options = cast.cast_options(); - let new_child_interval = match &child.interval { - Some(_) => { - parent_calculated_interval.cast_to(&cast_type, cast_options)? 
- } - None => continue, - }; + let new_child_interval = + node_interval.cast_to(&cast_type, cast_options)?; let mutable_child = self.1.index_mut(child_index); - mutable_child.interval = Some(new_child_interval.clone()); + mutable_child.interval = new_child_interval.clone(); } } Ok(()) @@ -319,7 +310,7 @@ mod tests { left_waited: (Option, Option), right_waited: (Option, Option), ) -> Result<()> { - let mut col_stats = vec![ + let col_stats = vec![ ( exprs.0.clone(), Interval::Range(Range { @@ -355,10 +346,21 @@ mod tests { expr, &col_stats.iter().map(|(e, _)| e.clone()).collect_vec(), )?; - graph.calculate_new_intervals(&mut col_stats)?; - col_stats + let mut col_stat_nodes = col_stats + .iter() + .zip(graph.2.iter()) + .map(|((_, interval), (_, index))| (*index, interval.clone())) + .collect_vec(); + let expected_nodes = expected + .iter() + .zip(graph.2.iter()) + .map(|((_, interval), (_, index))| (*index, interval.clone())) + .collect_vec(); + + graph.calculate_new_intervals(&mut col_stat_nodes[..])?; + col_stat_nodes .iter() - .zip(expected.iter()) + .zip(expected_nodes.iter()) .for_each(|((_, res), (_, expected))| assert_eq!(res, expected)); Ok(()) } diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetics.rs b/datafusion/physical-expr/src/intervals/interval_aritmetics.rs index 39ce14a401d1..9c22b5d5587a 100644 --- a/datafusion/physical-expr/src/intervals/interval_aritmetics.rs +++ b/datafusion/physical-expr/src/intervals/interval_aritmetics.rs @@ -1,3 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +//http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Interval arithmetics library +//! use std::borrow::Borrow; use arrow::datatypes::DataType; @@ -27,6 +46,15 @@ impl Range { } } +impl Default for Range { + fn default() -> Self { + Range { + lower: ScalarValue::Null, + upper: ScalarValue::Null, + } + } +} + impl Add for Range { type Output = Range; @@ -78,6 +106,12 @@ pub enum Interval { Range(Range), } +impl Default for Interval { + fn default() -> Self { + Interval::Range(Range::default()) + } +} + impl Interval { pub(crate) fn cast_to( &self, @@ -548,71 +582,6 @@ impl Interval { }; Ok(result) } - pub fn mul>(&self, other: T) -> Result { - let rhs = other.borrow(); - let result = match (self, rhs) { - (Interval::Singleton(val), Interval::Singleton(val2)) => { - Interval::Singleton(val.sub(val2).unwrap()) - } - (Interval::Singleton(val), Interval::Range(Range { lower, upper })) => { - let new_lower = if lower.is_null() { - lower.clone() - } else { - lower.add(val)?.arithmetic_negate()? - }; - let new_upper = if upper.is_null() { - upper.clone() - } else { - upper.add(val)?.arithmetic_negate()? 
- }; - Interval::Range(Range { - lower: new_lower, - upper: new_upper, - }) - } - (Interval::Range(Range { lower, upper }), Interval::Singleton(val)) => { - let new_lower = if lower.is_null() { - lower.clone() - } else { - lower.sub(val)? - }; - let new_upper = if upper.is_null() { - upper.clone() - } else { - upper.sub(val)? - }; - Interval::Range(Range { - lower: new_lower, - upper: new_upper, - }) - } - ( - Interval::Range(Range { lower, upper }), - Interval::Range(Range { - lower: lower2, - upper: upper2, - }), - ) => { - let new_lower = if lower.is_null() || upper2.is_null() { - ScalarValue::try_from(&lower.get_datatype())? - } else { - lower.sub(upper2)? - }; - - let new_upper = if upper.is_null() || lower2.is_null() { - ScalarValue::try_from(&upper.get_datatype())? - } else { - upper.sub(lower2)? - }; - - Interval::Range(Range { - lower: new_lower, - upper: new_upper, - }) - } - }; - Ok(result) - } } pub fn apply_operator(lhs: &Interval, op: &Operator, rhs: &Interval) -> Result { diff --git a/datafusion/physical-expr/src/intervals/mod.rs b/datafusion/physical-expr/src/intervals/mod.rs index 11b7db4a2316..f35be672a3b5 100644 --- a/datafusion/physical-expr/src/intervals/mod.rs +++ b/datafusion/physical-expr/src/intervals/mod.rs @@ -1,3 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Interval calculations +//! + mod cp_solver; pub mod interval_aritmetics; diff --git a/datafusion/physical-expr/src/physical_expr_visitor.rs b/datafusion/physical-expr/src/physical_expr_visitor.rs index e916e0d8b178..b67178a3e6c0 100644 --- a/datafusion/physical-expr/src/physical_expr_visitor.rs +++ b/datafusion/physical-expr/src/physical_expr_visitor.rs @@ -1,3 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
 use crate::expressions::{
     BinaryExpr, CastExpr, Column, DateTimeIntervalExpr, GetIndexedFieldExpr, InListExpr,
 };
diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs
index 7baaf9f91042..ecda28f1d786 100644
--- a/datafusion/physical-expr/src/utils.rs
+++ b/datafusion/physical-expr/src/utils.rs
@@ -270,6 +270,8 @@ fn add_physical_expr_to_graph(
     graph: &mut StableGraph,
     input_node: T,
 ) -> NodeIndex {
+    // If we have visited this node before, we just return it; this means the node
+    // will have multiple parents.
     match visited
         .iter()
         .find(|(visited_expr, _node)| visited_expr.eq(&input.expr()))
@@ -313,7 +315,6 @@ where
 {
     let init = Arc::new(ExprTreeNode::new(expr.clone()));
     let mut graph: StableGraph = StableGraph::new();
-    // TODO: Make membership check O(1)
     let mut visited_plans: Vec<(Arc, NodeIndex)> = vec![];
     // Create a graph
     let root_tree_node = init.mutable_transform_up(&mut |expr| {

From 9adb7da887b4790c95bd52f841968ce3c9ae492a Mon Sep 17 00:00:00 2001
From: metesynnada <100111937+metesynnada@users.noreply.github.com>
Date: Mon, 23 Jan 2023 19:08:09 +0300
Subject: [PATCH 03/39] Filter mapping inside SymmetricHashJoin

---
 .../src/physical_optimizer/pipeline_fixer.rs  | 191 +----
 .../physical_plan/joins/hash_join_utils.rs    | 503 +++++++++++-
 .../core/src/physical_plan/joins/mod.rs       |   4 +-
 .../joins/symmetric_hash_join.rs              | 762 ++++++++--------
 datafusion/core/src/physical_plan/memory.rs   |  26 +-
 .../physical-expr/src/intervals/cp_solver.rs  |   1 +
 .../src/physical_expr_visitor.rs              |  31 +-
 7 files changed, 1004 insertions(+), 514 deletions(-)

diff --git a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs
index d98f8b269310..5232a09c753c 100644
--- a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs
+++ b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs
@@ -29,24 +29,19 @@ use crate::physical_optimizer::pipeline_checker::{
     check_finiteness_requirements, PipelineStatePropagator,
 };
 use crate::physical_optimizer::PhysicalOptimizerRule;
-use crate::physical_plan::joins::utils::{JoinFilter, JoinSide};
-use crate::physical_plan::joins::{
-    HashJoinExec, PartitionMode, SortedFilterExpr, SymmetricHashJoinExec,
-};
+use crate::physical_plan::joins::{HashJoinExec, PartitionMode, SymmetricHashJoinExec};
 use crate::physical_plan::rewrite::TreeNodeRewritable;
 use crate::physical_plan::ExecutionPlan;
-use arrow::datatypes::{DataType, SchemaRef};
+use arrow::datatypes::DataType;
 use datafusion_common::DataFusionError;
 use datafusion_expr::logical_plan::JoinType;
 use datafusion_expr::Operator;
 use datafusion_physical_expr::expressions::{BinaryExpr, CastExpr, Column, Literal};
 use datafusion_physical_expr::physical_expr_visitor::{
-    ExprVisitable, PhysicalExpressionVisitor, Recursion,
+    PhysicalExprVisitable, PhysicalExpressionVisitor, Recursion,
 };
-use datafusion_physical_expr::rewrite::TreeNodeRewritable as physical;
-use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
+use datafusion_physical_expr::PhysicalExpr;
 
-use std::collections::HashMap;
 use std::sync::Arc;
 
 /// The [PipelineFixer] rule tries to modify a given plan so that it can
@@ -149,118 +144,26 @@ impl PhysicalExpressionVisitor for UnsupportedPhysicalExprVisitor<'_> {
 // Sorting information must cover every column in the filter expression.
 // Floats are unsupported.
 // Anything other than BinaryExpr, Column, Literal and the Ge and Le operators is currently unsupported.
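
To make these restrictions concrete, here is a self-contained sketch of such a suitability check over a toy expression tree. The names and the exact operator set are illustrative only; the actual check is the visitor above combined with `is_datatype_supported` over the filter schema.

```rust
// Toy expression tree; the real check visits Arc<dyn PhysicalExpr> nodes.
enum Expr {
    Column,
    Literal,
    Cast(Box<Expr>),
    Binary { op: Op, left: Box<Expr>, right: Box<Expr> },
}

#[derive(Clone, Copy)]
enum Op { Plus, Minus, Gt, Lt, And, Mul }

fn op_supported(op: Op) -> bool {
    // Mul is deliberately excluded, mirroring the limited operator set.
    matches!(op, Op::Plus | Op::Minus | Op::Gt | Op::Lt | Op::And)
}

fn is_supported(expr: &Expr) -> bool {
    match expr {
        Expr::Column | Expr::Literal => true,
        Expr::Cast(inner) => is_supported(inner),
        Expr::Binary { op, left, right } => {
            op_supported(*op) && is_supported(left) && is_supported(right)
        }
    }
}

fn main() {
    // a + b > c  ->  suitable for interval-based pruning
    let expr = Expr::Binary {
        op: Op::Gt,
        left: Box::new(Expr::Binary {
            op: Op::Plus,
            left: Box::new(Expr::Column),
            right: Box::new(Expr::Column),
        }),
        right: Box::new(Expr::Column),
    };
    assert!(is_supported(&expr));
}
```
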
-fn is_suitable_for_symmetric_hash_join(filter: &JoinFilter) -> Result { - let expr: Arc = filter.expression().clone(); - let mut is_expr_supported = true; - expr.accept(UnsupportedPhysicalExprVisitor { - supported: &mut is_expr_supported, - })?; - let is_fields_supported = filter - .schema() - .fields() - .iter() - .all(|f| is_datatype_supported(f.data_type())); - Ok(is_expr_supported && is_fields_supported) -} - -// TODO: Implement CollectLeft, CollectRight and CollectBoth modes for SymmetricHashJoin -fn enforce_symmetric_hash_join( - hash_join: &HashJoinExec, - sorted_columns: Vec, -) -> Result> { - let left = hash_join.left(); - let right = hash_join.right(); - let new_join = Arc::new(SymmetricHashJoinExec::try_new( - Arc::clone(left), - Arc::clone(right), - hash_join - .on() - .iter() - .map(|(l, r)| (l.clone(), r.clone())) - .collect(), - hash_join.filter().unwrap().clone(), - sorted_columns, - hash_join.join_type(), - hash_join.null_equals_null(), - )?); - Ok(new_join) -} - -/// We assume that the sort information is on a column, not a binary expr. -/// If a + b is sorted, we can not find intervals for 'a" and 'b". -/// -/// Here, we map the main column information to filter intermediate schema. -/// child_orders are derived from child schemas. However, filter schema is different. We enforce -/// that we get the column order for each corresponding [ColumnIndex]. If not, we return None and -/// assume this expression is not capable of pruning. -fn build_filter_input_order( - filter: &JoinFilter, - left_main_schema: SchemaRef, - right_main_schema: SchemaRef, - child_orders: Vec>, -) -> Result>> { - let left_child = child_orders[0]; - let right_child = child_orders[1]; - let mut left_hashmap: HashMap = HashMap::new(); - let mut right_hashmap: HashMap = HashMap::new(); - if left_child.is_none() || right_child.is_none() { - return Ok(None); - } - let left_sort_information = left_child.unwrap(); - let right_sort_information = right_child.unwrap(); - for (filter_schema_index, index) in filter.column_indices().iter().enumerate() { - if index.side.eq(&JoinSide::Left) { - let main_field = left_main_schema.field(index.index); - let main_col = - Column::new_with_schema(main_field.name(), left_main_schema.as_ref())?; - let filter_field = filter.schema().field(filter_schema_index); - let filter_col = Column::new(filter_field.name(), filter_schema_index); - left_hashmap.insert(main_col, filter_col); - } else { - let main_field = right_main_schema.field(index.index); - let main_col = - Column::new_with_schema(main_field.name(), right_main_schema.as_ref())?; - let filter_field = filter.schema().field(filter_schema_index); - let filter_col = Column::new(filter_field.name(), filter_schema_index); - right_hashmap.insert(main_col, filter_col); +fn is_suitable_for_symmetric_hash_join(hash_join: &HashJoinExec) -> Result { + let are_children_ordered = hash_join.left().output_ordering().is_some() + && hash_join.right().output_ordering().is_some(); + let is_filter_supported = match hash_join.filter() { + None => false, + Some(filter) => { + let expr: Arc = filter.expression().clone(); + let mut is_expr_supported = true; + expr.accept(UnsupportedPhysicalExprVisitor { + supported: &mut is_expr_supported, + })?; + let is_fields_supported = filter + .schema() + .fields() + .iter() + .all(|f| is_datatype_supported(f.data_type())); + is_expr_supported && is_fields_supported } - } - let mut sorted_exps = vec![]; - for sort_expr in left_sort_information { - let expr = sort_expr.expr.clone(); - let new_expr = - 
expr.transform_up(&|p| convert_filter_columns(p, &left_hashmap))?; - sorted_exps.push(SortedFilterExpr::new( - JoinSide::Left, - sort_expr.expr.clone(), - new_expr.clone(), - sort_expr.options, - )); - } - for sort_expr in right_sort_information { - let expr = sort_expr.expr.clone(); - let new_expr = - expr.transform_up(&|p| convert_filter_columns(p, &right_hashmap))?; - sorted_exps.push(SortedFilterExpr::new( - JoinSide::Right, - sort_expr.expr.clone(), - new_expr.clone(), - sort_expr.options, - )); - } - Ok(Some(sorted_exps)) -} - -fn convert_filter_columns( - input: Arc, - column_change_information: &HashMap, -) -> Result>> { - if let Some(col) = input.as_any().downcast_ref::() { - let filter_col = column_change_information.get(col).unwrap().clone(); - Ok(Some(Arc::new(filter_col))) - } else { - Ok(Some(input)) - } + }; + Ok(are_children_ordered && is_filter_supported) } fn hash_join_convert_symmetric_subrule( @@ -271,35 +174,25 @@ fn hash_join_convert_symmetric_subrule( if let Some(hash_join) = plan.as_any().downcast_ref::() { let (left_unbounded, right_unbounded) = (children[0], children[1]); let new_plan = if left_unbounded && right_unbounded { - match hash_join.filter() { - Some(filter) => match is_suitable_for_symmetric_hash_join(filter) { - Ok(suitable) => { - if suitable { - let children = input.plan.children(); - let input_order = children - .iter() - .map(|c| c.output_ordering()) - .collect::>(); - match build_filter_input_order( - filter, - hash_join.left().schema(), - hash_join.right().schema(), - input_order, - ) { - Ok(Some(sort_columns)) => { - enforce_symmetric_hash_join(hash_join, sort_columns) - } - Ok(None) => Ok(plan), - Err(e) => return Some(Err(e)), - } - } else { - Ok(plan) - } - } - Err(e) => return Some(Err(e)), - }, - None => Ok(plan), - } + Ok(match is_suitable_for_symmetric_hash_join(hash_join) { + Ok(true) => Arc::new( + SymmetricHashJoinExec::try_new( + hash_join.left().clone(), + hash_join.right().clone(), + hash_join + .on() + .iter() + .map(|(l, r)| (l.clone(), r.clone())) + .collect(), + hash_join.filter().unwrap().clone(), + hash_join.join_type(), + hash_join.null_equals_null(), + ) + .unwrap(), + ), + Ok(false) => plan, + Err(e) => return Some(Err(e)), + }) } else { Ok(plan) }; diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs index 170732293223..e6891cd13df2 100644 --- a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs +++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs @@ -20,6 +20,7 @@ use std::{fmt, usize}; use ahash::RandomState; +use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow::{ array::{ @@ -39,15 +40,26 @@ use arrow::{ use hashbrown::raw::RawTable; use smallvec::{smallvec, SmallVec}; +use crate::common::Result; +use arrow::compute::CastOptions; use datafusion_common::cast::{as_dictionary_array, as_string_array}; -use datafusion_common::DataFusionError; -use datafusion_physical_expr::expressions::Column; +use datafusion_common::{DataFusionError, ScalarValue}; +use datafusion_expr::Operator; +use datafusion_physical_expr::expressions::{BinaryExpr, CastExpr, Column, Literal}; use datafusion_physical_expr::hash_utils::create_hashes; -use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::intervals::interval_aritmetics::Interval; +use datafusion_physical_expr::physical_expr_visitor::{ + PhysicalExprVisitable, PhysicalExpressionVisitor, Recursion, +}; +use 
datafusion_physical_expr::rewrite::TreeNodeRewritable; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use std::collections::HashMap; +use std::sync::Arc; use crate::physical_plan::joins::utils::{ apply_join_filter_to_indices, JoinFilter, JoinSide, }; +use crate::physical_plan::sorts::sort::SortOptions; /// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`, /// assuming that the [RecordBatch] corresponds to the `index`th @@ -58,12 +70,12 @@ pub fn update_hash( offset: usize, random_state: &RandomState, hashes_buffer: &mut Vec, -) -> datafusion_common::Result<()> { +) -> Result<()> { // evaluate the keys let keys_values = on .iter() .map(|c| Ok(c.evaluate(batch)?.into_array(batch.num_rows()))) - .collect::>>()?; + .collect::>>()?; // calculate the hash values let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; @@ -100,7 +112,7 @@ pub fn build_join_indices( hashes_buffer: &mut Vec, offset: Option, build_side: JoinSide, -) -> datafusion_common::Result<(UInt64Array, UInt32Array)> { +) -> Result<(UInt64Array, UInt32Array)> { // Get the indices which is satisfies the equal join condition, like `left.a1 = right.a2` let (build_indices, probe_indices) = build_equal_condition_join_indices( build_hashmap, @@ -201,7 +213,7 @@ pub fn equal_rows( left_arrays: &[ArrayRef], right_arrays: &[ArrayRef], null_equals_null: bool, -) -> datafusion_common::Result { +) -> Result { let mut err = None; let res = left_arrays .iter() @@ -499,18 +511,18 @@ pub fn build_equal_condition_join_indices( null_equals_null: &bool, hashes_buffer: &mut Vec, offset: Option, -) -> datafusion_common::Result<(UInt64Array, UInt32Array)> { +) -> Result<(UInt64Array, UInt32Array)> { let keys_values = probe_on .iter() .map(|c| Ok(c.evaluate(probe_batch)?.into_array(probe_batch.num_rows()))) - .collect::>>()?; + .collect::>>()?; let build_join_values = build_on .iter() .map(|c| { Ok(c.evaluate(build_input_buffer)? .into_array(build_input_buffer.num_rows())) }) - .collect::>>()?; + .collect::>>()?; hashes_buffer.clear(); hashes_buffer.resize(probe_batch.num_rows(), 0); let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; @@ -582,3 +594,474 @@ impl fmt::Debug for JoinHashMap { Ok(()) } } + +#[derive(Debug)] +pub struct CheckFilterExprContainsSortInformation<'a> { + /// Supported state + pub contains: &'a mut bool, + pub checked_expr: Arc, +} + +impl<'a> CheckFilterExprContainsSortInformation<'a> { + pub fn new(contains: &'a mut bool, checked_expr: Arc) -> Self { + Self { + contains, + checked_expr, + } + } +} + +impl PhysicalExpressionVisitor for CheckFilterExprContainsSortInformation<'_> { + fn pre_visit(self, expr: Arc) -> Result> { + *self.contains = *self.contains || self.checked_expr.eq(&expr); + if *self.contains { + Ok(Recursion::Stop(self)) + } else { + Ok(Recursion::Continue(self)) + } + } +} + +#[derive(Debug)] +pub struct PhysicalExprColumnCollector<'a> { + pub columns: &'a mut Vec, +} + +impl<'a> PhysicalExprColumnCollector<'a> { + pub fn new(columns: &'a mut Vec) -> Self { + Self { columns } + } +} + +impl PhysicalExpressionVisitor for PhysicalExprColumnCollector<'_> { + fn pre_visit(self, expr: Arc) -> Result> { + if let Some(column) = expr.as_any().downcast_ref::() { + self.columns.push(column.clone()) + } + Ok(Recursion::Continue(self)) + } +} + +/// Create main_col -> filter_col one to one mapping from filter column indices. 
+/// A column index looks like:
+/// ColumnIndex {
+///     index: 0,             // field index in the main schema
+///     side: JoinSide::Left, // child side
+/// },
+fn map_origin_col_to_filter_col(
+    filter: &JoinFilter,
+    main_schema: SchemaRef,
+    side: JoinSide,
+) -> Result<HashMap<Column, Column>> {
+    let mut col_to_col_map: HashMap<Column, Column> = HashMap::new();
+    for (filter_schema_index, index) in filter.column_indices().iter().enumerate() {
+        if index.side.eq(&side) {
+            // Get main field from column index
+            let main_field = main_schema.field(index.index);
+            // Create a column PhysicalExpr
+            let main_col =
+                Column::new_with_schema(main_field.name(), main_schema.as_ref())?;
+            // Since the filter.column_indices() order is exactly the same as the
+            // intermediate schema fields, we can get the column directly.
+            let filter_field = filter.schema().field(filter_schema_index);
+            let filter_col = Column::new(filter_field.name(), filter_schema_index);
+            // Insert mapping
+            col_to_col_map.insert(main_col, filter_col);
+        }
+    }
+    Ok(col_to_col_map)
+}
+/// If the converted PhysicalExpr is inside the filter expression, we can use its sort
+/// information. The visitor visits each node and checks whether the filter expression
+/// includes the converted expression.
+///
+/// We want to filter out a sort expression if, for a filter like
+/// a + b > c + 10 AND a + b < c + 100:
+/// - a + d is sorted: d is not part of the filter.
+/// - d is sorted: d is not part of the filter.
+/// - a + b + c is sorted: all columns are present in the filter expression, but there
+///   is no exact match.
+/// None of these expressions enable pruning.
+///
+fn rewrite_sort_information_for_filter(
+    side: JoinSide,
+    filter: &JoinFilter,
+    col_to_col_map: &HashMap<Column, Column>,
+    child_order: &[PhysicalSortExpr],
+) -> Result<Vec<SortedFilterExpr>> {
+    let mut sorted_exps = vec![];
+    for sort_expr in child_order {
+        let expr = sort_expr.expr.clone();
+        // Get main schema columns
+        let mut expr_columns = vec![];
+        expr.clone()
+            .accept(PhysicalExprColumnCollector::new(&mut expr_columns))?;
+        // Calculation is possible with col_to_col_map since sort exprs belong to a child.
+        let all_columns_are_included = expr_columns
+            .iter()
+            .all(|col| col_to_col_map.contains_key(col));
+        if all_columns_are_included {
+            // Since we are sure that the one-to-one column mapping covers all columns,
+            // we can convert the sort expression into a filter expression.
+            let converted_filter_expr =
+                expr.transform_up(&|p| convert_filter_columns(p, col_to_col_map))?;
+            let mut contains = false;
+            // Search for the converted PhysicalExpr in the filter expression.
+            filter.expression().clone().accept(
+                CheckFilterExprContainsSortInformation::new(
+                    &mut contains,
+                    converted_filter_expr.clone(),
+                ),
+            )?;
+            // If an exact match is encountered, use this sorted expression in graph traversals.
+            if contains {
+                sorted_exps.push(SortedFilterExpr::new(
+                    side,
+                    sort_expr.expr.clone(),
+                    converted_filter_expr.clone(),
+                    sort_expr.options,
+                ));
+            }
+        }
+    }
+    Ok(sorted_exps)
+}
+
+/// Here, we map the main (child) column information to the filter's intermediate schema.
+/// child_orders are derived from the child schemas, while the filter schema is different.
+/// We require an exact column mapping for each corresponding [ColumnIndex]; sort
+/// expressions without an exact match in the filter are skipped, since they cannot be
+/// used for pruning.
+pub fn build_filter_input_order( + filter: &JoinFilter, + left_main_schema: SchemaRef, + right_main_schema: SchemaRef, + left_order: &[PhysicalSortExpr], + right_order: &[PhysicalSortExpr], +) -> Result> { + // Get left mapping + let left_hashmap: HashMap = + map_origin_col_to_filter_col(filter, left_main_schema, JoinSide::Left)?; + // Get right mapping + let right_hashmap: HashMap = + map_origin_col_to_filter_col(filter, right_main_schema, JoinSide::Right)?; + // Extract sorted expression that belongs to the filter expression + let left_sorted_filter_exprs = rewrite_sort_information_for_filter( + JoinSide::Left, + filter, + &left_hashmap, + left_order, + )?; + let right_sorted_filter_exprs = rewrite_sort_information_for_filter( + JoinSide::Right, + filter, + &right_hashmap, + right_order, + )?; + Ok([left_sorted_filter_exprs, right_sorted_filter_exprs].concat()) +} + +fn convert_filter_columns( + input: Arc, + column_change_information: &HashMap, +) -> Result>> { + if let Some(col) = input.as_any().downcast_ref::() { + let filter_col = column_change_information.get(col).unwrap().clone(); + Ok(Some(Arc::new(filter_col))) + } else { + Ok(Some(input)) + } +} + +#[derive(Debug, Clone)] +/// The SortedFilterExpr struct is used to represent a sorted filter expression in the +/// [SymmetricHashJoinExec] struct. It contains information about the join side, the origin +/// expression, the filter expression, and the sort option. +/// The struct has several methods to access and modify its fields. +pub struct SortedFilterExpr { + /// Column side + pub join_side: JoinSide, + /// Sorted expr from a particular join side (child) + pub origin_expr: Arc, + /// For interval calculations, one to one mapping of the columns according to filter expression, + /// and column indices. 
+ pub filter_expr: Arc, + /// Sort option + pub sort_option: SortOptions, + /// Interval + pub interval: Interval, + /// NodeIndex in Graph + pub node_index: usize, +} + +impl SortedFilterExpr { + /// Constructor + pub fn new( + join_side: JoinSide, + origin_expr: Arc, + filter_expr: Arc, + sort_option: SortOptions, + ) -> Self { + Self { + join_side, + origin_expr, + filter_expr, + sort_option, + interval: Interval::default(), + node_index: 0, + } + } + /// Get origin expr information + pub fn origin_expr(&self) -> Arc { + self.origin_expr.clone() + } + /// Get filter expr information + pub fn filter_expr(&self) -> Arc { + self.filter_expr.clone() + } + /// Get sort information + pub fn sort_option(&self) -> SortOptions { + self.sort_option + } + /// Get interval information + pub fn interval(&self) -> &Interval { + &self.interval + } + /// Sets interval + pub fn set_interval(&mut self, interval: Interval) { + self.interval = interval; + } + /// Node index in ExprIntervalGraph + pub fn node_index(&self) -> usize { + self.node_index + } + /// Node index setter in ExprIntervalGraph + pub fn set_node_index(&mut self, node_index: usize) { + self.node_index = node_index; + } +} +/// Filter expr for a + b > c + 10 AND a + b < c + 100 +#[allow(dead_code)] +pub(crate) fn complicated_filter() -> Arc { + let left_expr = BinaryExpr { + left: Arc::new(CastExpr { + expr: Arc::new(BinaryExpr { + left: Arc::new(Column { + name: "0".to_string(), + index: 0, + }), + op: Operator::Plus, + right: Arc::new(Column { + name: "1".to_string(), + index: 1, + }), + }), + cast_type: DataType::Int64, + cast_options: CastOptions { safe: false }, + }), + op: Operator::Gt, + right: Arc::new(BinaryExpr { + left: Arc::new(CastExpr { + expr: Arc::new(Column { + name: "2".to_string(), + index: 2, + }), + cast_type: DataType::Int64, + cast_options: CastOptions { safe: false }, + }), + op: Operator::Plus, + right: Arc::new(Literal::new(ScalarValue::Int64(Some(10)))), + }), + }; + let right_expr = BinaryExpr { + left: Arc::new(CastExpr { + expr: Arc::new(BinaryExpr { + left: Arc::new(Column { + name: "0".to_string(), + index: 0, + }), + op: Operator::Plus, + right: Arc::new(Column { + name: "1".to_string(), + index: 1, + }), + }), + cast_type: DataType::Int64, + cast_options: CastOptions { safe: false }, + }), + op: Operator::Lt, + right: Arc::new(BinaryExpr { + left: Arc::new(CastExpr { + expr: Arc::new(Column { + name: "2".to_string(), + index: 2, + }), + cast_type: DataType::Int64, + cast_options: CastOptions { safe: false }, + }), + op: Operator::Plus, + right: Arc::new(Literal::new(ScalarValue::Int64(Some(100)))), + }), + }; + Arc::new(BinaryExpr::new( + Arc::new(left_expr), + Operator::And, + Arc::new(right_expr), + )) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::physical_plan::{ + expressions::Column, + expressions::PhysicalSortExpr, + joins::utils::{ColumnIndex, JoinFilter, JoinSide}, + }; + use arrow::compute::{CastOptions, SortOptions}; + use arrow::datatypes::{DataType, Field, Schema}; + use std::sync::Arc; + #[test] + fn find_expr_inside_expr() -> Result<()> { + let filter_expr = complicated_filter(); + + let mut contains1 = false; + let expr_1 = Arc::new(Column::new("gnz", 0)); + filter_expr + .clone() + .accept(CheckFilterExprContainsSortInformation::new( + &mut contains1, + expr_1, + ))?; + assert!(!contains1); + + let mut contains2 = false; + let expr_2 = Arc::new(Column::new("1", 1)); + filter_expr + .clone() + .accept(CheckFilterExprContainsSortInformation::new( + &mut contains2, + expr_2, + 
))?; + assert!(contains2); + + let mut contains3 = false; + let expr_3 = Arc::new(CastExpr { + expr: Arc::new(BinaryExpr { + left: Arc::new(Column { + name: "0".to_string(), + index: 0, + }), + op: Operator::Plus, + right: Arc::new(Column { + name: "1".to_string(), + index: 1, + }), + }), + cast_type: DataType::Int64, + cast_options: CastOptions { safe: false }, + }); + filter_expr + .clone() + .accept(CheckFilterExprContainsSortInformation::new( + &mut contains3, + expr_3, + ))?; + assert!(contains3); + + let mut contains4 = false; + let expr_4 = Arc::new(Column::new("1", 42)); + filter_expr + .clone() + .accept(CheckFilterExprContainsSortInformation::new( + &mut contains4, + expr_4, + ))?; + assert!(!contains4); + + Ok(()) + } + + #[test] + fn build_sorted_expr() -> Result<()> { + let left_schema = Arc::new(Schema::new(vec![ + Field::new("la1", DataType::Int32, false), + Field::new("lb1", DataType::Int32, false), + Field::new("lc1", DataType::Int32, false), + Field::new("lt1", DataType::Int32, false), + Field::new("la2", DataType::Int32, false), + Field::new("la1_des", DataType::Int32, false), + ])); + + let right_schema = Arc::new(Schema::new(vec![ + Field::new("ra1", DataType::Int32, false), + Field::new("rb1", DataType::Int32, false), + Field::new("rc1", DataType::Int32, false), + Field::new("rt1", DataType::Int32, false), + Field::new("ra2", DataType::Int32, false), + Field::new("ra1_des", DataType::Int32, false), + ])); + + let filter_col_0 = Arc::new(Column::new("0", 0)); + let filter_col_1 = Arc::new(Column::new("1", 1)); + let filter_col_2 = Arc::new(Column::new("2", 2)); + + let column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 4, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ]; + let intermediate_schema = Schema::new(vec![ + Field::new(filter_col_0.name(), DataType::Int32, true), + Field::new(filter_col_1.name(), DataType::Int32, true), + Field::new(filter_col_2.name(), DataType::Int32, true), + ]); + + let filter_expr = complicated_filter(); + + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + // Only two of them in Filter PhysicalExpr (la1, la2) + let left_sorted = vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("lt1", 3)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("la2", 4)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("la1", 0)), + options: SortOptions::default(), + }, + ]; + // Only "ra1" in Filter PhysicalExpr + let right_sorted = vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("ra1", 0)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("rt1", 3)), + options: SortOptions::default(), + }, + ]; + + let sorted = build_filter_input_order( + &filter, + left_schema, + right_schema, + &left_sorted, + &right_sorted, + )?; + assert_eq!(sorted.len(), 3); + + Ok(()) + } +} diff --git a/datafusion/core/src/physical_plan/joins/mod.rs b/datafusion/core/src/physical_plan/joins/mod.rs index feeefc1605aa..79ebe3cce0bc 100644 --- a/datafusion/core/src/physical_plan/joins/mod.rs +++ b/datafusion/core/src/physical_plan/joins/mod.rs @@ -21,9 +21,9 @@ pub use cross_join::CrossJoinExec; pub use hash_join::HashJoinExec; pub use nested_loop_join::NestedLoopJoinExec; // Note: SortMergeJoin is not used in plans yet +pub use hash_join_utils::SortedFilterExpr; pub use sort_merge_join::SortMergeJoinExec; -pub use 
symmetric_hash_join::{SortedFilterExpr, SymmetricHashJoinExec};
-
+pub use symmetric_hash_join::SymmetricHashJoinExec;
 mod cross_join;
 mod hash_join;
 mod hash_join_utils;
diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
index f31e7b95b5e4..727a813aa0d2 100644
--- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
@@ -29,7 +29,7 @@ use std::{any::Any, usize};
 use ahash::RandomState;
 use arrow::array::PrimitiveArray;
 use arrow::array::{ArrowPrimitiveType, NativeAdapter, PrimitiveBuilder};
-use arrow::compute::{concat_batches, SortOptions};
+use arrow::compute::concat_batches;
 use arrow::datatypes::ArrowNativeType;
 use arrow::datatypes::{Schema, SchemaRef};
 use arrow::error::{ArrowError, Result as ArrowResult};
@@ -50,7 +50,8 @@ use crate::execution::context::TaskContext;
 use crate::logical_expr::JoinType;
 use crate::physical_plan::common::merge_batches;
 use crate::physical_plan::joins::hash_join_utils::{
-    build_join_indices, update_hash, JoinHashMap,
+    build_filter_input_order, build_join_indices, update_hash, JoinHashMap,
+    SortedFilterExpr,
 };
 use crate::physical_plan::joins::utils::build_batch_from_indices;
 use crate::physical_plan::{
@@ -65,70 +66,6 @@ use crate::physical_plan::{
     PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics,
 };
 
-#[derive(Debug, Clone)]
-/// The SortedFilterExpr struct is used to represent a sorted filter expression in the
-/// [SymmetricHashJoinExec] struct. It contains information about the join side, the origin
-/// expression, the filter expression, and the sort option.
-/// The struct has several methods to access and modify its fields.
-pub struct SortedFilterExpr {
-    // Column side
-    join_side: JoinSide,
-    // Sorted expr from a particular join side (child)
-    origin_expr: Arc<dyn PhysicalExpr>,
-    // For interval calculations, one to one mapping of the columns according to filter expression,
-    // and column indices.
-    filter_expr: Arc<dyn PhysicalExpr>,
-    // Sort option
-    sort_option: SortOptions,
-    // Interval
-    interval: Interval,
-    // NodeIndex in Graph
-    node_index: usize,
-}
-
-impl SortedFilterExpr {
-    /// Constructor
-    pub fn new(
-        join_side: JoinSide,
-        origin_expr: Arc<dyn PhysicalExpr>,
-        filter_expr: Arc<dyn PhysicalExpr>,
-        sort_option: SortOptions,
-    ) -> Self {
-        Self {
-            join_side,
-            origin_expr,
-            filter_expr,
-            sort_option,
-            interval: Interval::default(),
-            node_index: 0,
-        }
-    }
-    /// Get origin expr information
-    pub fn origin_expr(&self) -> Arc<dyn PhysicalExpr> {
-        self.origin_expr.clone()
-    }
-    /// Get filter expr information
-    pub fn filter_expr(&self) -> Arc<dyn PhysicalExpr> {
-        self.filter_expr.clone()
-    }
-    /// Get sort information
-    pub fn sort_option(&self) -> SortOptions {
-        self.sort_option
-    }
-    /// Get interval information
-    pub fn interval(&self) -> &Interval {
-        &self.interval
-    }
-    /// Sets interval
-    pub fn set_interval(&mut self, interval: Interval) {
-        self.interval = interval;
-    }
-    /// Node index in ExprIntervalGraph
-    pub fn node_index(&self) -> usize {
-        self.node_index
-    }
-}
-
 /// The symmetric hash join is a special type of hash join designed for data streams and large
 /// datasets.
 /// For each input, create a hash table.
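(Aside: a minimal, self-contained sketch of the loop this operator implements, to orient the hunks that follow. It is illustrative only: the `Side` type, its `push`/`prune` helpers, and the fixed `key > latest - 3` pruning rule are invented for the sketch, standing in for the real `OneSideHashJoiner`, `JoinFilter`, and interval-analysis machinery in this patch.)

```rust
use std::collections::HashMap;

/// One bounded buffer per input: a hash table over the join key plus the rows.
#[derive(Default)]
struct Side {
    rows: Vec<(i64, &'static str)>,  // (sorted join key, payload)
    table: HashMap<i64, Vec<usize>>, // key -> indices into `rows`
}

impl Side {
    fn push(&mut self, key: i64, payload: &'static str) {
        self.table.entry(key).or_default().push(self.rows.len());
        self.rows.push((key, payload));
    }

    /// Drop rows that can no longer join, given the latest key seen on the
    /// other side and a range condition of the form `this_key > other_key - 3`.
    fn prune(&mut self, other_latest: i64) {
        let threshold = other_latest - 3;
        self.rows.retain(|(key, _)| *key > threshold);
        // Rebuild the hash table over the surviving rows.
        self.table.clear();
        for (idx, (key, _)) in self.rows.iter().enumerate() {
            self.table.entry(*key).or_default().push(idx);
        }
    }
}

fn main() {
    let (mut left, mut right) = (Side::default(), Side::default());
    // Interleaved arrivals; each new row is inserted on its own side and
    // probed against the opposite side's hash table.
    let arrivals = [(true, 1, "l1"), (false, 1, "r1"), (true, 5, "l5"), (false, 5, "r5")];
    for (is_left, key, payload) in arrivals {
        let (own, other) = if is_left {
            (&mut left, &mut right)
        } else {
            (&mut right, &mut left)
        };
        own.push(key, payload);
        if let Some(matches) = other.table.get(&key) {
            for &idx in matches {
                println!("joined: {payload} <-> {}", other.rows[idx].1);
            }
        }
        // Symmetric pruning keeps both buffers bounded.
        other.prune(key);
    }
}
```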
@@ -150,6 +87,8 @@ pub struct SymmetricHashJoinExec {
     pub(crate) join_type: JoinType,
     /// Order information of filter columns
     filter_columns: Vec<SortedFilterExpr>,
+    /// Expression graph for interval calculations
+    physical_expr_graph: ExprIntervalGraph,
     /// The schema once the join is applied
     schema: SchemaRef,
     /// Shares the `RandomState` for the hashing algorithm
@@ -218,7 +157,6 @@ impl SymmetricHashJoinExec {
         right: Arc<dyn ExecutionPlan>,
         on: JoinOn,
         filter: JoinFilter,
-        filter_columns: Vec<SortedFilterExpr>,
         join_type: &JoinType,
         null_equals_null: &bool,
     ) -> Result<Self> {
@@ -237,13 +175,41 @@ impl SymmetricHashJoinExec {
 
         let random_state = RandomState::with_seeds(0, 0, 0, 0);
 
+        // Build the sorted filter expressions from the filter and the output
+        // orderings of the children.
+        let mut filter_columns = build_filter_input_order(
+            &filter,
+            left.schema(),
+            right.schema(),
+            left.output_ordering().unwrap(),
+            right.output_ordering().unwrap(),
+        )?;
+
+        // Create expression graph for PhysicalExpr.
+        let physical_expr_graph = ExprIntervalGraph::try_new(
+            filter.expression().clone(),
+            &filter_columns
+                .iter()
+                .map(|sorted_expr| sorted_expr.filter_expr())
+                .collect_vec(),
+        )?;
+        // We inject the calculated node indices into the SortedFilterExprs. In graph
+        // calculations, we use these node indices to write calculated intervals back
+        // into the corresponding Columns or BinaryExprs.
+        filter_columns
+            .iter_mut()
+            .zip(physical_expr_graph.2.iter())
+            .for_each(|(sorted_expr, (_, index))| {
+                sorted_expr.set_node_index(*index);
+            });
+
         Ok(SymmetricHashJoinExec {
             left,
             right,
             on,
             filter,
-            filter_columns,
             join_type: *join_type,
+            filter_columns,
+            physical_expr_graph,
             schema: Arc::new(schema),
             random_state,
             metrics: ExecutionPlanMetricsSet::new(),
@@ -373,7 +339,6 @@ impl ExecutionPlan for SymmetricHashJoinExec {
             children[1].clone(),
             self.on.clone(),
             self.filter.clone(),
-            self.filter_columns.clone(),
             &self.join_type,
             &self.null_equals_null,
         )?))
@@ -416,27 +381,6 @@ impl ExecutionPlan for SymmetricHashJoinExec {
         let left_stream = self.left.execute(partition, context.clone())?;
         let right_stream = self.right.execute(partition, context)?;
 
-        let physical_expr_graph = ExprIntervalGraph::try_new(
-            self.filter.expression().clone(),
-            &self
-                .filter_columns
-                .iter()
-                .map(|sorted_expr| sorted_expr.filter_expr())
-                .collect_vec(),
-        )?;
-        // We inject calculated node indexes into SortedFilterExpr. In graph calculations,
-        // we will be using node index to put calculated intervals into Columns or BinaryExprs .
-        let node_injected_filter_exprs = self
-            .filter_columns
-            .iter()
-            .zip(physical_expr_graph.2.iter())
-            .map(|(sorted_expr, (_, index))| {
-                let mut temp = sorted_expr.clone();
-                temp.node_index = *index;
-                temp
-            })
-            .collect_vec();
-
         Ok(Box::pin(SymmetricHashJoinStream {
             left_stream,
             right_stream,
@@ -450,9 +394,9 @@ impl ExecutionPlan for SymmetricHashJoinExec {
             right: right_side_joiner,
             column_indices: self.column_indices.clone(),
             join_metrics: SymmetricHashJoinMetrics::new(partition, &self.metrics),
-            physical_expr_graph,
+            physical_expr_graph: self.physical_expr_graph.clone(),
             null_equals_null: self.null_equals_null,
-            filter_columns: node_injected_filter_exprs,
+            filter_columns: self.filter_columns.clone(),
             final_result: false,
             data_side: JoinSide::Left,
         }))
@@ -918,11 +862,6 @@ impl OneSideHashJoiner {
         // We use Vec<(usize, Interval)> instead of Hashmap since the expected
         // filter exprs relatively low and conversion between Vec and Hashmap
         // back and forth may be slower.
-        // TODO (for Ozan): We could pass the Vec in directly without this computation,
-        // but that creates a cyclic dependency between the files. We could move the
-        // SortedFilterExpr struct into common; then SortedFilterExpr could be used
-        // directly inside the graph without this computation.
         let mut filter_intervals: Vec<(usize, Interval)> = filter_columns
             .iter()
             .map(|sorted_expr| (sorted_expr.node_index(), sorted_expr.interval().clone()))
@@ -1170,19 +1109,20 @@ impl SymmetricHashJoinStream {
 mod tests {
     use std::fs::File;
 
-    use arrow::array::ArrayRef;
+    use arrow::array::{Array, ArrayRef};
     use arrow::array::{Int32Array, TimestampNanosecondArray};
-    use arrow::compute::CastOptions;
+    use arrow::compute::SortOptions;
     use arrow::datatypes::{DataType, Field, Schema};
     use arrow::util::pretty::pretty_format_batches;
     use rstest::*;
     use tempfile::TempDir;
 
     use datafusion_expr::Operator;
-    use datafusion_physical_expr::expressions::{BinaryExpr, CastExpr, Column, Literal};
+    use datafusion_physical_expr::expressions::{BinaryExpr, Column};
     use datafusion_physical_expr::utils;
 
     use crate::physical_plan::collect;
+    use crate::physical_plan::joins::hash_join_utils::complicated_filter;
     use crate::physical_plan::joins::{HashJoinExec, PartitionMode};
     use crate::physical_plan::{
         common, memory::MemoryExec, repartition::RepartitionExec,
@@ -1192,6 +1132,15 @@ mod tests {
 
     use super::*;
 
+    const TABLE_SIZE: i32 = 1_000;
+
+    // TODO: Lazy statistics for same record batches.
+    // use lazy_static::lazy_static;
+    //
+    // lazy_static! {
+    //     static ref SIDE_BATCH: Result<(RecordBatch, RecordBatch)> = build_sides_record_batches(TABLE_SIZE, (10, 11));
+    // }
+
     fn compare_batches(collected_1: &[RecordBatch], collected_2: &[RecordBatch]) {
         // compare
         let first_formatted = pretty_format_batches(collected_1).unwrap().to_string();
@@ -1219,7 +1168,6 @@ mod tests {
         right: Arc<dyn ExecutionPlan>,
         on: JoinOn,
         filter: JoinFilter,
-        sorted_filter_columns: Vec<SortedFilterExpr>,
         join_type: &JoinType,
         null_equals_null: bool,
        context: Arc<TaskContext>,
@@ -1247,7 +1195,6 @@ mod tests {
             )?),
             on,
             filter,
-            sorted_filter_columns,
             join_type,
             &null_equals_null,
         )?;
@@ -1334,28 +1281,21 @@ mod tests {
         Ok(result)
     }
 
-    fn build_table(columns: Vec<(&str, ArrayRef)>) -> Arc<dyn ExecutionPlan> {
+    fn build_record_batch(columns: Vec<(&str, ArrayRef)>) -> Result<RecordBatch> {
         let schema = Schema::new(
             columns
                 .iter()
-                .map(|(name, array)| Field::new(*name, array.data_type().clone(), false))
+                .map(|(name, array)| {
+                    let nullable = array.null_count() > 0;
+                    Field::new(*name, array.data_type().clone(), nullable)
+                })
                .collect(),
         );
         let batch = RecordBatch::try_new(
             Arc::new(schema),
             columns.into_iter().map(|(_, array)| array).collect(),
-        )
-        .unwrap();
-
-        let schema = batch.schema();
-        Arc::new(
-            MemoryExec::try_new(
-                &[split_record_batches(&batch, 13).unwrap()],
-                schema,
-                None,
-            )
-            .unwrap(),
-        )
+        )?;
+        Ok(batch)
     }
 
     fn join_expr_tests_fixture(
@@ -1432,13 +1372,15 @@ mod tests {
             _ => unreachable!(),
         }
     }
-
-    fn create_memory_table(
+    fn build_sides_record_batches(
         table_size: i32,
         key_cardinality: (i32, i32),
-    ) -> (Arc<dyn ExecutionPlan>, Arc<dyn ExecutionPlan>) {
+    ) -> Result<(RecordBatch, RecordBatch)> {
+        let null_ratio: f64 = 0.4;
         let initial_range = 0..table_size;
-        let left = build_table(vec![
+        let index = (table_size as f64 * null_ratio).round() as i32;
+        let rest_of = index..table_size;
+        let left = build_record_batch(vec![
            (
                "la1",
                Arc::new(Int32Array::from_iter(
                    initial_range.clone().collect::<Vec<i32>>(),
                )),
            ),
@@ -1481,8 +1423,36 @@ mod tests {
                    initial_range.clone().rev().collect::<Vec<i32>>(),
                )),
            ),
-        ]);
-        let right = build_table(vec![
+            (
+                "l_asc_null_first",
+                Arc::new(Int32Array::from_iter({
+                    std::iter::repeat(None)
+                        .take(index as usize)
+                        .chain(rest_of.clone().map(Some))
+                        .collect::<Vec<Option<i32>>>()
+                })),
+            ),
+            (
+                "l_asc_null_last",
+                Arc::new(Int32Array::from_iter({
+                    rest_of
+                        .clone()
+                        .map(Some)
+                        .chain(std::iter::repeat(None).take(index as usize))
+                        .collect::<Vec<Option<i32>>>()
+                })),
+            ),
+            (
+                "l_desc_null_first",
+                Arc::new(Int32Array::from_iter({
+                    std::iter::repeat(None)
+                        .take(index as usize)
+                        .chain(rest_of.clone().rev().map(Some))
+                        .collect::<Vec<Option<i32>>>()
+                })),
+            ),
+        ])?;
+        let right = build_record_batch(vec![
            (
                "ra1",
                Arc::new(Int32Array::from_iter(
                    initial_range.clone().collect::<Vec<i32>>(),
                )),
            ),
@@ -1525,80 +1495,94 @@ mod tests {
                    initial_range.rev().collect::<Vec<i32>>(),
                )),
            ),
-        ]);
-        (left, right)
+            (
+                "r_asc_null_first",
+                Arc::new(Int32Array::from_iter({
+                    std::iter::repeat(None)
+                        .take(index as usize)
+                        .chain(rest_of.clone().map(Some))
+                        .collect::<Vec<Option<i32>>>()
+                })),
+            ),
+            (
+                "r_asc_null_last",
+                Arc::new(Int32Array::from_iter({
+                    rest_of
+                        .clone()
+                        .map(Some)
+                        .chain(std::iter::repeat(None).take(index as usize))
+                        .collect::<Vec<Option<i32>>>()
+                })),
+            ),
+            (
+                "r_desc_null_first",
+                Arc::new(Int32Array::from_iter({
+                    std::iter::repeat(None)
+                        .take(index as usize)
+                        .chain(rest_of.rev().map(Some))
+                        .collect::<Vec<Option<i32>>>()
+                })),
+            ),
+        ])?;
+        Ok((left, right))
     }
 
-    fn complicated_fiter() -> Arc<dyn PhysicalExpr> {
-        let left_expr = BinaryExpr {
-            left: Arc::new(CastExpr {
-                expr: Arc::new(BinaryExpr {
-                    left: Arc::new(Column {
-                        name: "0".to_string(),
-                        index: 0,
-                    }),
-                    op: Operator::Plus,
-                    right: Arc::new(Column {
-                        name: "1".to_string(),
-                        index: 1,
-                    }),
-                }),
-                cast_type: DataType::Int64,
-                cast_options: CastOptions { safe: false },
-            }),
-            op: Operator::Gt,
-            right: Arc::new(BinaryExpr {
-                left: Arc::new(CastExpr {
-                    expr: Arc::new(Column {
-                        name: "2".to_string(),
-                        index: 2,
-                    }),
-                    cast_type: DataType::Int64,
-                    cast_options: CastOptions { safe: false },
-                }),
-                op: Operator::Plus,
-                right: Arc::new(Literal::new(ScalarValue::Int64(Some(10)))),
-            }),
-        };
-        let right_expr = BinaryExpr {
-            left: Arc::new(CastExpr {
-                expr: Arc::new(BinaryExpr {
-                    left: Arc::new(Column {
-                        name: "0".to_string(),
-                        index: 0,
-                    }),
-                    op: Operator::Plus,
-                    right: Arc::new(Column {
-                        name: "1".to_string(),
-                        index: 1,
-                    }),
-                }),
-                cast_type: DataType::Int64,
-                cast_options: CastOptions { safe: false },
-            }),
-            op: Operator::Lt,
-            right: Arc::new(BinaryExpr {
-                left: Arc::new(CastExpr {
-                    expr: Arc::new(Column {
-                        name: "2".to_string(),
-                        index: 2,
-                    }),
-                    cast_type: DataType::Int64,
-                    cast_options: CastOptions { safe: false },
-                }),
-                op: Operator::Plus,
-                right: Arc::new(Literal::new(ScalarValue::Int64(Some(100)))),
-            }),
-        };
-        Arc::new(BinaryExpr::new(
-            Arc::new(left_expr),
-            Operator::And,
-            Arc::new(right_expr),
-        ))
-    }
+    fn create_memory_table(
+        left_batch: RecordBatch,
+        right_batch: RecordBatch,
+        left_sorted: Vec<PhysicalSortExpr>,
+        right_sorted: Vec<PhysicalSortExpr>,
+    ) -> Result<(Arc<dyn ExecutionPlan>, Arc<dyn ExecutionPlan>)> {
+        Ok((
+            Arc::new(MemoryExec::try_new_with_sort_information(
+                &[split_record_batches(&left_batch, 13).unwrap()],
+                left_batch.schema(),
+                None,
+                Some(left_sorted),
+            )?),
+            Arc::new(MemoryExec::try_new_with_sort_information(
+                &[split_record_batches(&right_batch, 13).unwrap()],
+                right_batch.schema(),
+                None,
+                Some(right_sorted),
+            )?),
        ))
    }
 
+    async fn experiment(
+        left: Arc<dyn ExecutionPlan>,
+        right: Arc<dyn ExecutionPlan>,
+        filter: JoinFilter,
+        join_type: JoinType,
+        on: JoinOn,
+        task_ctx: Arc<TaskContext>,
+    ) -> Result<()> {
+        let first_batches = partitioned_sym_join_with_filter(
+            left.clone(),
+            right.clone(),
+            on.clone(),
+            filter.clone(),
+            &join_type,
+            false,
+            task_ctx.clone(),
+        )
+        .await?;
+        let second_batches = partitioned_hash_join_with_filter(
+            left.clone(),
+            right.clone(),
on.clone(), + filter.clone(), + &join_type, + false, + task_ctx.clone(), + ) + .await?; + compare_batches(&first_batches, &second_batches); + Ok(()) + } + #[rstest] - #[tokio::test(flavor = "multi_thread", worker_threads = 3)] + #[tokio::test(flavor = "multi_thread")] async fn complex_join_all_one_ascending_numeric( #[values( JoinType::Inner, @@ -1623,7 +1607,23 @@ mod tests { let config = SessionConfig::new().with_repartition_joins(false); let session_ctx = SessionContext::with_config(config); let task_ctx = session_ctx.task_ctx(); - let (left, right) = create_memory_table(1000, cardinality); + let (left_batch, right_batch) = + build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_sorted = vec![PhysicalSortExpr { + expr: Arc::new(BinaryExpr { + left: Arc::new(Column::new_with_schema("la1", &left_batch.schema())?), + op: Operator::Plus, + right: Arc::new(Column::new_with_schema("la2", &left_batch.schema())?), + }), + options: SortOptions::default(), + }]; + + let right_sorted = vec![PhysicalSortExpr { + expr: Arc::new(Column::new_with_schema("ra1", &right_batch.schema())?), + options: SortOptions::default(), + }]; + let (left, right) = + create_memory_table(left_batch, right_batch, left_sorted, right_sorted)?; let on = vec![( Column::new_with_schema("lc1", &left.schema())?, @@ -1634,36 +1634,6 @@ mod tests { let filter_col_1 = Arc::new(Column::new("1", 1)); let filter_col_2 = Arc::new(Column::new("2", 2)); - let main_sorted_binary = Arc::new(BinaryExpr { - left: Arc::new(Column::new_with_schema("la1", &left.schema())?), - op: Operator::Plus, - right: Arc::new(Column::new_with_schema("la2", &left.schema())?), - }); - - let filter_sorted_binary = Arc::new(BinaryExpr { - left: filter_col_0.clone(), - op: Operator::Plus, - right: filter_col_1.clone(), - }); - - let right_main_sorted_col = - Arc::new(Column::new_with_schema("ra1", &right.schema())?); - - let sorted_filter_columns = vec![ - SortedFilterExpr::new( - JoinSide::Left, - main_sorted_binary, - filter_sorted_binary, - SortOptions::default(), - ), - SortedFilterExpr::new( - JoinSide::Right, - right_main_sorted_col.clone(), - filter_col_2.clone(), - SortOptions::default(), - ), - ]; - let column_indices = vec![ ColumnIndex { index: 0, @@ -1684,7 +1654,7 @@ mod tests { Field::new(filter_col_2.name(), DataType::Int32, true), ]); - let filter_expr = complicated_fiter(); + let filter_expr = complicated_filter(); let filter = JoinFilter::new( filter_expr, @@ -1692,33 +1662,12 @@ mod tests { intermediate_schema.clone(), ); - let first_batches = partitioned_sym_join_with_filter( - left.clone(), - right.clone(), - on.clone(), - filter.clone(), - sorted_filter_columns, - &join_type, - false, - task_ctx.clone(), - ) - .await?; - let second_batches = partitioned_hash_join_with_filter( - left.clone(), - right.clone(), - on.clone(), - filter.clone(), - &join_type, - false, - task_ctx.clone(), - ) - .await?; - compare_batches(&first_batches, &second_batches); + experiment(left, right, filter, join_type, on, task_ctx).await?; Ok(()) } #[rstest] - #[tokio::test(flavor = "multi_thread", worker_threads = 3)] + #[tokio::test(flavor = "multi_thread")] async fn join_all_one_ascending_numeric( #[values( JoinType::Inner, @@ -1743,7 +1692,18 @@ mod tests { let config = SessionConfig::new().with_repartition_joins(false); let session_ctx = SessionContext::with_config(config); let task_ctx = session_ctx.task_ctx(); - let (left, right) = create_memory_table(1000, cardinality); + let (left_batch, right_batch) = + build_sides_record_batches(TABLE_SIZE, 
cardinality)?; + let left_sorted = vec![PhysicalSortExpr { + expr: Arc::new(Column::new_with_schema("la1", &left_batch.schema())?), + options: SortOptions::default(), + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: Arc::new(Column::new_with_schema("ra1", &right_batch.schema())?), + options: SortOptions::default(), + }]; + let (left, right) = + create_memory_table(left_batch, right_batch, left_sorted, right_sorted)?; let on = vec![( Column::new_with_schema("lc1", &left.schema())?, @@ -1753,21 +1713,6 @@ mod tests { let left_col = Arc::new(Column::new("left", 0)); let right_col = Arc::new(Column::new("right", 1)); - let sorted_filter_columns = vec![ - SortedFilterExpr::new( - JoinSide::Left, - Arc::new(Column::new_with_schema("la1", &left.schema())?), - left_col.clone(), - SortOptions::default(), - ), - SortedFilterExpr::new( - JoinSide::Right, - Arc::new(Column::new_with_schema("ra1", &right.schema())?), - right_col.clone(), - SortOptions::default(), - ), - ]; - let column_indices = vec![ ColumnIndex { index: 0, @@ -1792,33 +1737,12 @@ mod tests { intermediate_schema.clone(), ); - let first_batches = partitioned_sym_join_with_filter( - left.clone(), - right.clone(), - on.clone(), - filter.clone(), - sorted_filter_columns, - &join_type, - false, - task_ctx.clone(), - ) - .await?; - let second_batches = partitioned_hash_join_with_filter( - left.clone(), - right.clone(), - on.clone(), - filter.clone(), - &join_type, - false, - task_ctx.clone(), - ) - .await?; - compare_batches(&first_batches, &second_batches); + experiment(left, right, filter, join_type, on, task_ctx).await?; Ok(()) } #[rstest] - #[tokio::test(flavor = "multi_thread", worker_threads = 3)] + #[tokio::test(flavor = "multi_thread")] async fn join_all_one_descending_numeric_particular( #[values( JoinType::Inner, @@ -1843,7 +1767,24 @@ mod tests { let config = SessionConfig::new().with_repartition_joins(false); let session_ctx = SessionContext::with_config(config); let task_ctx = session_ctx.task_ctx(); - let (left, right) = create_memory_table(20, cardinality); + let (left_batch, right_batch) = + build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_sorted = vec![PhysicalSortExpr { + expr: Arc::new(Column::new_with_schema("la1_des", &left_batch.schema())?), + options: SortOptions { + descending: true, + nulls_first: true, + }, + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: Arc::new(Column::new_with_schema("ra1_des", &right_batch.schema())?), + options: SortOptions { + descending: true, + nulls_first: true, + }, + }]; + let (left, right) = + create_memory_table(left_batch, right_batch, left_sorted, right_sorted)?; let on = vec![( Column::new_with_schema("lc1", &left.schema())?, @@ -1853,27 +1794,6 @@ mod tests { let left_col = Arc::new(Column::new("left", 0)); let right_col = Arc::new(Column::new("right", 1)); - let sorted_filter_columns = vec![ - SortedFilterExpr::new( - JoinSide::Left, - Arc::new(Column::new_with_schema("la1_des", &left.schema())?), - left_col.clone(), - SortOptions { - descending: true, - nulls_first: true, - }, - ), - SortedFilterExpr::new( - JoinSide::Right, - Arc::new(Column::new_with_schema("ra1_des", &right.schema())?), - right_col.clone(), - SortOptions { - descending: true, - nulls_first: true, - }, - ), - ]; - let column_indices = vec![ ColumnIndex { index: 5, @@ -1898,28 +1818,7 @@ mod tests { intermediate_schema.clone(), ); - let first_batches = partitioned_sym_join_with_filter( - left.clone(), - right.clone(), - on.clone(), - filter.clone(), - sorted_filter_columns, - 
&join_type, - false, - task_ctx.clone(), - ) - .await?; - let second_batches = partitioned_hash_join_with_filter( - left.clone(), - right.clone(), - on.clone(), - filter.clone(), - &join_type, - false, - task_ctx.clone(), - ) - .await?; - compare_batches(&first_batches, &second_batches); + experiment(left, right, filter, join_type, on, task_ctx).await?; Ok(()) } @@ -1949,11 +1848,216 @@ mod tests { let task_ctx = ctx.task_ctx(); let results = collect(physical_plan.clone(), task_ctx).await.unwrap(); let formatted = pretty_format_batches(&results).unwrap().to_string(); - println!("Query Output:\n\n{formatted}"); let found = formatted .lines() .any(|line| line.contains("SymmetricHashJoinExec")); assert!(found); Ok(()) } + + #[tokio::test(flavor = "multi_thread")] + async fn build_null_columns_first() -> Result<()> { + let join_type = JoinType::Full; + let cardinality = (10, 11); + let case_expr = 1; + let config = SessionConfig::new().with_repartition_joins(false); + let session_ctx = SessionContext::with_config(config); + let task_ctx = session_ctx.task_ctx(); + let (left_batch, right_batch) = + build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_sorted = vec![PhysicalSortExpr { + expr: Arc::new(Column::new_with_schema( + "l_asc_null_first", + &left_batch.schema(), + )?), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: Arc::new(Column::new_with_schema( + "r_asc_null_first", + &right_batch.schema(), + )?), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + let (left, right) = + create_memory_table(left_batch, right_batch, left_sorted, right_sorted)?; + + let on = vec![( + Column::new_with_schema("lc1", &left.schema())?, + Column::new_with_schema("rc1", &right.schema())?, + )]; + + let left_col = Arc::new(Column::new("left", 0)); + let right_col = Arc::new(Column::new("right", 1)); + + let column_indices = vec![ + ColumnIndex { + index: 6, + side: JoinSide::Left, + }, + ColumnIndex { + index: 6, + side: JoinSide::Right, + }, + ]; + let intermediate_schema = Schema::new(vec![ + Field::new(left_col.name(), DataType::Int32, true), + Field::new(right_col.name(), DataType::Int32, true), + ]); + + let filter_expr = + join_expr_tests_fixture(case_expr, left_col.clone(), right_col.clone()); + + let filter = JoinFilter::new( + filter_expr, + column_indices.clone(), + intermediate_schema.clone(), + ); + experiment(left, right, filter, join_type, on, task_ctx).await?; + Ok(()) + } + + #[tokio::test(flavor = "multi_thread")] + async fn build_null_columns_last() -> Result<()> { + let join_type = JoinType::Full; + let cardinality = (10, 11); + let case_expr = 1; + let config = SessionConfig::new().with_repartition_joins(false); + let session_ctx = SessionContext::with_config(config); + let task_ctx = session_ctx.task_ctx(); + let (left_batch, right_batch) = + build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_sorted = vec![PhysicalSortExpr { + expr: Arc::new(Column::new_with_schema( + "l_asc_null_last", + &left_batch.schema(), + )?), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: Arc::new(Column::new_with_schema( + "r_asc_null_last", + &right_batch.schema(), + )?), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]; + let (left, right) = + create_memory_table(left_batch, right_batch, left_sorted, right_sorted)?; + + let on = vec![( + 
Column::new_with_schema("lc1", &left.schema())?, + Column::new_with_schema("rc1", &right.schema())?, + )]; + + let left_col = Arc::new(Column::new("left", 0)); + let right_col = Arc::new(Column::new("right", 1)); + + let column_indices = vec![ + ColumnIndex { + index: 7, + side: JoinSide::Left, + }, + ColumnIndex { + index: 7, + side: JoinSide::Right, + }, + ]; + let intermediate_schema = Schema::new(vec![ + Field::new(left_col.name(), DataType::Int32, true), + Field::new(right_col.name(), DataType::Int32, true), + ]); + + let filter_expr = + join_expr_tests_fixture(case_expr, left_col.clone(), right_col.clone()); + + let filter = JoinFilter::new( + filter_expr, + column_indices.clone(), + intermediate_schema.clone(), + ); + + experiment(left, right, filter, join_type, on, task_ctx).await?; + Ok(()) + } + + #[tokio::test(flavor = "multi_thread")] + async fn build_null_columns_first_descending() -> Result<()> { + let join_type = JoinType::Full; + let cardinality = (10, 11); + let case_expr = 1; + let config = SessionConfig::new().with_repartition_joins(false); + let session_ctx = SessionContext::with_config(config); + let task_ctx = session_ctx.task_ctx(); + let (left_batch, right_batch) = + build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_sorted = vec![PhysicalSortExpr { + expr: Arc::new(Column::new_with_schema( + "l_desc_null_first", + &left_batch.schema(), + )?), + options: SortOptions { + descending: true, + nulls_first: true, + }, + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: Arc::new(Column::new_with_schema( + "r_desc_null_first", + &right_batch.schema(), + )?), + options: SortOptions { + descending: true, + nulls_first: true, + }, + }]; + let (left, right) = + create_memory_table(left_batch, right_batch, left_sorted, right_sorted)?; + + let on = vec![( + Column::new_with_schema("lc1", &left.schema())?, + Column::new_with_schema("rc1", &right.schema())?, + )]; + + let left_col = Arc::new(Column::new("left", 0)); + let right_col = Arc::new(Column::new("right", 1)); + + let column_indices = vec![ + ColumnIndex { + index: 8, + side: JoinSide::Left, + }, + ColumnIndex { + index: 8, + side: JoinSide::Right, + }, + ]; + let intermediate_schema = Schema::new(vec![ + Field::new(left_col.name(), DataType::Int32, true), + Field::new(right_col.name(), DataType::Int32, true), + ]); + + let filter_expr = + join_expr_tests_fixture(case_expr, left_col.clone(), right_col.clone()); + + let filter = JoinFilter::new( + filter_expr, + column_indices.clone(), + intermediate_schema.clone(), + ); + + experiment(left, right, filter, join_type, on, task_ctx).await?; + Ok(()) + } } diff --git a/datafusion/core/src/physical_plan/memory.rs b/datafusion/core/src/physical_plan/memory.rs index a15b59552294..6141497182de 100644 --- a/datafusion/core/src/physical_plan/memory.rs +++ b/datafusion/core/src/physical_plan/memory.rs @@ -46,6 +46,8 @@ pub struct MemoryExec { projected_schema: SchemaRef, /// Optional projection projection: Option>, + // Optional sort information + sort_information: Option>, } impl fmt::Debug for MemoryExec { @@ -78,7 +80,10 @@ impl ExecutionPlan for MemoryExec { } fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - None + match self.sort_information.as_ref() { + None => None, + Some(sort) => Some(sort.as_slice()), + } } fn with_new_children( @@ -145,6 +150,25 @@ impl MemoryExec { schema, projected_schema, projection, + sort_information: None, + }) + } + /// Create a new execution plan for reading in-memory record batches + /// The provided `schema` 
should not have the projection applied. Also, you can specify sort
+    /// information on PhysicalExprs.
+    pub fn try_new_with_sort_information(
+        partitions: &[Vec<RecordBatch>],
+        schema: SchemaRef,
+        projection: Option<Vec<usize>>,
+        sort_information: Option<Vec<PhysicalSortExpr>>,
+    ) -> Result<Self> {
+        let projected_schema = project_schema(&schema, projection.as_ref())?;
+        Ok(Self {
+            partitions: partitions.to_vec(),
+            schema,
+            projected_schema,
+            projection,
+            sort_information,
+        })
+    }
 }
diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs
index 7dd2d82f6371..6a3e99c6110d 100644
--- a/datafusion/physical-expr/src/intervals/cp_solver.rs
+++ b/datafusion/physical-expr/src/intervals/cp_solver.rs
@@ -85,6 +85,7 @@ impl PartialEq for ExprIntervalGraphNode {
     }
 }
 
+#[derive(Clone)]
 pub struct ExprIntervalGraph(
     NodeIndex,
     StableGraph<ExprIntervalGraphNode, usize>,
diff --git a/datafusion/physical-expr/src/physical_expr_visitor.rs b/datafusion/physical-expr/src/physical_expr_visitor.rs
index b67178a3e6c0..7bc0227e9ecb 100644
--- a/datafusion/physical-expr/src/physical_expr_visitor.rs
+++ b/datafusion/physical-expr/src/physical_expr_visitor.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use crate::expressions::{
-    BinaryExpr, CastExpr, Column, DateTimeIntervalExpr, GetIndexedFieldExpr, InListExpr,
+    BinaryExpr, CastExpr, DateTimeIntervalExpr, GetIndexedFieldExpr, InListExpr,
 };
 use crate::PhysicalExpr;
 
@@ -34,18 +34,18 @@ pub enum Recursion {
 }
 
 /// Encode the traversal of an expression tree. When passed to
-/// `Expr::accept`, `PhysicalExpressionVisitor::visit` is invoked
+/// `Arc<dyn PhysicalExpr>::accept`, `PhysicalExpressionVisitor::visit` is invoked
 /// recursively on all nodes of an expression tree. See the comments
-/// on `Expr::accept` for details on its use
-pub trait PhysicalExpressionVisitor<E: ExprVisitable = Arc<dyn PhysicalExpr>>:
+/// on `Arc<dyn PhysicalExpr>::accept` for details on its use
+pub trait PhysicalExpressionVisitor<E: PhysicalExprVisitable = Arc<dyn PhysicalExpr>>:
     Sized
 {
-    /// Invoked before any children of `expr` are visited.
+    /// Invoked before any children of `Arc<dyn PhysicalExpr>` are visited.
     fn pre_visit(self, expr: E) -> Result<Recursion<Self>>
     where
         Self: PhysicalExpressionVisitor;
 
-    /// Invoked after all children of `expr` are visited. Default
+    /// Invoked after all children of `Arc<dyn PhysicalExpr>` are visited. Default
     /// implementation does nothing.
     fn post_visit(self, _expr: E) -> Result<Self> {
         Ok(self)
@@ -53,13 +53,12 @@ pub trait PhysicalExpressionVisitor<E: PhysicalExprVisitable = Arc<dyn PhysicalExpr>>:
 }
 
 /// trait for types that can be visited by [`ExpressionVisitor`]
-pub trait ExprVisitable: Sized {
+pub trait PhysicalExprVisitable: Sized {
     /// accept a visitor, calling `visit` on all children of this
     fn accept<V: PhysicalExpressionVisitor>(self, visitor: V) -> Result<V>;
 }
 
-// TODO: Widen the options.
-impl ExprVisitable for Arc<dyn PhysicalExpr> {
+impl PhysicalExprVisitable for Arc<dyn PhysicalExpr> {
     fn accept<V: PhysicalExpressionVisitor>(self, visitor: V) -> Result<V> {
         let visitor = match visitor.pre_visit(self.clone())? {
            Recursion::Continue(visitor) => visitor,
@@ -99,17 +98,3 @@ impl ExprVisitable for Arc<dyn PhysicalExpr> {
         visitor.post_visit(self.clone())
     }
 }
-
-#[derive(Default, Debug)]
-pub struct PostOrderPhysicalColumnCollector {
-    pub columns: Vec<Column>,
-}
-
-impl PhysicalExpressionVisitor for PostOrderPhysicalColumnCollector {
-    fn pre_visit(mut self, expr: Arc<dyn PhysicalExpr>) -> Result<Recursion<Self>> {
-        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
-            self.columns.push(column.clone())
-        }
-        Ok(Recursion::Continue(self))
-    }
-}

From b9d50bd93843aae5bf7a5a1cf1d5db21583b57b3 Mon Sep 17 00:00:00 2001
From: metesynnada <100111937+metesynnada@users.noreply.github.com>
Date: Tue, 24 Jan 2023 14:15:47 +0300
Subject: [PATCH 04/39] Commenting on

---
 .../joins/symmetric_hash_join.rs              |  99 +++++++-
 .../physical-expr/src/intervals/cp_solver.rs  | 225 +++++++++++++-----
 .../src/intervals/interval_aritmetics.rs      |  36 ++-
 3 files changed, 283 insertions(+), 77 deletions(-)

diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
index 727a813aa0d2..660264bbc50b 100644
--- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
@@ -66,14 +66,95 @@ use crate::physical_plan::{
     PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics,
 };
 
-/// The symmetric hash join is a special type of hash join designed for data streams and large
-/// datasets.
-/// For each input, create a hash table.
-/// - For each new record, hash and insert into inputs hash table. Update offsets.
-/// - Test if input is equal to a predefined set of other inputs.
-/// - If so, output the records, record the visited rows.
-/// - If the join type indicates that unmatched rows results must be produced (LEFT, FULL etc.),
-/// produce result if a pruning happens or at the end of the data.
+/// A symmetric hash join with range conditions hashes both streams on the
+/// join key and uses the resulting hash tables to join the streams.
+/// The join is considered symmetric because a hash table is built on the join keys of both
+/// streams, and the matching of rows is based on the values of the join keys in both streams.
+/// This type of join is efficient in a streaming context as it allows for fast lookups in the
+/// hash table, rather than having to scan through one or both of the streams to find matching
+/// rows. It also only considers elements from a stream that fall within a certain sliding
+/// window (via range conditions), making it more efficient and less likely to store stale data.
+/// This enables operating on unbounded streaming data without any memory issues.
+///
+/// For each input stream, create a hash table.
+/// - For each new [RecordBatch] on the build side, hash and insert it into the input's hash table; update offsets.
+/// - Test if input is equal to a predefined set of other inputs.
+/// - If so, record the visited rows. If the matched row results must be produced (INNER, LEFT), output the [RecordBatch].
+/// - Try to prune the other (probe) side with the new [RecordBatch].
+/// - If the join type indicates that unmatched row results must be produced (LEFT, FULL etc.),
+/// output the [RecordBatch] when a pruning happens or at the end of the data.
+///
+///
+/// ``` text
+///
+///                        +-------------------------+
+///                        |                         |
+///   left stream ---------|  Left OneSideHashJoiner |---+
+///                        |                         |   |
+///                        +-------------------------+   |
+///                                                      |
+///                                                      |--------- Joined output
+///                                                      |
+///                        +-------------------------+   |
+///                        |                         |   |
+///  right stream ---------| Right OneSideHashJoiner |---+
+///                        |                         |
+///                        +-------------------------+
+///
+/// Prune the build side when a new RecordBatch arrives at the probe side. We utilize
+/// interval arithmetic on the JoinFilter's sorted PhysicalExprs to calculate the
+/// joinable range.
+///
+///
+///                     PROBE SIDE                 BUILD SIDE
+///                       BUFFER                     BUFFER
+///                  +-------------+              +------------+
+///                  |             |              |            |    Unjoinable
+///                  |             |              |            |    Range
+///                  |             |              |            |
+///                  |             |          |---------------------------------
+///                  |             |          |   |            |
+///                  |             |          |   |            |
+///                  |             |         /    |            |
+///                  |             |         |    |            |
+///                  |             |         |    |            |
+///                  |             |         |    |            |
+///                  |             |         |    |            |
+///                  |             |         |    |            |    Joinable
+///                  |             |/        |    |            |    Range
+///                  |             ||        |    |            |
+///                  |+-----------+||        |    |            |
+///                  || Record    ||         |    |            |
+///                  || Batch     ||         |    |            |
+///                  |+-----------+||        |    |            |
+///                  +-------------+\        |    +------------+
+///                                  |
+///                                  \
+///                                   |---------------------------------
+///
+/// This happens when range conditions are provided on sorted columns. E.g.
+///
+///   SELECT * FROM left_table, right_table
+///   ON
+///     left_key = right_key AND
+///     left_time > right_time - INTERVAL 12 MINUTES AND left_time < right_time + INTERVAL 2 HOUR
+///
+/// or
+///   SELECT * FROM left_table, right_table
+///   ON
+///     left_key = right_key AND
+///     left_sorted > right_sorted - 3 AND left_sorted < right_sorted + 10
+///
+/// In the general case, when new data arrives at the probe side, the conditions can be used to
+/// determine a specific threshold for discarding rows from the inner buffer. For example, if the
+/// sort order of the two columns ("left_sorted" and "right_sorted") is ascending (it can differ
+/// in other scenarios), the join condition is "left_sorted > right_sorted - 3", and the latest
+/// value on the right input is 1234, then the left side buffer only needs to keep rows where
+/// "left_sorted > 1234 - 3 = 1231". This makes 1231 the smallest usable value in 'left_sorted';
+/// any rows below it (since the order is ascending) can be dropped from the inner buffer.
+///
+///
+///
+/// ```
+///
 pub struct SymmetricHashJoinExec {
     /// left side stream
     pub(crate) left: Arc<dyn ExecutionPlan>,
@@ -150,8 +231,8 @@ impl SymmetricHashJoinMetrics {
 impl SymmetricHashJoinExec {
     /// Tries to create a new [SymmetricHashJoinExec].
     /// # Error
-    /// This function errors when it is not possible to join the left and right sides on keys `on`.
-    /// TODO: Support for CollectLeft (or CorrectRight options.)
+    /// This function errors when it is not possible to join the left and right sides on keys `on`,
+    /// to iterate and construct [SortedFilterExpr]s, or to create the [ExprIntervalGraph].
     pub fn try_new(
         left: Arc<dyn ExecutionPlan>,
         right: Arc<dyn ExecutionPlan>,
         on: JoinOn,
         filter: JoinFilter,
diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs
index 6a3e99c6110d..e3122772556f 100644
--- a/datafusion/physical-expr/src/intervals/cp_solver.rs
+++ b/datafusion/physical-expr/src/intervals/cp_solver.rs
@@ -35,56 +35,83 @@ use petgraph::stable_graph::StableGraph;
 use std::ops::{Index, IndexMut};
 use std::sync::Arc;
 
-#[derive(Clone, Debug)]
-/// Graph nodes contains interval information.
-pub struct ExprIntervalGraphNode {
-    expr: Arc<dyn PhysicalExpr>,
-    interval: Interval,
-}
-
-impl Display for ExprIntervalGraphNode {
-    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        write!(f, "{}", self.expr)
-    }
-}
-
-impl ExprIntervalGraphNode {
-    // Construct ExprIntervalGraphNode
-    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
-        ExprIntervalGraphNode {
-            expr,
-            interval: Interval::default(),
-        }
-    }
-    /// Specify interval
-    pub fn new_with_interval(expr: Arc<dyn PhysicalExpr>, interval: Interval) -> Self {
-        ExprIntervalGraphNode { expr, interval }
-    }
-    /// Get interval
-    pub fn interval(&self) -> &Interval {
-        &self.interval
-    }
-
-    /// Within this static method, one can customize how PhysicalExpr generator constructs its nodes
-    pub fn expr_node_builder(input: Arc) -> ExprIntervalGraphNode {
-        let binding = input.expr();
-        let plan_any = binding.as_any();
-        if let Some(literal) = plan_any.downcast_ref::<Literal>() {
-            // Create interval
-            let interval = Interval::Singleton(literal.value().clone());
-            ExprIntervalGraphNode::new_with_interval(input.expr().clone(), interval)
-        } else {
-            ExprIntervalGraphNode::new(input.expr().clone())
-        }
-    }
-}
-
-impl PartialEq for ExprIntervalGraphNode {
-    fn eq(&self, other: &ExprIntervalGraphNode) -> bool {
-        self.expr.eq(&other.expr)
-    }
-}
-
+///
+/// Interval arithmetic provides a way to perform mathematical operations on intervals,
+/// which represent ranges of possible values rather than single point values.
+/// This allows for the propagation of ranges through mathematical operations,
+/// and can be used to compute bounds for a complicated expression. The key idea is that by
+/// breaking down a complicated expression into simpler terms, and then combining the bounds for
+/// those simpler terms, one can obtain bounds for the overall expression.
+///
+/// For example, consider a mathematical expression such as x^2 + y = 4. Since it would be
+/// a binary tree in [PhysicalExpr] notation, this type of hierarchical
+/// computation is well-suited for a graph-based implementation. In such an implementation,
+/// an equation system f(x) = 0 is represented by a directed acyclic expression
+/// graph (DAEG).
+///
+/// To use interval arithmetic to compute bounds for this expression, one would first determine
+/// intervals that represent the possible values of x and y.
+/// Let's say that the interval for x is [1, 2] and the interval for y is [-3, 1]. Below, you can
+/// see how the computation happens.
+///
+/// This way of using interval arithmetic to compute bounds for a complicated expression by
+/// combining the bounds for the simpler terms within the original expression allows us to
+/// reason about the range of possible values of the expression. This information can later be
+/// used in range pruning of the unused parts of the [RecordBatch]es.
+///
+/// References
+/// 1 - Kabak, Mehmet Ozan. Analog Circuit Start-Up Behavior Analysis: An Interval
+/// Arithmetic Based Approach. Stanford University, 2015.
+/// 2 - Moore, Ramon E. Interval analysis. Vol. 4. Englewood Cliffs: Prentice-Hall, 1966.
+/// 3 - F. Messine, "Deterministic global optimization using interval constraint
+/// propagation techniques," RAIRO-Operations Research, vol. 38, no. 04,
+/// pp. 277-293, 2004.
+///
+/// ``` text
+/// Computing bounds for an expression via a bottom-up evaluation of the expression graph,
+/// and constraint propagation via a top-down traversal of the graph using inverse semantics.
+///
+///  [-2, 5] ∩ [4, 4] = [4, 4]                                             [4, 4]
+///      +-----+                    +-----+                     +-----+                    +-----+
+///  +---|  +  |---+            +---|  +  |---+             +---|  +  |---+            +---|  +  |---+
+///  |   |     |   |            |   |     |   |             |   |     |   |            |   |     |   |
+///  |   +-----+   |            |   +-----+   |             |   +-----+   |            |   +-----+   |
+///  |             |            |             |             |             |            |             |
+/// +-----+     +-----+        +-----+     +-----+         +-----+     +-----+        +-----+     +-----+
+/// |  2  |     |  y  |        |  2  |[1,4]|  y  |         |  2  |[1,4]|  y  |        |  2  |[1,4]|  y  |[0, 1]*
+/// | [.] |     |     |        | [.] |     |     |         | [.] |     |     |        | [.] |     |     |
+/// +-----+     +-----+        +-----+     +-----+         +-----+     +-----+        +-----+     +-----+
+///    |                          |                           | [-3, 1]                  |
+///    |                          |                           |                          |
+///  +---+                      +---+                       +---+                      +---+
+///  | x | [1, 2]              | x | [1, 2]                | x | [1, 2]               | x | [1, 2]
+///  +---+                      +---+                       +---+                      +---+
+///
+///  (a) Bottom-up evaluation: Step 1    (b) Bottom-up evaluation: Step 2    (a) Top-down propagation: Step 1    (b) Top-down propagation: Step 2
+///
+///  [1 - 3, 4 + 1] = [-2, 5]    [1 - 3, 4 + 1] = [-2, 5]
+///      +-----+                    +-----+                     +-----+                    +-----+
+///  +---|  +  |---+            +---|  +  |---+             +---|  +  |---+            +---|  +  |---+
+///  |   |     |   |            |   |     |   |             |   |     |   |            |   |     |   |
+///  |   +-----+   |            |   +-----+   |             |   +-----+   |            |   +-----+   |
+///  |             |            |             |             |             |            |             |
+/// +-----+     +-----+        +-----+     +-----+         +-----+     +-----+        +-----+     +-----+
+/// |  2  |[1,4]|  y  |        |  2  |[1,4]|  y  |         |  2  |[3,4]**  y  |       |  2  |[1,4]|  y  |
+/// | [.] |     |     |        | [.] |     |     |         | [.] |     |     |        | [.] |     |     |
+/// +-----+     +-----+        +-----+     +-----+         +-----+     +-----+        +-----+     +-----+
+///    |  [-3, 1]                 |  [-3, 1]                  |  [0, 1]                  |  [-3, 1]
+///    |                          |                           |                          |
+///  +---+                      +---+                       +---+                      +---+
+///  | x | [1, 2]              | x | [1, 2]                | x | [1, 2]               | x | [sqrt(3), 2]***
+///  +---+                      +---+                       +---+                      +---+
+///
+///  (c) Bottom-up evaluation: Step 3    (d) Bottom-up evaluation: Step 4    (c) Top-down propagation: Step 3    (d) Top-down propagation: Step 4
+///
+/// *   [-3, 1] ∩ ([4, 4] - [1, 4]) = [0, 1]
+/// **  [1, 4] ∩ ([4, 4] - [0, 1]) = [3, 4]
+/// *** [1, 2] ∩ [sqrt(3), sqrt(4)] = [sqrt(3), 2]
+/// ```
+///
 #[derive(Clone)]
 pub struct ExprIntervalGraph(
     NodeIndex,
     StableGraph<ExprIntervalGraphNode, usize>,
@@ -108,10 +135,42 @@ fn calculate_node_interval(
 }
 
 impl ExprIntervalGraph {
-    /// We initiate [ExprIntervalGraph] with [PhysicalExpr] and the ones with we provide internals.
-    /// Let's say we have "a@0 + b@1 > 3" expression.
-    /// We can provide initial interval information for "a@0" and "b@1". Also, we can provide
-    /// information
+    /// Constructs
+    ///
+    /// # Arguments
+    /// * `expr` - `Arc<dyn PhysicalExpr>`. The complex expression that we compute bounds on by
+    /// traversing the graph.
+    /// * `provided_expr` - `Arc<dyn PhysicalExpr>`s that are used for child nodes. If one of them
+    /// is an expression with multiple children, the child nodes of that expression
+    /// will be pruned, since we do not need them while traversing.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use std::sync::Arc;
+    /// use datafusion_common::ScalarValue;
+    /// use datafusion_expr::Operator;
+    /// use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
+    /// use datafusion_physical_expr::intervals::ExprIntervalGraph;
+    /// // syn > gnz + 1 AND syn < gnz + 10
+    /// let syn = Arc::new(Column::new("syn", 0));
+    /// let left_and_2 = Arc::new(BinaryExpr::new(
+    ///     Arc::new(Column::new("gnz", 0)),
+    ///     Operator::Plus,
+    ///     Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
+    /// ));
+    /// let right_and_2 = Arc::new(BinaryExpr::new(
+    ///     Arc::new(Column::new("gnz", 0)),
+    ///     Operator::Plus,
+    ///     Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
+    /// ));
+    /// let left_expr = Arc::new(BinaryExpr::new(syn, Operator::Gt, left_and_2));
+    /// let right_expr = Arc::new(BinaryExpr::new(syn, Operator::Lt, right_and_2));
+    /// let expr = Arc::new(BinaryExpr::new(left_expr, Operator::And, right_expr));
+    /// let provided_exprs = vec![left_and_2.clone()];
+    /// // You can provide exprs for child pruning.
+    /// let graph = ExprIntervalGraph::try_new(expr, &provided_exprs);
+    /// ```
     pub fn try_new(
         expr: Arc<dyn PhysicalExpr>,
         provided_expr: &[Arc<dyn PhysicalExpr>],
@@ -147,7 +206,11 @@ impl ExprIntervalGraph {
         }
         Ok(Self(root_node, graph, expr_node_indices))
     }
-
+    /// Calculates new intervals and mutates the vector without changing its order.
+    ///
+    /// # Arguments
+    ///
+    /// * `expr_stats` - PhysicalExpr intervals. They are the child nodes used for the interval calculation.
     pub fn calculate_new_intervals(
         &mut self,
         expr_stats: &mut [(usize, Interval)],
@@ -156,7 +219,7 @@ impl ExprIntervalGraph {
         self.pre_order_interval_propagation(expr_stats)?;
         Ok(())
     }
-
+    /// Computes bounds for an expression using interval arithmetic.
     fn post_order_interval_calculation(
         &mut self,
         expr_stats: &[(usize, Interval)],
@@ -290,6 +353,56 @@ impl ExprIntervalGraph {
     }
 }
 
+#[derive(Clone, Debug)]
+/// Graph nodes contain interval information.
+pub struct ExprIntervalGraphNode {
+    expr: Arc<dyn PhysicalExpr>,
+    interval: Interval,
+}
+
+impl Display for ExprIntervalGraphNode {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.expr)
+    }
+}
+
+impl ExprIntervalGraphNode {
+    // Construct ExprIntervalGraphNode
+    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
+        ExprIntervalGraphNode {
+            expr,
+            interval: Interval::default(),
+        }
+    }
+    /// Specify interval
+    pub fn new_with_interval(expr: Arc<dyn PhysicalExpr>, interval: Interval) -> Self {
+        ExprIntervalGraphNode { expr, interval }
+    }
+    /// Get interval
+    pub fn interval(&self) -> &Interval {
+        &self.interval
+    }
+
+    /// Within this static method, one can customize how PhysicalExpr generator constructs its nodes
+    pub fn expr_node_builder(input: Arc) -> ExprIntervalGraphNode {
+        let binding = input.expr();
+        let plan_any = binding.as_any();
+        if let Some(literal) = plan_any.downcast_ref::<Literal>() {
+            // Create interval
+            let interval = Interval::Singleton(literal.value().clone());
+            ExprIntervalGraphNode::new_with_interval(input.expr().clone(), interval)
+        } else {
+            ExprIntervalGraphNode::new(input.expr().clone())
+        }
+    }
+}
+
+impl PartialEq for ExprIntervalGraphNode {
+    fn eq(&self, other: &ExprIntervalGraphNode) -> bool {
+        self.expr.eq(&other.expr)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetics.rs b/datafusion/physical-expr/src/intervals/interval_aritmetics.rs
index 9c22b5d5587a..8924a5aea8f4 100644
--- a/datafusion/physical-expr/src/intervals/interval_aritmetics.rs
+++ b/datafusion/physical-expr/src/intervals/interval_aritmetics.rs
@@ -34,6 +34,20 @@ use crate::aggregate::min_max::{max, min};
 use arrow::compute::{kernels, CastOptions};
 use std::ops::{Add, Sub};
 
+///
+/// At the moment, it only supports addition and subtraction,
+/// but we are constantly expanding its features to provide more capabilities in the future.
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub enum Interval {
+    Singleton(ScalarValue),
+    Range(Range),
+}
+
+impl Default for Interval {
+    fn default() -> Self {
+        Interval::Range(Range::default())
+    }
+}
 #[derive(Debug, PartialEq, Clone, Eq, Hash)]
 pub struct Range {
     pub lower: ScalarValue,
     pub upper: ScalarValue,
@@ -100,18 +114,6 @@ fn cast_scalar_value(
     ScalarValue::try_from_array(&cast_array, 0)
 }
 
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
-pub enum Interval {
-    Singleton(ScalarValue),
-    Range(Range),
-}
-
-impl Default for Interval {
-    fn default() -> Self {
-        Interval::Range(Range::default())
-    }
-}
-
 impl Interval {
     pub(crate) fn cast_to(
         &self,
@@ -466,6 +468,12 @@ impl Interval {
         })
     }
 
+    /// The interval addition operation takes two intervals, say [a1, b1] and [a2, b2],
+    /// and returns a new interval that represents the range of possible values obtained by adding
+    /// any value from the first interval to any value from the second interval.
+    /// The resulting interval is defined as [a1 + a2, b1 + b2]. For example,
+    /// if we have the intervals [1, 2] and [3, 4], the interval addition
+    /// operation would return the interval [4, 6].
     pub fn add<T: Borrow<Interval>>(&self, other: T) -> Result<Interval> {
         let rhs = other.borrow();
         let result = match (self, rhs) {
@@ -517,6 +525,10 @@ impl Interval {
         Ok(result)
     }
 
+    /// The interval subtraction operation is similar to the addition operation,
+    /// but it subtracts one interval from another. The resulting interval is defined as [a1 - b2, b1 - a2].
+    /// For example, if we have the intervals [1, 2] and [3, 4], the interval subtraction operation
+    /// would return the interval [-3, -1].
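+    ///
+    /// A self-contained sketch of these two rules on plain integer bounds
+    /// (`IntInterval` is a stand-in defined inside the example, not a type
+    /// from this crate):
+    ///
+    /// ```
+    /// #[derive(Debug, PartialEq)]
+    /// struct IntInterval { lower: i64, upper: i64 }
+    ///
+    /// impl IntInterval {
+    ///     // [a1, b1] + [a2, b2] = [a1 + a2, b1 + b2]
+    ///     fn add(&self, rhs: &Self) -> Self {
+    ///         Self { lower: self.lower + rhs.lower, upper: self.upper + rhs.upper }
+    ///     }
+    ///     // [a1, b1] - [a2, b2] = [a1 - b2, b1 - a2]
+    ///     fn sub(&self, rhs: &Self) -> Self {
+    ///         Self { lower: self.lower - rhs.upper, upper: self.upper - rhs.lower }
+    ///     }
+    /// }
+    ///
+    /// let a = IntInterval { lower: 1, upper: 2 };
+    /// let b = IntInterval { lower: 3, upper: 4 };
+    /// assert_eq!(a.add(&b), IntInterval { lower: 4, upper: 6 });
+    /// assert_eq!(a.sub(&b), IntInterval { lower: -3, upper: -1 });
+    /// ```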
     pub fn sub<T: Borrow<Interval>>(&self, other: T) -> Result<Interval> {
         let rhs = other.borrow();
         let result = match (self, rhs) {

From 0abf0d89bb818ed61a25a0afb959ab6b78d4de71 Mon Sep 17 00:00:00 2001
From: metesynnada <100111937+metesynnada@users.noreply.github.com>
Date: Tue, 24 Jan 2023 15:01:10 +0300
Subject: [PATCH 05/39] Minor changes after merge

---
 .../core/src/physical_plan/joins/symmetric_hash_join.rs | 2 +-
 datafusion/physical-expr/src/intervals/cp_solver.rs     | 9 +++++----
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
index 660264bbc50b..07ca1f8637a0 100644
--- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
@@ -39,7 +39,7 @@ use hashbrown::raw::RawTable;
 use hashbrown::HashSet;
 use itertools::Itertools;
 
-use datafusion_common::bisect::bisect;
+use datafusion_common::utils::bisect;
 use datafusion_common::ScalarValue;
 use datafusion_physical_expr::intervals::interval_aritmetics::{Interval, Range};
 use datafusion_physical_expr::intervals::ExprIntervalGraph;
diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs
index e3122772556f..024b7aebaced 100644
--- a/datafusion/physical-expr/src/intervals/cp_solver.rs
+++ b/datafusion/physical-expr/src/intervals/cp_solver.rs
@@ -135,7 +135,7 @@ fn calculate_node_interval(
 }
 
 impl ExprIntervalGraph {
-    /// Constructs
+    /// Constructs ExprIntervalGraph
     ///
     /// # Arguments
     /// * `expr` - `Arc<dyn PhysicalExpr>`. The complex expression that we compute bounds on by
@@ -152,6 +152,7 @@ impl ExprIntervalGraph {
     /// use datafusion_expr::Operator;
     /// use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
     /// use datafusion_physical_expr::intervals::ExprIntervalGraph;
+    /// use datafusion_physical_expr::PhysicalExpr;
     /// // syn > gnz + 1 AND syn < gnz + 10
     /// let syn = Arc::new(Column::new("syn", 0));
     /// let left_and_2 = Arc::new(BinaryExpr::new(
@@ -164,12 +165,12 @@ impl ExprIntervalGraph {
     ///     Operator::Plus,
     ///     Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
     /// ));
-    /// let left_expr = Arc::new(BinaryExpr::new(syn, Operator::Gt, left_and_2));
+    /// let left_expr = Arc::new(BinaryExpr::new(syn.clone(), Operator::Gt, left_and_2.clone()));
     /// let right_expr = Arc::new(BinaryExpr::new(syn, Operator::Lt, right_and_2));
     /// let expr = Arc::new(BinaryExpr::new(left_expr, Operator::And, right_expr));
-    /// let provided_exprs = vec![left_and_2.clone()];
+    /// let provided_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![left_and_2];
     /// // You can provide exprs for child pruning.
-    /// let graph = ExprIntervalGraph::try_new(expr, &provided_exprs);
+    /// let graph = ExprIntervalGraph::try_new(expr, provided_exprs.as_slice());
     /// ```
     pub fn try_new(
         expr: Arc<dyn PhysicalExpr>,

From 6a15fe7234b9d25097f65a83bdeb57c65981e4d4 Mon Sep 17 00:00:00 2001
From: Mehmet Ozan Kabak
Date: Sun, 29 Jan 2023 21:47:47 -0600
Subject: [PATCH 06/39] Simplify interval arithmetic library code

---
 .../joins/symmetric_hash_join.rs              |  14 +-
 .../physical-expr/src/intervals/cp_solver.rs  |  46 +-
 .../src/intervals/interval_aritmetics.rs      | 604 ++++--------
 3 files changed, 170 insertions(+), 494 deletions(-)

diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
index 07ca1f8637a0..af41d0b376f7 100644
--- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
@@ -41,7 +41,7 @@ use itertools::Itertools;
 
 use datafusion_common::utils::bisect;
 use datafusion_common::ScalarValue;
-use datafusion_physical_expr::intervals::interval_aritmetics::{Interval, Range};
+use datafusion_physical_expr::intervals::interval_aritmetics::Interval;
 use datafusion_physical_expr::intervals::ExprIntervalGraph;
 
@@ -610,15 +610,15 @@ fn column_stats_two_side(
         let value = ScalarValue::try_from_array(&array, 0)?;
         let infinite = ScalarValue::try_from(value.get_datatype())?;
         *interval = if sort_option.descending {
-            Interval::Range(Range {
+            Interval {
                 lower: infinite,
                 upper: value,
-            })
+            }
         } else {
-            Interval::Range(Range {
+            Interval {
                 lower: value,
                 upper: infinite,
-            })
+            }
         };
     }
     Ok(())
@@ -645,9 +645,9 @@ fn determine_prune_length(
                 .unwrap()
                 .into_array(buffer.num_rows());
             let target = if sort_option.descending {
-                interval.upper_value()
+                interval.upper.clone()
            } else {
-                interval.lower_value()
+                interval.lower.clone()
             };
             Some(bisect::(&[batch_arr], &[target], &[*sort_option]))
         } else {
diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs
index 024b7aebaced..9fa91b1fd555 100644
--- a/datafusion/physical-expr/src/intervals/cp_solver.rs
+++ b/datafusion/physical-expr/src/intervals/cp_solver.rs
@@ -125,13 +125,9 @@ fn calculate_node_interval(
     current_child_interval: &Interval,
     negated_op: Operator,
     other_child_interval: &Interval,
-) -> Result<Interval> {
-    Ok(if current_child_interval.is_singleton() {
-        current_child_interval.clone()
-    } else {
-        apply_operator(parent, &negated_op, other_child_interval)?
-            .intersect(current_child_interval)?
-    })
+) -> Result<Option<Interval>> {
+    let interv = apply_operator(parent, &negated_op, other_child_interval)?;
+    interv.intersect(current_child_interval)
 }
 
 impl ExprIntervalGraph {
@@ -324,13 +320,17 @@
                 second_child_interval,
                 negated_op,
                 first_child_interval,
-            )?;
+            )?
+            // TODO: Fix this unwrap
+            .unwrap();
             let new_left_operator = calculate_node_interval(
                 &node_interval,
                 first_child_interval,
                 negated_op,
                 &new_right_operator,
-            )?;
+            )?
+ // TODO: Fix this unwrap + .unwrap(); (new_left_operator, new_right_operator) }; let mutable_first_child = self.1.index_mut(first_child_node_index); @@ -388,12 +388,17 @@ impl ExprIntervalGraphNode { pub fn expr_node_builder(input: Arc) -> ExprIntervalGraphNode { let binding = input.expr(); let plan_any = binding.as_any(); + let clone_expr = input.expr().clone(); if let Some(literal) = plan_any.downcast_ref::() { // Create interval - let interval = Interval::Singleton(literal.value().clone()); - ExprIntervalGraphNode::new_with_interval(input.expr().clone(), interval) + let value = literal.value(); + let interval = Interval { + lower: value.clone(), + upper: value.clone(), + }; + ExprIntervalGraphNode::new_with_interval(clone_expr, interval) } else { - ExprIntervalGraphNode::new(input.expr().clone()) + ExprIntervalGraphNode::new(clone_expr) } } } @@ -411,7 +416,6 @@ mod tests { use itertools::Itertools; use crate::expressions::Column; - use crate::intervals::interval_aritmetics::Range; use datafusion_common::ScalarValue; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; @@ -428,33 +432,33 @@ mod tests { let col_stats = vec![ ( exprs.0.clone(), - Interval::Range(Range { + Interval { lower: ScalarValue::Int32(left_interval.0), upper: ScalarValue::Int32(left_interval.1), - }), + }, ), ( exprs.1.clone(), - Interval::Range(Range { + Interval { lower: ScalarValue::Int32(right_interval.0), upper: ScalarValue::Int32(right_interval.1), - }), + }, ), ]; let expected = vec![ ( exprs.0.clone(), - Interval::Range(Range { + Interval { lower: ScalarValue::Int32(left_waited.0), upper: ScalarValue::Int32(left_waited.1), - }), + }, ), ( exprs.1.clone(), - Interval::Range(Range { + Interval { lower: ScalarValue::Int32(right_waited.0), upper: ScalarValue::Int32(right_waited.1), - }), + }, ), ]; let mut graph = ExprIntervalGraph::try_new( diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetics.rs b/datafusion/physical-expr/src/intervals/interval_aritmetics.rs index 8924a5aea8f4..a95cd7e9a730 100644 --- a/datafusion/physical-expr/src/intervals/interval_aritmetics.rs +++ b/datafusion/physical-expr/src/intervals/interval_aritmetics.rs @@ -15,90 +15,41 @@ // specific language governing permissions and limitations // under the License. -//! Interval arithmetics library +//! Interval arithmetic library //! use std::borrow::Borrow; +use std::fmt; +use std::fmt::{Display, Formatter}; +use arrow::compute::{kernels, CastOptions}; use arrow::datatypes::DataType; - use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; - use datafusion_expr::Operator; -use std::fmt; -use std::fmt::{Display, Formatter}; - use crate::aggregate::min_max::{max, min}; -use arrow::compute::{kernels, CastOptions}; -use std::ops::{Add, Sub}; - -/// -/// At the moment, it only supports addition and subtraction, -/// but we are constantly expanding its features to provide more capabilities in the future. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum Interval { - Singleton(ScalarValue), - Range(Range), -} - -impl Default for Interval { - fn default() -> Self { - Interval::Range(Range::default()) - } -} +/// This type represents an interval, which is used to calculate reliable +/// bounds for expressions. Currently, we only support addition and +/// subtraction, but more capabilities will be added in the future. 
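+///
+/// For example, adding the intervals [1, 2] and [3, 4] via [Interval::add]
+/// below yields [4, 6]. A rough usage sketch (marked `ignore` since it is
+/// illustrative; the module path is assumed, and `lower`/`upper` are public):
+///
+/// ```ignore
+/// use datafusion_common::ScalarValue;
+/// use datafusion_physical_expr::intervals::interval_aritmetics::Interval;
+///
+/// let lhs = Interval {
+///     lower: ScalarValue::Int32(Some(1)),
+///     upper: ScalarValue::Int32(Some(2)),
+/// };
+/// let rhs = Interval {
+///     lower: ScalarValue::Int32(Some(3)),
+///     upper: ScalarValue::Int32(Some(4)),
+/// };
+/// let sum = lhs.add(&rhs)?;
+/// assert_eq!(sum.lower, ScalarValue::Int32(Some(4)));
+/// assert_eq!(sum.upper, ScalarValue::Int32(Some(6)));
+/// ```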
#[derive(Debug, PartialEq, Clone, Eq, Hash)] -pub struct Range { +pub struct Interval { pub lower: ScalarValue, pub upper: ScalarValue, } -impl Range { - fn new(lower: ScalarValue, upper: ScalarValue) -> Range { - Range { lower, upper } - } -} - -impl Default for Range { +impl Default for Interval { fn default() -> Self { - Range { + Interval { lower: ScalarValue::Null, upper: ScalarValue::Null, } } } -impl Add for Range { - type Output = Range; - - fn add(self, other: Range) -> Range { - Range::new( - self.lower.add(&other.lower).unwrap(), - self.upper.add(&other.upper).unwrap(), - ) - } -} - -impl Sub for Range { - type Output = Range; - - fn sub(self, other: Range) -> Range { - Range::new( - self.lower.add(&other.lower).unwrap(), - self.upper.sub(&other.lower).unwrap(), - ) - } -} - impl Display for Interval { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - Interval::Singleton(value) => write!(f, "Singleton [{}]", value), - Interval::Range(Range { lower, upper }) => { - write!(f, "Range [{}, {}]", lower, upper) - } - } + write!(f, "Interval [{}, {}]", self.lower, self.upper) } } @@ -120,151 +71,55 @@ impl Interval { data_type: &DataType, cast_options: &CastOptions, ) -> Result { - Ok(match self { - Interval::Range(Range { lower, upper }) => Interval::Range(Range { - lower: cast_scalar_value(lower, data_type, cast_options)?, - upper: cast_scalar_value(upper, data_type, cast_options)?, - }), - Interval::Singleton(value) => { - Interval::Singleton(cast_scalar_value(value, data_type, cast_options)?) - } + Ok(Interval { + lower: cast_scalar_value(&self.lower, data_type, cast_options)?, + upper: cast_scalar_value(&self.upper, data_type, cast_options)?, }) } + pub(crate) fn is_boolean(&self) -> bool { matches!( self, - &Interval::Range(Range { + Interval { lower: ScalarValue::Boolean(_), upper: ScalarValue::Boolean(_), - }) + } ) } - pub fn lower_value(&self) -> ScalarValue { - match self { - Interval::Range(Range { lower, .. }) => lower.clone(), - Interval::Singleton(val) => val.clone(), - } - } - - pub fn upper_value(&self) -> ScalarValue { - match self { - Interval::Range(Range { upper, .. }) => upper.clone(), - Interval::Singleton(val) => val.clone(), - } - } - pub(crate) fn get_datatype(&self) -> DataType { - match self { - Interval::Range(Range { lower, .. 
}) => lower.get_datatype(), - Interval::Singleton(val) => val.get_datatype(), - } - } - pub(crate) fn is_singleton(&self) -> bool { - match self { - Interval::Range(_) => false, - Interval::Singleton(_) => true, - } - } - pub(crate) fn is_strict_false(&self) -> bool { - matches!( - self, - &Interval::Range(Range { - lower: ScalarValue::Boolean(Some(false)), - upper: ScalarValue::Boolean(Some(false)), - }) - ) + self.lower.get_datatype() } + pub(crate) fn gt(&self, other: &Interval) -> Interval { - let bools = match (self, other) { - (Interval::Singleton(val), Interval::Singleton(val2)) => { - let equal = val > val2; - (equal, equal) - } - (Interval::Singleton(val), Interval::Range(Range { lower, upper })) => { - if val < lower { - (false, false) - } else if val > upper { - (true, true) - } else { - (false, true) - } - } - (Interval::Range(Range { lower, upper }), Interval::Singleton(val)) => { - if val > upper { - (false, false) - } else if lower > val { - (true, true) - } else { - (false, true) - } - } - ( - Interval::Range(Range { lower, upper }), - Interval::Range(Range { - lower: lower2, - upper: upper2, - }), - ) => { - if upper < lower2 { - (false, false) - } else if lower > upper2 { - (true, true) - } else { - (false, true) - } - } + let flags = if self.upper < other.lower { + (false, false) + } else if self.lower > other.upper { + (true, true) + } else { + (false, true) }; - Interval::Range(Range { - lower: ScalarValue::Boolean(Some(bools.0)), - upper: ScalarValue::Boolean(Some(bools.1)), - }) + Interval { + lower: ScalarValue::Boolean(Some(flags.0)), + upper: ScalarValue::Boolean(Some(flags.1)), + } } pub(crate) fn and(&self, other: &Interval) -> Result { - let bools = match (self, other) { - ( - Interval::Singleton(ScalarValue::Boolean(Some(val))), - Interval::Singleton(ScalarValue::Boolean(Some(val2))), - ) => { - let equal = *val && *val2; - (equal, equal) - } - ( - Interval::Singleton(ScalarValue::Boolean(Some(val))), - Interval::Range(Range { - lower: ScalarValue::Boolean(Some(lower)), - upper: ScalarValue::Boolean(Some(upper)), - }), - ) - | ( - Interval::Range(Range { - lower: ScalarValue::Boolean(Some(lower)), - upper: ScalarValue::Boolean(Some(upper)), - }), - Interval::Singleton(ScalarValue::Boolean(Some(val))), - ) => { - if (*val && *lower) && *upper { - (true, true) - } else if !(*val && *lower) && *upper { - (false, false) - } else { - (false, true) - } - } + let flags = match (self, other) { ( - Interval::Range(Range { + Interval { lower: ScalarValue::Boolean(Some(lower)), upper: ScalarValue::Boolean(Some(upper)), - }), - Interval::Range(Range { + }, + Interval { lower: ScalarValue::Boolean(Some(lower2)), upper: ScalarValue::Boolean(Some(upper2)), - }), + }, ) => { - if (*lower && *lower2) && (*upper && *upper2) { + if *lower && *lower2 { (true, true) - } else if (*lower || *lower2) || (*upper || *upper2) { + } else if *upper || *upper2 { (false, true) } else { (false, false) @@ -272,164 +127,71 @@ impl Interval { } _ => { return Err(DataFusionError::Internal( - "Booleans cannot be None.".to_string(), + "Incompatible types for logical conjunction".to_string(), )) } }; - Ok(Interval::Range(Range { - lower: ScalarValue::Boolean(Some(bools.0)), - upper: ScalarValue::Boolean(Some(bools.1)), - })) + Ok(Interval { + lower: ScalarValue::Boolean(Some(flags.0)), + upper: ScalarValue::Boolean(Some(flags.1)), + }) } pub(crate) fn equal(&self, other: &Interval) -> Interval { - let bool = match (self, other) { - (Interval::Singleton(val), Interval::Singleton(val2)) => { - let 
equal = val == val2; - (equal, equal) - } - (Interval::Singleton(val), Interval::Range(Range { lower, upper })) => { - let equal = lower >= val && val <= upper; - (equal, equal) - } - (Interval::Range(Range { lower, upper }), Interval::Singleton(val)) => { - let equal = lower >= val && val <= upper; - (equal, equal) - } - ( - Interval::Range(Range { lower, upper }), - Interval::Range(Range { - lower: lower2, - upper: upper2, - }), - ) => { - if lower == lower2 && upper == upper2 { - (true, true) - } else if (lower < lower2 && upper < upper2) - || (lower > lower2 && upper > upper2) - { - (false, false) - } else { - (false, true) - } - } + let flags = if (self.lower == self.upper) + && (other.lower == other.upper) + && (self.lower == other.lower) + { + (true, true) + } else if (self.lower > other.upper) || (self.upper < other.lower) { + (false, false) + } else { + (false, true) }; - Interval::Range(Range { - lower: ScalarValue::Boolean(Some(bool.0)), - upper: ScalarValue::Boolean(Some(bool.1)), - }) + Interval { + lower: ScalarValue::Boolean(Some(flags.0)), + upper: ScalarValue::Boolean(Some(flags.1)), + } } pub(crate) fn lt(&self, other: &Interval) -> Interval { - let bool = match (self, other) { - (Interval::Singleton(val), Interval::Singleton(val2)) => { - let result = val < val2; - (result, result) - } - (Interval::Singleton(ref val), Interval::Range(Range { lower, upper })) => { - if val > upper { - (false, false) - } else if val < lower { - (true, true) - } else { - (true, false) - } - } - ( - Interval::Range(Range { - ref lower, - ref upper, - }), - Interval::Singleton(val), - ) => { - if val < upper { - (false, false) - } else if lower > val { - (true, true) - } else { - (true, false) - } - } - ( - Interval::Range(Range { - ref lower, - ref upper, - }), - Interval::Range(Range { - lower: lower2, - upper: upper2, - }), - ) => { - if upper < lower2 { - (true, true) - } else if lower >= upper2 { - (false, false) - } else { - (true, false) - } - } - }; - Interval::Range(Range { - lower: ScalarValue::Boolean(Some(bool.0)), - upper: ScalarValue::Boolean(Some(bool.1)), - }) + other.gt(self) } - pub(crate) fn intersect(&self, other: &Interval) -> Result { - let result = match (self, other) { - (Interval::Singleton(val1), Interval::Singleton(val2)) => { - if val1 == val2 { - Interval::Singleton(val1.clone()) - } else { - Interval::Range(Range { - lower: ScalarValue::Boolean(Some(false)), - upper: ScalarValue::Boolean(Some(false)), - }) - } - } - (Interval::Singleton(val), Interval::Range(range)) => { - if val >= &range.lower && val <= &range.upper { - Interval::Singleton(val.clone()) - } else { - Interval::Range(Range { - lower: ScalarValue::Boolean(Some(false)), - upper: ScalarValue::Boolean(Some(false)), - }) - } - } - (Interval::Range(Range { lower, upper }), Interval::Singleton(val)) => { - if (lower.is_null() && upper.is_null()) || (val >= lower && val <= upper) + pub(crate) fn intersect(&self, other: &Interval) -> Result> { + Ok(match (self, other) { + ( + Interval { + lower: ScalarValue::Boolean(Some(low)), + upper: ScalarValue::Boolean(Some(high)), + }, + Interval { + lower: ScalarValue::Boolean(Some(other_low)), + upper: ScalarValue::Boolean(Some(other_high)), + }, + ) => { + if low > other_high || high < other_low { + // This None value signals an empty interval. 
+ None + } else if (low == high) && (other_low == other_high) && (low == other_low) { - Interval::Singleton(val.clone()) + Some(self.clone()) } else { - Interval::Range(Range { + Some(Interval { lower: ScalarValue::Boolean(Some(false)), - upper: ScalarValue::Boolean(Some(false)), + upper: ScalarValue::Boolean(Some(true)), }) } } ( - Interval::Range(Range { - lower: ScalarValue::Boolean(Some(val)), - upper: ScalarValue::Boolean(Some(val2)), - }), - Interval::Range(Range { - lower: ScalarValue::Boolean(Some(val3)), - upper: ScalarValue::Boolean(Some(val4)), - }), - ) => Interval::Range(Range { - lower: ScalarValue::Boolean(Some(*val || *val3)), - upper: ScalarValue::Boolean(Some(*val2 || *val4)), - }), - ( - Interval::Range(Range { + Interval { lower: lower1, upper: upper1, - }), - Interval::Range(Range { + }, + Interval { lower: lower2, upper: upper2, - }), + }, ) => { let lower = if lower1.is_null() { lower2.clone() @@ -446,25 +208,19 @@ impl Interval { min(upper1, upper2)? }; if !upper.is_null() && lower > upper { - Interval::Range(Range { - lower: ScalarValue::Boolean(Some(false)), - upper: ScalarValue::Boolean(Some(false)), - }) + // This None value signals an empty interval. + None } else { - Interval::Range(Range { lower, upper }) + Some(Interval { lower, upper }) } } - }; - Ok(result) + }) } pub(crate) fn arithmetic_negate(&self) -> Result { - Ok(match self { - Interval::Singleton(value) => Interval::Singleton(value.arithmetic_negate()?), - Interval::Range(Range { lower, upper }) => Interval::Range(Range { - lower: upper.arithmetic_negate()?, - upper: lower.arithmetic_negate()?, - }), + Ok(Interval { + lower: self.upper.arithmetic_negate()?, + upper: self.lower.arithmetic_negate()?, }) } @@ -476,53 +232,17 @@ impl Interval { /// operation would return the interval [4, 6]. pub fn add>(&self, other: T) -> Result { let rhs = other.borrow(); - let result = match (self, rhs) { - (Interval::Singleton(val), Interval::Singleton(val2)) => { - Interval::Singleton(val.add(val2).unwrap()) - } - (Interval::Singleton(val), Interval::Range(Range { lower, upper })) - | (Interval::Range(Range { lower, upper }), Interval::Singleton(val)) => { - let new_lower = if lower.is_null() { - lower.clone() - } else { - lower.add(val).unwrap() - }; - let new_upper = if upper.is_null() { - upper.clone() - } else { - upper.add(val).unwrap() - }; - Interval::Range(Range { - lower: new_lower, - upper: new_upper, - }) - } - - ( - Interval::Range(Range { lower, upper }), - Interval::Range(Range { - lower: lower2, - upper: upper2, - }), - ) => { - let new_lower = if lower.is_null() || lower2.is_null() { - ScalarValue::try_from(lower.get_datatype()).unwrap() - } else { - lower.add(lower2).unwrap() - }; - - let new_upper = if upper.is_null() || upper2.is_null() { - ScalarValue::try_from(upper.get_datatype()).unwrap() - } else { - upper.add(upper2).unwrap() - }; - Interval::Range(Range { - lower: new_lower, - upper: new_upper, - }) - } - }; - Ok(result) + let lower = if self.lower.is_null() || rhs.lower.is_null() { + ScalarValue::try_from(self.lower.get_datatype()) + } else { + self.lower.add(&rhs.lower) + }?; + let upper = if self.upper.is_null() || rhs.upper.is_null() { + ScalarValue::try_from(self.upper.get_datatype()) + } else { + self.upper.add(&rhs.upper) + }?; + Ok(Interval { lower, upper }) } /// The interval subtraction operation is similar to the addition operation, @@ -531,68 +251,17 @@ impl Interval { /// would return the interval [-3, -1]. 
     pub fn sub<T: Borrow<Interval>>(&self, other: T) -> Result<Interval> {
         let rhs = other.borrow();
-        let result = match (self, rhs) {
-            (Interval::Singleton(val), Interval::Singleton(val2)) => {
-                Interval::Singleton(val.sub(val2).unwrap())
-            }
-            (Interval::Singleton(val), Interval::Range(Range { lower, upper })) => {
-                let new_lower = if lower.is_null() {
-                    lower.clone()
-                } else {
-                    lower.add(val)?.arithmetic_negate()?
-                };
-                let new_upper = if upper.is_null() {
-                    upper.clone()
-                } else {
-                    upper.add(val)?.arithmetic_negate()?
-                };
-                Interval::Range(Range {
-                    lower: new_lower,
-                    upper: new_upper,
-                })
-            }
-            (Interval::Range(Range { lower, upper }), Interval::Singleton(val)) => {
-                let new_lower = if lower.is_null() {
-                    lower.clone()
-                } else {
-                    lower.sub(val)?
-                };
-                let new_upper = if upper.is_null() {
-                    upper.clone()
-                } else {
-                    upper.sub(val)?
-                };
-                Interval::Range(Range {
-                    lower: new_lower,
-                    upper: new_upper,
-                })
-            }
-            (
-                Interval::Range(Range { lower, upper }),
-                Interval::Range(Range {
-                    lower: lower2,
-                    upper: upper2,
-                }),
-            ) => {
-                let new_lower = if lower.is_null() || upper2.is_null() {
-                    ScalarValue::try_from(&lower.get_datatype())?
-                } else {
-                    lower.sub(upper2)?
-                };
-
-                let new_upper = if upper.is_null() || lower2.is_null() {
-                    ScalarValue::try_from(&upper.get_datatype())?
-                } else {
-                    upper.sub(lower2)?
-                };
-
-                Interval::Range(Range {
-                    lower: new_lower,
-                    upper: new_upper,
-                })
-            }
-        };
-        Ok(result)
+        let lower = if self.lower.is_null() || rhs.upper.is_null() {
+            ScalarValue::try_from(self.lower.get_datatype())
+        } else {
+            self.lower.sub(&rhs.upper)
+        }?;
+        let upper = if self.upper.is_null() || rhs.lower.is_null() {
+            ScalarValue::try_from(self.upper.get_datatype())
+        } else {
+            self.upper.sub(&rhs.lower)
+        }?;
+        Ok(Interval { lower, upper })
     }
 }
 
@@ -604,37 +273,35 @@ pub fn apply_operator(lhs: &Interval, op: &Operator, rhs: &Interval) -> Result<Interval>
         Operator::And => lhs.and(rhs),
         Operator::Plus => lhs.add(rhs),
         Operator::Minus => lhs.sub(rhs),
-        _ => Ok(Interval::Singleton(ScalarValue::Null)),
+        _ => Ok(Interval {
+            lower: ScalarValue::Null,
+            upper: ScalarValue::Null,
+        }),
     }
 }
 
 pub fn target_parent_interval(datatype: &DataType, op: &Operator) -> Result<Interval> {
-    let inf = ScalarValue::try_from(datatype)?;
+    let unbounded = ScalarValue::try_from(datatype)?;
     let zero = ScalarValue::try_from_string("0".to_string(), datatype)?;
-    let parent_interval = match *op {
+    Ok(match *op {
+        // TODO: Tidy these comments, write a text that explains what this function does.
         // If [x1, y1] > [x2, y2], then
         // [x1, y1] + [-y2, -x2] = [0, inf]
         // [x1_new, x2_new] = ([0, inf] - [-y2, -x2]) intersect [x1, y1]
         // [-y2_new, -x2_new] = ([0, inf] - [x1_new, x2_new]) intersect [-y2, -x2]
-        Operator::Gt => {
-            // [0, inf]
-            Interval::Range(Range {
-                lower: zero,
-                upper: inf,
-            })
-        }
+        Operator::Gt => Interval {
+            lower: zero,
+            upper: unbounded,
+        },
+        // TODO: Tidy these comments.
         // If [x1, y1] < [x2, y2], then
         // [x1, y1] + [-y2, -x2] = [-inf, 0]
-        Operator::Lt => {
-            // [-inf, 0]
-            Interval::Range(Range {
-                lower: inf,
-                upper: zero,
-            })
-        }
+        Operator::Lt => Interval {
+            lower: unbounded,
+            upper: zero,
+        },
         _ => unreachable!(),
-    };
-    Ok(parent_interval)
+    })
 }
 
 pub fn propagate_logical_operators(
     left_interval: &Interval,
     op: &Operator,
     right_interval: &Interval,
 ) -> Result<(Interval, Interval)> {
     if left_interval.is_boolean() || right_interval.is_boolean() {
         return Ok((
-            Interval::Range(Range {
+            Interval {
                 lower: ScalarValue::Boolean(Some(true)),
                 upper: ScalarValue::Boolean(Some(true)),
-            }),
-            Interval::Range(Range {
+            },
+            Interval {
                 lower: ScalarValue::Boolean(Some(true)),
                 upper: ScalarValue::Boolean(Some(true)),
-            }),
+            },
         ));
     }
-
-    let intersection = left_interval.clone().intersect(right_interval)?;
-    if intersection.is_strict_false() {
+    let intersection = left_interval.intersect(right_interval)?;
+    // TODO: Is this right with the new "intersect" semantics?
+    if intersection.is_none() {
         return Ok((left_interval.clone(), right_interval.clone()));
     }
     let parent_interval = target_parent_interval(&left_interval.get_datatype(), op)?;
     let negate_right = right_interval.arithmetic_negate()?;
+    // TODO: Fix the following unwraps.
     let new_left = parent_interval
         .sub(&negate_right)?
-        .intersect(left_interval)?;
-    let new_right = parent_interval.sub(&new_left)?.intersect(&negate_right)?;
+        .intersect(left_interval)?
+        .unwrap();
+    let new_right = parent_interval
+        .sub(&new_left)?
+        .intersect(&negate_right)?
+        .unwrap();
     let range = (new_left, new_right.arithmetic_negate()?);
     Ok(range)
 }

From d2b6a66584683eb88689665cd163183296e50db6 Mon Sep 17 00:00:00 2001
From: metesynnada <100111937+metesynnada@users.noreply.github.com>
Date: Mon, 30 Jan 2023 16:41:39 +0300
Subject: [PATCH 07/39] Make the interval arithmetic library more robust

---
 .../physical-expr/src/intervals/cp_solver.rs  | 197 ++++--
 .../src/intervals/interval_aritmetics.rs      | 444 ++++++++++++------
 2 files changed, 476 insertions(+), 165 deletions(-)

diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs
index 9fa91b1fd555..425bba6b541b 100644
--- a/datafusion/physical-expr/src/intervals/cp_solver.rs
+++ b/datafusion/physical-expr/src/intervals/cp_solver.rs
@@ -18,19 +18,19 @@
 //! It solves constraint propagation problem for the custom PhysicalExpr
 //!
use crate::expressions::{BinaryExpr, CastExpr, Literal}; -use crate::intervals::interval_aritmetics::{ - apply_operator, negate_ops, propagate_logical_operators, Interval, -}; +use crate::intervals::interval_aritmetics::{negate_ops, Interval}; use crate::utils::{build_physical_expr_graph, ExprTreeNode}; use crate::PhysicalExpr; use std::fmt::{Display, Formatter}; -use datafusion_common::Result; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Operator; use petgraph::graph::NodeIndex; use petgraph::visit::{Bfs, DfsPostOrder}; use petgraph::Outgoing; +use arrow::compute::CastOptions; +use arrow_schema::DataType; use petgraph::stable_graph::StableGraph; use std::ops::{Index, IndexMut}; use std::sync::Arc; @@ -123,11 +123,119 @@ pub struct ExprIntervalGraph( fn calculate_node_interval( parent: &Interval, current_child_interval: &Interval, - negated_op: Operator, + negated_op: &Operator, other_child_interval: &Interval, -) -> Result> { - let interv = apply_operator(parent, &negated_op, other_child_interval)?; - interv.intersect(current_child_interval) +) -> Result { + let interv = apply_operator(parent, negated_op, other_child_interval)?; + match interv.intersect(current_child_interval) { + Ok(Some(val)) => Ok(val), + Ok(None) => Ok(current_child_interval.clone()), + Err(e) => Err(e), + } +} + +pub fn propagate_arithmetic_operators( + first_child_interval: &Interval, + op: &Operator, + second_child_interval: &Interval, + parent_node_interval: &Interval, +) -> Result<(Interval, Interval)> { + let negated_op = negate_ops(*op); + let new_right_operator = calculate_node_interval( + parent_node_interval, + second_child_interval, + &negated_op, + first_child_interval, + )?; + let new_left_operator = calculate_node_interval( + parent_node_interval, + first_child_interval, + &negated_op, + &new_right_operator, + )?; + Ok((new_left_operator, new_right_operator)) +} + +/// Casting data type with arrow kernel +pub fn cast_scalar_value( + value: &ScalarValue, + data_type: &DataType, + cast_options: &CastOptions, +) -> Result { + let scalar_array = value.to_array(); + let cast_array = + arrow::compute::cast_with_options(&scalar_array, data_type, cast_options)?; + ScalarValue::try_from_array(&cast_array, 0) +} + +pub fn apply_operator(lhs: &Interval, op: &Operator, rhs: &Interval) -> Result { + match *op { + Operator::Eq => Ok(lhs.equal(rhs)), + Operator::Gt => Ok(lhs.gt(rhs)), + Operator::Lt => Ok(lhs.lt(rhs)), + Operator::And => lhs.and(rhs), + Operator::Plus => lhs.add(rhs), + Operator::Minus => lhs.sub(rhs), + _ => Ok(Interval { + lower: ScalarValue::Null, + upper: ScalarValue::Null, + }), + } +} +/// We use this function provide a target parent interval for comparison operators. +/// If expression > 0, that means expression must has interval of [0, ∞] +/// If expression < 0, that means expression must has interval of [-∞, 0] +/// +/// Currently, we only support GT and LT since we did not implemented open and closed intervals. +pub fn comparison_operator_target( + datatype: &DataType, + op: &Operator, +) -> Result { + let unbounded = ScalarValue::try_from(datatype)?; + let zero = ScalarValue::try_from_string("0".to_string(), datatype)?; + Ok(match *op { + Operator::Gt => Interval { + lower: zero, + upper: unbounded, + }, + Operator::Lt => Interval { + lower: unbounded, + upper: zero, + }, + _ => unreachable!(), + }) +} +/// We propagate the intervals for comparison operators. Main idea is that if +/// [x1, y1] > [x2, y2], it can be rewritten as [x1, y1] + [-y2, -x2] = [0, ∞]. 
This is quite a powerful conversion, since we can now shrink the intervals.
+/// Similarly, if [x1, y1] < [x2, y2], then it can be rewritten as
+/// [x1, y1] + [-y2, -x2] = [-∞, 0].
+///
+/// Let's use the GT operator to demonstrate. We first shrink the left side by
+/// using the not-yet-shrunk right side via
+/// [x1_new, y1_new] = ([0, ∞] - [-y2, -x2]) ∩ [x1, y1].
+/// After calculating the new left side, we can then use it to shrink the right
+/// side via [-y2_new, -x2_new] = ([0, ∞] - [x1_new, y1_new]) ∩ [-y2, -x2].
+///
+pub fn propagate_comparison_operators(
+    left_interval: &Interval,
+    op: &Operator,
+    right_interval: &Interval,
+) -> Result<(Interval, Interval)> {
+    let parent_interval = comparison_operator_target(&left_interval.get_datatype(), op)?;
+    let negate_right = right_interval.arithmetic_negate()?;
+    let new_left = match parent_interval
+        .sub(&negate_right)?
+        .intersect(left_interval)?
+    {
+        Some(i) => i,
+        None => left_interval.clone(),
+    };
+    let new_right = match parent_interval.sub(&new_left)?.intersect(&negate_right)? {
+        Some(i) => i,
+        None => right_interval.clone(),
+    };
+    let range = (new_left, new_right.arithmetic_negate()?);
+    Ok(range)
+}
 
 impl ExprIntervalGraph {
@@ -284,7 +392,6 @@ impl ExprIntervalGraph {
             // Get calculated interval. BinaryExpr will propagate the interval according to
             // this.
             let node_interval = input.interval().clone();
-
             if let Some((_, interval)) = expr_stats
                 .iter_mut()
                 .find(|(e, _)| *e == node.index())
             {
@@ -307,31 +414,24 @@ impl ExprIntervalGraph {
                 self.1.index(first_child_node_index).interval();
             let (shrink_left_interval, shrink_right_interval) =
-                if node_interval.is_boolean() {
-                    propagate_logical_operators(
+                // There is no propagation process for logical operators.
+                if binary.op().is_logic_operator() {
+                    continue
+                } else if binary.op().is_comparison_operator() {
+                    // If the comparison is strictly false, there is nothing to shrink.
+                    if let Interval {
+                        lower: ScalarValue::Boolean(Some(false)),
+                        upper: ScalarValue::Boolean(Some(false)),
+                    } = node_interval
+                    {
+                        continue;
+                    }
+                    // Propagate the comparison operator.
+                    propagate_comparison_operators(
                         first_child_interval,
                         binary.op(),
                         second_child_interval,
                     )?
                 } else {
-                    let negated_op = negate_ops(*binary.op());
-                    let new_right_operator = calculate_node_interval(
-                        &node_interval,
-                        second_child_interval,
-                        negated_op,
-                        first_child_interval,
-                    )?
-                    // TODO: Fix this unwrap
-                    .unwrap();
-                    let new_left_operator = calculate_node_interval(
-                        &node_interval,
-                        first_child_interval,
-                        negated_op,
-                        &new_right_operator,
-                    )?
-                    // TODO: Fix this unwrap
-                    .unwrap();
-                    (new_left_operator, new_right_operator)
+                    // Propagate the arithmetic operator.
+                    propagate_arithmetic_operators(
+                        first_child_interval,
+                        binary.op(),
+                        second_child_interval,
+                        &node_interval,
+                    )?
}; let mutable_first_child = self.1.index_mut(first_child_node_index); mutable_first_child.interval = shrink_left_interval.clone(); @@ -541,6 +641,45 @@ mod tests { )?; Ok(()) } + + #[test] + fn particular() -> Result<()> { + let seed = 0; + let left_col = Arc::new(Column::new("left_watermark", 0)); + let right_col = Arc::new(Column::new("right_watermark", 0)); + // left_watermark - 1 > right_watermark + 5 AND left_watermark + 3 < right_watermark + 10 + let expr = filter_numeric_expr_generation( + left_col.clone(), + right_col.clone(), + Operator::Minus, + Operator::Plus, + Operator::Plus, + Operator::Plus, + 1, + 5, + 3, + 10, + ); + // l > r + 6 AND r > l - 7 + let l_gt_r = 6; + let r_gt_l = -7; + generate_case::( + expr.clone(), + left_col.clone(), + right_col.clone(), + seed, + l_gt_r, + r_gt_l, + )?; + // Descending tests + // r < l - 6 AND l < r + 7 + let r_lt_l = -l_gt_r; + let l_lt_r = -r_gt_l; + generate_case::(expr, left_col, right_col, seed, l_lt_r, r_lt_l)?; + + Ok(()) + } + #[rstest] #[test] fn case_1( diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetics.rs b/datafusion/physical-expr/src/intervals/interval_aritmetics.rs index a95cd7e9a730..f271b04087f1 100644 --- a/datafusion/physical-expr/src/intervals/interval_aritmetics.rs +++ b/datafusion/physical-expr/src/intervals/interval_aritmetics.rs @@ -21,13 +21,14 @@ use std::borrow::Borrow; use std::fmt; use std::fmt::{Display, Formatter}; -use arrow::compute::{kernels, CastOptions}; +use arrow::compute::CastOptions; use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Operator; use crate::aggregate::min_max::{max, min}; +use crate::intervals::cp_solver; /// This type represents an interval, which is used to calculate reliable /// bounds for expressions. Currently, we only support addition and @@ -53,18 +54,6 @@ impl Display for Interval { } } -/// Casting data type with arrow kernel -fn cast_scalar_value( - value: &ScalarValue, - data_type: &DataType, - cast_options: &CastOptions, -) -> Result { - let scalar_array = value.to_array(); - let cast_array = - kernels::cast::cast_with_options(&scalar_array, data_type, cast_options)?; - ScalarValue::try_from_array(&cast_array, 0) -} - impl Interval { pub(crate) fn cast_to( &self, @@ -72,29 +61,38 @@ impl Interval { cast_options: &CastOptions, ) -> Result { Ok(Interval { - lower: cast_scalar_value(&self.lower, data_type, cast_options)?, - upper: cast_scalar_value(&self.upper, data_type, cast_options)?, + lower: cp_solver::cast_scalar_value(&self.lower, data_type, cast_options)?, + upper: cp_solver::cast_scalar_value(&self.upper, data_type, cast_options)?, }) } - pub(crate) fn is_boolean(&self) -> bool { - matches!( - self, - Interval { - lower: ScalarValue::Boolean(_), - upper: ScalarValue::Boolean(_), - } - ) - } - pub(crate) fn get_datatype(&self) -> DataType { self.lower.get_datatype() } + /// Null is treated as infinite. + /// [a , b] > [c, d] + /// where a, b, c or d can be infinite or a real value. We use None if a or c is minus infinity or + /// b or d is infinity. [10,20], [10, ∞], [-∞, 100] or [-∞, ∞] are all valid intervals. 
+ /// While we are calculating correct comparison, we take the PartialOrd implementation of ScalarValue, + /// which is + /// None < a true + /// a < None false + /// None > a false + /// a > None true + /// None < None false + /// None > None false pub(crate) fn gt(&self, other: &Interval) -> Interval { - let flags = if self.upper < other.lower { + let flags = if !self.upper.is_null() + && !other.lower.is_null() + && self.upper < other.lower + { (false, false) - } else if self.lower > other.upper { + // If lhs has lower negative infinity or rhs has upper infinity, this can not be completely true. + } else if !other.upper.is_null() + && !self.lower.is_null() + && (self.lower > other.upper) + { (true, true) } else { (false, true) @@ -104,7 +102,15 @@ impl Interval { upper: ScalarValue::Boolean(Some(flags.1)), } } - + /// We define vague result as (false, true). If we "AND" a vague value with another, + /// result is + /// vague AND false -> false + /// vague AND true -> vague, + /// vague AND vague -> vague + /// false AND true -> true + /// + /// where "false" represented with (false, false), "true" represented with (true, true) and + /// vague represented with (false, true). pub(crate) fn and(&self, other: &Interval) -> Result { let flags = match (self, other) { ( @@ -119,7 +125,7 @@ impl Interval { ) => { if *lower && *lower2 { (true, true) - } else if *upper || *upper2 { + } else if *upper && *upper2 { (false, true) } else { (false, false) @@ -153,59 +159,54 @@ impl Interval { upper: ScalarValue::Boolean(Some(flags.1)), } } - + /// Null is treated as infinite. + /// [a , b] < [c, d] + /// where a, b, c or d can be infinite or a real value. We use None if a or c is minus infinity or + /// b or d is infinity. [10,20], [10, ∞], [-∞, 100] or [-∞, ∞] are all valid intervals. + /// Implemented as complementary of greater than. pub(crate) fn lt(&self, other: &Interval) -> Interval { other.gt(self) } + /// Null is treated as infinite. + /// [a , b] < [c, d] + /// where a, b, c or d can be infinite or a real value. We use None if a or c is minus infinity or + /// b or d is infinity. [10,20], [10, ∞], [-∞, 100] or [-∞, ∞] are all valid intervals. pub(crate) fn intersect(&self, other: &Interval) -> Result> { Ok(match (self, other) { ( Interval { - lower: ScalarValue::Boolean(Some(low)), - upper: ScalarValue::Boolean(Some(high)), + lower: ScalarValue::Boolean(_), + upper: ScalarValue::Boolean(_), }, Interval { - lower: ScalarValue::Boolean(Some(other_low)), - upper: ScalarValue::Boolean(Some(other_high)), + lower: ScalarValue::Boolean(_), + upper: ScalarValue::Boolean(_), }, - ) => { - if low > other_high || high < other_low { - // This None value signals an empty interval. - None - } else if (low == high) && (other_low == other_high) && (low == other_low) - { - Some(self.clone()) - } else { - Some(Interval { - lower: ScalarValue::Boolean(Some(false)), - upper: ScalarValue::Boolean(Some(true)), - }) - } - } + ) => None, ( Interval { - lower: lower1, - upper: upper1, + lower: low, + upper: high, }, Interval { - lower: lower2, - upper: upper2, + lower: other_low, + upper: other_high, }, ) => { - let lower = if lower1.is_null() { - lower2.clone() - } else if lower2.is_null() { - lower1.clone() + let lower = if low.is_null() { + other_low.clone() + } else if other_low.is_null() { + low.clone() } else { - max(lower1, lower2)? + max(low, other_low)? 
}; - let upper = if upper1.is_null() { - upper2.clone() - } else if upper2.is_null() { - upper1.clone() + let upper = if high.is_null() { + other_high.clone() + } else if other_high.is_null() { + high.clone() } else { - min(upper1, upper2)? + min(high, other_high)? }; if !upper.is_null() && lower > upper { // This None value signals an empty interval. @@ -265,86 +266,257 @@ impl Interval { } } -pub fn apply_operator(lhs: &Interval, op: &Operator, rhs: &Interval) -> Result { - match *op { - Operator::Eq => Ok(lhs.equal(rhs)), - Operator::Gt => Ok(lhs.gt(rhs)), - Operator::Lt => Ok(lhs.lt(rhs)), - Operator::And => lhs.and(rhs), - Operator::Plus => lhs.add(rhs), - Operator::Minus => lhs.sub(rhs), - _ => Ok(Interval { - lower: ScalarValue::Null, - upper: ScalarValue::Null, - }), +pub fn negate_ops(op: Operator) -> Operator { + match op { + Operator::Plus => Operator::Minus, + Operator::Minus => Operator::Plus, + _ => unreachable!(), } } +#[cfg(test)] +mod tests { + use crate::intervals::interval_aritmetics::Interval; + use datafusion_common::Result; + use datafusion_common::ScalarValue; -pub fn target_parent_interval(datatype: &DataType, op: &Operator) -> Result { - let unbounded = ScalarValue::try_from(datatype)?; - let zero = ScalarValue::try_from_string("0".to_string(), datatype)?; - Ok(match *op { - // TODO: Tidy these comments, write a text that explains what this function does. - // [x1, y1] > [x2, y2] ise - // [x1, y1] + [-y2, -x2] = [0, inf] - // [x1_new, x2_new] = ([0, inf] - [-y2, -x2]) intersect [x1, y1] - // [-y2_new, -x2_new] = ([0, inf] - [x1_new, x2_new]) intersect [-y2, -x2] - Operator::Gt => Interval { - lower: zero, - upper: unbounded, - }, - // TODO: Tidy these comments. - // [x1, y1] < [x2, y2] ise - // [x1, y1] + [-y2, -x2] = [-inf, 0] - Operator::Lt => Interval { - lower: unbounded, - upper: zero, - }, - _ => unreachable!(), - }) -} + #[test] + fn intersect_test() -> Result<()> { + let possible_cases = vec![ + (Some(1000), None, None, None, Some(1000), None), + (None, Some(1000), None, None, None, Some(1000)), + (None, None, Some(1000), None, Some(1000), None), + (None, None, None, Some(1000), None, Some(1000)), + (Some(1000), None, Some(1000), None, Some(1000), None), + ( + None, + Some(1000), + Some(999), + Some(1002), + Some(999), + Some(1000), + ), + (None, Some(1000), Some(1000), None, Some(1000), Some(1000)), // singleton + (None, None, None, None, None, None), + ]; + + for case in possible_cases { + assert_eq!( + Interval { + lower: ScalarValue::Int64(case.0), + upper: ScalarValue::Int64(case.1) + } + .intersect(&Interval { + lower: ScalarValue::Int64(case.2), + upper: ScalarValue::Int64(case.3) + })? 
+ .unwrap(), + Interval { + lower: ScalarValue::Int64(case.4), + upper: ScalarValue::Int64(case.5) + } + ) + } -pub fn propagate_logical_operators( - left_interval: &Interval, - op: &Operator, - right_interval: &Interval, -) -> Result<(Interval, Interval)> { - if left_interval.is_boolean() || right_interval.is_boolean() { - return Ok(( - Interval { - lower: ScalarValue::Boolean(Some(true)), - upper: ScalarValue::Boolean(Some(true)), - }, - Interval { - lower: ScalarValue::Boolean(Some(true)), - upper: ScalarValue::Boolean(Some(true)), - }, - )); + let not_possible_cases = vec![ + (None, Some(1000), Some(1001), None), + (Some(1001), None, None, Some(1000)), + (None, Some(1000), Some(1001), Some(1002)), + (Some(1001), Some(1002), None, Some(1000)), + ]; + //(None, Some(1000), Some(1001), None, --), + // (None, Some(1000), Some(1000), None, false, true), + // (None, Some(1000), Some(1001), Some(1002), false, false), + + for case in not_possible_cases { + assert_eq!( + Interval { + lower: ScalarValue::Int64(case.0), + upper: ScalarValue::Int64(case.1) + } + .intersect(&Interval { + lower: ScalarValue::Int64(case.2), + upper: ScalarValue::Int64(case.3) + })?, + None + ) + } + + Ok(()) } - let intersection = left_interval.intersect(right_interval)?; - // TODO: Is this right with the new "intersect" semantics? - if intersection.is_none() { - return Ok((left_interval.clone(), right_interval.clone())); + + #[test] + fn gt_test() { + let cases = vec![ + (Some(1000), None, None, None, false, true), + (None, Some(1000), None, None, false, true), + (None, None, Some(1000), None, false, true), + (None, None, None, Some(1000), false, true), + (None, Some(1000), Some(1000), None, false, true), + (None, Some(1000), Some(1001), None, false, false), + (Some(1000), None, Some(1000), None, false, true), + (None, Some(1000), Some(1001), Some(1002), false, false), + (None, Some(1000), Some(999), Some(1002), false, true), + (None, None, None, None, false, true), + ]; + + for case in cases { + assert_eq!( + Interval { + lower: ScalarValue::Int64(case.0), + upper: ScalarValue::Int64(case.1) + } + .gt(&Interval { + lower: ScalarValue::Int64(case.2), + upper: ScalarValue::Int64(case.3) + }), + Interval { + lower: ScalarValue::Boolean(Some(case.4)), + upper: ScalarValue::Boolean(Some(case.5)) + } + ) + } } - let parent_interval = target_parent_interval(&left_interval.get_datatype(), op)?; - let negate_right = right_interval.arithmetic_negate()?; - // TODO: Fix the following unwraps. - let new_left = parent_interval - .sub(&negate_right)? - .intersect(left_interval)? - .unwrap(); - let new_right = parent_interval - .sub(&new_left)? - .intersect(&negate_right)? 
- .unwrap(); - let range = (new_left, new_right.arithmetic_negate()?); - Ok(range) -} -pub fn negate_ops(op: Operator) -> Operator { - match op { - Operator::Plus => Operator::Minus, - Operator::Minus => Operator::Plus, - _ => unreachable!(), + #[test] + fn lt_test() { + let cases = vec![ + (Some(1000), None, None, None, false, true), + (None, Some(1000), None, None, false, true), + (None, None, Some(1000), None, false, true), + (None, None, None, Some(1000), false, true), + (None, Some(1000), Some(1000), None, false, true), + (None, Some(1000), Some(1001), None, true, true), + (Some(1000), None, Some(1000), None, false, true), + (None, Some(1000), Some(1001), Some(1002), true, true), + (None, Some(1000), Some(999), Some(1002), false, true), + (None, None, None, None, false, true), + ]; + + for case in cases { + assert_eq!( + Interval { + lower: ScalarValue::Int64(case.0), + upper: ScalarValue::Int64(case.1) + } + .lt(&Interval { + lower: ScalarValue::Int64(case.2), + upper: ScalarValue::Int64(case.3) + }), + Interval { + lower: ScalarValue::Boolean(Some(case.4)), + upper: ScalarValue::Boolean(Some(case.5)) + } + ) + } + } + + #[test] + fn and_test() -> Result<()> { + let cases = vec![ + (false, true, false, false, false, false), + (false, false, false, true, false, false), + (false, true, false, true, false, true), + (false, true, true, true, false, true), + (false, false, false, false, false, false), + (true, true, true, true, true, true), + ]; + + for case in cases { + assert_eq!( + Interval { + lower: ScalarValue::Boolean(Some(case.0)), + upper: ScalarValue::Boolean(Some(case.1)) + } + .and(&Interval { + lower: ScalarValue::Boolean(Some(case.2)), + upper: ScalarValue::Boolean(Some(case.3)) + })?, + Interval { + lower: ScalarValue::Boolean(Some(case.4)), + upper: ScalarValue::Boolean(Some(case.5)) + } + ) + } + Ok(()) + } + + #[test] + fn add_test() -> Result<()> { + let cases = vec![ + (Some(1000), None, None, None, None, None), + (None, Some(1000), None, None, None, None), + (None, None, Some(1000), None, None, None), + (None, None, None, Some(1000), None, None), + (Some(1000), None, Some(1000), None, Some(2000), None), + (None, Some(1000), Some(999), Some(1002), None, Some(2002)), + (None, Some(1000), Some(1000), None, None, None), + ( + Some(2001), + Some(1), + Some(1005), + Some(-999), + Some(3006), + Some(-998), + ), + (None, None, None, None, None, None), + ]; + + for case in cases { + assert_eq!( + Interval { + lower: ScalarValue::Int64(case.0), + upper: ScalarValue::Int64(case.1) + } + .add(&Interval { + lower: ScalarValue::Int64(case.2), + upper: ScalarValue::Int64(case.3) + })?, + Interval { + lower: ScalarValue::Int64(case.4), + upper: ScalarValue::Int64(case.5) + } + ) + } + Ok(()) + } + + #[test] + fn sub_test() -> Result<()> { + let cases = vec![ + (Some(1000), None, None, None, None, None), + (None, Some(1000), None, None, None, None), + (None, None, Some(1000), None, None, None), + (None, None, None, Some(1000), None, None), + (Some(1000), None, Some(1000), None, None, None), + (None, Some(1000), Some(999), Some(1002), None, Some(1)), + (None, Some(1000), Some(1000), None, None, Some(0)), + ( + Some(2001), + Some(1000), + Some(1005), + Some(999), + Some(1002), + Some(-5), + ), + (None, None, None, None, None, None), + ]; + + for case in cases { + assert_eq!( + Interval { + lower: ScalarValue::Int64(case.0), + upper: ScalarValue::Int64(case.1) + } + .sub(&Interval { + lower: ScalarValue::Int64(case.2), + upper: ScalarValue::Int64(case.3) + })?, + Interval { + lower: 
ScalarValue::Int64(case.4), + upper: ScalarValue::Int64(case.5) + } + ) + } + Ok(()) } } From 1da529fbfffd336b230dd42e6a7b99db84bb952c Mon Sep 17 00:00:00 2001 From: metesynnada <100111937+metesynnada@users.noreply.github.com> Date: Mon, 30 Jan 2023 16:54:06 +0300 Subject: [PATCH 08/39] After merge corrections --- datafusion-cli/Cargo.lock | 44 +++++++++++++-------------- datafusion/core/src/test_util.rs | 3 +- datafusion/core/tests/fifo.rs | 2 +- datafusion/physical-expr/src/utils.rs | 5 ++- 4 files changed, 29 insertions(+), 25 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index ac9370666964..0cd18cc3009e 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -10,9 +10,9 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "ahash" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf6ccdb167abbf410dcb915cabd428929d7f6a04980b54a11f26a39f1c7f7107" +checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" dependencies = [ "cfg-if", "const-random", @@ -290,9 +290,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.62" +version = "0.1.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "689894c2db1ea643a50834b999abf1c110887402542955ff5451dab8f861f9ed" +checksum = "eff18d764974428cf3a9328e23fc5c986f5fbed46e6cd4cdf42544df5d297ec1" dependencies = [ "proc-macro2", "quote", @@ -440,9 +440,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.78" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a20104e2335ce8a659d6dd92a51a767a0c062599c73b343fd152cb401e828c3d" +checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" dependencies = [ "jobserver", ] @@ -629,9 +629,9 @@ dependencies = [ [[package]] name = "cxx" -version = "1.0.87" +version = "1.0.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b61a7545f753a88bcbe0a70de1fcc0221e10bfc752f576754fa91e663db1622e" +checksum = "322296e2f2e5af4270b54df9e85a02ff037e271af20ba3e7fe1575515dc840b8" dependencies = [ "cc", "cxxbridge-flags", @@ -641,9 +641,9 @@ dependencies = [ [[package]] name = "cxx-build" -version = "1.0.87" +version = "1.0.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f464457d494b5ed6905c63b0c4704842aba319084a0a3561cdc1359536b53200" +checksum = "017a1385b05d631e7875b1f151c9f012d37b53491e2a87f65bff5c262b2111d8" dependencies = [ "cc", "codespan-reporting", @@ -656,15 +656,15 @@ dependencies = [ [[package]] name = "cxxbridge-flags" -version = "1.0.87" +version = "1.0.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43c7119ce3a3701ed81aca8410b9acf6fc399d2629d057b87e2efa4e63a3aaea" +checksum = "c26bbb078acf09bc1ecda02d4223f03bdd28bd4874edcb0379138efc499ce971" [[package]] name = "cxxbridge-macro" -version = "1.0.87" +version = "1.0.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65e07508b90551e610910fa648a1878991d367064997a596135b86df30daf07e" +checksum = "357f40d1f06a24b60ae1fe122542c1fb05d28d32acb2aed064e84bc2ad1e252e" dependencies = [ "proc-macro2", "quote", @@ -896,9 +896,9 @@ checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" [[package]] name = "either" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"90e5c1c8368803113bf0c9584fc495a58b86dc8a29edbf8fe877d21d9507e797" +checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" [[package]] name = "encoding_rs" @@ -970,9 +970,9 @@ dependencies = [ [[package]] name = "fd-lock" -version = "3.0.8" +version = "3.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb21c69b9fea5e15dbc1049e4b77145dd0ba1c84019c488102de0dc4ea4b0a27" +checksum = "28c0190ff0bd3b28bfdd4d0cf9f92faa12880fb0b8ae2054723dd6c76a4efd42" dependencies = [ "cfg-if", "rustix", @@ -2449,9 +2449,9 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" [[package]] name = "tokio" -version = "1.24.2" +version = "1.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "597a12a59981d9e3c38d216785b0c37399f6e415e8d0712047620f189371b0bb" +checksum = "c8e00990ebabbe4c14c08aca901caed183ecd5c09562a12c824bb53d3c3fd3af" dependencies = [ "autocfg", "bytes", @@ -2575,9 +2575,9 @@ checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" [[package]] name = "unicode-bidi" -version = "0.3.9" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0046be40136ef78dc325e0edefccf84ccddacd0afcc1ca54103fa3c61bbdab1d" +checksum = "d54675592c1dbefd78cbd98db9bacd89886e1ca50692a0692baefffdeb92dd58" [[package]] name = "unicode-ident" diff --git a/datafusion/core/src/test_util.rs b/datafusion/core/src/test_util.rs index 0228d6038669..4ef6c7d52dec 100644 --- a/datafusion/core/src/test_util.rs +++ b/datafusion/core/src/test_util.rs @@ -24,6 +24,7 @@ use std::{env, error::Error, path::PathBuf, sync::Arc}; use crate::datasource::datasource::TableProviderFactory; use crate::datasource::{empty::EmptyTable, provider_as_source, TableProvider}; use crate::execution::context::SessionState; +use crate::execution::options::ReadOptions; use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE}; use crate::physical_plan::ExecutionPlan; use crate::prelude::{CsvReadOptions, SessionContext}; @@ -414,7 +415,7 @@ pub async fn test_create_unbounded_sorted_file( .mark_infinite(true); // Get listing options let options_sort = fifo_options - .to_listing_options(ctx.copied_config().target_partitions()) + .to_listing_options(&ctx.copied_config()) .with_file_sort_order(Some(file_sort_order)); // Register table ctx.register_listing_table( diff --git a/datafusion/core/tests/fifo.rs b/datafusion/core/tests/fifo.rs index 1d182835dc09..92d56a5262a0 100644 --- a/datafusion/core/tests/fifo.rs +++ b/datafusion/core/tests/fifo.rs @@ -280,7 +280,7 @@ mod unix_test { while *waiting_thread.lock().unwrap() && TEST_BATCH_SIZE < cnt { thread::sleep(Duration::from_millis(200)); } - let line = format!("{},{}\n", a1, a2).to_owned(); + let line = format!("{a1},{a2}\n").to_owned(); write_to_fifo(&file, &line, execution_start, broken_pipe_timeout) .unwrap(); } diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs index 07f435bcf86c..a91804340f1b 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils.rs @@ -21,9 +21,9 @@ use crate::expressions::UnKnownColumn; use crate::expressions::{BinaryExpr, Literal}; use crate::rewrite::TreeNodeRewritable; use crate::PhysicalSortExpr; +use crate::{EquivalenceProperties, PhysicalExpr}; use arrow::datatypes::SchemaRef; use datafusion_common::{Result, ScalarValue}; -use crate::{EquivalenceProperties, PhysicalExpr}; use datafusion_expr::Operator; use petgraph::graph::NodeIndex; @@ 
-237,6 +237,9 @@ impl ExprTreeNode { }) }) .collect() + } + } +} /// Checks whether given ordering requirements are satisfied by provided [PhysicalSortExpr]s. pub fn ordering_satisfy EquivalenceProperties>( provided: Option<&[PhysicalSortExpr]>, From 6921aafd7a6b55c5d1eff449e7b1ec2e5a566b3a Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Mon, 30 Jan 2023 22:23:55 -0600 Subject: [PATCH 09/39] Simplifications to constraint propagation code --- .../joins/symmetric_hash_join.rs | 2 +- .../physical-expr/src/intervals/cp_solver.rs | 558 ++++++++---------- .../src/intervals/interval_aritmetics.rs | 242 ++++---- 3 files changed, 368 insertions(+), 434 deletions(-) diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs index af41d0b376f7..84bad6d5042b 100644 --- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs @@ -948,7 +948,7 @@ impl OneSideHashJoiner { .map(|sorted_expr| (sorted_expr.node_index(), sorted_expr.interval().clone())) .collect_vec(); // Use this vector to seed the child PhysicalExpr interval. - physical_expr_graph.calculate_new_intervals(&mut filter_intervals)?; + physical_expr_graph.update_intervals(&mut filter_intervals)?; // Mutate the Vec for for (sorted_expr, (_, interval)) in filter_columns.iter_mut().zip(filter_intervals.into_iter()) diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index 425bba6b541b..c831fa1e4403 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -15,103 +15,105 @@ // specific language governing permissions and limitations // under the License. -//! It solves constraint propagation problem for the custom PhysicalExpr -//! -use crate::expressions::{BinaryExpr, CastExpr, Literal}; -use crate::intervals::interval_aritmetics::{negate_ops, Interval}; -use crate::utils::{build_physical_expr_graph, ExprTreeNode}; -use crate::PhysicalExpr; +//! Constraint propagator/solver for custom PhysicalExpr graphs. + use std::fmt::{Display, Formatter}; +use std::ops::{Index, IndexMut}; +use std::sync::Arc; +use arrow_schema::DataType; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Operator; use petgraph::graph::NodeIndex; +use petgraph::stable_graph::StableGraph; use petgraph::visit::{Bfs, DfsPostOrder}; use petgraph::Outgoing; -use arrow::compute::CastOptions; -use arrow_schema::DataType; -use petgraph::stable_graph::StableGraph; -use std::ops::{Index, IndexMut}; -use std::sync::Arc; +use crate::expressions::{BinaryExpr, CastExpr, Literal}; +use crate::intervals::interval_aritmetics::{apply_operator, Interval}; +use crate::utils::{build_physical_expr_graph, ExprTreeNode}; +use crate::PhysicalExpr; + +// Interval arithmetic provides a way to perform mathematical operations on +// intervals, which represent a range of possible values rather than a single +// point value. This allows for the propagation of ranges through mathematical +// operations, and can be used to compute bounds for a complicated expression. +// The key idea is that by breaking down a complicated expression into simpler +// terms, and then combining the bounds for those simpler terms, one can +// obtain bounds for the overall expression. +// +// For example, consider a mathematical expression such as x^2 + y = 4. 
Since +// it would be a binary tree in [PhysicalExpr] notation, this type of an +// hierarchical computation is well-suited for a graph based implementation. +// In such an implementation, an equation system f(x) = 0 is represented by a +// directed acyclic expression graph (DAEG). +// +// In order to use interval arithmetic to compute bounds for this expression, +// one would first determine intervals that represent the possible values of x +// and y. Let's say that the interval for x is [1, 2] and the interval for y +// is [-3, 1]. In the chart below, you can see how the computation takes place. +// +// This way of using interval arithmetic to compute bounds for a complex +// expression by combining the bounds for the constituent terms within the +// original expression allows us to reason about the range of possible values +// of the expression. This information later can be used in range pruning of +// the provably unnecessary parts of `RecordBatch`es. +// +// References +// 1 - Kabak, Mehmet Ozan. Analog Circuit Start-Up Behavior Analysis: An Interval +// Arithmetic Based Approach, Chapter 4. Stanford University, 2015. +// 2 - Moore, Ramon E. Interval analysis. Vol. 4. Englewood Cliffs: Prentice-Hall, 1966. +// 3 - F. Messine, "Deterministic global optimization using interval constraint +// propagation techniques," RAIRO-Operations Research, vol. 38, no. 04, +// pp. 277{293, 2004. +// +// ``` text +// Computing bounds for an expression using interval arithmetic. Constraint propagation through a top-down evaluation of the expression +// graph using inverse semantics. +// +// [-2, 5] ∩ [4, 4] = [4, 4] [4, 4] +// +-----+ +-----+ +-----+ +-----+ +// +----| + |----+ +----| + |----+ +----| + |----+ +----| + |----+ +// | | | | | | | | | | | | | | | | +// | +-----+ | | +-----+ | | +-----+ | | +-----+ | +// | | | | | | | | +// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +// | 2 | | y | | 2 | [1, 4] | y | | 2 | [1, 4] | y | | 2 | [1, 4] | y | [0, 1]* +// |[.] | | | |[.] | | | |[.] | | | |[.] | | | +// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +// | | | [-3, 1] | +// | | | | +// +---+ +---+ +---+ +---+ +// | x | [1, 2] | x | [1, 2] | x | [1, 2] | x | [1, 2] +// +---+ +---+ +---+ +---+ +// +// (a) Bottom-up evaluation: Step1 (b) Bottom up evaluation: Step2 (a) Top-down propagation: Step1 (b) Top-down propagation: Step2 +// +// [1 - 3, 4 + 1] = [-2, 5] [1 - 3, 4 + 1] = [-2, 5] +// +-----+ +-----+ +-----+ +-----+ +// +----| + |----+ +----| + |----+ +----| + |----+ +----| + |----+ +// | | | | | | | | | | | | | | | | +// | +-----+ | | +-----+ | | +-----+ | | +-----+ | +// | | | | | | | | +// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +// | 2 |[1, 4] | y | | 2 |[1, 4] | y | | 2 |[3, 4]** | y | | 2 |[1, 4] | y | +// |[.] | | | |[.] | | | |[.] | | | |[.] 
| | | +// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +// | [-3, 1] | [-3, 1] | [0, 1] | [-3, 1] +// | | | | +// +---+ +---+ +---+ +---+ +// | x | [1, 2] | x | [1, 2] | x | [1, 2] | x | [sqrt(3), 2]*** +// +---+ +---+ +---+ +---+ +// +// (c) Bottom-up evaluation: Step3 (d) Bottom-up evaluation: Step4 (c) Top-down propagation: Step3 (d) Top-down propagation: Step4 +// +// * [-3, 1] ∩ ([4, 4] - [1, 4]) = [0, 1] +// ** [1, 4] ∩ ([4, 4] - [0, 1]) = [3, 4] +// *** [1, 2] ∩ [sqrt(3), sqrt(4)] = [sqrt(3), 2] +// ``` -/// -/// Interval arithmetic provides a way to perform mathematical operations on intervals, -/// which represents a range of possible values rather than a single point value. -/// This allows for the propagation of ranges through mathematical operations, -/// and can be used to compute bounds for a complicated expression.The key idea is that by -/// breaking down a complicated expression into simpler terms, and then combining the bounds for -/// those simpler terms, one can obtain bounds for the overall expression. -/// -/// For example, consider a mathematical expression such as x^2 + y = 4. Since it would be -/// a binary tree in [PhysicalExpr] notation, this type of an hierarchical -/// computation is well-suited for a graph based implementation. In such an implementation, -/// an equation system f(x) = 0 is represented by a directed acyclic expression -/// graph (DAEG). -/// -/// To use interval arithmetic to compute bounds for this expression, one would first determine -/// intervals that represent the possible values of x and y. -/// Let's say that the interval for x is [1, 2] and the interval for y is [-3, 1]. Below, you can -/// see how the computation happens. -/// -/// This way of using interval arithmetic to compute bounds for a complicated expression by -/// combining the bounds for the simpler terms within the original expression allows us to -/// reason about the range of possible values of the expression. This information later can be -/// used in range pruning of the unused parts of the [RecordBatch]es. -/// -/// References -/// 1 - Kabak, Mehmet Ozan. Analog Circuit Start-Up Behavior Analysis: An Interval -/// Arithmetic Based Approach. Stanford University, 2015. -/// 2 - Moore, Ramon E. Interval analysis. Vol. 4. Englewood Cliffs: Prentice-Hall, 1966. -/// 3 - F. Messine, "Deterministic global optimization using interval constraint -/// propagation techniques," RAIRO-Operations Research, vol. 38, no. 04, -/// pp. 277{293, 2004. -/// -/// ``` text -/// Computing bounds for an expression using interval arithmetic. Constraint propagation through a top-down evaluation of the expression -/// graph using inverse semantics. -/// -/// [-2, 5] ∩ [4, 4] = [4, 4] [4, 4] -/// +-----+ +-----+ +-----+ +-----+ -/// +----| + |----+ +----| + |----+ +----| + |----+ +----| + |----+ -/// | | | | | | | | | | | | | | | | -/// | +-----+ | | +-----+ | | +-----+ | | +-----+ | -/// | | | | | | | | -/// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ -/// | 2 | | y | | 2 | [1, 4] | y | | 2 | [1, 4] | y | | 2 | [1, 4] | y | [0, 1]* -/// |[.] | | | |[.] | | | |[.] | | | |[.] 
| | | -/// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ -/// | | | [-3, 1] | -/// | | | | -/// +---+ +---+ +---+ +---+ -/// | x | [1, 2] | x | [1, 2] | x | [1, 2] | x | [1, 2] -/// +---+ +---+ +---+ +---+ -/// -/// (a) Bottom-up evaluation: Step1 (b) Bottom up evaluation: Step2 (a) Top-down propagation: Step1 (b) Top-down propagation: Step2 -/// -/// [1 - 3, 4 + 1] = [-2, 5] [1 - 3, 4 + 1] = [-2, 5] -/// +-----+ +-----+ +-----+ +-----+ -/// +----| + |----+ +----| + |----+ +----| + |----+ +----| + |----+ -/// | | | | | | | | | | | | | | | | -/// | +-----+ | | +-----+ | | +-----+ | | +-----+ | -/// | | | | | | | | -/// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ -/// | 2 |[1, 4] | y | | 2 |[1, 4] | y | | 2 |[3, 4]** | y | | 2 |[1, 4] | y | -/// |[.] | | | |[.] | | | |[.] | | | |[.] | | | -/// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ -/// | [-3, 1] | [-3, 1] | [0, 1] | [-3, 1] -/// | | | | -/// +---+ +---+ +---+ +---+ -/// | x | [1, 2] | x | [1, 2] | x | [1, 2] | x | [sqrt(3), 2]*** -/// +---+ +---+ +---+ +---+ -/// -/// (c) Bottom-up evaluation: Step3 (d) Bottom-up evaluation: Step4 (c) Top-down propagation: Step3 (d) Top-down propagation: Step4 -/// -/// * [-3, 1] ∩ ([4, 4] - [1, 4]) = [0, 1] -/// ** [1, 4] ∩ ([4, 4] - [0, 1]) = [3, 4] -/// *** [1, 2] ∩ [sqrt(3), sqrt(4)] = [sqrt(3), 2] -/// ``` -/// +/// This object implements a directed acyclic expression graph (DAEG) that +/// is used to compute ranges for expressions through interval arithmetic. #[derive(Clone)] pub struct ExprIntervalGraph( NodeIndex, @@ -119,79 +121,122 @@ pub struct ExprIntervalGraph( pub Vec<(Arc, usize)>, ); -/// This function calculates the current node interval -fn calculate_node_interval( +/// This is a node in the DAEG; it encapsulates a reference to the actual +/// [PhysicalExpr] as well as an interval containing expression bounds. +#[derive(Clone, Debug)] +pub struct ExprIntervalGraphNode { + expr: Arc, + interval: Interval, +} + +impl Display for ExprIntervalGraphNode { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.expr) + } +} + +impl ExprIntervalGraphNode { + /// Constructs a new DAEG node with an [-∞, ∞] range. + pub fn new(expr: Arc) -> Self { + ExprIntervalGraphNode { + expr, + interval: Interval::default(), + } + } + + /// Constructs a new DAEG node with the given range. + pub fn new_with_interval(expr: Arc, interval: Interval) -> Self { + ExprIntervalGraphNode { expr, interval } + } + + /// Get the interval object representing the range of the expression. + pub fn interval(&self) -> &Interval { + &self.interval + } + + /// This function creates a DAEG node from Datafusion's [ExprTreeNode] + /// object. Literals are created with definite, singleton intervals while + /// any other expression starts with an indefinite interval ([-∞, ∞]). 
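+    ///
+    /// For example, a `Literal` holding the scalar value 4 is assigned the
+    /// singleton interval [4, 4], while a `Column` reference starts with the
+    /// indefinite interval [-∞, ∞].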
+    pub fn make_node(node: Arc<ExprTreeNode>) -> ExprIntervalGraphNode {
+        let expr = node.expr();
+        if let Some(literal) = expr.as_any().downcast_ref::<Literal>() {
+            let value = literal.value();
+            let interval = Interval {
+                lower: value.clone(),
+                upper: value.clone(),
+            };
+            ExprIntervalGraphNode::new_with_interval(expr, interval)
+        } else {
+            ExprIntervalGraphNode::new(expr)
+        }
+    }
+}
+
+impl PartialEq for ExprIntervalGraphNode {
+    fn eq(&self, other: &ExprIntervalGraphNode) -> bool {
+        self.expr.eq(&other.expr)
+    }
+}
+
+/// This function refines the interval `left_child` by propagating intervals
+/// `parent` and `right_child` through the application of `inverse_op`, which
+/// is the inverse operation of the operator in the corresponding DAEG node.
+fn propagate_left(
+    inverse_op: &Operator,
     parent: &Interval,
-    current_child_interval: &Interval,
-    negated_op: &Operator,
-    other_child_interval: &Interval,
+    left_child: &Interval,
+    right_child: &Interval,
 ) -> Result<Interval> {
-    let interv = apply_operator(parent, negated_op, other_child_interval)?;
-    match interv.intersect(current_child_interval) {
+    let interv = apply_operator(inverse_op, parent, right_child)?;
+    match interv.intersect(left_child) {
         Ok(Some(val)) => Ok(val),
-        Ok(None) => Ok(current_child_interval.clone()),
+        // TODO: We should not have this condition realize in practice. If this
+        // happens, it would mean the constraint you are propagating will
+        // never get satisfied (there is no satisfying value). When solving
+        // equations, that means we should stop and report there is no solution.
+        // For joins, I think it means we will never have a match. Therefore,
+        // I think we should just propagate this out in the return value all
+        // the way up and handle it properly so that we don't lose generality.
+        // Just returning the left interval without change is not wrong (i.e.
+        // it is conservatively true), but loses very useful information.
+        Ok(None) => Ok(left_child.clone()),
         Err(e) => Err(e),
     }
 }
 
-pub fn propagate_arithmetic_operators(
-    first_child_interval: &Interval,
-    op: &Operator,
-    second_child_interval: &Interval,
-    parent_node_interval: &Interval,
-) -> Result<(Interval, Interval)> {
-    let negated_op = negate_ops(*op);
-    let new_right_operator = calculate_node_interval(
-        parent_node_interval,
-        second_child_interval,
-        &negated_op,
-        first_child_interval,
-    )?;
-    let new_left_operator = calculate_node_interval(
-        parent_node_interval,
-        first_child_interval,
-        &negated_op,
-        &new_right_operator,
-    )?;
-    Ok((new_left_operator, new_right_operator))
+// This function returns the inverse operator of the given operator.
+fn get_inverse_op(op: Operator) -> Operator {
+    match op {
+        Operator::Plus => Operator::Minus,
+        Operator::Minus => Operator::Plus,
+        _ => unreachable!(),
+    }
 }
 
-/// Casting data type with arrow kernel
-pub fn cast_scalar_value(
-    value: &ScalarValue,
-    data_type: &DataType,
-    cast_options: &CastOptions,
-) -> Result<ScalarValue> {
-    let scalar_array = value.to_array();
-    let cast_array =
-        arrow::compute::cast_with_options(&scalar_array, data_type, cast_options)?;
-    ScalarValue::try_from_array(&cast_array, 0)
+/// This function refines intervals `left_child` and `right_child` by
+/// applying constraint propagation through `parent` via `inverse_op`, which
+/// is the inverse operation of the operator in the corresponding DAEG node.
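+///
+/// As a hedged, self-contained sketch with the same numbers as the diagram
+/// above (parent [4, 4], children [1, 4] and [-3, 1]; the `Iv` type is an
+/// illustrative stand-in, not this module's [Interval]):
+///
+/// ```
+/// #[derive(Clone, Copy, Debug, PartialEq)]
+/// struct Iv { lo: i64, hi: i64 }
+/// impl Iv {
+///     fn sub(self, o: Iv) -> Iv { Iv { lo: self.lo - o.hi, hi: self.hi - o.lo } }
+///     fn intersect(self, o: Iv) -> Option<Iv> {
+///         let (lo, hi) = (self.lo.max(o.lo), self.hi.min(o.hi));
+///         (lo <= hi).then_some(Iv { lo, hi })
+///     }
+/// }
+/// let parent = Iv { lo: 4, hi: 4 };  // x^2 + y = 4
+/// let left = Iv { lo: 1, hi: 4 };    // x^2, from x in [1, 2]
+/// let right = Iv { lo: -3, hi: 1 };  // y
+/// // The inverse of `+` is `-`: refine the left child first ...
+/// let new_left = parent.sub(right).intersect(left).unwrap();
+/// assert_eq!(new_left, Iv { lo: 3, hi: 4 });
+/// // ... then refine the right child with the already-shrunk left child.
+/// let new_right = parent.sub(new_left).intersect(right).unwrap();
+/// assert_eq!(new_right, Iv { lo: 0, hi: 1 });
+/// ```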
+pub fn propagate(
+    op: &Operator,
+    parent: &Interval,
+    left_child: &Interval,
+    right_child: &Interval,
+) -> Result<(Interval, Interval)> {
+    let inverse_op = get_inverse_op(*op);
+    let new_right = propagate_left(&inverse_op, parent, right_child, left_child)?;
+    let new_left = propagate_left(&inverse_op, parent, left_child, &new_right)?;
+    Ok((new_left, new_right))
 }
 
-pub fn apply_operator(lhs: &Interval, op: &Operator, rhs: &Interval) -> Result<Interval> {
-    match *op {
-        Operator::Eq => Ok(lhs.equal(rhs)),
-        Operator::Gt => Ok(lhs.gt(rhs)),
-        Operator::Lt => Ok(lhs.lt(rhs)),
-        Operator::And => lhs.and(rhs),
-        Operator::Plus => lhs.add(rhs),
-        Operator::Minus => lhs.sub(rhs),
-        _ => Ok(Interval {
-            lower: ScalarValue::Null,
-            upper: ScalarValue::Null,
-        }),
-    }
-}
-/// We use this function provide a target parent interval for comparison operators.
-/// If expression > 0, that means expression must has interval of [0, ∞]
-/// If expression < 0, that means expression must has interval of [-∞, 0]
-///
-/// Currently, we only support GT and LT since we did not implemented open and closed intervals.
-pub fn comparison_operator_target(
-    datatype: &DataType,
-    op: &Operator,
-) -> Result<Interval> {
+/// This function provides a target parent interval for comparison operators.
+/// If we have expression > 0, expression must have the range [0, ∞].
+/// If we have expression < 0, expression must have the range [-∞, 0].
+/// Currently, we only support strict inequalities since open/closed intervals
+/// are not implemented yet.
+fn comparison_operator_target(datatype: &DataType, op: &Operator) -> Result<Interval> {
     let unbounded = ScalarValue::try_from(datatype)?;
+    // TODO: Going through a string zero to a numeric zero is very inefficient.
+    // Let's find an efficient way to initialize a zero value with given data type.
     let zero = ScalarValue::try_from_string("0".to_string(), datatype)?;
     Ok(match *op {
         Operator::Gt => Interval {
@@ -205,130 +250,78 @@ impl ExprIntervalGraph {
         _ => unreachable!(),
     })
 }
-/// We propagate the intervals for comparison operators. Main idea is that if
-/// [x1, y1] > [x2, y2], it can be rewritten as [x1, y1] + [-y2, -x2] = [0, ∞]. This is quite
-/// powerful conversion, since we now can shrink the intervals. For less than operator, it is also
-/// available that if [x1, y1] < [x2, y2], than it can be rewritten as [x1, y1] + [-y2, -x2] = [-∞, 0].
-///
-/// Let's say we have GT operator to demonstrate. We can further use this information to first
-/// calculate the right side by using the not-yet-shrunk
-/// left side by [x1_new, x2_new] = ([0, ∞] - [-y2, -x2]) ∩ [x1, y1].
-/// After calculated the new right side, than we can use the shrunk right side to shrink left side as
-/// [-y2_new, -x2_new] = ([0, ∞] - [x1_new, x2_new]) ∩ [-y2, -x2].
-///
+
+/// This function propagates constraints arising from comparison operators.
+/// The main idea is that we can analyze an inequality like x > y through the
+/// equivalent inequality x - y > 0. Assuming that x and y have ranges [xL, xU]
+/// and [yL, yU], we simply apply constraint propagation across [xL, xU],
+/// [yL, yU] and [0, ∞]. Specifically, we would first do
+/// - [xL, xU] <- ([yL, yU] + [0, ∞]) ∩ [xL, xU], and then
+/// - [yL, yU] <- ([xL, xU] - [0, ∞]) ∩ [yL, yU].
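+///
+/// As a hedged, self-contained sketch of these two steps (the `Iv` alias and
+/// helpers below are illustrative stand-ins, not this module's [Interval];
+/// `None` encodes an unbounded side, and empty-intersection detection is
+/// omitted for brevity):
+///
+/// ```
+/// type Iv = (Option<i64>, Option<i64>); // (lower, upper)
+/// fn add(a: Iv, b: Iv) -> Iv {
+///     (a.0.zip(b.0).map(|(x, y)| x + y), a.1.zip(b.1).map(|(x, y)| x + y))
+/// }
+/// fn sub(a: Iv, b: Iv) -> Iv {
+///     (a.0.zip(b.1).map(|(x, y)| x - y), a.1.zip(b.0).map(|(x, y)| x - y))
+/// }
+/// fn meet(a: Iv, b: Iv) -> Iv {
+///     (match (a.0, b.0) { (Some(x), Some(y)) => Some(x.max(y)), (x, y) => x.or(y) },
+///      match (a.1, b.1) { (Some(x), Some(y)) => Some(x.min(y)), (x, y) => x.or(y) })
+/// }
+/// // x > y with x in [1, 5] and y in [3, 10]; the target for x - y is [0, ∞]:
+/// let (x, y, target) = ((Some(1), Some(5)), (Some(3), Some(10)), (Some(0), None));
+/// let new_x = meet(add(y, target), x);
+/// assert_eq!(new_x, (Some(3), Some(5)));
+/// let new_y = meet(sub(new_x, target), y);
+/// assert_eq!(new_y, (Some(3), Some(5)));
+/// ```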
pub fn propagate_comparison_operators( - left_interval: &Interval, op: &Operator, - right_interval: &Interval, + left_child: &Interval, + right_child: &Interval, ) -> Result<(Interval, Interval)> { - let parent_interval = comparison_operator_target(&left_interval.get_datatype(), op)?; - let negate_right = right_interval.arithmetic_negate()?; - let new_left = match parent_interval - .sub(&negate_right)? - .intersect(left_interval)? - { - Some(i) => i, - None => left_interval.clone(), - }; - let new_right = match parent_interval.sub(&new_left)?.intersect(&negate_right)? { - Some(i) => i, - None => right_interval.clone(), - }; - let range = (new_left, new_right.arithmetic_negate()?); - Ok(range) + let parent = comparison_operator_target(&left_child.get_datatype(), op)?; + // TODO: Same issue with propagate_left, we should report empty intervals + // and handle this outside. + let new_left = right_child + .add(&parent)? + .intersect(left_child)? + .unwrap_or_else(|| left_child.clone()); + let new_right = left_child + .sub(&parent)? + .intersect(right_child)? + .unwrap_or_else(|| right_child.clone()); + Ok((new_left, new_right)) } impl ExprIntervalGraph { - /// Constructs ExprIntervalGraph - /// - /// # Arguments - /// * `expr` - Arc. The complex expression that we compute bounds on by - /// traversing the graph. - /// * `provided_expr` - Arc that is used for child nodes. If one of them is - /// a Arc with multiple child, the child nodes of this node - /// will be pruned, since we do not need them while traversing. - /// - /// # Examples - /// - /// ``` - /// use std::sync::Arc; - /// use datafusion_common::ScalarValue; - /// use datafusion_expr::Operator; - /// use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; - /// use datafusion_physical_expr::intervals::ExprIntervalGraph; - /// use datafusion_physical_expr::PhysicalExpr; - /// // syn > gnz + 1 AND syn < gnz + 10 - /// let syn = Arc::new(Column::new("syn", 0)); - /// let left_and_2 = Arc::new(BinaryExpr::new( - /// Arc::new(Column::new("gnz", 0)), - /// Operator::Plus, - /// Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), - /// )); - /// let right_and_2 = Arc::new(BinaryExpr::new( - /// Arc::new(Column::new("gnz", 0)), - /// Operator::Plus, - /// Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), - /// )); - /// let left_expr = Arc::new(BinaryExpr::new(syn.clone(), Operator::Gt, left_and_2.clone())); - /// let right_expr = Arc::new(BinaryExpr::new(syn, Operator::Lt, right_and_2)); - /// let expr = Arc::new(BinaryExpr::new(left_expr, Operator::And, right_expr)); - /// let provided_exprs: Vec> = vec![left_and_2]; - /// // You can provide exprs for child prunning. - /// let graph = ExprIntervalGraph::try_new(expr, provided_exprs.as_slice()); - /// ``` + /// Constructs a new ExprIntervalGraph. pub fn try_new( expr: Arc, provided_expr: &[Arc], ) -> Result { + // Build the full graph: let (root_node, mut graph) = - build_physical_expr_graph(expr, &ExprIntervalGraphNode::expr_node_builder)?; + build_physical_expr_graph(expr, &ExprIntervalGraphNode::make_node)?; let mut bfs = Bfs::new(&graph, root_node); - let mut will_be_removed = vec![]; + // TODO: I don't understand what this comment means. Write a clearer + // comment and you should probably find a better name for `provided_expr` + // too. // We preserve the order with Vec in SymmetricHashJoin. 
let mut expr_node_indices: Vec<(Arc, usize)> = provided_expr .iter() .map(|e| (e.clone(), usize::MAX)) .collect(); + let mut removals = vec![]; while let Some(node) = bfs.next(&graph) { - // Get plan + // Get the plan corresponding to this node: let input = graph.index(node); let expr = input.expr.clone(); - // If we find a expr, remove the childs since we could not propagate the intervals there. + // If the current expression is among `provided_exprs`, slate its + // children for removal. + // TODO: Complete the following comment with a clear explanation: + // We remove the children because ............................... if let Some(value) = provided_expr.iter().position(|e| expr.eq(e)) { - // Remove all child nodes here, since we have information about them. If there is no child, - // iterator will not have anything. - let expr_node_index = expr_node_indices.get_mut(value).unwrap(); - *expr_node_index = (expr_node_index.0.clone(), node.index()); + expr_node_indices[value].1 = node.index(); let mut edges = graph.neighbors_directed(node, Outgoing).detach(); while let Some(n_index) = edges.next_node(&graph) { - // We will not delete while traversing - will_be_removed.push(n_index); + // Slate the child for removal, do not remove immediately. + removals.push(n_index); } } } - for node in will_be_removed { + for node in removals { graph.remove_node(node); } Ok(Self(root_node, graph, expr_node_indices)) } - /// Calculates new intervals and mutate the vector without changing the order. - /// - /// # Arguments - /// - /// * `expr_stats` - PhysicalExpr intervals. They are the child nodes for the interval calculation. - pub fn calculate_new_intervals( - &mut self, - expr_stats: &mut [(usize, Interval)], - ) -> Result<()> { - self.post_order_interval_calculation(expr_stats)?; - self.pre_order_interval_propagation(expr_stats)?; - Ok(()) - } - /// Computing bounds for an expression using interval arithmetic. - fn post_order_interval_calculation( - &mut self, - expr_stats: &[(usize, Interval)], - ) -> Result<()> { + + /// Computes bounds for an expression using interval arithmetic via a + /// bottom-up traversal. + fn evaluate_bounds(&mut self, expr_stats: &[(usize, Interval)]) -> Result<()> { let mut dfs = DfsPostOrder::new(&self.1, self.0); while let Some(node) = dfs.next(&self.1) { // Outgoing (Children) edges @@ -360,7 +353,7 @@ impl ExprIntervalGraph { let input = self.1.index_mut(node); // Calculate and replace the interval input.interval = - apply_operator(&left_interval, binary.op(), &right_interval)?; + apply_operator(binary.op(), &left_interval, &right_interval)?; } else if let Some(CastExpr { cast_type, cast_options, @@ -380,7 +373,9 @@ impl ExprIntervalGraph { Ok(()) } - pub fn pre_order_interval_propagation( + /// Updates/shrinks bounds for leaf expressions using interval arithmetic + /// via a top-down traversal. + pub fn propagate_constraints( &mut self, expr_stats: &mut [(usize, Interval)], ) -> Result<()> { @@ -422,16 +417,16 @@ impl ExprIntervalGraph { if let Interval{ lower: ScalarValue::Boolean(Some(false)), upper: ScalarValue::Boolean(Some(false))} = node_interval { continue } // Propagate the comparison operator. propagate_comparison_operators( - first_child_interval, binary.op(), + first_child_interval, second_child_interval, )? } else { // Propagate the arithmetic operator. - propagate_arithmetic_operators(first_child_interval, - binary.op(), - second_child_interval, - &node_interval)? + propagate(binary.op(), + &node_interval, + first_child_interval, + second_child_interval)? 
}; let mutable_first_child = self.1.index_mut(first_child_node_index); mutable_first_child.interval = shrink_left_interval.clone(); @@ -452,60 +447,15 @@ impl ExprIntervalGraph { } Ok(()) } -} - -#[derive(Clone, Debug)] -/// Graph nodes contains interval information. -pub struct ExprIntervalGraphNode { - expr: Arc, - interval: Interval, -} - -impl Display for ExprIntervalGraphNode { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.expr) - } -} - -impl ExprIntervalGraphNode { - // Construct ExprIntervalGraphNode - pub fn new(expr: Arc) -> Self { - ExprIntervalGraphNode { - expr, - interval: Interval::default(), - } - } - /// Specify interval - pub fn new_with_interval(expr: Arc, interval: Interval) -> Self { - ExprIntervalGraphNode { expr, interval } - } - /// Get interval - pub fn interval(&self) -> &Interval { - &self.interval - } - /// Within this static method, one can customize how PhysicalExpr generator constructs its nodes - pub fn expr_node_builder(input: Arc) -> ExprIntervalGraphNode { - let binding = input.expr(); - let plan_any = binding.as_any(); - let clone_expr = input.expr().clone(); - if let Some(literal) = plan_any.downcast_ref::() { - // Create interval - let value = literal.value(); - let interval = Interval { - lower: value.clone(), - upper: value.clone(), - }; - ExprIntervalGraphNode::new_with_interval(clone_expr, interval) - } else { - ExprIntervalGraphNode::new(clone_expr) - } - } -} - -impl PartialEq for ExprIntervalGraphNode { - fn eq(&self, other: &ExprIntervalGraphNode) -> bool { - self.expr.eq(&other.expr) + /// Updates intervals for all expressions in the DAEG by successive + /// bottom-up and top-down traversals. + pub fn update_intervals( + &mut self, + expr_stats: &mut [(usize, Interval)], + ) -> Result<()> { + self.evaluate_bounds(expr_stats) + .and(self.propagate_constraints(expr_stats)) } } @@ -576,7 +526,7 @@ mod tests { .map(|((_, interval), (_, index))| (*index, interval.clone())) .collect_vec(); - graph.calculate_new_intervals(&mut col_stat_nodes[..])?; + graph.update_intervals(&mut col_stat_nodes[..])?; col_stat_nodes .iter() .zip(expected_nodes.iter()) diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetics.rs b/datafusion/physical-expr/src/intervals/interval_aritmetics.rs index f271b04087f1..261659fabc3e 100644 --- a/datafusion/physical-expr/src/intervals/interval_aritmetics.rs +++ b/datafusion/physical-expr/src/intervals/interval_aritmetics.rs @@ -16,23 +16,23 @@ // under the License. //! Interval arithmetic library -//! + use std::borrow::Borrow; use std::fmt; use std::fmt::{Display, Formatter}; -use arrow::compute::CastOptions; +use arrow::compute::{cast_with_options, CastOptions}; use arrow::datatypes::DataType; -use datafusion_common::ScalarValue; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::Operator; use crate::aggregate::min_max::{max, min}; -use crate::intervals::cp_solver; /// This type represents an interval, which is used to calculate reliable /// bounds for expressions. Currently, we only support addition and /// subtraction, but more capabilities will be added in the future. +/// Upper/lower bounds having NULL values indicate an unbounded side. For +/// example; [10, 20], [10, ∞], [-∞, 100] and [-∞, ∞] are all valid intervals. 
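+///
+/// For instance, the half-open interval [10, ∞] would be encoded with a NULL
+/// upper bound (an illustrative sketch, using the fields defined just below):
+///
+/// ```ignore
+/// let half_open = Interval {
+///     lower: ScalarValue::Int64(Some(10)),
+///     upper: ScalarValue::Int64(None), // NULL encodes the unbounded (∞) side
+/// };
+/// ```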
#[derive(Debug, PartialEq, Clone, Eq, Hash)] pub struct Interval { pub lower: ScalarValue, @@ -61,8 +61,8 @@ impl Interval { cast_options: &CastOptions, ) -> Result { Ok(Interval { - lower: cp_solver::cast_scalar_value(&self.lower, data_type, cast_options)?, - upper: cp_solver::cast_scalar_value(&self.upper, data_type, cast_options)?, + lower: cast_scalar_value(&self.lower, data_type, cast_options)?, + upper: cast_scalar_value(&self.upper, data_type, cast_options)?, }) } @@ -70,27 +70,17 @@ impl Interval { self.lower.get_datatype() } - /// Null is treated as infinite. - /// [a , b] > [c, d] - /// where a, b, c or d can be infinite or a real value. We use None if a or c is minus infinity or - /// b or d is infinity. [10,20], [10, ∞], [-∞, 100] or [-∞, ∞] are all valid intervals. - /// While we are calculating correct comparison, we take the PartialOrd implementation of ScalarValue, - /// which is - /// None < a true - /// a < None false - /// None > a false - /// a > None true - /// None < None false - /// None > None false + /// Decide if this interval is certainly greater than, possibly greater than, + /// or can't be greater than `other` by returning [true, true], + /// [false, true] or [false, false] respectively. pub(crate) fn gt(&self, other: &Interval) -> Interval { let flags = if !self.upper.is_null() && !other.lower.is_null() - && self.upper < other.lower + && (self.upper <= other.lower) { (false, false) - // If lhs has lower negative infinity or rhs has upper infinity, this can not be completely true. - } else if !other.upper.is_null() - && !self.lower.is_null() + } else if !self.lower.is_null() + && !other.upper.is_null() && (self.lower > other.upper) { (true, true) @@ -102,15 +92,43 @@ impl Interval { upper: ScalarValue::Boolean(Some(flags.1)), } } - /// We define vague result as (false, true). If we "AND" a vague value with another, - /// result is - /// vague AND false -> false - /// vague AND true -> vague, - /// vague AND vague -> vague - /// false AND true -> true - /// - /// where "false" represented with (false, false), "true" represented with (true, true) and - /// vague represented with (false, true). + + /// Decide if this interval is certainly less than, possibly less than, + /// or can't be less than `other` by returning [true, true], + /// [false, true] or [false, false] respectively. + pub(crate) fn lt(&self, other: &Interval) -> Interval { + other.gt(self) + } + + /// Decide if this interval is certainly equal to, possibly equal to, + /// or can't be equal to `other` by returning [true, true], + /// [false, true] or [false, false] respectively. + pub(crate) fn equal(&self, other: &Interval) -> Interval { + let flags = if !self.lower.is_null() + && (self.lower == self.upper) + && (other.lower == other.upper) + && (self.lower == other.lower) + { + (true, true) + } else if (!self.lower.is_null() + && !other.upper.is_null() + && (self.lower > other.upper)) + || (!self.upper.is_null() + && !other.lower.is_null() + && (self.upper < other.lower)) + { + (false, false) + } else { + (false, true) + }; + Interval { + lower: ScalarValue::Boolean(Some(flags.0)), + upper: ScalarValue::Boolean(Some(flags.1)), + } + } + + /// Compute the logical conjunction of this (boolean) interval with the + /// given boolean interval. 
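+    /// The conjunction follows three-valued (Kleene-style) logic, where
+    /// [true, true] means certainly true, [false, true] means uncertain and
+    /// [false, false] means certainly false. A hedged, self-contained sketch
+    /// of the same rule on plain `(bool, bool)` pairs (stand-ins for the
+    /// boolean intervals used here):
+    ///
+    /// ```
+    /// fn and(a: (bool, bool), b: (bool, bool)) -> (bool, bool) {
+    ///     if a.0 && b.0 {
+    ///         (true, true)
+    ///     } else if a.1 && b.1 {
+    ///         (false, true)
+    ///     } else {
+    ///         (false, false)
+    ///     }
+    /// }
+    /// assert_eq!(and((false, true), (true, true)), (false, true));
+    /// assert_eq!(and((false, true), (false, false)), (false, false));
+    /// assert_eq!(and((false, true), (false, true)), (false, true));
+    /// ```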
pub(crate) fn and(&self, other: &Interval) -> Result { let flags = match (self, other) { ( @@ -119,13 +137,13 @@ impl Interval { upper: ScalarValue::Boolean(Some(upper)), }, Interval { - lower: ScalarValue::Boolean(Some(lower2)), - upper: ScalarValue::Boolean(Some(upper2)), + lower: ScalarValue::Boolean(Some(other_lower)), + upper: ScalarValue::Boolean(Some(other_upper)), }, ) => { - if *lower && *lower2 { + if *lower && *other_lower { (true, true) - } else if *upper && *upper2 { + } else if *upper && *other_upper { (false, true) } else { (false, false) @@ -143,81 +161,33 @@ impl Interval { }) } - pub(crate) fn equal(&self, other: &Interval) -> Interval { - let flags = if (self.lower == self.upper) - && (other.lower == other.upper) - && (self.lower == other.lower) - { - (true, true) - } else if (self.lower > other.upper) || (self.upper < other.lower) { - (false, false) + /// Compute the intersection of the interval with the given interval. + /// If the intersection is empty, return None. + pub(crate) fn intersect(&self, other: &Interval) -> Result> { + let lower = if self.lower.is_null() { + other.lower.clone() + } else if other.lower.is_null() { + self.lower.clone() } else { - (false, true) + max(&self.lower, &other.lower)? }; - Interval { - lower: ScalarValue::Boolean(Some(flags.0)), - upper: ScalarValue::Boolean(Some(flags.1)), - } - } - /// Null is treated as infinite. - /// [a , b] < [c, d] - /// where a, b, c or d can be infinite or a real value. We use None if a or c is minus infinity or - /// b or d is infinity. [10,20], [10, ∞], [-∞, 100] or [-∞, ∞] are all valid intervals. - /// Implemented as complementary of greater than. - pub(crate) fn lt(&self, other: &Interval) -> Interval { - other.gt(self) - } - - /// Null is treated as infinite. - /// [a , b] < [c, d] - /// where a, b, c or d can be infinite or a real value. We use None if a or c is minus infinity or - /// b or d is infinity. [10,20], [10, ∞], [-∞, 100] or [-∞, ∞] are all valid intervals. - pub(crate) fn intersect(&self, other: &Interval) -> Result> { - Ok(match (self, other) { - ( - Interval { - lower: ScalarValue::Boolean(_), - upper: ScalarValue::Boolean(_), - }, - Interval { - lower: ScalarValue::Boolean(_), - upper: ScalarValue::Boolean(_), - }, - ) => None, - ( - Interval { - lower: low, - upper: high, - }, - Interval { - lower: other_low, - upper: other_high, - }, - ) => { - let lower = if low.is_null() { - other_low.clone() - } else if other_low.is_null() { - low.clone() - } else { - max(low, other_low)? - }; - let upper = if high.is_null() { - other_high.clone() - } else if other_high.is_null() { - high.clone() - } else { - min(high, other_high)? - }; - if !upper.is_null() && lower > upper { - // This None value signals an empty interval. - None - } else { - Some(Interval { lower, upper }) - } - } + let upper = if self.upper.is_null() { + other.upper.clone() + } else if other.upper.is_null() { + self.upper.clone() + } else { + min(&self.upper, &other.upper)? + }; + Ok(if !lower.is_null() && !upper.is_null() && lower > upper { + // This None value signals an empty interval. + None + } else { + Some(Interval { lower, upper }) }) } + // Compute the negation of the interval. 
+    #[allow(dead_code)]
     pub(crate) fn arithmetic_negate(&self) -> Result<Interval> {
         Ok(Interval {
             lower: self.upper.arithmetic_negate()?,
@@ -225,12 +195,10 @@ impl Interval {
         })
     }
 
-    /// The interval addition operation takes two intervals, say [a1, b1] and [a2, b2],
-    /// and returns a new interval that represents the range of possible values obtained by adding
-    /// any value from the first interval to any value from the second interval.
-    /// The resulting interval is defined as [a1 + a2, b1 + b2]. For example,
-    /// if we have the intervals [1, 2] and [3, 4], the interval addition
-    /// operation would return the interval [4, 6].
+    /// Add the given interval (`other`) to this interval. Say we have
+    /// intervals [a1, b1] and [a2, b2], then their sum is [a1 + a2, b1 + b2];
+    /// e.g. [1, 2] + [3, 4] yields [4, 6]. Note that this represents all
+    /// possible values the sum can take if one can choose single values
+    /// arbitrarily from each of the operands.
     pub fn add<T: Borrow<Interval>>(&self, other: T) -> Result<Interval> {
         let rhs = other.borrow();
         let lower = if self.lower.is_null() || rhs.lower.is_null() {
@@ -246,10 +214,10 @@ impl Interval {
         Ok(Interval { lower, upper })
     }
 
-    /// The interval subtraction operation is similar to the addition operation,
-    /// but it subtracts one interval from another. The resulting interval is defined as [a1 - b2, b1 - a2].
-    /// For example, if we have the intervals [1, 2] and [3, 4], the interval subtraction operation
-    /// would return the interval [-3, -1].
+    /// Subtract the given interval (`other`) from this interval. Say we have
+    /// intervals [a1, b1] and [a2, b2], then their difference is [a1 - b2, b1 - a2];
+    /// e.g. [1, 2] - [3, 4] yields [-3, -1]. Note that this represents all
+    /// possible values the difference can take if one can choose single
+    /// values arbitrarily from each of the operands.
     pub fn sub<T: Borrow<Interval>>(&self, other: T) -> Result<Interval> {
         let rhs = other.borrow();
         let lower = if self.lower.is_null() || rhs.upper.is_null() {
@@ -266,18 +234,35 @@ impl Interval {
     }
 }
 
-pub fn negate_ops(op: Operator) -> Operator {
-    match op {
-        Operator::Plus => Operator::Minus,
-        Operator::Minus => Operator::Plus,
-        _ => unreachable!(),
+pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result<Interval> {
+    match *op {
+        Operator::Eq => Ok(lhs.equal(rhs)),
+        Operator::Gt => Ok(lhs.gt(rhs)),
+        Operator::Lt => Ok(lhs.lt(rhs)),
+        Operator::And => lhs.and(rhs),
+        Operator::Plus => lhs.add(rhs),
+        Operator::Minus => lhs.sub(rhs),
+        _ => Ok(Interval {
+            lower: ScalarValue::Null,
+            upper: ScalarValue::Null,
+        }),
    }
 }
+
+/// Cast scalar value to the given data type using an arrow kernel.
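+///
+/// For example, one might use it as follows (a hedged sketch; this helper is
+/// private to the module, so the snippet is illustrative only):
+///
+/// ```ignore
+/// let value = ScalarValue::Int32(Some(7));
+/// let options = CastOptions { safe: false };
+/// let as_int64 = cast_scalar_value(&value, &DataType::Int64, &options)?;
+/// assert_eq!(as_int64, ScalarValue::Int64(Some(7)));
+/// ```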
+fn cast_scalar_value( + value: &ScalarValue, + data_type: &DataType, + cast_options: &CastOptions, +) -> Result { + let cast_array = cast_with_options(&value.to_array(), data_type, cast_options)?; + ScalarValue::try_from_array(&cast_array, 0) +} + #[cfg(test)] mod tests { use crate::intervals::interval_aritmetics::Interval; - use datafusion_common::Result; - use datafusion_common::ScalarValue; + use datafusion_common::{Result, ScalarValue}; #[test] fn intersect_test() -> Result<()> { @@ -317,17 +302,14 @@ mod tests { ) } - let not_possible_cases = vec![ + let empty_cases = vec![ (None, Some(1000), Some(1001), None), (Some(1001), None, None, Some(1000)), (None, Some(1000), Some(1001), Some(1002)), (Some(1001), Some(1002), None, Some(1000)), ]; - //(None, Some(1000), Some(1001), None, --), - // (None, Some(1000), Some(1000), None, false, true), - // (None, Some(1000), Some(1001), Some(1002), false, false), - for case in not_possible_cases { + for case in empty_cases { assert_eq!( Interval { lower: ScalarValue::Int64(case.0), @@ -351,11 +333,13 @@ mod tests { (None, Some(1000), None, None, false, true), (None, None, Some(1000), None, false, true), (None, None, None, Some(1000), false, true), - (None, Some(1000), Some(1000), None, false, true), + (None, Some(1000), Some(1000), None, false, false), (None, Some(1000), Some(1001), None, false, false), (Some(1000), None, Some(1000), None, false, true), (None, Some(1000), Some(1001), Some(1002), false, false), (None, Some(1000), Some(999), Some(1002), false, true), + (Some(1002), None, Some(999), Some(1002), false, true), + (Some(1003), None, Some(999), Some(1002), true, true), (None, None, None, None, false, true), ]; From 582573ebace74ac85f0c059a22cb75c5bbfec88d Mon Sep 17 00:00:00 2001 From: metesynnada <100111937+metesynnada@users.noreply.github.com> Date: Tue, 31 Jan 2023 15:50:46 +0300 Subject: [PATCH 10/39] Revamp some API's and enhance comments - Utilize estimate_bounds without propagation for better API. - Remove coupling between node_index & PhysicalExpr pairing and graph. - Better commenting on symmetric hash join while using graph --- .../joins/symmetric_hash_join.rs | 161 +++++++++++-- .../physical-expr/src/intervals/cp_solver.rs | 218 ++++++++++++------ 2 files changed, 282 insertions(+), 97 deletions(-) diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs index 84bad6d5042b..1f3c5b0ae737 100644 --- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs @@ -266,18 +266,23 @@ impl SymmetricHashJoinExec { )?; // Create expression graph for PhysicalExpr. - let physical_expr_graph = ExprIntervalGraph::try_new( - filter.expression().clone(), - &filter_columns - .iter() - .map(|sorted_expr| sorted_expr.filter_expr()) - .collect_vec(), - )?; + let mut physical_expr_graph = + ExprIntervalGraph::try_new(filter.expression().clone())?; + + // Since the graph is stable, we reuse NodeIndex of our filter expressions. + let child_node_indexes = physical_expr_graph + .pair_node_indices_with_interval_providers( + &filter_columns + .iter() + .map(|sorted_expr| sorted_expr.filter_expr()) + .collect_vec(), + )?; + // We inject calculated node indexes into SortedFilterExpr. In graph calculations, // we will be using node index to put calculated intervals into Columns or BinaryExprs . 
 filter_columns
         .iter_mut()
-        .zip(physical_expr_graph.2.iter())
+        .zip(child_node_indexes.iter())
         .map(|(sorted_expr, (_, index))| {
             sorted_expr.set_node_index(*index);
         })
@@ -458,7 +463,6 @@ impl ExecutionPlan for SymmetricHashJoinExec {
         let left_side_joiner = OneSideHashJoiner::new(JoinSide::Left, self.left.clone());
         let right_side_joiner =
             OneSideHashJoiner::new(JoinSide::Right, self.right.clone());
-        // TODO: Discuss unbounded mpsc and bounded one.
         let left_stream = self.left.execute(partition, context.clone())?;
 
         let right_stream = self.right.execute(partition, context)?;
@@ -581,13 +585,41 @@ fn prune_visited_rows(
     Ok(())
 }
 
-/// We can prune build side when a probe batch comes.
-fn column_stats_two_side(
+/// We calculate [PhysicalExpr] boundaries: we take the first ScalarValues from the
+/// build side and the last ScalarValues from the probe side, then update the
+/// [SortedFilterExpr] intervals.
+///
+///   Build              Probe
+/// +-------+          +-------+
+/// |+-----+|          |       |
+/// ||a |b ||          |       |
+/// |+--|--+|          |       |
+/// |   |   |          |       |
+/// |   |   |          |       |
+/// |   |   |          |       |
+/// |   |   |          |       |
+/// |   |   |          |+-----+|
+/// |   |   |          ||c |d ||
+/// |   |   |          |+--|--+|
+/// +-------+          +-------+
+/// Expr1: [a, inf]
+/// Expr2: [b, inf]
+/// Expr3: [c, inf]
+/// Expr4: [d, inf]
+///
+/// Each probe batch triggers a recalculation. Remember that the calculation logic
+/// differs according to the build side. When the build side changes due to new data
+/// arriving from the other child, [SortedFilterExpr] intervals change according to
+/// the side. You must use these calculated intervals within the same side, per
+/// [RecordBatch] arrival.
+///
+fn calculate_filter_expr_intervals(
     build_input_buffer: &RecordBatch,
     probe_batch: &RecordBatch,
     filter_columns: &mut [SortedFilterExpr],
     build_side: JoinSide,
 ) -> Result<()> {
+    if build_input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 {
+        return Ok(());
+    }
     for sorted_expr in filter_columns.iter_mut() {
         let SortedFilterExpr {
             join_side,
@@ -840,7 +872,7 @@ impl OneSideHashJoiner {
         random_state: &RandomState,
         null_equals_null: &bool,
     ) -> ArrowResult<Option<RecordBatch>> {
-        if self.input_buffer.num_rows() == 0 {
+        if self.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 {
             return Ok(Some(RecordBatch::new_empty(schema)));
         }
         let (build_side, probe_side) = build_join_indices(
@@ -933,13 +965,6 @@ impl OneSideHashJoiner {
         if self.input_buffer.num_rows() == 0 {
             return Ok(None);
        }
-        column_stats_two_side(
-            &self.input_buffer,
-            probe_batch,
-            filter_columns,
-            self.build_side,
-        )?;
-
         // We use Vec<(usize, Interval)> instead of a HashMap since the expected
         // number of filter exprs is relatively low, and converting back and forth
         // between a Vec and a HashMap may be slower.
@@ -949,15 +974,25 @@ impl OneSideHashJoiner {
             .collect_vec();
         // Use this vector to seed the child PhysicalExpr interval.
         physical_expr_graph.update_intervals(&mut filter_intervals)?;
-        // Mutate the Vec for
+        let mut change_track = Vec::with_capacity(filter_intervals.len());
+        // Update each SortedFilterExpr interval with its new value, keeping
+        // track of which intervals actually changed.
         for (sorted_expr, (_, interval)) in
             filter_columns.iter_mut().zip(filter_intervals.into_iter())
         {
-            sorted_expr.set_interval(interval.clone())
+            if interval.eq(sorted_expr.interval()) {
+                change_track.push(false);
+            } else {
+                sorted_expr.set_interval(interval.clone());
+                change_track.push(true);
+            }
         }
-
-        let prune_length =
-            determine_prune_length(&self.input_buffer, filter_columns, self.build_side)?;
+        // If any interval changed, try to determine a prune length.
+ let prune_length = if change_track.iter().any(|x| *x) { + determine_prune_length(&self.input_buffer, filter_columns, self.build_side)? + } else { + 0 + }; if prune_length > 0 { let result = self.build_side_determined_results( schema, @@ -1057,6 +1092,13 @@ impl SymmetricHashJoinStream { )))) } } + // Each probe batch trigger recalculation. + calculate_filter_expr_intervals( + &self.right.input_buffer, + &probe_batch, + &mut self.filter_columns, + JoinSide::Right, + )?; // Using right as build side. let equal_result = self.right.record_batch_from_other_side( self.schema.clone(), @@ -1130,6 +1172,12 @@ impl SymmetricHashJoinStream { )))) } } + calculate_filter_expr_intervals( + &self.left.input_buffer, + &probe_batch, + &mut self.filter_columns, + JoinSide::Left, + )?; let equal_result = self.left.record_batch_from_other_side( self.schema.clone(), self.join_type, @@ -1144,6 +1192,7 @@ impl SymmetricHashJoinStream { &self.null_equals_null, ); self.right.offset += probe_batch.num_rows(); + let anti_result = self.left.prune_build_side( self.schema.clone(), &probe_batch, @@ -2141,4 +2190,68 @@ mod tests { experiment(left, right, filter, join_type, on, task_ctx).await?; Ok(()) } + + #[tokio::test(flavor = "multi_thread")] + async fn complex_join_all_one_ascending_numeric_missing_stat() -> Result<()> { + let cardinality = (3, 4); + let join_type = JoinType::Full; + + // a + b > c + 10 AND a + b < c + 100 + let config = SessionConfig::new().with_repartition_joins(false); + let session_ctx = SessionContext::with_config(config); + let task_ctx = session_ctx.task_ctx(); + let (left_batch, right_batch) = + build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_sorted = vec![PhysicalSortExpr { + expr: Arc::new(Column::new_with_schema("la1", &left_batch.schema())?), + options: SortOptions::default(), + }]; + + let right_sorted = vec![PhysicalSortExpr { + expr: Arc::new(Column::new_with_schema("ra1", &right_batch.schema())?), + options: SortOptions::default(), + }]; + let (left, right) = + create_memory_table(left_batch, right_batch, left_sorted, right_sorted)?; + + let on = vec![( + Column::new_with_schema("lc1", &left.schema())?, + Column::new_with_schema("rc1", &right.schema())?, + )]; + + let filter_col_0 = Arc::new(Column::new("0", 0)); + let filter_col_1 = Arc::new(Column::new("1", 1)); + let filter_col_2 = Arc::new(Column::new("2", 2)); + + let column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 4, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ]; + let intermediate_schema = Schema::new(vec![ + Field::new(filter_col_0.name(), DataType::Int32, true), + Field::new(filter_col_1.name(), DataType::Int32, true), + Field::new(filter_col_2.name(), DataType::Int32, true), + ]); + + let filter_expr = complicated_filter(); + + let filter = JoinFilter::new( + filter_expr, + column_indices.clone(), + intermediate_schema.clone(), + ); + + experiment(left, right, filter, join_type, on, task_ctx).await?; + Ok(()) + } } diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index c831fa1e4403..c02b76514c16 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -115,11 +115,7 @@ use crate::PhysicalExpr; /// This object implements a directed acyclic expression graph (DAEG) that /// is used to compute ranges for expressions through interval arithmetic. 
 #[derive(Clone)]
-pub struct ExprIntervalGraph(
-    NodeIndex,
-    StableGraph<ExprIntervalGraphNode, usize>,
-    pub Vec<(Arc<dyn PhysicalExpr>, usize)>,
-);
+pub struct ExprIntervalGraph(NodeIndex, StableGraph<ExprIntervalGraphNode, usize>);
 
 /// This is a node in the DAEG; it encapsulates a reference to the actual
 /// [PhysicalExpr] as well as an interval containing expression bounds.
@@ -278,50 +274,134 @@ pub fn propagate_comparison_operators(
 }
 
 impl ExprIntervalGraph {
-    /// Constructs a new ExprIntervalGraph.
-    pub fn try_new(
-        expr: Arc<dyn PhysicalExpr>,
-        provided_expr: &[Arc<dyn PhysicalExpr>],
-    ) -> Result<Self> {
+    pub fn try_new(expr: Arc<dyn PhysicalExpr>) -> Result<Self> {
         // Build the full graph:
-        let (root_node, mut graph) =
+        let (root_node, graph) =
             build_physical_expr_graph(expr, &ExprIntervalGraphNode::make_node)?;
-        let mut bfs = Bfs::new(&graph, root_node);
-        // TODO: I don't understand what this comment means. Write a clearer
-        // comment and you should probably find a better name for `provided_expr`
-        // too.
-        // We preserve the order with Vec in SymmetricHashJoin.
-        let mut expr_node_indices: Vec<(Arc<dyn PhysicalExpr>, usize)> = provided_expr
-            .iter()
-            .map(|e| (e.clone(), usize::MAX))
-            .collect();
+        Ok(Self(root_node, graph))
+    }
+
+    ///
+    /// ``` text
+    ///
+    /// If we want to calculate intervals starting from and propagate up to BinaryExpr('a', +, 'b')
+    /// instead of col('a') and col('b'), we prune under BinaryExpr('a', +, 'b') since we will
+    /// never visit those nodes.
+    /// We do not change the expression tree (`Arc<dyn PhysicalExpr>`) itself; we only
+    /// remove the graph nodes corresponding to the pruned expressions.
+    ///
+    ///
+    ///                           +-----+                                +-----+
+    ///                           | GT  |                                | GT  |
+    ///                  +--------|     |-------+               +--------|     |-------+
+    ///                  |        +-----+       |               |        +-----+       |
+    ///                  |                      |               |                      |
+    ///               +-----+                   |            +-----+                   |
+    ///               |Cast |                   |            |Cast |                   |
+    ///               |     |                   |    --\     |     |                   |
+    ///               +-----+                   |    --/     +-----+                   |
+    ///                  |                      |               |                      |
+    ///                  |                      |               |                      |
+    ///               +-----+                +-----+         +-----+                +-----+
+    ///            +--|Plus |--+          +--|Plus |--+      |Plus |             +--|Plus |--+
+    ///            |  |     |  |          |  |     |  |      |     |             |  |     |  |
+    ///  Prune     |  +-----+  |          |  +-----+  |      +-----+             |  +-----+  |
+    ///  from here |           |          |           |                          |           |
+    ///  ----------|-----------|--        |           |                          |           |
+    ///         +-----+     +-----+    +-----+     +-----+                    +-----+     +-----+
+    ///         |  a  |     |  b  |    |  c  |     |  2  |                    |  c  |     |  2  |
+    ///         |     |     |     |    |     |     |     |                    |     |     |     |
+    ///         +-----+     +-----+    +-----+     +-----+                    +-----+     +-----+
+    ///
+    /// This operation mutates the underlying graph, so apply it right after construction.
+    /// ```
+    ///
+    /// Since the graph is stable, a NodeIndex stays the same even if we delete a node. We
+    /// exploit this stability by pairing each `Arc<dyn PhysicalExpr>` with its NodeIndex
+    /// for membership tests.
+    pub fn pair_node_indices_with_interval_providers(
+        &mut self,
+        exprs: &[Arc<dyn PhysicalExpr>],
+    ) -> Result<Vec<(Arc<dyn PhysicalExpr>, usize)>> {
+        let mut bfs = Bfs::new(&self.1, self.0);
+        // We collect the node indices (usize) of the given PhysicalExprs, in the same
+        // order as `exprs`. To preserve this order, we initialize each expr's node index
+        // with usize::MAX, then find the corresponding node indices by traversing the graph.
         let mut removals = vec![];
-        while let Some(node) = bfs.next(&graph) {
+        let mut expr_node_indices: Vec<(Arc<dyn PhysicalExpr>, usize)> =
+            exprs.iter().map(|e| (e.clone(), usize::MAX)).collect();
+        while let Some(node) = bfs.next(&self.1) {
             // Get the plan corresponding to this node:
-            let input = graph.index(node);
+            let input = self.1.index(node);
             let expr = input.expr.clone();
-            // If the current expression is among `provided_exprs`, slate its
+            // If the current expression is among `exprs`, slate its
             // children for removal.
-            // TODO: Complete the following comment with a clear explanation:
-            // We remove the children because ...............................
- if let Some(value) = provided_expr.iter().position(|e| expr.eq(e)) { + if let Some(value) = exprs.iter().position(|e| expr.eq(e)) { + // Update NodeIndex of the PhysicalExpr expr_node_indices[value].1 = node.index(); - let mut edges = graph.neighbors_directed(node, Outgoing).detach(); - while let Some(n_index) = edges.next_node(&graph) { + let mut edges = self.1.neighbors_directed(node, Outgoing).detach(); + while let Some(n_index) = edges.next_node(&self.1) { // Slate the child for removal, do not remove immediately. removals.push(n_index); } } } for node in removals { - graph.remove_node(node); + self.1.remove_node(node); } - Ok(Self(root_node, graph, expr_node_indices)) + Ok(expr_node_indices) } /// Computes bounds for an expression using interval arithmetic via a /// bottom-up traversal. - fn evaluate_bounds(&mut self, expr_stats: &[(usize, Interval)]) -> Result<()> { + /// It returns root node's interval. + /// + /// # Arguments + /// * `expr_stats` - &[(usize, Interval)]. Provide NodeIndex, Interval tuples for bound evaluation. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use datafusion_common::ScalarValue; + /// use datafusion_expr::Operator; + /// use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; + /// use datafusion_physical_expr::intervals::ExprIntervalGraph; + /// use datafusion_physical_expr::intervals::interval_aritmetics::Interval; + /// use datafusion_physical_expr::PhysicalExpr; + /// let expr = Arc::new(BinaryExpr::new( + /// Arc::new(Column::new("gnz", 0)), + /// Operator::Plus, + /// Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + /// )); + /// let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); + /// // Do it once, while constructing. + /// let node_indices = graph + /// .pair_node_indices_with_interval_providers(&[Arc::new(Column::new("gnz", 0))]) + /// .unwrap(); + /// let left_index = node_indices.get(0).unwrap().1; + /// // Provide intervals + /// let intervals = vec![( + /// left_index, + /// Interval { + /// lower: ScalarValue::Int32(Some(10)), + /// upper: ScalarValue::Int32(Some(20)), + /// }, + /// )]; + /// // Evaluate bounds + /// assert_eq!( + /// graph.evaluate_bounds(&intervals).unwrap(), + /// Interval { + /// lower: ScalarValue::Int32(Some(20)), + /// upper: ScalarValue::Int32(Some(30)) + /// } + /// ) + /// + /// ``` + pub fn evaluate_bounds( + &mut self, + expr_stats: &[(usize, Interval)], + ) -> Result { let mut dfs = DfsPostOrder::new(&self.1, self.0); while let Some(node) = dfs.next(&self.1) { // Outgoing (Children) edges @@ -337,7 +417,6 @@ impl ExprIntervalGraph { input.interval = interval.clone(); continue; } - let expr_any = expr.as_any(); if let Some(binary) = expr_any.downcast_ref::() { // Access left child immutable @@ -370,12 +449,13 @@ impl ExprIntervalGraph { input.interval = new_interval; } } - Ok(()) + let root_interval = self.1.index(self.0).interval.clone(); + Ok(root_interval) } /// Updates/shrinks bounds for leaf expressions using interval arithmetic /// via a top-down traversal. - pub fn propagate_constraints( + fn propagate_constraints( &mut self, expr_stats: &mut [(usize, Interval)], ) -> Result<()> { @@ -455,7 +535,13 @@ impl ExprIntervalGraph { expr_stats: &mut [(usize, Interval)], ) -> Result<()> { self.evaluate_bounds(expr_stats) - .and(self.propagate_constraints(expr_stats)) + .and_then(|interval| match interval { + Interval { + upper: ScalarValue::Boolean(Some(true)), + .. 
+ } => self.propagate_constraints(expr_stats), + _ => Ok(()), + }) } } @@ -473,7 +559,7 @@ mod tests { fn experiment( expr: Arc, - exprs: (Arc, Arc), + exprs_with_interval: (Arc, Arc), left_interval: (Option, Option), right_interval: (Option, Option), left_waited: (Option, Option), @@ -481,14 +567,14 @@ mod tests { ) -> Result<()> { let col_stats = vec![ ( - exprs.0.clone(), + exprs_with_interval.0.clone(), Interval { lower: ScalarValue::Int32(left_interval.0), upper: ScalarValue::Int32(left_interval.1), }, ), ( - exprs.1.clone(), + exprs_with_interval.1.clone(), Interval { lower: ScalarValue::Int32(right_interval.0), upper: ScalarValue::Int32(right_interval.1), @@ -497,32 +583,33 @@ mod tests { ]; let expected = vec![ ( - exprs.0.clone(), + exprs_with_interval.0.clone(), Interval { lower: ScalarValue::Int32(left_waited.0), upper: ScalarValue::Int32(left_waited.1), }, ), ( - exprs.1.clone(), + exprs_with_interval.1.clone(), Interval { lower: ScalarValue::Int32(right_waited.0), upper: ScalarValue::Int32(right_waited.1), }, ), ]; - let mut graph = ExprIntervalGraph::try_new( - expr, + let mut graph = ExprIntervalGraph::try_new(expr)?; + let expr_indexes = graph.pair_node_indices_with_interval_providers( &col_stats.iter().map(|(e, _)| e.clone()).collect_vec(), )?; + let mut col_stat_nodes = col_stats .iter() - .zip(graph.2.iter()) + .zip(expr_indexes.iter()) .map(|((_, interval), (_, index))| (*index, interval.clone())) .collect_vec(); let expected_nodes = expected .iter() - .zip(graph.2.iter()) + .zip(expr_indexes.iter()) .map(|((_, interval), (_, index))| (*index, interval.clone())) .collect_vec(); @@ -593,40 +680,25 @@ mod tests { } #[test] - fn particular() -> Result<()> { - let seed = 0; + fn testing_not_possible() -> Result<()> { let left_col = Arc::new(Column::new("left_watermark", 0)); let right_col = Arc::new(Column::new("right_watermark", 0)); - // left_watermark - 1 > right_watermark + 5 AND left_watermark + 3 < right_watermark + 10 - let expr = filter_numeric_expr_generation( + // left_watermark > right_watermark + 5 + + let left_and_1 = Arc::new(BinaryExpr::new( left_col.clone(), - right_col.clone(), - Operator::Minus, - Operator::Plus, - Operator::Plus, Operator::Plus, - 1, - 5, - 3, - 10, - ); - // l > r + 6 AND r > l - 7 - let l_gt_r = 6; - let r_gt_l = -7; - generate_case::( - expr.clone(), - left_col.clone(), - right_col.clone(), - seed, - l_gt_r, - r_gt_l, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )); + let expr = Arc::new(BinaryExpr::new(left_and_1, Operator::Gt, right_col.clone())); + experiment( + expr, + (left_col, right_col), + (Some(10), Some(20)), + (Some(100), None), + (Some(10), Some(20)), + (Some(100), None), )?; - // Descending tests - // r < l - 6 AND l < r + 7 - let r_lt_l = -l_gt_r; - let l_lt_r = -r_gt_l; - generate_case::(expr, left_col, right_col, seed, l_lt_r, r_lt_l)?; - Ok(()) } From 7638902de70599dc15e5937e3535f3656e344d6f Mon Sep 17 00:00:00 2001 From: metesynnada <100111937+metesynnada@users.noreply.github.com> Date: Thu, 2 Feb 2023 17:14:46 +0300 Subject: [PATCH 11/39] Resolve a propagation bug and make the propagation returns an opt. 
status
---
 .../physical-expr/src/intervals/cp_solver.rs  | 193 +++++++++++-------
 1 file changed, 118 insertions(+), 75 deletions(-)

diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs
index c02b76514c16..bdb03eae322f 100644
--- a/datafusion/physical-expr/src/intervals/cp_solver.rs
+++ b/datafusion/physical-expr/src/intervals/cp_solver.rs
@@ -117,6 +117,14 @@ use crate::PhysicalExpr;
 #[derive(Clone)]
 pub struct ExprIntervalGraph(NodeIndex, StableGraph<ExprIntervalGraphNode, usize>);
 
+/// Result corresponding to the `update_intervals` call.
+#[derive(PartialEq, Debug)]
+pub enum OptimizationResult {
+    CannotPropagate,
+    UnfeasibleSolution,
+    Success,
+}
+
 /// This is a node in the DAEG; it encapsulates a reference to the actual
 /// [PhysicalExpr] as well as an interval containing expression bounds.
 #[derive(Clone, Debug)]
@@ -174,32 +182,6 @@ impl PartialEq for ExprIntervalGraphNode {
     }
 }
 
-/// This function refines the interval `left_child` by propagating intervals
-/// `parent` and `right_child` through the application of `inverse_op`, which
-/// is the inverse operation of the operator in the corresponding DAEG node.
-fn propagate_left(
-    inverse_op: &Operator,
-    parent: &Interval,
-    left_child: &Interval,
-    right_child: &Interval,
-) -> Result<Interval> {
-    let interv = apply_operator(inverse_op, parent, right_child)?;
-    match interv.intersect(left_child) {
-        Ok(Some(val)) => Ok(val),
-        // TODO: We should not have this condition realize in practice. If this
-        // happens, it would mean the constraint you are propagating will
-        // never get satisfied (there is no satisfying value). When solving
-        // equations, that means we should stop and report there is no solution.
-        // For joins, I think it means we will never have a match. Therefore,
-        // I think we should just propagate this out in the return value all
-        // the way up and handle it properly so that we don't lose generality.
-        // Just returning the left interval without change is not wrong (i.e.
-        // it is conservatively true), but loses very useful information.
-        Ok(None) => Ok(left_child.clone()),
-        Err(e) => Err(e),
-    }
-}
-
 // This function returns the inverse operator of the given operator.
 fn get_inverse_op(op: Operator) -> Operator {
     match op {
@@ -210,18 +192,60 @@ fn get_inverse_op(op: Operator) -> Operator {
 }
 
 /// This function refines intervals `left_child` and `right_child` by
-/// applying constraint propagation through `parent` via `inverse_op`, which
-/// is the inverse operation of the operator in the corresponding DAEG node.
+/// applying constraint propagation through `parent` via the relevant operation.
+///
+/// The main idea is that we can shrink the domains of variables x and y by using
+/// parent interval p.
+///
+/// Assuming that x, y and p have ranges [xL, xU], [yL, yU], and [pL, pU], we simply apply
+/// constraint propagation across [xL, xU], [yL, yU] and [pL, pU].
+/// - For minus operation, specifically, we would first do
+///   - [xL, xU] <- ([yL, yU] + [pL, pU]) ∩ [xL, xU], and then
+///   - [yL, yU] <- ([xL, xU] - [pL, pU]) ∩ [yL, yU].
+/// - For plus operation, specifically, we would first do
+///   - [xL, xU] <- ([pL, pU] - [yL, yU]) ∩ [xL, xU], and then
+///   - [yL, yU] <- ([pL, pU] - [xL, xU]) ∩ [yL, yU].
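+///
+/// A hedged, self-contained sketch of the plus rule, including the infeasible
+/// case (the `Iv` type and helpers are illustrative stand-ins, not this
+/// module's types):
+///
+/// ```
+/// #[derive(Clone, Copy, Debug, PartialEq)]
+/// struct Iv { lo: i64, hi: i64 }
+/// fn sub(a: Iv, b: Iv) -> Iv { Iv { lo: a.lo - b.hi, hi: a.hi - b.lo } }
+/// fn intersect(a: Iv, b: Iv) -> Option<Iv> {
+///     let (lo, hi) = (a.lo.max(b.lo), a.hi.min(b.hi));
+///     (lo <= hi).then_some(Iv { lo, hi })
+/// }
+/// // Plus rule: x <- (p - y) ∩ x, then y <- (p - x') ∩ y; a `None` signals
+/// // that the constraint cannot be satisfied (the infeasible case).
+/// fn propagate_plus(p: Iv, x: Iv, y: Iv) -> Option<(Iv, Iv)> {
+///     let new_x = intersect(sub(p, y), x)?;
+///     let new_y = intersect(sub(p, new_x), y)?;
+///     Some((new_x, new_y))
+/// }
+/// let p = Iv { lo: 0, hi: 4 };
+/// // Feasible: x + y in [0, 4] with x in [1, 10] and y in [0, 5].
+/// let r = propagate_plus(p, Iv { lo: 1, hi: 10 }, Iv { lo: 0, hi: 5 });
+/// assert_eq!(r, Some((Iv { lo: 1, hi: 4 }, Iv { lo: 0, hi: 3 })));
+/// // Infeasible: x in [10, 20] can never bring x + y down into [0, 4].
+/// assert_eq!(propagate_plus(p, Iv { lo: 10, hi: 20 }, Iv { lo: 0, hi: 5 }), None);
+/// ```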
 pub fn propagate(
     op: &Operator,
     parent: &Interval,
     left_child: &Interval,
     right_child: &Interval,
-) -> Result<(Interval, Interval)> {
+) -> Result<(Option<Interval>, Option<Interval>)> {
     let inverse_op = get_inverse_op(*op);
-    let new_right = propagate_left(&inverse_op, parent, right_child, left_child)?;
-    let new_left = propagate_left(&inverse_op, parent, left_child, &new_right)?;
-    Ok((new_left, new_right))
+    Ok(
+        match apply_operator(&inverse_op, parent, right_child)?.intersect(left_child)? {
+            // Left is feasible
+            Some(value) => {
+                // Adjust the propagation operator and parent-child order.
+                let (o, l, r) = match op {
+                    Operator::Minus => (op, &value, parent),
+                    Operator::Plus => (&inverse_op, parent, &value),
+                    _ => unreachable!(),
+                };
+                // Calculate the right child with the new left child.
+                let right = apply_operator(o, l, r)?.intersect(right_child)?;
+                // Return intervals for both children.
+                (Some(value), right)
+            }
+            // If the left child is not feasible, return None for both children.
+            None => (None, None),
+        },
+    )
+}
+
+fn get_zero_scalar(datatype: &DataType) -> ScalarValue {
+    assert!(datatype.is_primitive());
+    match datatype {
+        DataType::Int8 => ScalarValue::Int8(Some(0)),
+        DataType::Int16 => ScalarValue::Int16(Some(0)),
+        DataType::Int32 => ScalarValue::Int32(Some(0)),
+        DataType::Int64 => ScalarValue::Int64(Some(0)),
+        DataType::UInt8 => ScalarValue::UInt8(Some(0)),
+        DataType::UInt16 => ScalarValue::UInt16(Some(0)),
+        DataType::UInt32 => ScalarValue::UInt32(Some(0)),
+        DataType::UInt64 => ScalarValue::UInt64(Some(0)),
+        _ => unreachable!(),
+    }
+}
 
 /// This function provides a target parent interval for comparison operators.
@@ -231,9 +255,7 @@ pub fn propagate(
 fn comparison_operator_target(datatype: &DataType, op: &Operator) -> Result<Interval> {
     let unbounded = ScalarValue::try_from(datatype)?;
-    // TODO: Going through a string zero to a numeric zero is very inefficient.
-    // Let's find an efficient way to initialize a zero value with given data type.
-    let zero = ScalarValue::try_from_string("0".to_string(), datatype)?;
+    let zero = get_zero_scalar(datatype);
     Ok(match *op {
         Operator::Gt => Interval {
             lower: zero,
@@ -254,23 +276,15 @@ fn comparison_operator_target(datatype: &DataType, op: &Operator) -> Result<Interval> {
 }
 
 /// This function propagates constraints arising from comparison operators.
+/// The main idea is that we can analyze an inequality like x > y through the
+/// equivalent inequality x - y > 0.
+///
+/// [xL, xU] > [yL, yU] can be rewritten as [xL, xU] - [yL, yU] = [0, ∞]
+/// and [xL, xU] < [yL, yU] can be rewritten as [xL, xU] - [yL, yU] = [-∞, 0].
 pub fn propagate_comparison_operators(
     op: &Operator,
     left_child: &Interval,
     right_child: &Interval,
-) -> Result<(Interval, Interval)> {
+) -> Result<(Option<Interval>, Option<Interval>)> {
     let parent = comparison_operator_target(&left_child.get_datatype(), op)?;
-    // TODO: Same issue with propagate_left, we should report empty intervals
-    // and handle this outside.
-    let new_left = right_child
-        .add(&parent)?
-        .intersect(left_child)?
-        .unwrap_or_else(|| left_child.clone());
-    let new_right = left_child
-        .sub(&parent)?
-        .intersect(right_child)?
-        .unwrap_or_else(|| right_child.clone());
-    Ok((new_left, new_right))
+    propagate(&Operator::Minus, &parent, left_child, right_child)
 }
 
 impl ExprIntervalGraph {
@@ -288,6 +302,18 @@ impl ExprIntervalGraph {
     /// If we want to calculate intervals starting from and propagate up to BinaryExpr('a', +, 'b')
     /// instead of col('a') and col('b'), we prune under BinaryExpr('a', +, 'b') since we will
     /// never visit those nodes.
+    ///
+    /// One of the reasons behind pruning the expression interval graph is that we may provide
+    ///
+    /// PhysicalSortExpr {
+    ///     expr: BinaryExpr('a', +, 'b'),
+    ///     sort_option: ..
+    /// }
+    ///
+    /// in the output_order of one of the SymmetricHashJoin's children. In this case,
+    /// we must calculate the interval for the BinaryExpr('a', +, 'b') instead of the
+    /// columns inside it. Thus, the child PhysicalExprs of the BinaryExpr may be
+    /// pruned for better performance.
+    ///
     /// We do not change the expression tree (`Arc<dyn PhysicalExpr>`) itself; we only
     /// remove the graph nodes corresponding to the pruned expressions.
     ///
@@ -455,10 +481,11 @@ impl ExprIntervalGraph {
 
     /// Updates/shrinks bounds for leaf expressions using interval arithmetic
     /// via a top-down traversal.
+    /// If it returns `UnfeasibleSolution`, the given constraints cannot be satisfied.
     fn propagate_constraints(
         &mut self,
         expr_stats: &mut [(usize, Interval)],
-    ) -> Result<()> {
+    ) -> Result<OptimizationResult> {
         let mut bfs = Bfs::new(&self.1, self.0);
         while let Some(node) = bfs.next(&self.1) {
             // Get plan
@@ -488,30 +515,42 @@ impl ExprIntervalGraph {
             let first_child_interval =
                 self.1.index(first_child_node_index).interval();
 
-            let (shrink_left_interval, shrink_right_interval) =
+            if let (Some(shrink_left_interval), Some(shrink_right_interval)) =
                 // There is no propagation process for logical operators.
-                if binary.op().is_logic_operator(){
-                    continue
-                } else if binary.op().is_comparison_operator() {
-                    // If comparison is strictly false, there is nothing to do for shrink.
-                    if let Interval{ lower: ScalarValue::Boolean(Some(false)), upper: ScalarValue::Boolean(Some(false))} = node_interval { continue }
-                    // Propagate the comparison operator.
-                    propagate_comparison_operators(
-                        binary.op(),
-                        first_child_interval,
-                        second_child_interval,
-                    )?
-                } else {
-                    // Propagate the arithmetic operator.
-                    propagate(binary.op(),
-                              &node_interval,
-                              first_child_interval,
-                              second_child_interval)?
-                };
-                let mutable_first_child = self.1.index_mut(first_child_node_index);
-                mutable_first_child.interval = shrink_left_interval.clone();
-                let mutable_second_child = self.1.index_mut(second_child_node_index);
-                mutable_second_child.interval = shrink_right_interval.clone();
+                if binary.op().is_logic_operator() {
+                    continue;
+                } else if binary.op().is_comparison_operator() {
+                    // If comparison is strictly false, there is nothing to do for shrink.
+                    if let Interval {
+                        lower: ScalarValue::Boolean(Some(false)),
+                        upper: ScalarValue::Boolean(Some(false)),
+                    } = node_interval
+                    {
+                        continue;
+                    }
+                    // Propagate the comparison operator.
+                    propagate_comparison_operators(
+                        binary.op(),
+                        first_child_interval,
+                        second_child_interval,
+                    )?
+                } else {
+                    // Propagate the arithmetic operator.
+                    propagate(
+                        binary.op(),
+                        &node_interval,
+                        first_child_interval,
+                        second_child_interval,
+                    )?
+ } + { + let mutable_first_child = self.1.index_mut(first_child_node_index); + mutable_first_child.interval = shrink_left_interval.clone(); + let mutable_second_child = self.1.index_mut(second_child_node_index); + mutable_second_child.interval = shrink_right_interval.clone(); + } else { + return Ok(OptimizationResult::UnfeasibleSolution); + }; } else if let Some(cast) = expr_any.downcast_ref::() { // Calculate new interval let child_index = edges.next_node(&self.1).unwrap(); @@ -525,7 +564,7 @@ impl ExprIntervalGraph { mutable_child.interval = new_child_interval.clone(); } } - Ok(()) + Ok(OptimizationResult::Success) } /// Updates intervals for all expressions in the DAEG by successive @@ -533,14 +572,14 @@ impl ExprIntervalGraph { pub fn update_intervals( &mut self, expr_stats: &mut [(usize, Interval)], - ) -> Result<()> { + ) -> Result { self.evaluate_bounds(expr_stats) .and_then(|interval| match interval { Interval { upper: ScalarValue::Boolean(Some(true)), .. } => self.propagate_constraints(expr_stats), - _ => Ok(()), + _ => Ok(OptimizationResult::CannotPropagate), }) } } @@ -564,6 +603,7 @@ mod tests { right_interval: (Option, Option), left_waited: (Option, Option), right_waited: (Option, Option), + result: OptimizationResult, ) -> Result<()> { let col_stats = vec![ ( @@ -613,7 +653,8 @@ mod tests { .map(|((_, interval), (_, index))| (*index, interval.clone())) .collect_vec(); - graph.update_intervals(&mut col_stat_nodes[..])?; + let exp_result = graph.update_intervals(&mut col_stat_nodes[..])?; + assert_eq!(exp_result, result); col_stat_nodes .iter() .zip(expected_nodes.iter()) @@ -675,6 +716,7 @@ mod tests { right_interval, left_waited, right_waited, + OptimizationResult::Success, )?; Ok(()) } @@ -698,6 +740,7 @@ mod tests { (Some(100), None), (Some(10), Some(20)), (Some(100), None), + OptimizationResult::CannotPropagate, )?; Ok(()) } From 287f2f47f27b3d6090265ddf1d0663614a15eade Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Thu, 2 Feb 2023 22:53:00 -0600 Subject: [PATCH 12/39] Refactor and simplify CP code, improve comments --- datafusion/common/src/scalar.rs | 23 + .../joins/symmetric_hash_join.rs | 38 +- .../physical-expr/src/intervals/cp_solver.rs | 423 ++++++++---------- 3 files changed, 228 insertions(+), 256 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 33dd583865c2..f08f18889d88 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -1019,6 +1019,29 @@ impl ScalarValue { Self::List(scalars, Box::new(Field::new("item", child_type, true))) } + // Create a zero value in the given type. 
+    pub fn new_zero(datatype: &DataType) -> Result<ScalarValue> {
+        assert!(datatype.is_primitive());
+        Ok(match datatype {
+            DataType::Boolean => ScalarValue::Boolean(Some(false)),
+            DataType::Int8 => ScalarValue::Int8(Some(0)),
+            DataType::Int16 => ScalarValue::Int16(Some(0)),
+            DataType::Int32 => ScalarValue::Int32(Some(0)),
+            DataType::Int64 => ScalarValue::Int64(Some(0)),
+            DataType::UInt8 => ScalarValue::UInt8(Some(0)),
+            DataType::UInt16 => ScalarValue::UInt16(Some(0)),
+            DataType::UInt32 => ScalarValue::UInt32(Some(0)),
+            DataType::UInt64 => ScalarValue::UInt64(Some(0)),
+            DataType::Float32 => ScalarValue::Float32(Some(0.0)),
+            DataType::Float64 => ScalarValue::Float64(Some(0.0)),
+            _ => {
+                return Err(DataFusionError::NotImplemented(format!(
+                    "Can't create a zero scalar from data_type \"{datatype:?}\""
+                )));
+            }
+        })
+    }
+
     /// Getter for the `DataType` of the value
     pub fn get_datatype(&self) -> DataType {
         match self {
diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
index 1f3c5b0ae737..bf2f0cdca789 100644
--- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
@@ -37,7 +37,6 @@ use arrow::record_batch::RecordBatch;
 use futures::{Stream, StreamExt};
 use hashbrown::raw::RawTable;
 use hashbrown::HashSet;
-use itertools::Itertools;
 
 use datafusion_common::utils::bisect;
 use datafusion_common::ScalarValue;
@@ -270,23 +269,20 @@ impl SymmetricHashJoinExec {
             ExprIntervalGraph::try_new(filter.expression().clone())?;
 
         // Since the graph is stable, we reuse NodeIndex of our filter expressions.
-        let child_node_indexes = physical_expr_graph
-            .pair_node_indices_with_interval_providers(
-                &filter_columns
-                    .iter()
-                    .map(|sorted_expr| sorted_expr.filter_expr())
-                    .collect_vec(),
-            )?;
+        let child_node_indexes = physical_expr_graph.gather_node_indices(
+            &filter_columns
+                .iter()
+                .map(|sorted_expr| sorted_expr.filter_expr())
+                .collect::<Vec<_>>(),
+        );
 
         // We inject calculated node indexes into SortedFilterExpr. In graph calculations,
         // we will be using node indexes to put calculated intervals into Columns or BinaryExprs.
-        filter_columns
-            .iter_mut()
-            .zip(child_node_indexes.iter())
-            .map(|(sorted_expr, (_, index))| {
-                sorted_expr.set_node_index(*index);
-            })
-            .collect_vec();
+        for (sorted_expr, (_, index)) in
+            filter_columns.iter_mut().zip(child_node_indexes.iter())
+        {
+            sorted_expr.set_node_index(*index);
+        }
 
         Ok(SymmetricHashJoinExec {
             left,
@@ -968,10 +964,10 @@ impl OneSideHashJoiner {
         // We use Vec<(usize, Interval)> instead of a Hashmap since the expected
         // number of filter exprs is relatively low, and converting between Vec
         // and Hashmap back and forth may be slower.
-        let mut filter_intervals: Vec<(usize, Interval)> = filter_columns
+        let mut filter_intervals = filter_columns
             .iter()
             .map(|sorted_expr| (sorted_expr.node_index(), sorted_expr.interval().clone()))
-            .collect_vec();
+            .collect::<Vec<_>>();
         // Use this vector to seed the child PhysicalExpr interval.
physical_expr_graph.update_intervals(&mut filter_intervals)?; let mut change_track = Vec::with_capacity(filter_intervals.len()); @@ -1304,15 +1300,15 @@ mod tests { ) -> Result> { let partition_count = 1; - let left_expr: Vec> = on + let left_expr = on .iter() .map(|(l, _)| Arc::new(l.clone()) as Arc) - .collect_vec(); + .collect::>(); - let right_expr: Vec> = on + let right_expr = on .iter() .map(|(_, r)| Arc::new(r.clone()) as Arc) - .collect_vec(); + .collect::>(); let join = SymmetricHashJoinExec::try_new( Arc::new(RepartitionExec::try_new( diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index bdb03eae322f..dfeb03a22406 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -18,7 +18,7 @@ //! Constraint propagator/solver for custom PhysicalExpr graphs. use std::fmt::{Display, Formatter}; -use std::ops::{Index, IndexMut}; +use std::ops::Index; use std::sync::Arc; use arrow_schema::DataType; @@ -115,13 +115,16 @@ use crate::PhysicalExpr; /// This object implements a directed acyclic expression graph (DAEG) that /// is used to compute ranges for expressions through interval arithmetic. #[derive(Clone)] -pub struct ExprIntervalGraph(NodeIndex, StableGraph); +pub struct ExprIntervalGraph { + graph: StableGraph, + root: NodeIndex, +} /// Result corresponding the 'update_interval' call. #[derive(PartialEq, Debug)] -pub enum OptimizationResult { +pub enum PropagationResult { CannotPropagate, - UnfeasibleSolution, + Infeasible, Success, } @@ -191,60 +194,41 @@ fn get_inverse_op(op: Operator) -> Operator { } } -/// This function refines intervals `left_child` and `right_child` by -/// applying constraint propagation through `parent` via operation. -/// -/// The main idea is that we can shrink the domains of variables x and y by using -/// parent interval p. +/// This function refines intervals `left_child` and `right_child` by applying +/// constraint propagation through `parent` via operation. The main idea is +/// that we can shrink ranges of variables x and y using parent interval p. /// -/// Assuming that x,y and p has ranges [xL, xU], [yL, yU], and [pL, pU], we simply apply -/// constraint propagation across [xL, xU], [yL, yH] and [pL, pU]. -/// - For minus operation, specifically, we would first do -/// - [xL, xU] <- ([yL, yU] + [pL, pU]) ∩ [xL, xU], and then -/// - [yL, yU] <- ([xL, xU] - [pL, pU]) ∩ [yL, yU]. +/// Assuming that x,y and p has ranges [xL, xU], [yL, yU], and [pL, pU], we +/// apply the following operations: /// - For plus operation, specifically, we would first do /// - [xL, xU] <- ([pL, pU] - [yL, yU]) ∩ [xL, xU], and then /// - [yL, yU] <- ([pL, pU] - [xL, xU]) ∩ [yL, yU]. -pub fn propagate( +/// - For minus operation, specifically, we would first do +/// - [xL, xU] <- ([yL, yU] + [pL, pU]) ∩ [xL, xU], and then +/// - [yL, yU] <- ([xL, xU] - [pL, pU]) ∩ [yL, yU]. +pub fn propagate_arithmetic( op: &Operator, parent: &Interval, left_child: &Interval, right_child: &Interval, ) -> Result<(Option, Option)> { let inverse_op = get_inverse_op(*op); - Ok( - match apply_operator(&inverse_op, parent, right_child)?.intersect(left_child)? { - // Left is feasible - Some(value) => { - // Adjust the propagation operator and parent-child order. - let (o, l, r) = match op { - Operator::Minus => (op, &value, parent), - Operator::Plus => (&inverse_op, parent, &value), - _ => unreachable!(), - }; - // Calculate right child with new left child. 
- let right = apply_operator(o, l, r)?.intersect(right_child)?; - // Return intervals for both children. - (Some(value), right) - } - // If left child is not feasible, return both childs None. - None => (None, None), - }, - ) -} - -fn get_zero_scalar(datatype: &DataType) -> ScalarValue { - assert!(datatype.is_primitive()); - match datatype { - DataType::Int8 => ScalarValue::Int8(Some(0)), - DataType::Int16 => ScalarValue::Int16(Some(0)), - DataType::Int32 => ScalarValue::Int32(Some(0)), - DataType::Int64 => ScalarValue::Int64(Some(0)), - DataType::UInt8 => ScalarValue::UInt8(Some(0)), - DataType::UInt16 => ScalarValue::UInt16(Some(0)), - DataType::UInt32 => ScalarValue::UInt32(Some(0)), - DataType::UInt64 => ScalarValue::UInt64(Some(0)), - _ => unreachable!(), + // First, propagate to the left: + match apply_operator(&inverse_op, parent, right_child)?.intersect(left_child)? { + // Left is feasible: + Some(value) => { + // Propagate to the right using the new left. + let right = match op { + Operator::Minus => apply_operator(op, &value, parent), + Operator::Plus => apply_operator(&inverse_op, parent, &value), + _ => unreachable!(), + }? + .intersect(right_child)?; + // Return intervals for both children: + Ok((Some(value), right)) + } + // If the left child is infeasible, short-circuit. + None => Ok((None, None)), } } @@ -255,7 +239,7 @@ fn get_zero_scalar(datatype: &DataType) -> ScalarValue { /// are not implemented yet. fn comparison_operator_target(datatype: &DataType, op: &Operator) -> Result { let unbounded = ScalarValue::try_from(datatype)?; - let zero = get_zero_scalar(datatype); + let zero = ScalarValue::new_zero(datatype)?; Ok(match *op { Operator::Gt => Interval { lower: zero, @@ -276,114 +260,113 @@ fn comparison_operator_target(datatype: &DataType, op: &Operator) -> Result [x2, y2] can be rewritten as [xL, xU] - [yL, yU] = [0, ∞] -/// and [xL, xU] < [yL, yU] can be rewritten as [xL, xU] - [yL, yU] = [-∞, 0]. -pub fn propagate_comparison_operators( +pub fn propagate_comparison( op: &Operator, left_child: &Interval, right_child: &Interval, ) -> Result<(Option, Option)> { let parent = comparison_operator_target(&left_child.get_datatype(), op)?; - propagate(&Operator::Minus, &parent, left_child, right_child) + propagate_arithmetic(&Operator::Minus, &parent, left_child, right_child) } impl ExprIntervalGraph { pub fn try_new(expr: Arc) -> Result { // Build the full graph: - let (root_node, graph) = + let (root, graph) = build_physical_expr_graph(expr, &ExprIntervalGraphNode::make_node)?; - Ok(Self(root_node, graph)) + Ok(Self { graph, root }) } - /// - /// ``` text - /// - /// - /// If we want to calculate intervals starting from and propagate up to BinaryExpr('a', +, 'b') - /// instead of col('a') and col('b'), we prune under BinaryExpr('a', +, 'b') since we will - /// never visit there. - /// - /// One of the reasons behind pruning the expression interval graph is that if we provide - /// - /// PhysicalSortExpr { - /// expr: BinaryExpr('a', +, 'b'), - /// sort_option: .. - /// } - /// - /// in output_order of one of the SymmetricHashJoin's childs. In this case, we must calculate - /// the interval for the BinaryExpr('a', +, 'b') instead of the columns inside the BinaryExpr. Thus, - /// the child PhysicalExprs of the BinaryExpr may be pruned for the performance. - /// - /// We do not change the Arc inside Arc, just remove the nodes - /// corresponding that Arc. 
- /// - /// - /// +-----+ +-----+ - /// | GT | | GT | - /// +--------| |-------+ +--------| |-------+ - /// | +-----+ | | +-----+ | - /// | | | | - /// +-----+ | +-----+ | - /// |Cast | | |Cast | | - /// | | | --\ | | | - /// +-----+ | ---------- +-----+ | - /// | | --/ | | - /// | | | | - /// +-----+ +-----+ +-----+ +-----+ - /// +--|Plus |--+ +--|Plus |--+ |Plus | +--|Plus |--+ - /// | | | | | | | | | | | | | | - /// Prune from here | +-----+ | | +-----+ | +-----+ | +-----+ | - /// ------------------------------------ | | | | - /// | | | | | | - /// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ - /// | a | | b | | c | | 2 | | c | | 2 | - /// | | | | | | | | | | | | - /// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ - /// - /// This operation mutates the underline graph, so use it after constructor. - /// ``` - /// - /// Since graph is stable, the NodeIndex will stay same even if we delete a node. We exploit - /// this stability by matching Arc and NodeIndex for membership tests. - pub fn pair_node_indices_with_interval_providers( + // Sometimes, we do not want to calculate and/or propagate intervals all + // way down to leaf expressions. For example, assume that we have a + // `SymmetricHashJoin` which has a child with an output ordering like: + // + // PhysicalSortExpr { + // expr: BinaryExpr('a', +, 'b'), + // sort_option: .. + // } + // + // i.e. its output order comes from a clause like "ORDER BY a + b". In such + // a case, we must calculate the interval for the BinaryExpr('a', +, 'b') + // instead of the columns inside this BinaryExpr, because this interval + // decides whether we prune or not. Therefore, children `PhysicalExpr`s of + // this `BinaryExpr` may be pruned for performance. The figure below + // explains this example visually. + // + // Note that we just remove the nodes from the DAEG, do not make any change + // to the plan itself. + // + // ```text + // + // +-----+ +-----+ + // | GT | | GT | + // +--------| |-------+ +--------| |-------+ + // | +-----+ | | +-----+ | + // | | | | + // +-----+ | +-----+ | + // |Cast | | |Cast | | + // | | | --\ | | | + // +-----+ | ---------- +-----+ | + // | | --/ | | + // | | | | + // +-----+ +-----+ +-----+ +-----+ + // +--|Plus |--+ +--|Plus |--+ |Plus | +--|Plus |--+ + // | | | | | | | | | | | | | | + // Prune from here | +-----+ | | +-----+ | +-----+ | +-----+ | + // ------------------------------------ | | | | + // | | | | | | + // +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ + // | a | | b | | c | | 2 | | c | | 2 | + // | | | | | | | | | | | | + // +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ + // + // ``` + + /// This function associates stable node indices with [PhysicalExpr]s so + /// that we can match Arc and NodeIndex objects during + /// membership tests. + pub fn gather_node_indices( &mut self, exprs: &[Arc], - ) -> Result, usize)>> { - let mut bfs = Bfs::new(&self.1, self.0); - // We collect the node indexes (usize) of the PhysicalExprs, with the same order - // with `exprs`. To preserve order, we initiate each expr's node index with usize::MAX, - // then find the corresponding node indexes by traversing the graph. + ) -> Vec<(Arc, usize)> { + let graph = &self.graph; + let mut bfs = Bfs::new(graph, self.root); + // We collect the node indices (usize) of [PhysicalExpr]s in the order + // given by argument `exprs`. To preserve this order, we initialize each + // expression's node index with usize::MAX, and then find the corresponding + // node indices by traversing the graph. 
let mut removals = vec![]; - let mut expr_node_indices: Vec<(Arc, usize)> = - exprs.iter().map(|e| (e.clone(), usize::MAX)).collect(); - while let Some(node) = bfs.next(&self.1) { + let mut expr_node_indices = exprs + .iter() + .map(|e| (e.clone(), usize::MAX)) + .collect::>(); + while let Some(node) = bfs.next(graph) { // Get the plan corresponding to this node: - let input = self.1.index(node); + let input = graph.index(node); let expr = input.expr.clone(); - // If the current expression is among `exprs`, slate its - // children for removal. + // If the current expression is among `exprs`, slate its children + // for removal: if let Some(value) = exprs.iter().position(|e| expr.eq(e)) { - // Update NodeIndex of the PhysicalExpr + // Update the node index of the associated `PhysicalExpr`: expr_node_indices[value].1 = node.index(); - let mut edges = self.1.neighbors_directed(node, Outgoing).detach(); - while let Some(n_index) = edges.next_node(&self.1) { + let mut edges = graph.neighbors_directed(node, Outgoing).detach(); + while let Some(n_index) = edges.next_node(graph) { // Slate the child for removal, do not remove immediately. removals.push(n_index); } } } for node in removals { - self.1.remove_node(node); + self.graph.remove_node(node); } - Ok(expr_node_indices) + expr_node_indices } /// Computes bounds for an expression using interval arithmetic via a /// bottom-up traversal. - /// It returns root node's interval. /// /// # Arguments - /// * `expr_stats` - &[(usize, Interval)]. Provide NodeIndex, Interval tuples for bound evaluation. + /// * `expr_stats` - &[(usize, Interval)]. Provide NodeIndex, Interval tuples for leaf variables. /// /// # Examples /// @@ -403,10 +386,9 @@ impl ExprIntervalGraph { /// let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); /// // Do it once, while constructing. /// let node_indices = graph - /// .pair_node_indices_with_interval_providers(&[Arc::new(Column::new("gnz", 0))]) - /// .unwrap(); + /// .gather_node_indices(&[Arc::new(Column::new("gnz", 0))]); /// let left_index = node_indices.get(0).unwrap().1; - /// // Provide intervals + /// // Provide intervals for leaf variables (here, there is only one). /// let intervals = vec![( /// left_index, /// Interval { @@ -414,7 +396,7 @@ impl ExprIntervalGraph { /// upper: ScalarValue::Int32(Some(20)), /// }, /// )]; - /// // Evaluate bounds + /// // Evaluate bounds for the composite expression: /// assert_eq!( /// graph.evaluate_bounds(&intervals).unwrap(), /// Interval { @@ -428,143 +410,115 @@ impl ExprIntervalGraph { &mut self, expr_stats: &[(usize, Interval)], ) -> Result { - let mut dfs = DfsPostOrder::new(&self.1, self.0); - while let Some(node) = dfs.next(&self.1) { - // Outgoing (Children) edges - let mut edges = self.1.neighbors_directed(node, Outgoing).detach(); - // Get PhysicalExpr - let expr = self.1[node].expr.clone(); - // Check if we have a interval information about given PhysicalExpr, if so, directly - // propagate it to the upper. - if let Some((_, interval)) = - expr_stats.iter().find(|(e, _)| *e == node.index()) - { - let input = self.1.index_mut(node); - input.interval = interval.clone(); + let mut dfs = DfsPostOrder::new(&self.graph, self.root); + while let Some(node) = dfs.next(&self.graph) { + // Get outgoing edges (i.e. children): + let mut edges = self.graph.neighbors_directed(node, Outgoing).detach(); + // If the current expression is a leaf expression, it should have + // externally-given intervals. 
Use this information if available: + let index = node.index(); + if let Some((_, interval)) = expr_stats.iter().find(|(e, _)| *e == index) { + self.graph[node].interval = interval.clone(); continue; } - let expr_any = expr.as_any(); + let expr_any = self.graph[node].expr.as_any(); if let Some(binary) = expr_any.downcast_ref::() { - // Access left child immutable - let second_child_node_index = edges.next_node(&self.1).unwrap(); - // Access left child immutable - let first_child_node_index = edges.next_node(&self.1).unwrap(); - // Do not use any reference from graph, MUST clone here. - let left_interval = - self.1.index(first_child_node_index).interval().clone(); - let right_interval = - self.1.index(second_child_node_index).interval().clone(); - // Since we release the reference, we can get mutable reference. - let input = self.1.index_mut(node); - // Calculate and replace the interval - input.interval = - apply_operator(binary.op(), &left_interval, &right_interval)?; + // Get children intervals: + let left_child_idx = edges.next_node(&self.graph).unwrap(); + let right_child_idx = edges.next_node(&self.graph).unwrap(); + let left_interval = self.graph.index(right_child_idx).interval(); + let right_interval = self.graph.index(left_child_idx).interval(); + // Calculate and replace current node's interval: + self.graph[node].interval = + apply_operator(binary.op(), left_interval, right_interval)?; } else if let Some(CastExpr { cast_type, cast_options, .. }) = expr_any.downcast_ref::() { - // Access the child immutable - let child_index = edges.next_node(&self.1).unwrap(); - let child = self.1.index(child_index); - // Cast the interval - let new_interval = child.interval.cast_to(cast_type, cast_options)?; - // Update the interval - let input = self.1.index_mut(node); - input.interval = new_interval; + let child_index = edges.next_node(&self.graph).unwrap(); + let child = self.graph.index(child_index); + // Cast current node's interval to the right type: + self.graph[node].interval = + child.interval.cast_to(cast_type, cast_options)?; } } - let root_interval = self.1.index(self.0).interval.clone(); - Ok(root_interval) + Ok(self.graph.index(self.root).interval.clone()) } /// Updates/shrinks bounds for leaf expressions using interval arithmetic /// via a top-down traversal. - /// If return false, it means we have a unfeasible solution. fn propagate_constraints( &mut self, expr_stats: &mut [(usize, Interval)], - ) -> Result { - let mut bfs = Bfs::new(&self.1, self.0); - while let Some(node) = bfs.next(&self.1) { + ) -> Result { + let mut bfs = Bfs::new(&self.graph, self.root); + while let Some(node) = bfs.next(&self.graph) { // Get plan - let input = self.1.index(node); - let expr = input.expr.clone(); + let input = self.graph.index(node); // Get calculated interval. BinaryExpr will propagate the interval according to // this. let node_interval = input.interval().clone(); - if let Some((_, interval)) = - expr_stats.iter_mut().find(|(e, _)| *e == node.index()) + let index = node.index(); + if let Some((_, interval)) = expr_stats.iter_mut().find(|(e, _)| *e == index) { *interval = node_interval; continue; } - // Outgoing (Children) edges - let mut edges = self.1.neighbors_directed(node, Outgoing).detach(); + // Get outgoing edges (i.e. children): + let mut edges = self.graph.neighbors_directed(node, Outgoing).detach(); - let expr_any = expr.as_any(); + let expr_any = input.expr.as_any(); if let Some(binary) = expr_any.downcast_ref::() { - // Get right node. 
- let second_child_node_index = edges.next_node(&self.1).unwrap(); - let second_child_interval = - self.1.index(second_child_node_index).interval(); - // Get left node. - let first_child_node_index = edges.next_node(&self.1).unwrap(); - let first_child_interval = - self.1.index(first_child_node_index).interval(); - - if let (Some(shrink_left_interval), Some(shrink_right_interval)) = - // There is no propagation process for logical operators. - if binary.op().is_logic_operator() { + // Get children intervals (note that we traverse in reverse): + let right_child_idx = edges.next_node(&self.graph).unwrap(); + let right_interval = self.graph.index(right_child_idx).interval(); + let left_child_idx = edges.next_node(&self.graph).unwrap(); + let left_interval = self.graph.index(left_child_idx).interval(); + + let op = binary.op(); + if let (Some(new_left_interval), Some(new_right_interval)) = + if op.is_logic_operator() { + // There is no propagation process for logical operators. + continue; + } else if op.is_comparison_operator() { + // If comparison is strictly false, there is nothing to do for shrink. + if let Interval { + lower: ScalarValue::Boolean(Some(false)), + upper: ScalarValue::Boolean(Some(false)), + } = node_interval + { continue; - } else if binary.op().is_comparison_operator() { - // If comparison is strictly false, there is nothing to do for shrink. - if let Interval { - lower: ScalarValue::Boolean(Some(false)), - upper: ScalarValue::Boolean(Some(false)), - } = node_interval - { - continue; - } - // Propagate the comparison operator. - propagate_comparison_operators( - binary.op(), - first_child_interval, - second_child_interval, - )? - } else { - // Propagate the arithmetic operator. - propagate( - binary.op(), - &node_interval, - first_child_interval, - second_child_interval, - )? } + // Propagate the comparison operator. + propagate_comparison(op, left_interval, right_interval)? + } else { + // Propagate the arithmetic operator. + propagate_arithmetic( + op, + &node_interval, + left_interval, + right_interval, + )? + } { - let mutable_first_child = self.1.index_mut(first_child_node_index); - mutable_first_child.interval = shrink_left_interval.clone(); - let mutable_second_child = self.1.index_mut(second_child_node_index); - mutable_second_child.interval = shrink_right_interval.clone(); + self.graph[left_child_idx].interval = new_left_interval; + self.graph[right_child_idx].interval = new_right_interval; } else { - return Ok(OptimizationResult::UnfeasibleSolution); + return Ok(PropagationResult::Infeasible); }; } else if let Some(cast) = expr_any.downcast_ref::() { - // Calculate new interval - let child_index = edges.next_node(&self.1).unwrap(); - let child = self.1.index(child_index); - // Get child's internal datatype. 
+ let child_index = edges.next_node(&self.graph).unwrap(); + let child = self.graph.index(child_index); + // Get child's data type: let cast_type = child.interval().get_datatype(); - let cast_options = cast.cast_options(); - let new_child_interval = - node_interval.cast_to(&cast_type, cast_options)?; - let mutable_child = self.1.index_mut(child_index); - mutable_child.interval = new_child_interval.clone(); + self.graph[child_index].interval = + node_interval.cast_to(&cast_type, cast.cast_options())?; } } - Ok(OptimizationResult::Success) + Ok(PropagationResult::Success) } /// Updates intervals for all expressions in the DAEG by successive @@ -572,14 +526,14 @@ impl ExprIntervalGraph { pub fn update_intervals( &mut self, expr_stats: &mut [(usize, Interval)], - ) -> Result { + ) -> Result { self.evaluate_bounds(expr_stats) .and_then(|interval| match interval { Interval { upper: ScalarValue::Boolean(Some(true)), .. } => self.propagate_constraints(expr_stats), - _ => Ok(OptimizationResult::CannotPropagate), + _ => Ok(PropagationResult::CannotPropagate), }) } } @@ -603,7 +557,7 @@ mod tests { right_interval: (Option, Option), left_waited: (Option, Option), right_waited: (Option, Option), - result: OptimizationResult, + result: PropagationResult, ) -> Result<()> { let col_stats = vec![ ( @@ -638,9 +592,8 @@ mod tests { ), ]; let mut graph = ExprIntervalGraph::try_new(expr)?; - let expr_indexes = graph.pair_node_indices_with_interval_providers( - &col_stats.iter().map(|(e, _)| e.clone()).collect_vec(), - )?; + let expr_indexes = graph + .gather_node_indices(&col_stats.iter().map(|(e, _)| e.clone()).collect_vec()); let mut col_stat_nodes = col_stats .iter() @@ -716,7 +669,7 @@ mod tests { right_interval, left_waited, right_waited, - OptimizationResult::Success, + PropagationResult::Success, )?; Ok(()) } @@ -740,7 +693,7 @@ mod tests { (Some(100), None), (Some(10), Some(20)), (Some(100), None), - OptimizationResult::CannotPropagate, + PropagationResult::CannotPropagate, )?; Ok(()) } From c8d7c21af5c64e29e670c8429bf4229bd3a69785 Mon Sep 17 00:00:00 2001 From: metesynnada <100111937+metesynnada@users.noreply.github.com> Date: Sat, 4 Feb 2023 18:27:59 +0300 Subject: [PATCH 13/39] Code deduplication between pipeline fixer and utils, also enhance comments. 
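To make the new conversion logic easier to follow, here is a simplified, self-contained
sketch of the check that `convert_sort_expr_with_filter_schema` (introduced below)
performs: rewrite a child's sort expression into the filter's intermediate schema, and
accept it only if the rewritten expression occurs verbatim inside the filter expression.
The toy `Expr` type and the `rewrite`/`contains` helpers are illustrative stand-ins, not
DataFusion APIs; the real code works on `Arc<dyn PhysicalExpr>` with visitors.

```rust
use std::collections::HashMap;

// Toy expression type standing in for `PhysicalExpr`.
#[derive(Clone, PartialEq, Debug)]
enum Expr {
    Col(String),
    Plus(Box<Expr>, Box<Expr>),
    Gt(Box<Expr>, Box<Expr>),
}

// Rewrite child-schema columns into filter-schema columns; fail if any
// column has no mapping (use cases 1 and 2 in the doc comment below).
fn rewrite(e: &Expr, map: &HashMap<String, String>) -> Option<Expr> {
    match e {
        Expr::Col(c) => map.get(c).cloned().map(Expr::Col),
        Expr::Plus(l, r) => Some(Expr::Plus(
            Box::new(rewrite(l, map)?),
            Box::new(rewrite(r, map)?),
        )),
        Expr::Gt(l, r) => Some(Expr::Gt(
            Box::new(rewrite(l, map)?),
            Box::new(rewrite(r, map)?),
        )),
    }
}

// Check whether `needle` occurs as a subexpression of `haystack`; all
// columns may be mapped, yet without an exact match we still reject
// (use case 3 in the doc comment below).
fn contains(haystack: &Expr, needle: &Expr) -> bool {
    if haystack == needle {
        return true;
    }
    match haystack {
        Expr::Col(_) => false,
        Expr::Plus(l, r) | Expr::Gt(l, r) => contains(l, needle) || contains(r, needle),
    }
}

fn main() {
    // Child column "a" maps to filter column "0", "b" maps to "1".
    let map = HashMap::from([
        ("a".to_string(), "0".to_string()),
        ("b".to_string(), "1".to_string()),
    ]);
    // Filter expression: col0 + col1 > col0.
    let filter = Expr::Gt(
        Box::new(Expr::Plus(
            Box::new(Expr::Col("0".into())),
            Box::new(Expr::Col("1".into())),
        )),
        Box::new(Expr::Col("0".into())),
    );
    // Sort expression a + b rewrites to col0 + col1, which the filter contains.
    let sort = Expr::Plus(
        Box::new(Expr::Col("a".into())),
        Box::new(Expr::Col("b".into())),
    );
    let converted = rewrite(&sort, &map);
    assert!(matches!(&converted, Some(e) if contains(&filter, e)));
}
```

Cases 1 and 2 from the doc comment fail inside `rewrite` (an unmapped column), while
case 3 fails inside `contains` (all columns mapped, but no exact subexpression match).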
--- .../src/physical_optimizer/pipeline_fixer.rs | 35 +- .../physical_plan/joins/hash_join_utils.rs | 285 +++++++++------- .../core/src/physical_plan/joins/mod.rs | 4 +- .../joins/symmetric_hash_join.rs | 304 ++++++++++++++---- .../physical-expr/src/intervals/cp_solver.rs | 10 +- 5 files changed, 463 insertions(+), 175 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs index 5232a09c753c..2c35634a011c 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs @@ -29,7 +29,10 @@ use crate::physical_optimizer::pipeline_checker::{ check_finiteness_requirements, PipelineStatePropagator, }; use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::joins::{HashJoinExec, PartitionMode, SymmetricHashJoinExec}; +use crate::physical_plan::joins::{ + convert_sort_expr_with_filter_schema, HashJoinExec, PartitionMode, + SymmetricHashJoinExec, +}; use crate::physical_plan::rewrite::TreeNodeRewritable; use crate::physical_plan::ExecutionPlan; use arrow::datatypes::DataType; @@ -42,6 +45,7 @@ use datafusion_physical_expr::physical_expr_visitor::{ }; use datafusion_physical_expr::PhysicalExpr; +use crate::physical_plan::joins::utils::JoinSide; use std::sync::Arc; /// The [PipelineFixer] rule tries to modify a given plan so that it can @@ -155,12 +159,39 @@ fn is_suitable_for_symmetric_hash_join(hash_join: &HashJoinExec) -> Result expr.accept(UnsupportedPhysicalExprVisitor { supported: &mut is_expr_supported, })?; + let left_is_convertable = convert_sort_expr_with_filter_schema( + &JoinSide::Left, + filter, + hash_join.left().schema(), + hash_join + .left() + .output_ordering() + .as_ref() + .unwrap() + .get(0) + .unwrap(), + )? + .is_some(); + let right_is_convertable = convert_sort_expr_with_filter_schema( + &JoinSide::Right, + filter, + hash_join.right().schema(), + hash_join + .right() + .output_ordering() + .as_ref() + .unwrap() + .get(0) + .unwrap(), + )? 
+ .is_some(); + let sort_expressions_supported = left_is_convertable && right_is_convertable; let is_fields_supported = filter .schema() .fields() .iter() .all(|f| is_datatype_supported(f.data_type())); - is_expr_supported && is_fields_supported + is_expr_supported && is_fields_supported && sort_expressions_supported } }; Ok(are_children_ordered && is_filter_supported) diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs index e6891cd13df2..11e30142ade9 100644 --- a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs +++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs @@ -648,19 +648,18 @@ impl PhysicalExpressionVisitor for PhysicalExprColumnCollector<'_> { // index: 0, -> field index in main schema // side: JoinSide::Left, -> child side // }, -fn map_origin_col_to_filter_col( +pub fn map_origin_col_to_filter_col( filter: &JoinFilter, - main_schema: SchemaRef, - side: JoinSide, + schema: SchemaRef, + side: &JoinSide, ) -> Result> { let mut col_to_col_map: HashMap = HashMap::new(); for (filter_schema_index, index) in filter.column_indices().iter().enumerate() { - if index.side.eq(&side) { + if index.side.eq(side) { // Get main field from column index - let main_field = main_schema.field(index.index); + let main_field = schema.field(index.index); // Create a column PhysicalExpr - let main_col = - Column::new_with_schema(main_field.name(), main_schema.as_ref())?; + let main_col = Column::new_with_schema(main_field.name(), schema.as_ref())?; // Since the filter.column_indices() order directly same with intermediate schema fields, we can // get the column. let filter_field = filter.schema().field(filter_schema_index); @@ -671,100 +670,107 @@ fn map_origin_col_to_filter_col( } Ok(col_to_col_map) } -/// If converted PhysicalExpr is inside the Filter, we use the sort information. The visitor -/// visits each node and checks if the filter expression includes the converted expr. + +/// This function plays an important role in the expression graph traversal process. It is necessary to analyze the `PhysicalSortExpr` +/// because the sorting of expressions is required for join filter expressions. /// -/// We want to filter out the expression if the filter is a + b > c + 10 AND a + b < c + 100 -/// - a + d is sorted. D is not part of the filter. -/// - d is sorted. D is not part of the filter. -/// - a + b + c is sorted. All columns are represented in filter expression however there is no exact match. -/// These expression does not indicate prune. +/// The method works as follows: +/// 1. Maps the original columns to the filter columns using the `map_origin_col_to_filter_col` function. +/// 2. Collects all columns in the sort expression using the `PhysicalExprColumnCollector` visitor. +/// 3. Checks if all columns are included in the `column_mapping_information` map. +/// 4. If all columns are included, the sort expression is converted into a filter expression using the `transform_up` and `convert_filter_columns` functions. +/// 5. Searches the converted filter expression in the filter expression using the `CheckFilterExprContainsSortInformation` visitor. +/// 6. If an exact match is encountered, returns the converted filter expression as `Some(Arc)`. +/// 7. If all columns are not included or the exact match is not encountered, returns `None`. /// -fn rewrite_sort_information_for_filter( - side: JoinSide, +/// Use Cases: +/// Consider the filter expression "a + b > c + 10 AND a + b < c + 100". +/// 1. 
If the expression "a@ + d@" is sorted, it will not be accepted since the "d@" column is not part of the filter. +/// 2. If the expression "d@" is sorted, it will not be accepted since the "d@" column is not part of the filter. +/// 3. If the expression "a@ + b@ + c@" is sorted, all columns are represented in the filter expression. However, +/// there is no exact match, so this expression does not indicate pruning. +/// +pub fn convert_sort_expr_with_filter_schema( + side: &JoinSide, filter: &JoinFilter, - col_to_col_map: &HashMap, - child_order: &[PhysicalSortExpr], -) -> Result> { - let mut sorted_exps = vec![]; - for sort_expr in child_order { - let expr = sort_expr.expr.clone(); - // Get main schema columns - let mut expr_columns = vec![]; - expr.clone() - .accept(PhysicalExprColumnCollector::new(&mut expr_columns))?; - // Calculation is possible with col_to_col_map since sort exprs belong to a child. - let all_columns_are_included = expr_columns - .iter() - .all(|col| col_to_col_map.contains_key(col)); - if all_columns_are_included { - // Since we are sure that one to one column mapping includes all column, we convert - // the sort expression into Filter expression. - let converted_filter_expr = - expr.transform_up(&|p| convert_filter_columns(p, col_to_col_map))?; - let mut contains = false; - // Search converted PhysicalExpr in filter expression - filter.expression().clone().accept( - CheckFilterExprContainsSortInformation::new( - &mut contains, - converted_filter_expr.clone(), - ), - )?; - // If the exact match is encountered, use this sorted expression in graph traversals. - if contains { - sorted_exps.push(SortedFilterExpr::new( - side, - sort_expr.expr.clone(), - converted_filter_expr.clone(), - sort_expr.options, - )); - } + schema: SchemaRef, + sort_expr: &PhysicalSortExpr, +) -> Result>> { + let column_mapping_information: HashMap = + map_origin_col_to_filter_col(filter, schema, side)?; + let expr = sort_expr.expr.clone(); + // Get main schema columns + let mut expr_columns = vec![]; + expr.clone() + .accept(PhysicalExprColumnCollector::new(&mut expr_columns))?; + // Calculation is possible with 'column_mapping_information' since sort exprs belong to a child. + let all_columns_are_included = expr_columns + .iter() + .all(|col| column_mapping_information.contains_key(col)); + if all_columns_are_included { + // Since we are sure that one to one column mapping includes all columns, we convert + // the sort expression into a filter expression. + let converted_filter_expr = expr + .transform_up(&|p| convert_filter_columns(p, &column_mapping_information))?; + let mut contains = false; + // Search converted PhysicalExpr in filter expression + filter.expression().clone().accept( + CheckFilterExprContainsSortInformation::new( + &mut contains, + converted_filter_expr.clone(), + ), + )?; + // If the exact match is encountered, use this sorted expression in graph traversals. + if contains { + Ok(Some(converted_filter_expr)) + } else { + Ok(None) } + } else { + Ok(None) } - Ok(sorted_exps) } -/// Here, we map the main column information to filter intermediate schema. -/// child_orders are derived from child schemas. However, filter schema is different. We enforce -/// that we get the column order for each corresponding [ColumnIndex]. If not, we return None and -/// assume this expression is not capable of pruning. -pub fn build_filter_input_order( +/// This function is used to build the filter expression based on the sort order of input columns. 
+/// +/// It first calls the convert_sort_expr_with_filter_schema method to determine if the sort +/// order of columns can be used in the filter expression. +/// If it returns a Some value, the method wraps the result in a SortedFilterExpr +/// instance with the original sort expression, converted filter expression, and sort options. +/// If it returns a None value, this function returns None. +/// +/// The SortedFilterExpr instance contains information about the sort order of columns +/// that can be used in the filter expression, which can be used to optimize the query execution process. +pub fn build_filter_input_order_v2( + side: JoinSide, filter: &JoinFilter, - left_main_schema: SchemaRef, - right_main_schema: SchemaRef, - left_order: &[PhysicalSortExpr], - right_order: &[PhysicalSortExpr], -) -> Result> { - // Get left mapping - let left_hashmap: HashMap = - map_origin_col_to_filter_col(filter, left_main_schema, JoinSide::Left)?; - // Get right mapping - let right_hashmap: HashMap = - map_origin_col_to_filter_col(filter, right_main_schema, JoinSide::Right)?; - // Extract sorted expression that belongs to the filter expression - let left_sorted_filter_exprs = rewrite_sort_information_for_filter( - JoinSide::Left, - filter, - &left_hashmap, - left_order, - )?; - let right_sorted_filter_exprs = rewrite_sort_information_for_filter( - JoinSide::Right, - filter, - &right_hashmap, - right_order, - )?; - Ok([left_sorted_filter_exprs, right_sorted_filter_exprs].concat()) + schema: SchemaRef, + order: &PhysicalSortExpr, +) -> Result> { + match convert_sort_expr_with_filter_schema(&side, filter, schema, order)? { + Some(expr) => Ok(Some(SortedFilterExpr::new( + side, + order.expr.clone(), + expr, + order.options, + ))), + None => Ok(None), + } } +/// Convert a physical expression into a filter expression using a column mapping information. fn convert_filter_columns( input: Arc, - column_change_information: &HashMap, + column_mapping_information: &HashMap, ) -> Result>> { + // Attempt to downcast the input expression to a Column type. if let Some(col) = input.as_any().downcast_ref::() { - let filter_col = column_change_information.get(col).unwrap().clone(); + // If the downcast is successful, retrieve the corresponding filter column. + let filter_col = column_mapping_information.get(col).unwrap().clone(); + // Return the filter column as an Arc wrapped in an Option. Ok(Some(Arc::new(filter_col))) } else { + // If the downcast is not successful, return the input expression as is. Ok(Some(input)) } } @@ -1026,42 +1032,99 @@ mod tests { let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); - // Only two of them in Filter PhysicalExpr (la1, la2) - let left_sorted = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("lt1", 3)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("la2", 4)), - options: SortOptions::default(), - }, - PhysicalSortExpr { + assert!(build_filter_input_order_v2( + JoinSide::Left, + &filter, + left_schema.clone(), + &PhysicalSortExpr { expr: Arc::new(Column::new("la1", 0)), options: SortOptions::default(), - }, - ]; - // Only "ra1" in Filter PhysicalExpr - let right_sorted = vec![ - PhysicalSortExpr { + } + )? + .is_some()); + assert!(build_filter_input_order_v2( + JoinSide::Left, + &filter, + left_schema, + &PhysicalSortExpr { + expr: Arc::new(Column::new("lt1", 3)), + options: SortOptions::default(), + } + )? 
+ .is_none()); + assert!(build_filter_input_order_v2( + JoinSide::Right, + &filter, + right_schema.clone(), + &PhysicalSortExpr { expr: Arc::new(Column::new("ra1", 0)), options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("rt1", 3)), + } + )? + .is_some()); + assert!(build_filter_input_order_v2( + JoinSide::Right, + &filter, + right_schema, + &PhysicalSortExpr { + expr: Arc::new(Column::new("rb1", 1)), options: SortOptions::default(), + } + )? + .is_none()); + + Ok(()) + } + // if one side is sorted by ORDER BY (a+b), and join filter condition includes (a-b). + #[test] + fn sorted_filter_expr_build() -> Result<()> { + let filter_col_0 = Arc::new(Column::new("0", 0)); + let filter_col_1 = Arc::new(Column::new("1", 1)); + + let column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 1, + side: JoinSide::Left, }, ]; + let intermediate_schema = Schema::new(vec![ + Field::new(filter_col_0.name(), DataType::Int32, true), + Field::new(filter_col_1.name(), DataType::Int32, true), + ]); + + let filter_expr = Arc::new(BinaryExpr { + left: Arc::new(Column::new("0", 0)), + op: Operator::Minus, + right: Arc::new(Column::new("1", 1)), + }); + + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int64, false), + ])); - let sorted = build_filter_input_order( + let sorted = PhysicalSortExpr { + expr: Arc::new(BinaryExpr { + left: Arc::new(Column::new("a", 0)), + op: Operator::Plus, + right: Arc::new(Column::new("b", 1)), + }), + options: SortOptions::default(), + }; + + let res = convert_sort_expr_with_filter_schema( + &JoinSide::Left, &filter, - left_schema, - right_schema, - &left_sorted, - &right_sorted, + schema, + &sorted, )?; - assert_eq!(sorted.len(), 3); - + assert!(res.is_none()); Ok(()) } } diff --git a/datafusion/core/src/physical_plan/joins/mod.rs b/datafusion/core/src/physical_plan/joins/mod.rs index 79ebe3cce0bc..629cac4a64a3 100644 --- a/datafusion/core/src/physical_plan/joins/mod.rs +++ b/datafusion/core/src/physical_plan/joins/mod.rs @@ -21,7 +21,9 @@ pub use cross_join::CrossJoinExec; pub use hash_join::HashJoinExec; pub use nested_loop_join::NestedLoopJoinExec; // Note: SortMergeJoin is not used in plans yet -pub use hash_join_utils::SortedFilterExpr; +pub use hash_join_utils::{ + convert_sort_expr_with_filter_schema, map_origin_col_to_filter_col, SortedFilterExpr, +}; pub use sort_merge_join::SortMergeJoinExec; pub use symmetric_hash_join::SymmetricHashJoinExec; mod cross_join; diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs index bf2f0cdca789..f0781584e3ce 100644 --- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs @@ -49,7 +49,7 @@ use crate::execution::context::TaskContext; use crate::logical_expr::JoinType; use crate::physical_plan::common::merge_batches; use crate::physical_plan::joins::hash_join_utils::{ - build_filter_input_order, build_join_indices, update_hash, JoinHashMap, + build_filter_input_order_v2, build_join_indices, update_hash, JoinHashMap, SortedFilterExpr, }; use crate::physical_plan::joins::utils::build_batch_from_indices; @@ -242,33 +242,64 @@ impl SymmetricHashJoinExec { ) -> Result { let left_schema = left.schema(); let right_schema = right.schema(); + + // Return 
error if the on constraints in the join are empty if on.is_empty() { return Err(DataFusionError::Plan( "On constraints in HashJoinExec should be non-empty".to_string(), )); } + // Check if the join is valid with the given on constraints check_join_is_valid(&left_schema, &right_schema, &on)?; + // Build the join schema from the left and right schemas let (schema, column_indices) = build_join_schema(&left_schema, &right_schema, join_type); + // Set a random state for the join let random_state = RandomState::with_seeds(0, 0, 0, 0); - // - let mut filter_columns = build_filter_input_order( + // Create expression graph for PhysicalExpr + let mut physical_expr_graph = + ExprIntervalGraph::try_new(filter.expression().clone())?; + + // Build the sorted filter expression for the left child + let left_filter_expression = match build_filter_input_order_v2( + JoinSide::Left, &filter, left.schema(), + left.output_ordering().unwrap().get(0).unwrap(), + )? { + Some(value) => value, + None => { + return Err(DataFusionError::Plan( + "The left side of the join does not have an expression sorted." + .to_string(), + )) + } + }; + // Build the sorted filter expression for the right child + let right_filter_expression = match build_filter_input_order_v2( + JoinSide::Right, + &filter, right.schema(), - left.output_ordering().unwrap(), - right.output_ordering().unwrap(), - )?; + right.output_ordering().unwrap().get(0).unwrap(), + )? { + Some(value) => value, + None => { + return Err(DataFusionError::Plan( + "The right side of the join does not have an expression sorted." + .to_string(), + )) + } + }; - // Create expression graph for PhysicalExpr. - let mut physical_expr_graph = - ExprIntervalGraph::try_new(filter.expression().clone())?; + // Store the left and right sorted filter expressions in a vector + let mut filter_columns = vec![left_filter_expression, right_filter_expression]; - // Since the graph is stable, we reuse NodeIndex of our filter expressions. + // Gather node indices of converted filter expressions in SortedFilterExpr + // using the filter columns vector let child_node_indexes = physical_expr_graph.gather_node_indices( &filter_columns .iter() @@ -276,8 +307,7 @@ impl SymmetricHashJoinExec { .collect::>(), ); - // We inject calculated node indexes into SortedFilterExpr. In graph calculations, - // we will be using node index to put calculated intervals into Columns or BinaryExprs . + // Inject calculated node indexes into SortedFilterExpr for (sorted_expr, (_, index)) in filter_columns.iter_mut().zip(child_node_indexes.iter()) { @@ -347,7 +377,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { } fn required_input_ordering(&self) -> Vec> { - vec![] + vec![None, None] } fn unbounded_output(&self, children: &[bool]) -> Result { @@ -689,7 +719,20 @@ fn determine_prune_length( .min() .unwrap()) } - +/// This method determines if the result of the join should be produced in the final step or not. +/// +/// # Arguments +/// +/// * build_side - Enum indicating the side of the join used as the build side. +/// * join_type - Enum indicating the type of join to be performed. +/// +/// # Returns +/// +/// A boolean indicating whether the result of the join should be produced in the final step or not. +/// The result will be true if the build side is JoinSide::Left and the join type is one of +/// JoinType::Left, JoinType::LeftAnti, JoinType::Full or JoinType::LeftSemi. 
+/// If the build side is JoinSide::Right, the result will be true if the join type +/// is one of JoinType::Right, JoinType::RightAnti, JoinType::Full, or JoinType::RightSemi. fn need_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> bool { if build_side.eq(&JoinSide::Left) { matches!( @@ -704,6 +747,19 @@ fn need_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> bo } } +/// Get the anti join indices from the visited hash set. +/// +/// This method returns the indices from the original input that were not present in the visited hash set. +/// +/// # Arguments +/// +/// * prune_length - The length of the pruned record batch. +/// * deleted_offset - The offset to the indices. +/// * visited_rows - The hash set of visited indices. +/// +/// # Returns +/// +/// A `PrimitiveArray` of the anti join indices. fn get_anti_indices( prune_length: usize, deleted_offset: usize, @@ -714,6 +770,7 @@ where { let mut bitmap = BooleanBufferBuilder::new(prune_length); bitmap.append_n(prune_length, false); + // mark the indices as true if they are present in the visited hash set (0..prune_length).for_each(|v| { let row = &(v + deleted_offset); bitmap.set_bit(v, visited_rows.contains(row)); @@ -724,6 +781,20 @@ where .collect::>() } +/// This method creates a boolean buffer from the visited rows hash set +/// and the indices of the pruned record batch slice. +/// +/// It gets the indices from the original input that were present in the visited hash set. +/// +/// # Arguments +/// +/// * prune_length - The length of the pruned record batch. +/// * deleted_offset - The offset to the indices. +/// * visited_rows - The hash set of visited indices. +/// +/// # Returns +/// +/// A PrimitiveArray of the specified type T, containing the semi indices. fn get_semi_indices( prune_length: usize, deleted_offset: usize, @@ -734,6 +805,7 @@ where { let mut bitmap = BooleanBufferBuilder::new(prune_length); bitmap.append_n(prune_length, false); + // mark the indices as true if they are present in the visited hash set (0..prune_length).for_each(|v| { let row = &(v + deleted_offset); bitmap.set_bit(v, visited_rows.contains(row)); @@ -743,7 +815,15 @@ where .filter_map(|idx| (bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx))) .collect::>() } - +/// Records the visited indices from the input `PrimitiveArray` of type `T` into the given hash set `visited`. +/// This function will insert the indices (offset by `offset`) into the `visited` hash set. +/// +/// # Arguments +/// +/// * `visited` - A hash set to store the visited indices. +/// * `offset` - An offset to the indices in the `PrimitiveArray`. +/// * `indices` - The input `PrimitiveArray` of type `T` which stores the indices to be recorded. +/// fn record_visited_indices( visited: &mut HashSet, offset: usize, @@ -755,6 +835,23 @@ fn record_visited_indices( } } +/// Calculate indices by join type. +/// +/// This method returns a tuple of two arrays: build and probe indices. +/// The length of both arrays will be the same. +/// +/// # Arguments +/// +/// * `build_side`: Join side which defines the build side. +/// * `prune_length`: Length of the prune data. +/// * `visited_rows`: Hash set of visited rows of the build side. +/// * `deleted_offset`: Deleted offset of the build side. +/// * `join_type`: The type of join to be performed. +/// +/// # Returns +/// +/// A tuple of two arrays of primitive types representing the build and probe indices. 
+/// fn calculate_indices_by_join_type( build_side: JoinSide, prune_length: usize, @@ -765,20 +862,21 @@ fn calculate_indices_by_join_type( where NativeAdapter: From<::Native>, { + // Store the result in a tuple let result = match (build_side, join_type) { + // In the case of `Left` or `Right` join, or `Full` join, get the anti indices (JoinSide::Left, JoinType::Left | JoinType::LeftAnti) | (JoinSide::Right, JoinType::Right | JoinType::RightAnti) | (_, JoinType::Full) => { let build_unmatched_indices = get_anti_indices(prune_length, deleted_offset, visited_rows); - // right_indices - // all the element in the right side is None let mut builder = PrimitiveBuilder::::with_capacity(build_unmatched_indices.len()); builder.append_nulls(build_unmatched_indices.len()); let probe_indices = builder.finish(); (build_unmatched_indices, probe_indices) } + // In the case of `LeftSemi` or `RightSemi` join, get the semi indices (JoinSide::Left, JoinType::LeftSemi) | (JoinSide::Right, JoinType::RightSemi) => { let build_unmatched_indices = get_semi_indices(prune_length, deleted_offset, visited_rows); @@ -788,6 +886,7 @@ where let probe_indices = builder.finish(); (build_unmatched_indices, probe_indices) } + // The case of other join types is not considered _ => unreachable!(), }; Ok(result) @@ -829,15 +928,31 @@ impl OneSideHashJoiner { } } + /// Updates the internal state of the OneSideHashJoiner with the incoming batch. + /// + /// # Arguments + /// + /// * `on_build` - A slice of Columns that defines the join key of the OneSideHashJoiner + /// * `batch` - The incoming RecordBatch to be merged with the internal input buffer + /// * `random_state` - The random state used to hash values + /// + /// # Return + /// + /// Returns a Result with an empty Ok if the update was successful, otherwise returns an error. fn update_internal_state( &mut self, on_build: &[Column], batch: &RecordBatch, random_state: &RandomState, ) -> Result<()> { + // Merge the incoming batch with the existing input buffer self.input_buffer = merge_batches(&self.input_buffer, batch, batch.schema()).unwrap(); + + // Resize the hashes buffer to the number of rows in the incoming batch self.hashes_buffer.resize(batch.num_rows(), 0); + + // Update the hashmap with the join key values and hashes of the incoming batch update_hash( on_build, batch, @@ -846,13 +961,36 @@ impl OneSideHashJoiner { random_state, &mut self.hashes_buffer, )?; + + // Drain the hashes buffer and add them to the row_hash_values deque self.hashes_buffer .drain(0..) .into_iter() .for_each(|hash| self.row_hash_values.push_back(hash)); + Ok(()) } + /// This method performs a join between the build side input buffer and probe side RecordBatch. + /// + /// # Arguments + /// + /// * `schema` - A reference to the schema of the output record batch. + /// * `join_type` - The type of join to be performed. + /// * `on_build` - An array of columns on which the join will be performed. The columns are from the build side of the join. + /// * `on_probe` - An array of columns on which the join will be performed. The columns are from the probe side of the join. + /// * `filter` - An optional filter on the join condition. + /// * `probe_batch` - The second record batch to be joined. + /// * `probe_visited` - A hashset to store the visited indices from the probe batch. + /// * `probe_offset` - The offset of the probe side for visited indices calculations. + /// * `column_indices` - An array of columns to be selected for the result of the join. 
+ /// * `random_state` - The random state for the join. + /// * `null_equals_null` - A boolean indicating whether null values should be treated as equal when joining. + /// + /// # Returns + /// + /// A `Result` containing an optional record batch if the join type is not one of LeftAnti, RightAnti, LeftSemi or RightSemi. + /// If the join type is one of the above four, the function will return None. #[allow(clippy::too_many_arguments)] fn record_batch_from_other_side( &mut self, @@ -860,7 +998,7 @@ impl OneSideHashJoiner { join_type: JoinType, on_build: &[Column], on_probe: &[Column], - filter: Option<&JoinFilter>, + filter: &JoinFilter, probe_batch: &RecordBatch, probe_visited: &mut HashSet, probe_offset: usize, @@ -877,7 +1015,7 @@ impl OneSideHashJoiner { &self.input_buffer, on_build, on_probe, - filter, + Some(filter), random_state, null_equals_null, &mut self.hashes_buffer, @@ -894,29 +1032,45 @@ impl OneSideHashJoiner { if need_produce_result_in_final(self.build_side.negate(), join_type) { record_visited_indices(probe_visited, probe_offset, &probe_side); } - match (self.build_side, join_type) { - ( - _, - JoinType::LeftAnti + if matches!( + join_type, + JoinType::LeftAnti | JoinType::RightAnti | JoinType::LeftSemi - | JoinType::RightSemi, - ) => Ok(None), - (_, _) => { - let res = build_batch_from_indices( - schema.as_ref(), - &self.input_buffer, - probe_batch, - build_side.clone(), - probe_side.clone(), - column_indices, - self.build_side, - )?; - Ok(Some(res)) - } + | JoinType::RightSemi + ) { + Ok(None) + } else { + let res = build_batch_from_indices( + schema.as_ref(), + &self.input_buffer, + probe_batch, + build_side.clone(), + probe_side.clone(), + column_indices, + self.build_side, + )?; + Ok(Some(res)) } } + /// Function to produce the unmatched record results based on the build side, + /// join type and other parameters. + /// + /// The method use first 'prune_length' rows from the build side input buffer to + /// produce results. + /// + /// # Arguments + /// + /// * `output_schema` - The schema of the final output record batch. + /// * `prune_length` - The length of the determined prune length. + /// * `probe_schema` - The schema of the probe RecordBatch. + /// * `join_type` - The type of join to be performed. + /// * `column_indices` - Indices of columns that are being joined. + /// + /// # Returns + /// + /// * `Option` - The final output record batch if required, otherwise `None`. fn build_side_determined_results( &self, output_schema: SchemaRef, @@ -925,7 +1079,9 @@ impl OneSideHashJoiner { join_type: JoinType, column_indices: &[ColumnIndex], ) -> ArrowResult> { + // Check if we need to produce a result in the final output let result = if need_produce_result_in_final(self.build_side, join_type) { + // Calculate the indices for build and probe sides based on join type and build side let (build_indices, probe_indices) = calculate_indices_by_join_type( self.build_side, prune_length, @@ -933,7 +1089,11 @@ impl OneSideHashJoiner { self.deleted_offset, join_type, )?; + + // Create an empty probe record batch let empty_probe_batch = RecordBatch::new_empty(probe_schema); + + // Build the final result from the indices of build and probe sides Some(build_batch_from_indices( output_schema.as_ref(), &self.input_buffer, @@ -944,37 +1104,63 @@ impl OneSideHashJoiner { self.build_side, )?) } else { + // If we don't need to produce a result, return None None }; Ok(result) } + /// Prunes the build side of a hash join. 
+ /// + /// The input record batch, `probe_batch`, is used to update the intervals of the sorted filter expressions in `sorted_filter_exprs`. + /// The updated intervals are then used to determine the length of the build side rows to prune. If there are rows to prune, they are pruned and the result is returned. + /// + /// # Arguments + /// + /// * `schema` - The schema of the final output record batch + /// * `probe_batch` - Incoming RecordBatch of the probe side. + /// * `sorted_filter_exprs` - A mutable vector of sorted filter expressions that are used to update the intervals. + /// * `join_type` - The type of join (e.g. inner, left, right, etc.). + /// * `column_indices` - A vector of column indices that specifies which columns from the + /// build side should be included in the output. + /// * `physical_expr_graph` - A mutable reference to the physical expression graph. + /// + /// # Returns + /// + /// If there are no rows in the input buffer, returns `Ok(None)`. + /// If there are rows to prune, returns the pruned build side record batch wrapped in an `Ok` variant. + /// Otherwise, returns `Ok(None)`. fn prune_build_side( &mut self, schema: SchemaRef, probe_batch: &RecordBatch, - filter_columns: &mut [SortedFilterExpr], + sorted_filter_exprs: &mut [SortedFilterExpr], join_type: JoinType, column_indices: &[ColumnIndex], physical_expr_graph: &mut ExprIntervalGraph, ) -> ArrowResult> { + // Return None if there are no rows in the input buffer if self.input_buffer.num_rows() == 0 { return Ok(None); } - // We use Vec<(usize, Interval)> instead of Hashmap since the expected - // filter exprs relatively low and conversion between Vec and Hashmap - // back and forth may be slower. - let mut filter_intervals = filter_columns + + // Convert the sorted filter expressions into a vector of (node_index, interval) tuples + // for use in updating the interval graph + let mut filter_intervals = sorted_filter_exprs .iter() .map(|sorted_expr| (sorted_expr.node_index(), sorted_expr.interval().clone())) .collect::>(); - // Use this vector to seed the child PhysicalExpr interval. + + // Use the filter intervals to update the physical expression graph physical_expr_graph.update_intervals(&mut filter_intervals)?; + + // A vector to track changes in the intervals of the sorted filter expressions let mut change_track = Vec::with_capacity(filter_intervals.len()); - // Mutate the Vec for intervals if it is not equal. - // Keep track of the interval changes. - for (sorted_expr, (_, interval)) in - filter_columns.iter_mut().zip(filter_intervals.into_iter()) + + // Update the intervals of the sorted filter expressions and track any changes + for (sorted_expr, (_, interval)) in sorted_filter_exprs + .iter_mut() + .zip(filter_intervals.into_iter()) { if interval.eq(sorted_expr.interval()) { change_track.push(false); @@ -983,12 +1169,19 @@ impl OneSideHashJoiner { change_track.push(true); } } - // If any interval change, try to determine prune index. + + // Determine the length of the rows to prune if there was a change in the intervals let prune_length = if change_track.iter().any(|x| *x) { - determine_prune_length(&self.input_buffer, filter_columns, self.build_side)? + determine_prune_length( + &self.input_buffer, + sorted_filter_exprs, + self.build_side, + )? 
} else { 0 }; + + // If there are rows to prune, prune them and return the result if prune_length > 0 { let result = self.build_side_determined_results( schema, @@ -1101,7 +1294,7 @@ impl SymmetricHashJoinStream { self.join_type, &self.on_right, &self.on_left, - Some(&self.filter), + &self.filter, &probe_batch, &mut self.left.visited_rows, self.left.offset, @@ -1179,7 +1372,7 @@ impl SymmetricHashJoinStream { self.join_type, &self.on_left, &self.on_right, - Some(&self.filter), + &self.filter, &probe_batch, &mut self.right.visited_rows, self.right.offset, @@ -1260,13 +1453,6 @@ mod tests { const TABLE_SIZE: i32 = 1_000; - // TODO: Lazy statistics for same record batches. - // use lazy_static::lazy_static; - // - // lazy_static! { - // static ref SIDE_BATCH: Result<(RecordBatch, RecordBatch)> = build_sides_record_batches(TABLE_SIZE, (10, 11)); - // } - fn compare_batches(collected_1: &[RecordBatch], collected_2: &[RecordBatch]) { // compare let first_formatted = pretty_format_batches(collected_1).unwrap().to_string(); @@ -1298,7 +1484,7 @@ mod tests { null_equals_null: bool, context: Arc, ) -> Result> { - let partition_count = 1; + let partition_count = 4; let left_expr = on .iter() diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index dfeb03a22406..ac1cb2782f3e 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -481,15 +481,21 @@ impl ExprIntervalGraph { let op = binary.op(); if let (Some(new_left_interval), Some(new_right_interval)) = if op.is_logic_operator() { - // There is no propagation process for logical operators. + // TODO: Currently, this implementation only supports the AND operator + // and does not require any further propagation. + // In the future, upon adding support for additional logical operators, + // this method will require modification to support propagating + // the changes accordingly. continue; } else if op.is_comparison_operator() { - // If comparison is strictly false, there is nothing to do for shrink. if let Interval { lower: ScalarValue::Boolean(Some(false)), upper: ScalarValue::Boolean(Some(false)), } = node_interval { + // TODO: The optimization of handling strictly false clauses through + // conversion to equivalent comparison operators (e.g. GT to LE, LT to GE) + // can be implemented once support for open/closed intervals is added. continue; } // Propagate the comparison operator. 
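The interval-driven pruning above is easier to follow in isolation. Below is a minimal, self-contained Rust sketch of the same flow: propagate fresh intervals into the sorted filter expressions, track whether anything actually changed, and only then compute a prune length. The `Interval` and `SortedExpr` types, the i64 bounds, and the linear scan are simplified stand-ins for the ScalarValue-based `Interval`, `SortedFilterExpr`, and the bisect-based `determine_prune_length`; this is a sketch of the idea, not the operator's actual code.

// Simplified stand-ins: i64 bounds instead of ScalarValue, and a linear
// scan instead of the bisect-based prune search.
#[derive(Clone, Debug, PartialEq)]
struct Interval {
    lower: i64,
    upper: i64,
}

struct SortedExpr {
    interval: Interval,
}

// Overwrite each expression's interval with its freshly propagated value
// and report whether anything changed; pruning is attempted only then.
fn update_and_track(exprs: &mut [SortedExpr], fresh: &[Interval]) -> bool {
    let mut changed = false;
    for (expr, interval) in exprs.iter_mut().zip(fresh) {
        if expr.interval != *interval {
            expr.interval = interval.clone();
            changed = true;
        }
    }
    changed
}

// Count the leading build-side rows that sit below every expression's
// lower bound; those rows can never match any future probe row.
fn prune_length(sorted_build_values: &[i64], exprs: &[SortedExpr]) -> usize {
    sorted_build_values
        .iter()
        .take_while(|v| exprs.iter().all(|e| **v < e.interval.lower))
        .count()
}

fn main() {
    let mut exprs = vec![SortedExpr {
        interval: Interval { lower: 0, upper: 10 },
    }];
    let build_side = [1, 2, 5, 7, 9];

    // A new probe batch tightened the interval to [5, 12].
    let fresh = vec![Interval { lower: 5, upper: 12 }];
    if update_and_track(&mut exprs, &fresh) {
        // Rows with values 1 and 2 are now unreachable and can be pruned.
        assert_eq!(prune_length(&build_side, &exprs), 2);
    }
}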
From 08cf4f0340041b16dafe8382234d0c45d16508ce Mon Sep 17 00:00:00 2001
From: metesynnada <100111937+metesynnada@users.noreply.github.com>
Date: Mon, 6 Feb 2023 14:00:45 +0300
Subject: [PATCH 14/39] Refactor input stream consumption in SymmetricHashJoin

---
 .../joins/symmetric_hash_join.rs              | 434 +++++++++---------
 .../physical-expr/src/intervals/cp_solver.rs  |   9 +-
 2 files changed, 223 insertions(+), 220 deletions(-)

diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
index f0781584e3ce..199fff2f0f08 100644
--- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
@@ -181,17 +181,21 @@ pub struct SymmetricHashJoinExec {
     pub(crate) null_equals_null: bool,
 }
 
+#[derive(Debug)]
+struct SymmetricHashJoinSideMetrics {
+    /// Number of batches consumed by this side of the operator
+    input_batches: metrics::Count,
+    /// Number of rows consumed by this side of the operator
+    input_rows: metrics::Count,
+}
+
 /// Metrics for HashJoinExec
 #[derive(Debug)]
 struct SymmetricHashJoinMetrics {
     /// Number of left batches consumed by this operator
-    left_input_batches: metrics::Count,
+    left: SymmetricHashJoinSideMetrics,
     /// Number of right batches consumed by this operator
-    right_input_batches: metrics::Count,
-    /// Number of left rows consumed by this operator
-    left_input_rows: metrics::Count,
-    /// Number of right rows consumed by this operator
-    right_input_rows: metrics::Count,
+    right: SymmetricHashJoinSideMetrics,
     /// Number of batches produced by this operator
     output_batches: metrics::Count,
     /// Number of rows produced by this operator
@@ -200,16 +204,21 @@ struct SymmetricHashJoinMetrics {
 
 impl SymmetricHashJoinMetrics {
     pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self {
-        let left_input_batches =
-            MetricBuilder::new(metrics).counter("left_input_batches", partition);
-        let right_input_batches =
-            MetricBuilder::new(metrics).counter("right_input_batches", partition);
-
-        let left_input_rows =
-            MetricBuilder::new(metrics).counter("left_input_rows", partition);
+        let input_batches =
+            MetricBuilder::new(metrics).counter("input_batches", partition);
+        let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition);
+        let left = SymmetricHashJoinSideMetrics {
+            input_batches,
+            input_rows,
+        };
 
-        let right_input_rows =
-            MetricBuilder::new(metrics).counter("right_input_rows", partition);
+        let input_batches =
+            MetricBuilder::new(metrics).counter("input_batches", partition);
+        let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition);
+        let right = SymmetricHashJoinSideMetrics {
+            input_batches,
+            input_rows,
+        };
 
         let output_batches =
             MetricBuilder::new(metrics).counter("output_batches", partition);
@@ -217,10 +226,8 @@ impl SymmetricHashJoinMetrics {
         let output_rows = MetricBuilder::new(metrics).output_rows(partition);
 
         Self {
-            left_input_batches,
-            right_input_batches,
-            left_input_rows,
-            right_input_rows,
+            left,
+            right,
             output_batches,
             output_rows,
         }
@@ -486,9 +493,10 @@ impl ExecutionPlan for SymmetricHashJoinExec {
         let on_left = self.on.iter().map(|on| on.0.clone()).collect::<Vec<_>>();
         let on_right = self.on.iter().map(|on| on.1.clone()).collect::<Vec<_>>();
         // TODO: Currently, working on partitioned left and right. We can also coalesce.
- let left_side_joiner = OneSideHashJoiner::new(JoinSide::Left, self.left.clone()); + let left_side_joiner = + OneSideHashJoiner::new(JoinSide::Left, on_left, self.left.clone()); let right_side_joiner = - OneSideHashJoiner::new(JoinSide::Right, self.right.clone()); + OneSideHashJoiner::new(JoinSide::Right, on_right, self.right.clone()); let left_stream = self.left.execute(partition, context.clone())?; let right_stream = self.right.execute(partition, context)?; @@ -496,41 +504,37 @@ impl ExecutionPlan for SymmetricHashJoinExec { left_stream, right_stream, schema: self.schema(), - on_left, - on_right, filter: self.filter.clone(), join_type: self.join_type, random_state: self.random_state.clone(), left: left_side_joiner, right: right_side_joiner, column_indices: self.column_indices.clone(), - join_metrics: SymmetricHashJoinMetrics::new(partition, &self.metrics), + metrics: SymmetricHashJoinMetrics::new(partition, &self.metrics), physical_expr_graph: self.physical_expr_graph.clone(), null_equals_null: self.null_equals_null, filter_columns: self.filter_columns.clone(), final_result: false, - data_side: JoinSide::Left, + probe_side: JoinSide::Left, })) } } /// A stream that issues [RecordBatch]es as they arrive from the right of the join. struct SymmetricHashJoinStream { + /// Left stream left_stream: SendableRecordBatchStream, + /// right stream right_stream: SendableRecordBatchStream, /// Input schema schema: Arc, - /// columns from the left - on_left: Vec, - /// columns from the right used to compute the hash - on_right: Vec, /// join filter filter: JoinFilter, /// type of the join join_type: JoinType, - // left + // left hash joiner left: OneSideHashJoiner, - /// right + /// right hash joiner right: OneSideHashJoiner, /// Information of index and left / right placement of columns column_indices: Vec, @@ -543,10 +547,11 @@ struct SymmetricHashJoinStream { /// If null_equals_null is true, null == null else null != null null_equals_null: bool, /// Metrics - join_metrics: SymmetricHashJoinMetrics, + metrics: SymmetricHashJoinMetrics, /// There is nothing to process anymore final_result: bool, - data_side: JoinSide, + /// The current probe side. We choose build and probe side according to this attribute. + probe_side: JoinSide, } impl RecordBatchStream for SymmetricHashJoinStream { @@ -611,8 +616,20 @@ fn prune_visited_rows( Ok(()) } -/// We calculate PhysicalExpr boundaries. We get first ScalarValues from build side and -/// last ScalarValues from probe side, then update the SortedFilterExpr intervals. +/// +/// Calculate the filter expression intervals +/// +/// This function updates the `interval` field of each `SortedFilterExpr` in `filter_columns` based on +/// the first or last value of the expression in `build_input_buffer` or `probe_batch`. +/// +/// # Arguments +/// +/// * `build_input_buffer` - The RecordBatch on the build side of the join +/// * `probe_batch` - The RecordBatch on the probe side of the join +/// * `filter_columns` - A mutable list of `SortedFilterExpr` objects to update +/// * `build_side` - The side of the join where `build_input_buffer` is used +/// +/// # Note /// /// Build Probe /// +-------+ +-------+ @@ -632,10 +649,10 @@ fn prune_visited_rows( /// Expr3: [c, inf] /// Expr4: [d, inf] /// -/// Each probe batch trigger recalculation. Remember that the calculation logic differs -/// according to the build side. When the build side is changed by a new data arrival from other child, -/// [SortedFilterExpr] intervals change according to the side. 
You must use these calculated intervals -/// within same side, per [RecordBatch] arrival. +/// +/// # Returns +/// +/// Result object indicating success or failure /// fn calculate_filter_expr_intervals( build_input_buffer: &RecordBatch, @@ -643,10 +660,14 @@ fn calculate_filter_expr_intervals( filter_columns: &mut [SortedFilterExpr], build_side: JoinSide, ) -> Result<()> { + // If either build or probe side has no data, return early if build_input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 { return Ok(()); } + + // Iterate through each SortedFilterExpr in filter_columns for sorted_expr in filter_columns.iter_mut() { + // Destructure the SortedFilterExpr let SortedFilterExpr { join_side, origin_expr, @@ -654,6 +675,8 @@ fn calculate_filter_expr_intervals( interval, .. } = sorted_expr; + + // Get the array of the first or last value for the expression, depending on join side let array = if build_side.eq(join_side) { // Get first value for expr origin_expr @@ -665,8 +688,14 @@ fn calculate_filter_expr_intervals( .evaluate(&probe_batch.slice(probe_batch.num_rows() - 1, 1))? .into_array(1) }; + + // Convert the array to a ScalarValue let value = ScalarValue::try_from_array(&array, 0)?; + + // Create a ScalarValue representing positive or negative infinity for the same data type let infinite = ScalarValue::try_from(value.get_datatype())?; + + // Update the interval with lower and upper bounds based on the sort option *interval = if sort_option.descending { Interval { lower: infinite, @@ -895,8 +924,10 @@ where struct OneSideHashJoiner { // Build side build_side: JoinSide, - // Inout record batch buffer + /// Inout record batch buffer input_buffer: RecordBatch, + /// columns from the side + on: Vec, /// Hashmap hashmap: JoinHashMap, /// To optimize hash deleting in case of pruning, we hold them in memory @@ -914,10 +945,15 @@ struct OneSideHashJoiner { } impl OneSideHashJoiner { - pub fn new(build_side: JoinSide, plan: Arc) -> Self { + pub fn new( + build_side: JoinSide, + on: Vec, + plan: Arc, + ) -> Self { Self { build_side, input_buffer: RecordBatch::new_empty(plan.schema()), + on, hashmap: JoinHashMap(RawTable::with_capacity(10_000)), row_hash_values: VecDeque::new(), hashes_buffer: vec![], @@ -941,7 +977,6 @@ impl OneSideHashJoiner { /// Returns a Result with an empty Ok if the update was successful, otherwise returns an error. fn update_internal_state( &mut self, - on_build: &[Column], batch: &RecordBatch, random_state: &RandomState, ) -> Result<()> { @@ -954,7 +989,7 @@ impl OneSideHashJoiner { // Update the hashmap with the join key values and hashes of the incoming batch update_hash( - on_build, + &self.on, batch, &mut self.hashmap, self.offset, @@ -977,7 +1012,6 @@ impl OneSideHashJoiner { /// /// * `schema` - A reference to the schema of the output record batch. /// * `join_type` - The type of join to be performed. - /// * `on_build` - An array of columns on which the join will be performed. The columns are from the build side of the join. /// * `on_probe` - An array of columns on which the join will be performed. The columns are from the probe side of the join. /// * `filter` - An optional filter on the join condition. /// * `probe_batch` - The second record batch to be joined. @@ -992,11 +1026,10 @@ impl OneSideHashJoiner { /// A `Result` containing an optional record batch if the join type is not one of LeftAnti, RightAnti, LeftSemi or RightSemi. /// If the join type is one of the above four, the function will return None. 
#[allow(clippy::too_many_arguments)] - fn record_batch_from_other_side( + fn join_with_probe_batch( &mut self, schema: SchemaRef, join_type: JoinType, - on_build: &[Column], on_probe: &[Column], filter: &JoinFilter, probe_batch: &RecordBatch, @@ -1013,7 +1046,7 @@ impl OneSideHashJoiner { probe_batch, &self.hashmap, &self.input_buffer, - on_build, + &self.on, on_probe, Some(filter), random_state, @@ -1110,10 +1143,11 @@ impl OneSideHashJoiner { Ok(result) } - /// Prunes the build side of a hash join. + /// Prunes the internal buffer. /// /// The input record batch, `probe_batch`, is used to update the intervals of the sorted filter expressions in `sorted_filter_exprs`. - /// The updated intervals are then used to determine the length of the build side rows to prune. If there are rows to prune, they are pruned and the result is returned. + /// The updated intervals are then used to determine the length of the build side rows to prune. + /// If there are rows to prune, they are pruned and the result is returned. /// /// # Arguments /// @@ -1130,7 +1164,7 @@ impl OneSideHashJoiner { /// If there are no rows in the input buffer, returns `Ok(None)`. /// If there are rows to prune, returns the pruned build side record batch wrapped in an `Ok` variant. /// Otherwise, returns `Ok(None)`. - fn prune_build_side( + fn prune_with_probe_batch( &mut self, schema: SchemaRef, probe_batch: &RecordBatch, @@ -1212,33 +1246,48 @@ impl OneSideHashJoiner { } } -fn produce_batch_result( +fn combine_two_batches_result( output_schema: SchemaRef, - equal_batch: ArrowResult>, - anti_batch: ArrowResult>, + left_batch: ArrowResult>, + right_batch: ArrowResult>, ) -> ArrowResult { - match (equal_batch, anti_batch) { - (Ok(Some(batch)), Ok(None)) | (Ok(None), Ok(Some(batch))) => Ok(batch), - (Err(e), _) | (_, Err(e)) => Err(e), - (Ok(Some(equal_batch)), Ok(Some(anti_batch))) => { - concat_batches(&output_schema, &[equal_batch, anti_batch]) + match (left_batch, right_batch) { + (Ok(Some(batch)), Ok(None)) | (Ok(None), Ok(Some(batch))) => { + // If either left_batch or right_batch is present, return that batch + Ok(batch) + } + (Err(e), _) | (_, Err(e)) => { + // If either left_batch or right_batch returns an error, return the error + Err(e) + } + (Ok(Some(left_batch)), Ok(Some(right_batch))) => { + // If both left_batch and right_batch are present, concatenate them + concat_batches(&output_schema, &[left_batch, right_batch]) + } + (Ok(None), Ok(None)) => { + // If both left_batch and right_batch are not present, return an empty batch + Ok(RecordBatch::new_empty(output_schema)) } - (Ok(None), Ok(None)) => Ok(RecordBatch::new_empty(output_schema)), } } impl SymmetricHashJoinStream { - /// Separate implementation function that unpins the [`SymmetricHashJoinStream`] so - /// that partial borrows work correctly + /// Polls the next result of the join operation. + /// + /// If the result of the join is ready, it returns the next record batch. If the join has completed and there are no more results, + /// it returns `Poll::Ready(None)`. If the join operation is not complete but the current stream is not ready yet, it returns `Poll::Pending`. 
fn poll_next_impl( &mut self, cx: &mut std::task::Context<'_>, ) -> Poll>> { loop { + // If the final result has already been obtained, return `Poll::Ready(None)` if self.final_result { return Poll::Ready(None); } + // If both streams have been exhausted, return the final result if self.right.exhausted && self.left.exhausted { + // Get left side results let left_result = self.left.build_side_determined_results( self.schema.clone(), self.left.input_buffer.num_rows(), @@ -1246,6 +1295,7 @@ impl SymmetricHashJoinStream { self.join_type, &self.column_indices, ); + // Get right side results let right_result = self.right.build_side_determined_results( self.schema.clone(), self.right.input_buffer.num_rows(), @@ -1254,169 +1304,125 @@ impl SymmetricHashJoinStream { &self.column_indices, ); self.final_result = true; - let result = - produce_batch_result(self.schema.clone(), left_result, right_result); + // Combine results + let result = combine_two_batches_result( + self.schema.clone(), + left_result, + right_result, + ); + // Update the metrics if let Ok(batch) = &result { - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); + self.metrics.output_batches.add(1); + self.metrics.output_rows.add(batch.num_rows()); } return Poll::Ready(Some(result)); } - if self.data_side.eq(&JoinSide::Left) { - match self.left_stream.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(probe_batch))) => { - self.join_metrics.left_input_batches.add(1); - self.join_metrics - .left_input_rows - .add(probe_batch.num_rows()); - match self.left.update_internal_state( - &self.on_left, - &probe_batch, - &self.random_state, - ) { - Ok(_) => {} - Err(e) => { - return Poll::Ready(Some(Err(ArrowError::ComputeError( - e.to_string(), - )))) - } - } - // Each probe batch trigger recalculation. - calculate_filter_expr_intervals( - &self.right.input_buffer, - &probe_batch, - &mut self.filter_columns, - JoinSide::Right, - )?; - // Using right as build side. - let equal_result = self.right.record_batch_from_other_side( - self.schema.clone(), - self.join_type, - &self.on_right, - &self.on_left, - &self.filter, - &probe_batch, - &mut self.left.visited_rows, - self.left.offset, - &self.column_indices, - &self.random_state, - &self.null_equals_null, - ); - self.left.offset += probe_batch.num_rows(); - // Right side will be pruned since the batch coming from left. - let anti_result = self.right.prune_build_side( - self.schema.clone(), - &probe_batch, - &mut self.filter_columns, - self.join_type, - &self.column_indices, - &mut self.physical_expr_graph, - ); - let result = produce_batch_result( - self.schema.clone(), - equal_result, - anti_result, - ); - if let Ok(batch) = &result { - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - if !self.right.exhausted { - self.data_side = JoinSide::Right; + + // Determine which stream should be polled next. The side the RecordBatch comes from becomes probe side. 
+            let (
+                input_stream,
+                probe_hash_joiner,
+                build_hash_joiner,
+                build_join_side,
+                probe_side_metrics,
+            ) = if self.probe_side.eq(&JoinSide::Left) {
+                (
+                    &mut self.left_stream,
+                    &mut self.left,
+                    &mut self.right,
+                    JoinSide::Right,
+                    &mut self.metrics.left,
+                )
+            } else {
+                (
+                    &mut self.right_stream,
+                    &mut self.right,
+                    &mut self.left,
+                    JoinSide::Left,
+                    &mut self.metrics.right,
+                )
+            };
+            // Poll the next batch from `input_stream`
+            match input_stream.poll_next_unpin(cx) {
+                // Batch is available
+                Poll::Ready(Some(Ok(probe_batch))) => {
+                    // Update the metrics for the stream that was polled
+                    probe_side_metrics.input_batches.add(1);
+                    probe_side_metrics.input_rows.add(probe_batch.num_rows());
+                    // Update the internal state of the probe hash joiner with the incoming batch
+                    match probe_hash_joiner
+                        .update_internal_state(&probe_batch, &self.random_state)
+                    {
+                        Ok(_) => {}
+                        Err(e) => {
+                            return Poll::Ready(Some(Err(ArrowError::ComputeError(
+                                e.to_string(),
+                            ))))
                         }
                     }
+                    // Calculate filter intervals
+                    calculate_filter_expr_intervals(
+                        &build_hash_joiner.input_buffer,
+                        &probe_batch,
+                        &mut self.filter_columns,
+                        build_join_side,
+                    )?;
+                    // Join the build side with the incoming probe batch
+                    let equal_result = build_hash_joiner.join_with_probe_batch(
+                        self.schema.clone(),
+                        self.join_type,
+                        &probe_hash_joiner.on,
+                        &self.filter,
+                        &probe_batch,
+                        &mut probe_hash_joiner.visited_rows,
+                        probe_hash_joiner.offset,
+                        &self.column_indices,
+                        &self.random_state,
+                        &self.null_equals_null,
+                    );
+                    // Increment the offset for the probe hash joiner
+                    probe_hash_joiner.offset += probe_batch.num_rows();
+                    // Prune build input buffer using the physical expression graph and filter intervals
+                    let anti_result = build_hash_joiner.prune_with_probe_batch(
+                        self.schema.clone(),
+                        &probe_batch,
+                        &mut self.filter_columns,
+                        self.join_type,
+                        &self.column_indices,
+                        &mut self.physical_expr_graph,
+                    );
+                    // Combine results
+                    let result = combine_two_batches_result(
+                        self.schema.clone(),
+                        equal_result,
+                        anti_result,
+                    );
+                    // Update the metrics
+                    if let Ok(batch) = &result {
+                        self.metrics.output_batches.add(1);
+                        self.metrics.output_rows.add(batch.num_rows());
                     }
+                    // Choose the next poll side. If the other side is not exhausted,
+                    // switch the probe side before returning the result.
+ if !build_hash_joiner.exhausted { + self.probe_side = build_join_side; } + return Poll::Ready(Some(result)); } - } else { - match self.right_stream.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(probe_batch))) => { - self.join_metrics.right_input_batches.add(1); - self.join_metrics - .right_input_rows - .add(probe_batch.num_rows()); - // Right is build side - match self.right.update_internal_state( - &self.on_right, - &probe_batch, - &self.random_state, - ) { - Ok(_) => {} - Err(e) => { - return Poll::Ready(Some(Err(ArrowError::ComputeError( - e.to_string(), - )))) - } - } - calculate_filter_expr_intervals( - &self.left.input_buffer, - &probe_batch, - &mut self.filter_columns, - JoinSide::Left, - )?; - let equal_result = self.left.record_batch_from_other_side( - self.schema.clone(), - self.join_type, - &self.on_left, - &self.on_right, - &self.filter, - &probe_batch, - &mut self.right.visited_rows, - self.right.offset, - &self.column_indices, - &self.random_state, - &self.null_equals_null, - ); - self.right.offset += probe_batch.num_rows(); - - let anti_result = self.left.prune_build_side( - self.schema.clone(), - &probe_batch, - &mut self.filter_columns, - self.join_type, - &self.column_indices, - &mut self.physical_expr_graph, - ); - let result = produce_batch_result( - self.schema.clone(), - equal_result, - anti_result, - ); - if let Ok(batch) = &result { - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - if !self.left.exhausted { - self.data_side = JoinSide::Left; - } - return Poll::Ready(Some(result)); - } - Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), - Poll::Ready(None) => { - self.right.exhausted = true; - self.data_side = JoinSide::Left; + Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), + Poll::Ready(None) => { + // Update probe side is exhausted. + probe_hash_joiner.exhausted = true; + // Change the probe side + self.probe_side = build_join_side; + continue; + } + Poll::Pending => { + if !build_hash_joiner.exhausted { + self.probe_side = build_join_side; continue; - } - Poll::Pending => { - if !self.left.exhausted { - self.data_side = JoinSide::Left; - continue; - } else { - return Poll::Pending; - } + } else { + return Poll::Pending; } } } diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index ac1cb2782f3e..e1e53ba8ad19 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -431,12 +431,9 @@ impl ExprIntervalGraph { // Calculate and replace current node's interval: self.graph[node].interval = apply_operator(binary.op(), left_interval, right_interval)?; - } else if let Some(CastExpr { - cast_type, - cast_options, - .. 
- }) = expr_any.downcast_ref::() - { + } else if let Some(cast_expr) = expr_any.downcast_ref::() { + let cast_type = cast_expr.cast_type(); + let cast_options = cast_expr.cast_options(); let child_index = edges.next_node(&self.graph).unwrap(); let child = self.graph.index(child_index); // Cast current node's interval to the right type: From 05beacff7ec23c21e5af98eaba2f171151a005b4 Mon Sep 17 00:00:00 2001 From: metesynnada <100111937+metesynnada@users.noreply.github.com> Date: Mon, 6 Feb 2023 17:13:10 +0300 Subject: [PATCH 15/39] After merge resolution, before proto update --- .../src/physical_plan/coalesce_batches.rs | 1 - .../joins/symmetric_hash_join.rs | 83 ++++++++++++++++--- datafusion/core/src/test_util.rs | 6 +- datafusion/core/tests/fifo.rs | 3 +- 4 files changed, 78 insertions(+), 15 deletions(-) diff --git a/datafusion/core/src/physical_plan/coalesce_batches.rs b/datafusion/core/src/physical_plan/coalesce_batches.rs index 0c3cf4df2ff2..2e7211fc3ae4 100644 --- a/datafusion/core/src/physical_plan/coalesce_batches.rs +++ b/datafusion/core/src/physical_plan/coalesce_batches.rs @@ -235,7 +235,6 @@ impl CoalesceBatchesStream { // reset buffer state self.buffer.clear(); self.buffered_rows = 0; - println!("coalesce yolluyor {}", batch.num_rows()); // return batch return Poll::Ready(Some(Ok(batch))); } diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs index 199fff2f0f08..97d5f61135b1 100644 --- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs @@ -32,7 +32,6 @@ use arrow::array::{ArrowPrimitiveType, NativeAdapter, PrimitiveBuilder}; use arrow::compute::concat_batches; use arrow::datatypes::ArrowNativeType; use arrow::datatypes::{Schema, SchemaRef}; -use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use futures::{Stream, StreamExt}; use hashbrown::raw::RawTable; @@ -561,7 +560,7 @@ impl RecordBatchStream for SymmetricHashJoinStream { } impl Stream for SymmetricHashJoinStream { - type Item = ArrowResult; + type Item = Result; fn poll_next( mut self: std::pin::Pin<&mut Self>, @@ -1038,7 +1037,7 @@ impl OneSideHashJoiner { column_indices: &[ColumnIndex], random_state: &RandomState, null_equals_null: &bool, - ) -> ArrowResult> { + ) -> Result> { if self.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 { return Ok(Some(RecordBatch::new_empty(schema))); } @@ -1111,7 +1110,7 @@ impl OneSideHashJoiner { probe_schema: SchemaRef, join_type: JoinType, column_indices: &[ColumnIndex], - ) -> ArrowResult> { + ) -> Result> { // Check if we need to produce a result in the final output let result = if need_produce_result_in_final(self.build_side, join_type) { // Calculate the indices for build and probe sides based on join type and build side @@ -1172,7 +1171,7 @@ impl OneSideHashJoiner { join_type: JoinType, column_indices: &[ColumnIndex], physical_expr_graph: &mut ExprIntervalGraph, - ) -> ArrowResult> { + ) -> Result> { // Return None if there are no rows in the input buffer if self.input_buffer.num_rows() == 0 { return Ok(None); @@ -1248,9 +1247,9 @@ impl OneSideHashJoiner { fn combine_two_batches_result( output_schema: SchemaRef, - left_batch: ArrowResult>, - right_batch: ArrowResult>, -) -> ArrowResult { + left_batch: Result>, + right_batch: Result>, +) -> Result { match (left_batch, right_batch) { (Ok(Some(batch)), Ok(None)) | (Ok(None), Ok(Some(batch))) => { // If either 
left_batch or right_batch is present, return that batch @@ -1263,6 +1262,7 @@ fn combine_two_batches_result( (Ok(Some(left_batch)), Ok(Some(right_batch))) => { // If both left_batch and right_batch are present, concatenate them concat_batches(&output_schema, &[left_batch, right_batch]) + .map_err(|e| DataFusionError::ArrowError(e)) } (Ok(None), Ok(None)) => { // If both left_batch and right_batch are not present, return an empty batch @@ -1279,7 +1279,7 @@ impl SymmetricHashJoinStream { fn poll_next_impl( &mut self, cx: &mut std::task::Context<'_>, - ) -> Poll>> { + ) -> Poll>> { loop { // If the final result has already been obtained, return `Poll::Ready(None)` if self.final_result { @@ -1355,7 +1355,7 @@ impl SymmetricHashJoinStream { { Ok(_) => {} Err(e) => { - return Poll::Ready(Some(Err(ArrowError::ComputeError( + return Poll::Ready(Some(Err(DataFusionError::Internal( e.to_string(), )))) } @@ -2059,6 +2059,69 @@ mod tests { Ok(()) } + #[tokio::test(flavor = "multi_thread")] + async fn single_test() -> Result<()> { + let case_expr = 1; + let cardinality = (11, 21); + let join_type = JoinType::Full; + let config = SessionConfig::new().with_repartition_joins(false); + let session_ctx = SessionContext::with_config(config); + let task_ctx = session_ctx.task_ctx(); + let (left_batch, right_batch) = + build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_sorted = vec![PhysicalSortExpr { + expr: Arc::new(Column::new_with_schema("la1_des", &left_batch.schema())?), + options: SortOptions { + descending: true, + nulls_first: true, + }, + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: Arc::new(Column::new_with_schema("ra1_des", &right_batch.schema())?), + options: SortOptions { + descending: true, + nulls_first: true, + }, + }]; + let (left, right) = + create_memory_table(left_batch, right_batch, left_sorted, right_sorted)?; + + let on = vec![( + Column::new_with_schema("lc1", &left.schema())?, + Column::new_with_schema("rc1", &right.schema())?, + )]; + + let left_col = Arc::new(Column::new("left", 0)); + let right_col = Arc::new(Column::new("right", 1)); + + let column_indices = vec![ + ColumnIndex { + index: 5, + side: JoinSide::Left, + }, + ColumnIndex { + index: 5, + side: JoinSide::Right, + }, + ]; + let intermediate_schema = Schema::new(vec![ + Field::new(left_col.name(), DataType::Int32, true), + Field::new(right_col.name(), DataType::Int32, true), + ]); + + let filter_expr = + join_expr_tests_fixture(case_expr, left_col.clone(), right_col.clone()); + + let filter = JoinFilter::new( + filter_expr, + column_indices.clone(), + intermediate_schema.clone(), + ); + + experiment(left, right, filter, join_type, on, task_ctx).await?; + Ok(()) + } + #[rstest] #[tokio::test(flavor = "multi_thread")] async fn join_all_one_descending_numeric_particular( diff --git a/datafusion/core/src/test_util.rs b/datafusion/core/src/test_util.rs index 4ef6c7d52dec..1ee912d7d089 100644 --- a/datafusion/core/src/test_util.rs +++ b/datafusion/core/src/test_util.rs @@ -395,7 +395,7 @@ pub async fn test_create_unbounded_sorted_file( table_name: &str, ) -> datafusion_common::Result<()> { // Register table - let right_schema = Arc::new(Schema::new(vec![ + let schema = Arc::new(Schema::new(vec![ Field::new("a1", DataType::UInt32, false), Field::new("a2", DataType::UInt32, false), ])); @@ -410,7 +410,7 @@ pub async fn test_create_unbounded_sorted_file( .collect::>(); // Mark infinite and provide schema let fifo_options = CsvReadOptions::new() - .schema(right_schema.as_ref()) + .schema(schema.as_ref()) 
.has_header(false) .mark_infinite(true); // Get listing options @@ -422,7 +422,7 @@ pub async fn test_create_unbounded_sorted_file( table_name, file_path.as_os_str().to_str().unwrap(), options_sort.clone(), - Some(right_schema), + Some(schema), None, ) .await?; diff --git a/datafusion/core/tests/fifo.rs b/datafusion/core/tests/fifo.rs index 92d56a5262a0..9604e8b4da8c 100644 --- a/datafusion/core/tests/fifo.rs +++ b/datafusion/core/tests/fifo.rs @@ -277,7 +277,8 @@ mod unix_test { for (cnt, (a1, a2)) in a1_iter.zip(a2_iter).enumerate() { // Wait a reading sign for unbounded execution // After first batch FIFO reading, we will wait for a batch created. - while *waiting_thread.lock().unwrap() && TEST_BATCH_SIZE < cnt { + while *waiting_thread.lock().unwrap() && TEST_BATCH_SIZE + 1 < cnt + { thread::sleep(Duration::from_millis(200)); } let line = format!("{a1},{a2}\n").to_owned(); From f59018b0fedf91a837000cbfaf0dc76f3d824421 Mon Sep 17 00:00:00 2001 From: metesynnada <100111937+metesynnada@users.noreply.github.com> Date: Mon, 6 Feb 2023 18:49:35 +0300 Subject: [PATCH 16/39] Revery unnecessary changes in some exprs Also, cargo.lock update. --- datafusion-cli/Cargo.lock | 188 ++++++++++-------- .../physical_plan/joins/hash_join_utils.rs | 154 ++++++-------- .../joins/symmetric_hash_join.rs | 13 +- .../physical-expr/src/expressions/binary.rs | 6 +- .../physical-expr/src/expressions/cast.rs | 6 +- .../physical-expr/src/expressions/column.rs | 4 +- .../physical-expr/src/intervals/cp_solver.rs | 6 +- 7 files changed, 191 insertions(+), 186 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 6678e538c7d6..8323ae52b6da 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -291,9 +291,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.63" +version = "0.1.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eff18d764974428cf3a9328e23fc5c986f5fbed46e6cd4cdf42544df5d297ec1" +checksum = "1cd7fce9ba8c3c042128ce72d8b2ddbf3a05747efb67ea0313c635e10bda47a2" dependencies = [ "proc-macro2", "quote", @@ -414,9 +414,9 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "bytes" -version = "1.3.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfb24e866b15a1af2a1b663f10c6b6b8f397a84aadb828f12e5b289ec23a3a3c" +checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" [[package]] name = "bzip2" @@ -630,9 +630,9 @@ dependencies = [ [[package]] name = "cxx" -version = "1.0.88" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "322296e2f2e5af4270b54df9e85a02ff037e271af20ba3e7fe1575515dc840b8" +checksum = "bc831ee6a32dd495436e317595e639a587aa9907bef96fe6e6abc290ab6204e9" dependencies = [ "cc", "cxxbridge-flags", @@ -642,9 +642,9 @@ dependencies = [ [[package]] name = "cxx-build" -version = "1.0.88" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "017a1385b05d631e7875b1f151c9f012d37b53491e2a87f65bff5c262b2111d8" +checksum = "94331d54f1b1a8895cd81049f7eaaaef9d05a7dcb4d1fd08bf3ff0806246789d" dependencies = [ "cc", "codespan-reporting", @@ -657,15 +657,15 @@ dependencies = [ [[package]] name = "cxxbridge-flags" -version = "1.0.88" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c26bbb078acf09bc1ecda02d4223f03bdd28bd4874edcb0379138efc499ce971" +checksum = 
"48dcd35ba14ca9b40d6e4b4b39961f23d835dbb8eed74565ded361d93e1feb8a" [[package]] name = "cxxbridge-macro" -version = "1.0.88" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "357f40d1f06a24b60ae1fe122542c1fb05d28d32acb2aed064e84bc2ad1e252e" +checksum = "81bbeb29798b407ccd82a3324ade1a7286e0d29851475990b612670f6f5124d2" dependencies = [ "proc-macro2", "quote", @@ -903,9 +903,9 @@ checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" [[package]] name = "encoding_rs" -version = "0.8.31" +version = "0.8.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9852635589dc9f9ea1b6fe9f05b50ef208c85c834a562f0c6abb1c475736ec2b" +checksum = "071a31f4ee85403370b58aca746f01041ede6f0da2730960ad001edc2b71b394" dependencies = [ "cfg-if", ] @@ -971,13 +971,13 @@ dependencies = [ [[package]] name = "fd-lock" -version = "3.0.9" +version = "3.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28c0190ff0bd3b28bfdd4d0cf9f92faa12880fb0b8ae2054723dd6c76a4efd42" +checksum = "8ef1a30ae415c3a691a4f41afddc2dbcd6d70baf338368d85ebc1e8ed92cedb9" dependencies = [ "cfg-if", "rustix", - "windows-sys", + "windows-sys 0.45.0", ] [[package]] @@ -1023,9 +1023,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38390104763dc37a5145a53c29c63c1290b5d316d6086ec32c293f6736051bb0" +checksum = "13e2792b0ff0340399d58445b88fd9770e3489eff258a4cbc1523418f12abf84" dependencies = [ "futures-channel", "futures-core", @@ -1038,9 +1038,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52ba265a92256105f45b719605a571ffe2d1f0fea3807304b522c1d778f79eed" +checksum = "2e5317663a9089767a1ec00a487df42e0ca174b61b4483213ac24448e4664df5" dependencies = [ "futures-core", "futures-sink", @@ -1048,15 +1048,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04909a7a7e4633ae6c4a9ab280aeb86da1236243a77b694a49eacd659a4bd3ac" +checksum = "ec90ff4d0fe1f57d600049061dc6bb68ed03c7d2fbd697274c41805dcb3f8608" [[package]] name = "futures-executor" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7acc85df6714c176ab5edf386123fafe217be88c0840ec11f199441134a074e2" +checksum = "e8de0a35a6ab97ec8869e32a2473f4b1324459e14c29275d14b10cb1fd19b50e" dependencies = [ "futures-core", "futures-task", @@ -1065,15 +1065,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00f5fb52a06bdcadeb54e8d3671f8888a39697dcb0b81b23b55174030427f4eb" +checksum = "bfb8371b6fb2aeb2d280374607aeabfc99d95c72edfe51692e42d3d7f0d08531" [[package]] name = "futures-macro" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bdfb8ce053d86b91919aad980c220b1fb8401a9394410e1c289ed7e66b61835d" +checksum = "95a73af87da33b5acf53acfebdc339fe592ecf5357ac7c0a7734ab9d8c876a70" dependencies = [ "proc-macro2", "quote", @@ -1082,21 +1082,21 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.25" +version = "0.3.26" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "39c15cf1a4aa79df40f1bb462fb39676d0ad9e366c2a33b590d7c66f4f81fcf9" +checksum = "f310820bb3e8cfd46c80db4d7fb8353e15dfff853a127158425f31e0be6c8364" [[package]] name = "futures-task" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ffb393ac5d9a6eaa9d3fdf37ae2776656b706e200c8e16b1bdb227f5198e6ea" +checksum = "dcf79a1bf610b10f42aea489289c5a2c478a786509693b80cd39c44ccd936366" [[package]] name = "futures-util" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "197676987abd2f9cadff84926f410af1c183608d36641465df73ae8211dc65d6" +checksum = "9c1d6de3acfef38d2be4b1f543f553131788603495be83da675e180c8d6b7bd1" dependencies = [ "futures-channel", "futures-core", @@ -1183,9 +1183,9 @@ dependencies = [ [[package]] name = "heck" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "hermit-abi" @@ -1247,9 +1247,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.23" +version = "0.14.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "034711faac9d2166cb1baf1a2fb0b60b1f277f8492fd72176c17f3515e1abd3c" +checksum = "5e011372fa0b68db8350aa7a248930ecc7839bf46d8485577d69f117a75f164c" dependencies = [ "bytes", "futures-channel", @@ -1343,12 +1343,12 @@ checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" [[package]] name = "io-lifetimes" -version = "1.0.4" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7d6c6f8c91b4b9ed43484ad1a938e393caf35960fce7f82a040497207bd8e9e" +checksum = "1abeb7a0dd0f8181267ff8adc397075586500b81b28a73e8a0208b00fc170fb3" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.45.0", ] [[package]] @@ -1389,9 +1389,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.60" +version = "0.3.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49409df3e3bf0856b916e2ceaca09ee28e6871cf7d9ce97a692cacfdb2a25a47" +checksum = "445dde2150c55e483f3d8416706b97ec8e8237c307e5b7b4b8dd15e6af2a0730" dependencies = [ "wasm-bindgen", ] @@ -1601,7 +1601,7 @@ dependencies = [ "libc", "log", "wasi", - "windows-sys", + "windows-sys 0.42.0", ] [[package]] @@ -1773,15 +1773,15 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.6" +version = "0.9.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba1ef8814b5c993410bb3adfad7a5ed269563e4a2f90c41f5d85be7fb47133bf" +checksum = "9069cbb9f99e3a5083476ccb29ceb1de18b9118cafa53e90c9551235de2b9521" dependencies = [ "cfg-if", "libc", "redox_syscall", "smallvec", - "windows-sys", + "windows-sys 0.45.0", ] [[package]] @@ -1895,9 +1895,9 @@ checksum = "dc375e1527247fe1a97d8b7156678dfe7c1af2fc075c9a4db3690ecd2a148068" [[package]] name = "proc-macro2" -version = "1.0.50" +version = "1.0.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ef7d57beacfaf2d8aee5937dab7b7f28de3cb8b1828479bb5de2a7106f2bae2" +checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6" dependencies = [ "unicode-ident", ] @@ -2080,16 +2080,16 @@ dependencies = [ 
[[package]] name = "rustix" -version = "0.36.7" +version = "0.36.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4fdebc4b395b7fbb9ab11e462e20ed9051e7b16e42d24042c776eca0ac81b03" +checksum = "f43abb88211988493c1abb44a70efa56ff0ce98f233b7b276146f1f3f7ba9644" dependencies = [ "bitflags", "errno", "io-lifetimes", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.45.0", ] [[package]] @@ -2213,9 +2213,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.91" +version = "1.0.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "877c235533714907a8c2464236f5c4b2a17262ef1bd71f38f35ea592c8da6883" +checksum = "7434af0dc1cbd59268aa98b4c22c131c0584d2232f6fb166efb993e2832e896a" dependencies = [ "itoa 1.0.5", "ryu", @@ -2459,9 +2459,9 @@ dependencies = [ [[package]] name = "tinyvec_macros" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" @@ -2479,7 +2479,7 @@ dependencies = [ "pin-project-lite", "socket2", "tokio-macros", - "windows-sys", + "windows-sys 0.42.0", ] [[package]] @@ -2612,9 +2612,9 @@ dependencies = [ [[package]] name = "unicode-segmentation" -version = "1.10.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fdbf052a0783de01e944a6ce7a8cb939e295b1e7be835a1112c3b9a7f047a5a" +checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" [[package]] name = "unicode-width" @@ -2647,9 +2647,9 @@ checksum = "936e4b492acfd135421d8dca4b1aa80a7bfc26e702ef3af710e0752684df5372" [[package]] name = "uuid" -version = "1.2.2" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "422ee0de9031b5b948b97a8fc04e3aa35230001a722ddd27943e0be31564ce4c" +checksum = "1674845326ee10d37ca60470760d4288a6f80f304007d92e5c53bab78c9cfd79" dependencies = [ "getrandom", ] @@ -2689,9 +2689,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eaf9f5aceeec8be17c128b2e93e031fb8a4d469bb9c4ae2d7dc1888b26887268" +checksum = "31f8dcbc21f30d9b8f2ea926ecb58f6b91192c17e9d33594b3df58b2007ca53b" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -2699,9 +2699,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c8ffb332579b0557b52d268b91feab8df3615f265d5270fec2a8c95b17c1142" +checksum = "95ce90fd5bcc06af55a641a86428ee4229e44e07033963a2290a8e241607ccb9" dependencies = [ "bumpalo", "log", @@ -2714,9 +2714,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.33" +version = "0.4.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23639446165ca5a5de86ae1d8896b737ae80319560fbaa4c2887b7da6e7ebd7d" +checksum = "f219e0d211ba40266969f6dbdd90636da12f75bee4fc9d6c23d1260dadb51454" dependencies = [ "cfg-if", "js-sys", @@ -2726,9 +2726,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"052be0f94026e6cbc75cdefc9bae13fd6052cdcaf532fa6c45e7ae33a1e6c810" +checksum = "4c21f77c0bedc37fd5dc21f897894a5ca01e7bb159884559461862ae90c0b4c5" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2736,9 +2736,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07bc0c051dc5f23e307b13285f9d75df86bfdf816c5721e573dec1f9b8aa193c" +checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6" dependencies = [ "proc-macro2", "quote", @@ -2749,9 +2749,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c38c045535d93ec4f0b4defec448e4291638ee608530863b1e2ba115d4fff7f" +checksum = "0046fef7e28c3804e5e38bfa31ea2a0f73905319b677e57ebe37e49358989b5d" [[package]] name = "wasm-streams" @@ -2768,9 +2768,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.60" +version = "0.3.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcda906d8be16e728fd5adc5b729afad4e444e106ab28cd1c7256e54fa61510f" +checksum = "e33b99f4b23ba3eec1a53ac264e35a755f00e966e0065077d6027c0f575b0b97" dependencies = [ "js-sys", "wasm-bindgen", @@ -2841,6 +2841,30 @@ dependencies = [ "windows_x86_64_msvc", ] +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e2522491fbfcd58cc84d47aeb2958948c4b8982e9a2d8a2a35bbaed431390e7" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.42.1" @@ -2903,18 +2927,18 @@ dependencies = [ [[package]] name = "zstd" -version = "0.12.2+zstd.1.5.2" +version = "0.12.3+zstd.1.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9262a83dc741c0b0ffec209881b45dbc232c21b02a2b9cb1adb93266e41303d" +checksum = "76eea132fb024e0e13fd9c2f5d5d595d8a967aa72382ac2f9d39fcc95afd0806" dependencies = [ "zstd-safe", ] [[package]] name = "zstd-safe" -version = "6.0.2+zstd.1.5.2" +version = "6.0.3+zstd.1.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6cf39f730b440bab43da8fb5faf5f254574462f73f260f85f7987f32154ff17" +checksum = "68e4a3f57d13d0ab7e478665c60f35e2a613dcd527851c2c7287ce5c787e134a" dependencies = [ "libc", "zstd-sys", @@ -2922,9 +2946,9 @@ dependencies = [ [[package]] name = "zstd-sys" -version = "2.0.5+zstd.1.5.2" +version = "2.0.6+zstd.1.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edc50ffce891ad571e9f9afe5039c4837bede781ac4bb13052ed7ae695518596" +checksum = "68a3f9792c0c3dc6c165840a75f47ae1f4da402c2d006881129579f6597e801b" dependencies = [ "cc", "libc", diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs index 11e30142ade9..ba5de32c30b3 100644 --- a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs +++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs @@ -98,7 
+98,7 @@ pub fn update_hash( Ok(()) } -// Get build and probe indices which is satisfies the on condition (include equal_conditon and filter_in_join) in the Join +// Get build and probe indices which satisfies the on condition (include equal_conditon and filter_in_join) in the Join #[allow(clippy::too_many_arguments)] pub fn build_join_indices( probe_batch: &RecordBatch, @@ -113,7 +113,7 @@ pub fn build_join_indices( offset: Option, build_side: JoinSide, ) -> Result<(UInt64Array, UInt32Array)> { - // Get the indices which is satisfies the equal join condition, like `left.a1 = right.a2` + // Get the indices which satisfies the equal join condition, like `left.a1 = right.a2` let (build_indices, probe_indices) = build_equal_condition_join_indices( build_hashmap, build_input_buffer, @@ -126,7 +126,7 @@ pub fn build_join_indices( offset, )?; if let Some(filter) = filter { - // Filter the indices which is satisfies the non-equal join condition, like `left.b1 = 10` + // Filter the indices which satisfies the non-equal join condition, like `left.b1 = 10` apply_join_filter_to_indices( build_input_buffer, probe_batch, @@ -845,66 +845,50 @@ impl SortedFilterExpr { /// Filter expr for a + b > c + 10 AND a + b < c + 100 #[allow(dead_code)] pub(crate) fn complicated_filter() -> Arc { - let left_expr = BinaryExpr { - left: Arc::new(CastExpr { - expr: Arc::new(BinaryExpr { - left: Arc::new(Column { - name: "0".to_string(), - index: 0, - }), - op: Operator::Plus, - right: Arc::new(Column { - name: "1".to_string(), - index: 1, - }), - }), - cast_type: DataType::Int64, - cast_options: CastOptions { safe: false }, - }), - op: Operator::Gt, - right: Arc::new(BinaryExpr { - left: Arc::new(CastExpr { - expr: Arc::new(Column { - name: "2".to_string(), - index: 2, - }), - cast_type: DataType::Int64, - cast_options: CastOptions { safe: false }, - }), - op: Operator::Plus, - right: Arc::new(Literal::new(ScalarValue::Int64(Some(10)))), - }), - }; - let right_expr = BinaryExpr { - left: Arc::new(CastExpr { - expr: Arc::new(BinaryExpr { - left: Arc::new(Column { - name: "0".to_string(), - index: 0, - }), - op: Operator::Plus, - right: Arc::new(Column { - name: "1".to_string(), - index: 1, - }), - }), - cast_type: DataType::Int64, - cast_options: CastOptions { safe: false }, - }), - op: Operator::Lt, - right: Arc::new(BinaryExpr { - left: Arc::new(CastExpr { - expr: Arc::new(Column { - name: "2".to_string(), - index: 2, - }), - cast_type: DataType::Int64, - cast_options: CastOptions { safe: false }, - }), - op: Operator::Plus, - right: Arc::new(Literal::new(ScalarValue::Int64(Some(100)))), - }), - }; + let left_expr = BinaryExpr::new( + Arc::new(CastExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("0", 0)), + Operator::Plus, + Arc::new(Column::new("1", 1)), + )), + DataType::Int64, + CastOptions { safe: false }, + )), + Operator::Gt, + Arc::new(BinaryExpr::new( + Arc::new(CastExpr::new( + Arc::new(Column::new("2", 2)), + DataType::Int64, + CastOptions { safe: false }, + )), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int64(Some(10)))), + )), + ); + + let right_expr = BinaryExpr::new( + Arc::new(CastExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("0", 0)), + Operator::Plus, + Arc::new(Column::new("1", 1)), + )), + DataType::Int64, + CastOptions { safe: false }, + )), + Operator::Lt, + Arc::new(BinaryExpr::new( + Arc::new(CastExpr::new( + Arc::new(Column::new("2", 2)), + DataType::Int64, + CastOptions { safe: false }, + )), + Operator::Plus, + 
Arc::new(Literal::new(ScalarValue::Int64(Some(100)))), + )), + ); + Arc::new(BinaryExpr::new( Arc::new(left_expr), Operator::And, @@ -948,21 +932,15 @@ mod tests { assert!(contains2); let mut contains3 = false; - let expr_3 = Arc::new(CastExpr { - expr: Arc::new(BinaryExpr { - left: Arc::new(Column { - name: "0".to_string(), - index: 0, - }), - op: Operator::Plus, - right: Arc::new(Column { - name: "1".to_string(), - index: 1, - }), - }), - cast_type: DataType::Int64, - cast_options: CastOptions { safe: false }, - }); + let expr_3 = Arc::new(CastExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("0", 0)), + Operator::Plus, + Arc::new(Column::new("1", 1)), + )), + DataType::Int64, + CastOptions { safe: false }, + )); filter_expr .clone() .accept(CheckFilterExprContainsSortInformation::new( @@ -1096,11 +1074,11 @@ mod tests { Field::new(filter_col_1.name(), DataType::Int32, true), ]); - let filter_expr = Arc::new(BinaryExpr { - left: Arc::new(Column::new("0", 0)), - op: Operator::Minus, - right: Arc::new(Column::new("1", 1)), - }); + let filter_expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new("0", 0)), + Operator::Minus, + Arc::new(Column::new("1", 1)), + )); let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); @@ -1110,11 +1088,11 @@ mod tests { ])); let sorted = PhysicalSortExpr { - expr: Arc::new(BinaryExpr { - left: Arc::new(Column::new("a", 0)), - op: Operator::Plus, - right: Arc::new(Column::new("b", 1)), - }), + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Plus, + Arc::new(Column::new("b", 1)), + )), options: SortOptions::default(), }; diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs index 97d5f61135b1..8b3dabe80135 100644 --- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs @@ -967,7 +967,6 @@ impl OneSideHashJoiner { /// /// # Arguments /// - /// * `on_build` - A slice of Columns that defines the join key of the OneSideHashJoiner /// * `batch` - The incoming RecordBatch to be merged with the internal input buffer /// * `random_state` - The random state used to hash values /// @@ -1262,7 +1261,7 @@ fn combine_two_batches_result( (Ok(Some(left_batch)), Ok(Some(right_batch))) => { // If both left_batch and right_batch are present, concatenate them concat_batches(&output_schema, &[left_batch, right_batch]) - .map_err(|e| DataFusionError::ArrowError(e)) + .map_err(DataFusionError::ArrowError) } (Ok(None), Ok(None)) => { // If both left_batch and right_batch are not present, return an empty batch @@ -1928,11 +1927,11 @@ mod tests { let (left_batch, right_batch) = build_sides_record_batches(TABLE_SIZE, cardinality)?; let left_sorted = vec![PhysicalSortExpr { - expr: Arc::new(BinaryExpr { - left: Arc::new(Column::new_with_schema("la1", &left_batch.schema())?), - op: Operator::Plus, - right: Arc::new(Column::new_with_schema("la2", &left_batch.schema())?), - }), + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new_with_schema("la1", &left_batch.schema())?), + Operator::Plus, + Arc::new(Column::new_with_schema("la2", &left_batch.schema())?), + )), options: SortOptions::default(), }]; diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index f4d08adb8182..b2023574848d 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ 
b/datafusion/physical-expr/src/expressions/binary.rs @@ -84,9 +84,9 @@ use datafusion_expr::{ColumnarValue, Operator}; /// Binary expression #[derive(Debug)] pub struct BinaryExpr { - pub left: Arc, - pub op: Operator, - pub right: Arc, + left: Arc, + op: Operator, + right: Arc, } impl BinaryExpr { diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index f8d0ba28a0c1..975ad3c6cd1b 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -38,11 +38,11 @@ pub const DEFAULT_DATAFUSION_CAST_OPTIONS: CastOptions = CastOptions { safe: fal #[derive(Debug)] pub struct CastExpr { /// The expression to cast - pub expr: Arc, + expr: Arc, /// The data type to cast to - pub cast_type: DataType, + cast_type: DataType, /// Cast options - pub cast_options: CastOptions, + cast_options: CastOptions, } impl CastExpr { diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index c7e153da6867..eb2be5ef217c 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -33,8 +33,8 @@ use datafusion_expr::ColumnarValue; /// Represents the column at a given index in a RecordBatch #[derive(Debug, Hash, PartialEq, Eq, Clone)] pub struct Column { - pub name: String, - pub index: usize, + name: String, + index: usize, } impl Column { diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index e1e53ba8ad19..210b4a451877 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -536,6 +536,10 @@ impl ExprIntervalGraph { upper: ScalarValue::Boolean(Some(true)), .. } => self.propagate_constraints(expr_stats), + Interval { + lower: ScalarValue::Boolean(Some(false)), + .. 
+ } => Ok(PropagationResult::Infeasible), _ => Ok(PropagationResult::CannotPropagate), }) } @@ -696,7 +700,7 @@ mod tests { (Some(100), None), (Some(10), Some(20)), (Some(100), None), - PropagationResult::CannotPropagate, + PropagationResult::Infeasible, )?; Ok(()) } From f333d6acae65bc3ba1ce8f533cdcef39eb9cb384 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Tue, 7 Feb 2023 18:38:38 -0600 Subject: [PATCH 17/39] Move support indicators to the interval library, rename module to use the standard name --- .../src/physical_optimizer/pipeline_fixer.rs | 30 +------------------ .../physical_plan/joins/hash_join_utils.rs | 2 +- .../joins/symmetric_hash_join.rs | 2 +- .../physical-expr/src/intervals/cp_solver.rs | 5 ++-- ...al_aritmetics.rs => interval_aritmetic.rs} | 29 +++++++++++++++++- datafusion/physical-expr/src/intervals/mod.rs | 3 +- 6 files changed, 35 insertions(+), 36 deletions(-) rename datafusion/physical-expr/src/intervals/{interval_aritmetics.rs => interval_aritmetic.rs} (95%) diff --git a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs index 2c35634a011c..1adf6796c691 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs @@ -35,11 +35,10 @@ use crate::physical_plan::joins::{ }; use crate::physical_plan::rewrite::TreeNodeRewritable; use crate::physical_plan::ExecutionPlan; -use arrow::datatypes::DataType; use datafusion_common::DataFusionError; use datafusion_expr::logical_plan::JoinType; -use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, CastExpr, Column, Literal}; +use datafusion_physical_expr::intervals::{is_datatype_supported, is_operator_supported}; use datafusion_physical_expr::physical_expr_visitor::{ PhysicalExprVisitable, PhysicalExpressionVisitor, Recursion, }; @@ -91,33 +90,6 @@ impl PhysicalOptimizerRule for PipelineFixer { } } -/// Interval calculations is supported by these operators. This will extend in the future. -fn is_operator_supported(op: &Operator) -> bool { - matches!( - op, - &Operator::Plus - | &Operator::Minus - | &Operator::And - | &Operator::Gt - | &Operator::Lt - ) -} - -/// Currently, integer data types are supported. -fn is_datatype_supported(data_type: &DataType) -> bool { - matches!( - data_type, - &DataType::Int64 - | &DataType::Int32 - | &DataType::Int16 - | &DataType::Int8 - | &DataType::UInt64 - | &DataType::UInt32 - | &DataType::UInt16 - | &DataType::UInt8 - ) -} - /// We do not support every type of [PhysicalExpr] for interval calculation. Also, we are not /// supporting every type of [Operator]s in [BinaryExpr]. This check is subject to change while /// we are adding additional [PhysicalExpr] and [Operator] support.
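For a concrete feel of the relocated predicates, here is a minimal, self-contained sketch (assuming the `is_operator_supported` and `is_datatype_supported` re-exports from `datafusion_physical_expr::intervals` introduced by this patch, plus `arrow` and `datafusion-expr` as dependencies); an operator or data type outside these whitelists keeps a filter expression out of interval analysis:

use arrow::datatypes::DataType;
use datafusion_expr::Operator;
use datafusion_physical_expr::intervals::{is_datatype_supported, is_operator_supported};

fn main() {
    // Only Plus, Minus, And, Gt and Lt participate in interval arithmetic for now.
    assert!(is_operator_supported(&Operator::Plus));
    assert!(!is_operator_supported(&Operator::Divide));
    // Only integer types are accepted at this point; floating point types are rejected.
    assert!(is_datatype_supported(&DataType::UInt16));
    assert!(!is_datatype_supported(&DataType::Float64));
}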
diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs index ba5de32c30b3..6fd4e6a22a14 100644 --- a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs +++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs @@ -47,7 +47,7 @@ use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, CastExpr, Column, Literal}; use datafusion_physical_expr::hash_utils::create_hashes; -use datafusion_physical_expr::intervals::interval_aritmetics::Interval; +use datafusion_physical_expr::intervals::Interval; use datafusion_physical_expr::physical_expr_visitor::{ PhysicalExprVisitable, PhysicalExpressionVisitor, Recursion, }; diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs index 8b3dabe80135..0bf5fc4da8f8 100644 --- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs @@ -39,8 +39,8 @@ use hashbrown::HashSet; use datafusion_common::utils::bisect; use datafusion_common::ScalarValue; -use datafusion_physical_expr::intervals::interval_aritmetics::Interval; use datafusion_physical_expr::intervals::ExprIntervalGraph; +use datafusion_physical_expr::intervals::Interval; use crate::arrow::array::BooleanBufferBuilder; use crate::error::{DataFusionError, Result}; diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index 210b4a451877..85845cb3224d 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -30,7 +30,7 @@ use petgraph::visit::{Bfs, DfsPostOrder}; use petgraph::Outgoing; use crate::expressions::{BinaryExpr, CastExpr, Literal}; -use crate::intervals::interval_aritmetics::{apply_operator, Interval}; +use crate::intervals::interval_aritmetic::{apply_operator, Interval}; use crate::utils::{build_physical_expr_graph, ExprTreeNode}; use crate::PhysicalExpr; @@ -375,8 +375,7 @@ impl ExprIntervalGraph { /// use datafusion_common::ScalarValue; /// use datafusion_expr::Operator; /// use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; - /// use datafusion_physical_expr::intervals::ExprIntervalGraph; - /// use datafusion_physical_expr::intervals::interval_aritmetics::Interval; + /// use datafusion_physical_expr::intervals::{Interval, ExprIntervalGraph}; /// use datafusion_physical_expr::PhysicalExpr; /// let expr = Arc::new(BinaryExpr::new( /// Arc::new(Column::new("gnz", 0)), diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetics.rs b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs similarity index 95% rename from datafusion/physical-expr/src/intervals/interval_aritmetics.rs rename to datafusion/physical-expr/src/intervals/interval_aritmetic.rs index 261659fabc3e..7fc3641b25ef 100644 --- a/datafusion/physical-expr/src/intervals/interval_aritmetics.rs +++ b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs @@ -234,6 +234,33 @@ impl Interval { } } +/// Indicates whether interval arithmetic is supported for the given operator. +pub fn is_operator_supported(op: &Operator) -> bool { + matches!( + op, + &Operator::Plus + | &Operator::Minus + | &Operator::And + | &Operator::Gt + | &Operator::Lt + ) +} + +/// Indicates whether interval arithmetic is supported for the given data type. 
+pub fn is_datatype_supported(data_type: &DataType) -> bool { + matches!( + data_type, + &DataType::Int64 + | &DataType::Int32 + | &DataType::Int16 + | &DataType::Int8 + | &DataType::UInt64 + | &DataType::UInt32 + | &DataType::UInt16 + | &DataType::UInt8 + ) +} + pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result { match *op { Operator::Eq => Ok(lhs.equal(rhs)), @@ -261,7 +288,7 @@ fn cast_scalar_value( #[cfg(test)] mod tests { - use crate::intervals::interval_aritmetics::Interval; + use crate::intervals::Interval; use datafusion_common::{Result, ScalarValue}; #[test] diff --git a/datafusion/physical-expr/src/intervals/mod.rs b/datafusion/physical-expr/src/intervals/mod.rs index f35be672a3b5..60fb5aa9777f 100644 --- a/datafusion/physical-expr/src/intervals/mod.rs +++ b/datafusion/physical-expr/src/intervals/mod.rs @@ -19,6 +19,7 @@ //! mod cp_solver; -pub mod interval_aritmetics; +mod interval_aritmetic; pub use cp_solver::ExprIntervalGraph; +pub use interval_aritmetic::*; From 34c71e8b5d5affef16e7639e50ba14139035662c Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Tue, 7 Feb 2023 20:00:58 -0600 Subject: [PATCH 18/39] Simplify PipelineFixer, remove clones, improve comments --- .../src/physical_optimizer/pipeline_fixer.rs | 180 +++++++++--------- 1 file changed, 85 insertions(+), 95 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs index 1adf6796c691..7e765742f22b 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs @@ -29,6 +29,7 @@ use crate::physical_optimizer::pipeline_checker::{ check_finiteness_requirements, PipelineStatePropagator, }; use crate::physical_optimizer::PhysicalOptimizerRule; +use crate::physical_plan::joins::utils::JoinSide; use crate::physical_plan::joins::{ convert_sort_expr_with_filter_schema, HashJoinExec, PartitionMode, SymmetricHashJoinExec, @@ -44,7 +45,6 @@ use datafusion_physical_expr::physical_expr_visitor::{ }; use datafusion_physical_expr::PhysicalExpr; -use crate::physical_plan::joins::utils::JoinSide; use std::sync::Arc; /// The [PipelineFixer] rule tries to modify a given plan so that it can @@ -60,7 +60,7 @@ impl PipelineFixer { } } type PipelineFixerSubrule = - dyn Fn(&PipelineStatePropagator) -> Option>; + dyn Fn(PipelineStatePropagator) -> Option>; impl PhysicalOptimizerRule for PipelineFixer { fn optimize( &self, @@ -90,13 +90,13 @@ impl PhysicalOptimizerRule for PipelineFixer { } } -/// We do not support every type of [PhysicalExpr] for interval calculation. Also, we are not -/// supporting every type of [Operator]s in [BinaryExpr]. This check is subject to change while -/// we are adding additional [PhysicalExpr] and [Operator] support. -/// -/// We support [CastExpr], [BinaryExpr], [Column], and [Literal] for interval calculations. +/// This object tracks the support state for [PhysicalExpr]s in the plan. +/// Currently, we do not support all [PhysicalExpr]s for interval calculations. +/// We do not support every type of [Operator]s either. Over time, this check +/// will relax as more types of [PhysicalExpr]s and [Operator]s are supported. +/// Currently, [CastExpr], [BinaryExpr], [Column] and [Literal] is supported. 
#[derive(Debug)] -pub struct UnsupportedPhysicalExprVisitor<'a> { +struct UnsupportedPhysicalExprVisitor<'a> { /// Supported state pub supported: &'a mut bool, } @@ -104,8 +104,7 @@ pub struct UnsupportedPhysicalExprVisitor<'a> { impl PhysicalExpressionVisitor for UnsupportedPhysicalExprVisitor<'_> { fn pre_visit(self, expr: Arc) -> Result> { if let Some(binary_expr) = expr.as_any().downcast_ref::() { - let op_is_supported = is_operator_supported(binary_expr.op()); - *self.supported = *self.supported && op_is_supported; + *self.supported &= is_operator_supported(binary_expr.op()); } else if !(expr.as_any().is::() || expr.as_any().is::() || expr.as_any().is::()) @@ -116,95 +115,87 @@ impl PhysicalExpressionVisitor for UnsupportedPhysicalExprVisitor<'_> { } } -// If it does not contain any unsupported [Operator]. -// Sorting information must cover every column in filter expression. -// Float is unsupported -// Anything beside than BinaryExpr, Col, Literal and operations Ge and Le is currently unsupported. +/// This function returns whether a given hash join is replacable by a +/// symmetric hash join. Basically, the requirement is that involved +/// [PhysicalExpr]s, [Operator]s and data types need to be supported, +/// and order information must cover every column in the filter expression. fn is_suitable_for_symmetric_hash_join(hash_join: &HashJoinExec) -> Result { - let are_children_ordered = hash_join.left().output_ordering().is_some() - && hash_join.right().output_ordering().is_some(); - let is_filter_supported = match hash_join.filter() { - None => false, - Some(filter) => { - let expr: Arc = filter.expression().clone(); - let mut is_expr_supported = true; - expr.accept(UnsupportedPhysicalExprVisitor { - supported: &mut is_expr_supported, - })?; - let left_is_convertable = convert_sort_expr_with_filter_schema( - &JoinSide::Left, - filter, - hash_join.left().schema(), - hash_join - .left() - .output_ordering() - .as_ref() - .unwrap() - .get(0) - .unwrap(), - )? - .is_some(); - let right_is_convertable = convert_sort_expr_with_filter_schema( - &JoinSide::Right, - filter, - hash_join.right().schema(), - hash_join - .right() - .output_ordering() - .as_ref() - .unwrap() - .get(0) - .unwrap(), - )? - .is_some(); - let sort_expressions_supported = left_is_convertable && right_is_convertable; - let is_fields_supported = filter - .schema() - .fields() - .iter() - .all(|f| is_datatype_supported(f.data_type())); - is_expr_supported && is_fields_supported && sort_expressions_supported + if let Some(filter) = hash_join.filter() { + let left = hash_join.left(); + if let Some(left_ordering) = left.output_ordering() { + let right = hash_join.right(); + if let Some(right_ordering) = right.output_ordering() { + let expr: Arc = filter.expression().clone(); + let mut expr_supported = true; + expr.accept(UnsupportedPhysicalExprVisitor { + supported: &mut expr_supported, + })?; + let left_convertible = convert_sort_expr_with_filter_schema( + &JoinSide::Left, + filter, + left.schema(), + &left_ordering[0], + )? + .is_some(); + let right_convertible = convert_sort_expr_with_filter_schema( + &JoinSide::Right, + filter, + hash_join.right().schema(), + &right_ordering[0], + )? 
+ .is_some(); + let fields_supported = filter + .schema() + .fields() + .iter() + .all(|f| is_datatype_supported(f.data_type())); + return Ok(expr_supported + && fields_supported + && left_convertible + && right_convertible); + } } - }; - Ok(are_children_ordered && is_filter_supported) + } + Ok(false) } +/// This subrule checks if one can replace a hash join with a symmetric hash +/// join so that the pipeline does not break due to the join operation in +/// question. If possible, it makes this replacement; otherwise, it has no +/// effect. fn hash_join_convert_symmetric_subrule( - input: &PipelineStatePropagator, + input: PipelineStatePropagator, ) -> Option> { - let plan = input.plan.clone(); - let children = &input.children_unbounded; + let plan = input.plan; if let Some(hash_join) = plan.as_any().downcast_ref::() { - let (left_unbounded, right_unbounded) = (children[0], children[1]); + let ub_flags = input.children_unbounded; + let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]); let new_plan = if left_unbounded && right_unbounded { - Ok(match is_suitable_for_symmetric_hash_join(hash_join) { - Ok(true) => Arc::new( - SymmetricHashJoinExec::try_new( - hash_join.left().clone(), - hash_join.right().clone(), - hash_join - .on() - .iter() - .map(|(l, r)| (l.clone(), r.clone())) - .collect(), - hash_join.filter().unwrap().clone(), - hash_join.join_type(), - hash_join.null_equals_null(), - ) - .unwrap(), - ), - Ok(false) => plan, + match is_suitable_for_symmetric_hash_join(hash_join) { + Ok(true) => SymmetricHashJoinExec::try_new( + hash_join.left().clone(), + hash_join.right().clone(), + hash_join + .on() + .iter() + .map(|(l, r)| (l.clone(), r.clone())) + .collect(), + hash_join.filter().unwrap().clone(), + hash_join.join_type(), + hash_join.null_equals_null(), + ) + .map(|e| Arc::new(e) as _), + Ok(false) => Ok(plan), Err(e) => return Some(Err(e)), - }) + } } else { Ok(plan) }; - let new_state = new_plan.map(|plan| PipelineStatePropagator { + Some(new_plan.map(|plan| PipelineStatePropagator { plan, unbounded: left_unbounded || right_unbounded, - children_unbounded: vec![left_unbounded, right_unbounded], - }); - Some(new_state) + children_unbounded: ub_flags, + })) } else { None } @@ -252,12 +243,12 @@ fn hash_join_convert_symmetric_subrule( /// /// ``` fn hash_join_swap_subrule( - input: &PipelineStatePropagator, + input: PipelineStatePropagator, ) -> Option> { - let plan = input.plan.clone(); - let children = &input.children_unbounded; + let plan = input.plan; if let Some(hash_join) = plan.as_any().downcast_ref::() { - let (left_unbounded, right_unbounded) = (children[0], children[1]); + let ub_flags = input.children_unbounded; + let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]); let new_plan = if left_unbounded && !right_unbounded { if matches!( *hash_join.join_type(), @@ -273,12 +264,11 @@ fn hash_join_swap_subrule( } else { Ok(plan) }; - let new_state = new_plan.map(|plan| PipelineStatePropagator { + Some(new_plan.map(|plan| PipelineStatePropagator { plan, unbounded: left_unbounded || right_unbounded, - children_unbounded: vec![left_unbounded, right_unbounded], - }); - Some(new_state) + children_unbounded: ub_flags, + })) } else { None } @@ -315,7 +305,7 @@ fn apply_subrules_and_check_finiteness_requirements( physical_optimizer_subrules: &Vec>, ) -> Result> { for sub_rule in physical_optimizer_subrules { - if let Some(value) = sub_rule(&input).transpose()? { + if let Some(value) = sub_rule(input.clone()).transpose()? 
{ input = value; } } @@ -694,7 +684,7 @@ mod hash_join_tests { children_unbounded: vec![left_unbounded, right_unbounded], }; let optimized_hash_join = - hash_join_swap_subrule(&initial_hash_join_state).unwrap()?; + hash_join_swap_subrule(initial_hash_join_state).unwrap()?; let optimized_join_plan = optimized_hash_join.plan; // If swap did happen From 8ea04000eac828bebebb65489d1baaf2ccce83db Mon Sep 17 00:00:00 2001 From: metesynnada <100111937+metesynnada@users.noreply.github.com> Date: Wed, 8 Feb 2023 15:00:06 +0300 Subject: [PATCH 19/39] Enhance the symmetric hash join code with reviews --- .../src/physical_optimizer/pipeline_fixer.rs | 25 ++-- .../physical_plan/joins/hash_join_utils.rs | 140 +++++++++--------- .../joins/symmetric_hash_join.rs | 104 +++++-------- .../core/src/physical_plan/joins/utils.rs | 10 ++ datafusion/core/src/physical_plan/memory.rs | 5 +- 5 files changed, 135 insertions(+), 149 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs index 7e765742f22b..0cb808d4a7ce 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs @@ -96,20 +96,26 @@ impl PhysicalOptimizerRule for PipelineFixer { /// will relax as more types of [PhysicalExpr]s and [Operator]s are supported. /// Currently, [CastExpr], [BinaryExpr], [Column] and [Literal] is supported. #[derive(Debug)] -struct UnsupportedPhysicalExprVisitor<'a> { +struct UnsupportedPhysicalExprVisitor { /// Supported state - pub supported: &'a mut bool, + pub supported: bool, } -impl PhysicalExpressionVisitor for UnsupportedPhysicalExprVisitor<'_> { - fn pre_visit(self, expr: Arc) -> Result> { +impl UnsupportedPhysicalExprVisitor { + pub fn new() -> Self { + Self { supported: true } + } +} + +impl PhysicalExpressionVisitor for UnsupportedPhysicalExprVisitor { + fn pre_visit(mut self, expr: Arc) -> Result> { if let Some(binary_expr) = expr.as_any().downcast_ref::() { - *self.supported &= is_operator_supported(binary_expr.op()); + self.supported &= is_operator_supported(binary_expr.op()); } else if !(expr.as_any().is::() || expr.as_any().is::() || expr.as_any().is::()) { - *self.supported = false; + self.supported = false; } Ok(Recursion::Continue(self)) } @@ -126,10 +132,9 @@ fn is_suitable_for_symmetric_hash_join(hash_join: &HashJoinExec) -> Result let right = hash_join.right(); if let Some(right_ordering) = right.output_ordering() { let expr: Arc = filter.expression().clone(); - let mut expr_supported = true; - expr.accept(UnsupportedPhysicalExprVisitor { - supported: &mut expr_supported, - })?; + let expr_supported = expr + .accept(UnsupportedPhysicalExprVisitor::new())? 
+ .supported; let left_convertible = convert_sort_expr_with_filter_schema( &JoinSide::Left, filter, diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs index 6fd4e6a22a14..c516092f4c51 100644 --- a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs +++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs @@ -596,25 +596,25 @@ impl fmt::Debug for JoinHashMap { } #[derive(Debug)] -pub struct CheckFilterExprContainsSortInformation<'a> { +pub struct CheckFilterExprContainsSortInformation { /// Supported state - pub contains: &'a mut bool, + pub contains: bool, pub checked_expr: Arc, } -impl<'a> CheckFilterExprContainsSortInformation<'a> { - pub fn new(contains: &'a mut bool, checked_expr: Arc) -> Self { +impl CheckFilterExprContainsSortInformation { + pub fn new(checked_expr: Arc) -> Self { Self { - contains, + contains: false, checked_expr, } } } -impl PhysicalExpressionVisitor for CheckFilterExprContainsSortInformation<'_> { - fn pre_visit(self, expr: Arc) -> Result> { - *self.contains = *self.contains || self.checked_expr.eq(&expr); - if *self.contains { +impl PhysicalExpressionVisitor for CheckFilterExprContainsSortInformation { + fn pre_visit(mut self, expr: Arc) -> Result> { + self.contains |= self.checked_expr.eq(&expr); + if self.contains { Ok(Recursion::Stop(self)) } else { Ok(Recursion::Continue(self)) @@ -623,18 +623,18 @@ impl PhysicalExpressionVisitor for CheckFilterExprContainsSortInformation<'_> { } #[derive(Debug)] -pub struct PhysicalExprColumnCollector<'a> { - pub columns: &'a mut Vec, +pub struct PhysicalExprColumnCollector { + pub columns: Vec, } -impl<'a> PhysicalExprColumnCollector<'a> { - pub fn new(columns: &'a mut Vec) -> Self { - Self { columns } +impl PhysicalExprColumnCollector { + pub fn new() -> Self { + Self { columns: vec![] } } } -impl PhysicalExpressionVisitor for PhysicalExprColumnCollector<'_> { - fn pre_visit(self, expr: Arc) -> Result> { +impl PhysicalExpressionVisitor for PhysicalExprColumnCollector { + fn pre_visit(mut self, expr: Arc) -> Result> { if let Some(column) = expr.as_any().downcast_ref::() { self.columns.push(column.clone()) } @@ -700,9 +700,10 @@ pub fn convert_sort_expr_with_filter_schema( map_origin_col_to_filter_col(filter, schema, side)?; let expr = sort_expr.expr.clone(); // Get main schema columns - let mut expr_columns = vec![]; - expr.clone() - .accept(PhysicalExprColumnCollector::new(&mut expr_columns))?; + let expr_columns = expr + .clone() + .accept(PhysicalExprColumnCollector::new())? + .columns; // Calculation is possible with 'column_mapping_information' since sort exprs belong to a child. let all_columns_are_included = expr_columns .iter() @@ -712,16 +713,17 @@ pub fn convert_sort_expr_with_filter_schema( // the sort expression into a filter expression. let converted_filter_expr = expr .transform_up(&|p| convert_filter_columns(p, &column_mapping_information))?; - let mut contains = false; + // Search converted PhysicalExpr in filter expression - filter.expression().clone().accept( - CheckFilterExprContainsSortInformation::new( - &mut contains, - converted_filter_expr.clone(), - ), - )?; // If the exact match is encountered, use this sorted expression in graph traversals. - if contains { + if filter + .expression() + .clone() + .accept(CheckFilterExprContainsSortInformation::new( + converted_filter_expr.clone(), + ))? 
+ .contains + { Ok(Some(converted_filter_expr)) } else { Ok(None) @@ -737,7 +739,7 @@ pub fn convert_sort_expr_with_filter_schema( /// order of columns can be used in the filter expression. /// If it returns a Some value, the method wraps the result in a SortedFilterExpr /// instance with the original sort expression, converted filter expression, and sort options. -/// If it returns a None value, this function returns None. +/// If it returns a None value, this function returns an error. /// /// The SortedFilterExpr instance contains information about the sort order of columns /// that can be used in the filter expression, which can be used to optimize the query execution process. @@ -746,15 +748,17 @@ pub fn build_filter_input_order_v2( filter: &JoinFilter, schema: SchemaRef, order: &PhysicalSortExpr, -) -> Result> { +) -> Result { match convert_sort_expr_with_filter_schema(&side, filter, schema, order)? { - Some(expr) => Ok(Some(SortedFilterExpr::new( + Some(expr) => Ok(SortedFilterExpr::new( side, order.expr.clone(), expr, order.options, + )), + None => Err(DataFusionError::Plan(format!( + "The {side} side of the join does not have an expression sorted." ))), - None => Ok(None), } } @@ -911,27 +915,23 @@ mod tests { fn find_expr_inside_expr() -> Result<()> { let filter_expr = complicated_filter(); - let mut contains1 = false; let expr_1 = Arc::new(Column::new("gnz", 0)); - filter_expr - .clone() - .accept(CheckFilterExprContainsSortInformation::new( - &mut contains1, - expr_1, - ))?; - assert!(!contains1); + assert!( + !filter_expr + .clone() + .accept(CheckFilterExprContainsSortInformation::new(expr_1,))? + .contains + ); - let mut contains2 = false; let expr_2 = Arc::new(Column::new("1", 1)); - filter_expr - .clone() - .accept(CheckFilterExprContainsSortInformation::new( - &mut contains2, - expr_2, - ))?; - assert!(contains2); - let mut contains3 = false; + assert!( + filter_expr + .clone() + .accept(CheckFilterExprContainsSortInformation::new(expr_2,))? + .contains + ); + let expr_3 = Arc::new(CastExpr::new( Arc::new(BinaryExpr::new( Arc::new(Column::new("0", 0)), @@ -941,24 +941,22 @@ mod tests { DataType::Int64, CastOptions { safe: false }, )); - filter_expr - .clone() - .accept(CheckFilterExprContainsSortInformation::new( - &mut contains3, - expr_3, - ))?; - assert!(contains3); - let mut contains4 = false; + assert!( + filter_expr + .clone() + .accept(CheckFilterExprContainsSortInformation::new(expr_3,))? + .contains + ); + let expr_4 = Arc::new(Column::new("1", 42)); - filter_expr - .clone() - .accept(CheckFilterExprContainsSortInformation::new( - &mut contains4, - expr_4, - ))?; - assert!(!contains4); + assert!( + !filter_expr + .clone() + .accept(CheckFilterExprContainsSortInformation::new(expr_4,))? + .contains + ); Ok(()) } @@ -1018,8 +1016,8 @@ mod tests { expr: Arc::new(Column::new("la1", 0)), options: SortOptions::default(), } - )? - .is_some()); + ) + .is_ok()); assert!(build_filter_input_order_v2( JoinSide::Left, &filter, @@ -1028,8 +1026,8 @@ mod tests { expr: Arc::new(Column::new("lt1", 3)), options: SortOptions::default(), } - )? - .is_none()); + ) + .is_err()); assert!(build_filter_input_order_v2( JoinSide::Right, &filter, @@ -1038,8 +1036,8 @@ mod tests { expr: Arc::new(Column::new("ra1", 0)), options: SortOptions::default(), } - )? - .is_some()); + ) + .is_ok()); assert!(build_filter_input_order_v2( JoinSide::Right, &filter, @@ -1048,8 +1046,8 @@ mod tests { expr: Arc::new(Column::new("rb1", 1)), options: SortOptions::default(), } - )? 
- .is_none()); + ) + .is_err()); Ok(()) } diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs index 0bf5fc4da8f8..b31bc81b84ee 100644 --- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs @@ -271,35 +271,19 @@ impl SymmetricHashJoinExec { ExprIntervalGraph::try_new(filter.expression().clone())?; // Build the sorted filter expression for the left child - let left_filter_expression = match build_filter_input_order_v2( + let left_filter_expression = build_filter_input_order_v2( JoinSide::Left, &filter, left.schema(), left.output_ordering().unwrap().get(0).unwrap(), - )? { - Some(value) => value, - None => { - return Err(DataFusionError::Plan( - "The left side of the join does not have an expression sorted." - .to_string(), - )) - } - }; + )?; // Build the sorted filter expression for the right child - let right_filter_expression = match build_filter_input_order_v2( + let right_filter_expression = build_filter_input_order_v2( JoinSide::Right, &filter, right.schema(), right.output_ordering().unwrap().get(0).unwrap(), - )? { - Some(value) => value, - None => { - return Err(DataFusionError::Plan( - "The right side of the join does not have an expression sorted." - .to_string(), - )) - } - }; + )?; // Store the left and right sorted filter expressions in a vector let mut filter_columns = vec![left_filter_expression, right_filter_expression]; @@ -995,11 +979,8 @@ impl OneSideHashJoiner { &mut self.hashes_buffer, )?; - // Drain the hashes buffer and add them to the row_hash_values deque - self.hashes_buffer - .drain(0..) - .into_iter() - .for_each(|hash| self.row_hash_values.push_back(hash)); + // Add the hashes buffer to the row_hash_values deque + self.row_hash_values.extend(self.hashes_buffer.iter()); Ok(()) } @@ -1246,26 +1227,25 @@ impl OneSideHashJoiner { fn combine_two_batches_result( output_schema: SchemaRef, - left_batch: Result>, - right_batch: Result>, -) -> Result { + left_batch: Option, + right_batch: Option, +) -> Result> { match (left_batch, right_batch) { - (Ok(Some(batch)), Ok(None)) | (Ok(None), Ok(Some(batch))) => { + (Some(batch), None) | (None, Some(batch)) => { // If either left_batch or right_batch is present, return that batch - Ok(batch) + Ok(Some(batch)) } - (Err(e), _) | (_, Err(e)) => { - // If either left_batch or right_batch returns an error, return the error - Err(e) - } - (Ok(Some(left_batch)), Ok(Some(right_batch))) => { + (Some(left_batch), Some(right_batch)) => { // If both left_batch and right_batch are present, concatenate them - concat_batches(&output_schema, &[left_batch, right_batch]) - .map_err(DataFusionError::ArrowError) + Some( + concat_batches(&output_schema, &[left_batch, right_batch]) + .map_err(DataFusionError::ArrowError), + ) + .transpose() } - (Ok(None), Ok(None)) => { + (None, None) => { // If both left_batch and right_batch are not present, return an empty batch - Ok(RecordBatch::new_empty(output_schema)) + Ok(None) } } } @@ -1293,7 +1273,7 @@ impl SymmetricHashJoinStream { self.right.input_buffer.schema(), self.join_type, &self.column_indices, - ); + )?; // Get right side results let right_result = self.right.build_side_determined_results( self.schema.clone(), @@ -1301,20 +1281,22 @@ impl SymmetricHashJoinStream { self.left.input_buffer.schema(), self.join_type, &self.column_indices, - ); + )?; self.final_result = true; // Combine results let result = 
combine_two_batches_result( self.schema.clone(), left_result, right_result, - ); - // Update the metrics - if let Ok(batch) = &result { + )?; + // Update the metrics if the batch is present; if not, continue the loop. + if let Some(batch) = &result { self.metrics.output_batches.add(1); self.metrics.output_rows.add(batch.num_rows()); + return Poll::Ready(Ok(result).transpose()); + } else { + continue; } - return Poll::Ready(Some(result)); } // Determine which stream should be polled next. The side the RecordBatch comes from becomes probe side. @@ -1349,16 +1331,8 @@ impl SymmetricHashJoinStream { probe_side_metrics.input_batches.add(1); probe_side_metrics.input_rows.add(probe_batch.num_rows()); // Update the internal state of the hash joiner for build side - match probe_hash_joiner - .update_internal_state(&probe_batch, &self.random_state) - { - Ok(_) => {} - Err(e) => { - return Poll::Ready(Some(Err(DataFusionError::Internal( - e.to_string(), - )))) - } - } + probe_hash_joiner + .update_internal_state(&probe_batch, &self.random_state)?; // Calculate filter intervals calculate_filter_expr_intervals( &build_hash_joiner.input_buffer, @@ -1378,7 +1352,7 @@ &self.column_indices, &self.random_state, &self.null_equals_null, - ); + )?; // Increment the offset for the probe hash joiner probe_hash_joiner.offset += probe_batch.num_rows(); // Prune build input buffer using the physical expression graph and filter intervals @@ -1389,24 +1363,26 @@ self.join_type, &self.column_indices, &mut self.physical_expr_graph, - ); + )?; // Combine results let result = combine_two_batches_result( self.schema.clone(), equal_result, anti_result, - ); - // Update the metrics - if let Ok(batch) = &result { - self.metrics.output_batches.add(1); - self.metrics.output_rows.add(batch.num_rows()); - } + )?; // Choose next poll side. If other side is not exhausted, revert the probe side // before returning the result. if !build_hash_joiner.exhausted { self.probe_side = build_join_side; } - return Poll::Ready(Some(result)); + // Update the metrics if the batch is present; if not, continue the loop.
+ if let Some(batch) = &result { + self.metrics.output_batches.add(1); + self.metrics.output_rows.add(batch.num_rows()); + return Poll::Ready(Ok(result).transpose()); + } else { + continue; + } } Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), Poll::Ready(None) => { diff --git a/datafusion/core/src/physical_plan/joins/utils.rs b/datafusion/core/src/physical_plan/joins/utils.rs index 1a21c8031cbd..5a573d30243e 100644 --- a/datafusion/core/src/physical_plan/joins/utils.rs +++ b/datafusion/core/src/physical_plan/joins/utils.rs @@ -29,6 +29,7 @@ use futures::{ready, FutureExt}; use parking_lot::Mutex; use std::cmp::max; use std::collections::HashSet; +use std::fmt::{Display, Formatter}; use std::future::Future; use std::sync::Arc; use std::task::{Context, Poll}; @@ -224,6 +225,15 @@ pub fn cross_join_equivalence_properties( new_properties } +impl Display for JoinSide { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + JoinSide::Left => write!(f, "Left"), + JoinSide::Right => write!(f, "Right"), + } + } +} + /// Used in ColumnIndex to distinguish which side the index is for #[derive(Debug, Clone, PartialEq, Copy)] pub enum JoinSide { diff --git a/datafusion/core/src/physical_plan/memory.rs b/datafusion/core/src/physical_plan/memory.rs index df7d0bd54290..6e2a0892508a 100644 --- a/datafusion/core/src/physical_plan/memory.rs +++ b/datafusion/core/src/physical_plan/memory.rs @@ -79,10 +79,7 @@ impl ExecutionPlan for MemoryExec { } fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - match self.sort_information.as_ref() { - None => None, - Some(sort) => Some(sort.as_slice()), - } + self.sort_information.as_deref() } fn with_new_children( From 76440a23a10a4a4d25ca6ff98c9ea41925230981 Mon Sep 17 00:00:00 2001 From: metesynnada <100111937+metesynnada@users.noreply.github.com> Date: Wed, 8 Feb 2023 17:47:24 +0300 Subject: [PATCH 20/39] Revamp according to reviews --- .../src/physical_optimizer/pipeline_fixer.rs | 16 ++++++++++------ .../physical_plan/joins/symmetric_hash_join.rs | 4 ---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs index 0cb808d4a7ce..37c3a3aadb40 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs @@ -90,6 +90,14 @@ impl PhysicalOptimizerRule for PipelineFixer { } } +/// Indicates whether interval arithmetic is supported for the given expression. +pub fn is_expr_supported(expr: &Arc) -> bool { + expr.as_any().is::() + || expr.as_any().is::() + || expr.as_any().is::() + || expr.as_any().is::() +} + /// This object tracks the support state for [PhysicalExpr]s in the plan. /// Currently, we do not support all [PhysicalExpr]s for interval calculations. /// We do not support every type of [Operator]s either. 
Over time, this check @@ -111,17 +119,13 @@ impl PhysicalExpressionVisitor for UnsupportedPhysicalExprVisitor { fn pre_visit(mut self, expr: Arc) -> Result> { if let Some(binary_expr) = expr.as_any().downcast_ref::() { self.supported &= is_operator_supported(binary_expr.op()); - } else if !(expr.as_any().is::() - || expr.as_any().is::() - || expr.as_any().is::()) - { - self.supported = false; } + self.supported &= is_expr_supported(&expr); Ok(Recursion::Continue(self)) } } -/// This function returns whether a given hash join is replacable by a +/// This function returns whether a given hash join is replaceable by a /// symmetric hash join. Basically, the requirement is that involved /// [PhysicalExpr]s, [Operator]s and data types need to be supported, /// and order information must cover every column in the filter expression. diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs index b31bc81b84ee..ca5d48afc3d5 100644 --- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs @@ -1380,8 +1380,6 @@ impl SymmetricHashJoinStream { self.metrics.output_batches.add(1); self.metrics.output_rows.add(batch.num_rows()); return Poll::Ready(Ok(result).transpose()); - } else { - continue; } } Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), @@ -1390,12 +1388,10 @@ impl SymmetricHashJoinStream { probe_hash_joiner.exhausted = true; // Change the probe side self.probe_side = build_join_side; - continue; } Poll::Pending => { if !build_hash_joiner.exhausted { self.probe_side = build_join_side; - continue; } else { return Poll::Pending; } From 9a66872856565e220dd84a9b899f7a198e504abf Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Wed, 8 Feb 2023 14:31:51 -0600 Subject: [PATCH 21/39] Use a simple, stateless, one-liner DFS to check for IA support --- .../src/physical_optimizer/pipeline_fixer.rs | 43 ++++--------------- 1 file changed, 9 insertions(+), 34 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs index 37c3a3aadb40..1d15f36f61cf 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs @@ -40,9 +40,6 @@ use datafusion_common::DataFusionError; use datafusion_expr::logical_plan::JoinType; use datafusion_physical_expr::expressions::{BinaryExpr, CastExpr, Column, Literal}; use datafusion_physical_expr::intervals::{is_datatype_supported, is_operator_supported}; -use datafusion_physical_expr::physical_expr_visitor::{ - PhysicalExprVisitable, PhysicalExpressionVisitor, Recursion, -}; use datafusion_physical_expr::PhysicalExpr; use std::sync::Arc; @@ -91,37 +88,16 @@ impl PhysicalOptimizerRule for PipelineFixer { } /// Indicates whether interval arithmetic is supported for the given expression. -pub fn is_expr_supported(expr: &Arc) -> bool { - expr.as_any().is::() - || expr.as_any().is::() - || expr.as_any().is::() - || expr.as_any().is::() -} - -/// This object tracks the support state for [PhysicalExpr]s in the plan. /// Currently, we do not support all [PhysicalExpr]s for interval calculations. /// We do not support every type of [Operator]s either. Over time, this check /// will relax as more types of [PhysicalExpr]s and [Operator]s are supported. /// Currently, [CastExpr], [BinaryExpr], [Column] and [Literal] is supported. 
-#[derive(Debug)] -struct UnsupportedPhysicalExprVisitor { - /// Supported state - pub supported: bool, -} - -impl UnsupportedPhysicalExprVisitor { - pub fn new() -> Self { - Self { supported: true } - } -} - -impl PhysicalExpressionVisitor for UnsupportedPhysicalExprVisitor { - fn pre_visit(mut self, expr: Arc) -> Result> { - if let Some(binary_expr) = expr.as_any().downcast_ref::() { - self.supported &= is_operator_supported(binary_expr.op()); - } - self.supported &= is_expr_supported(&expr); - Ok(Recursion::Continue(self)) +fn check_support(expr: &Arc) -> bool { + let expr_any = expr.as_any(); + if let Some(binary_expr) = expr_any.downcast_ref::() { + is_operator_supported(binary_expr.op()) + } else { + expr_any.is::() || expr_any.is::() || expr_any.is::() } } @@ -135,10 +111,9 @@ fn is_suitable_for_symmetric_hash_join(hash_join: &HashJoinExec) -> Result if let Some(left_ordering) = left.output_ordering() { let right = hash_join.right(); if let Some(right_ordering) = right.output_ordering() { - let expr: Arc = filter.expression().clone(); - let expr_supported = expr - .accept(UnsupportedPhysicalExprVisitor::new())? - .supported; + let expr = filter.expression(); + let expr_supported = + check_support(expr) && expr.children().iter().all(check_support); let left_convertible = convert_sort_expr_with_filter_schema( &JoinSide::Left, filter, From f9dfd77be39e6377222d8bba2f49cb7a2b5839a8 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Wed, 8 Feb 2023 16:10:34 -0600 Subject: [PATCH 22/39] Move test function to a test_utils module --- .../joins/symmetric_hash_join.rs | 32 ++++---- .../physical-expr/src/intervals/cp_solver.rs | 12 +-- datafusion/physical-expr/src/intervals/mod.rs | 1 + .../physical-expr/src/intervals/test_utils.rs | 67 ++++++++++++++++ datafusion/physical-expr/src/utils.rs | 76 +++++-------------- 5 files changed, 107 insertions(+), 81 deletions(-) create mode 100644 datafusion/physical-expr/src/intervals/test_utils.rs diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs index ca5d48afc3d5..cf052963ef5a 100644 --- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs @@ -1415,7 +1415,7 @@ mod tests { use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, Column}; - use datafusion_physical_expr::utils; + use datafusion_physical_expr::intervals::test_utils::gen_conjunctive_numeric_expr; use crate::physical_plan::collect; use crate::physical_plan::joins::hash_join_utils::complicated_filter; @@ -1594,9 +1594,9 @@ mod tests { ) -> Arc { match expr_id { // left_watermark + 1 > right_watermark + 5 AND left_watermark + 3 < right_watermark + 10 - 0 => utils::filter_numeric_expr_generation( - left_watermark.clone(), - right_watermark.clone(), + 0 => gen_conjunctive_numeric_expr( + left_watermark, + right_watermark, Operator::Plus, Operator::Plus, Operator::Plus, @@ -1607,9 +1607,9 @@ mod tests { 10, ), // left_watermark - 1 > right_watermark + 5 AND left_watermark + 3 < right_watermark + 10 - 1 => utils::filter_numeric_expr_generation( - left_watermark.clone(), - right_watermark.clone(), + 1 => gen_conjunctive_numeric_expr( + left_watermark, + right_watermark, Operator::Minus, Operator::Plus, Operator::Plus, @@ -1620,9 +1620,9 @@ mod tests { 10, ), // left_watermark - 1 > right_watermark + 5 AND left_watermark - 3 < right_watermark + 10 - 2 => utils::filter_numeric_expr_generation( - 
left_watermark.clone(), - right_watermark.clone(), + 2 => gen_conjunctive_numeric_expr( + left_watermark, + right_watermark, Operator::Minus, Operator::Plus, Operator::Minus, @@ -1633,9 +1633,9 @@ mod tests { 10, ), // left_watermark - 10 > right_watermark - 5 AND left_watermark - 3 < right_watermark + 10 - 3 => utils::filter_numeric_expr_generation( - left_watermark.clone(), - right_watermark.clone(), + 3 => gen_conjunctive_numeric_expr( + left_watermark, + right_watermark, Operator::Minus, Operator::Minus, Operator::Minus, @@ -1646,9 +1646,9 @@ mod tests { 10, ), // left_watermark - 10 > right_watermark - 5 AND left_watermark - 30 < right_watermark - 3 - 4 => utils::filter_numeric_expr_generation( - left_watermark.clone(), - right_watermark.clone(), + 4 => gen_conjunctive_numeric_expr( + left_watermark, + right_watermark, Operator::Minus, Operator::Minus, Operator::Minus, diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index 85845cb3224d..c2258859ed54 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -547,7 +547,7 @@ impl ExprIntervalGraph { #[cfg(test)] mod tests { use super::*; - use crate::utils::filter_numeric_expr_generation; + use crate::intervals::test_utils::gen_conjunctive_numeric_expr; use itertools::Itertools; use crate::expressions::Column; @@ -712,7 +712,7 @@ mod tests { let left_col = Arc::new(Column::new("left_watermark", 0)); let right_col = Arc::new(Column::new("right_watermark", 0)); // left_watermark + 1 > right_watermark + 11 AND left_watermark + 3 < right_watermark + 33 - let expr = filter_numeric_expr_generation( + let expr = gen_conjunctive_numeric_expr( left_col.clone(), right_col.clone(), Operator::Plus, @@ -751,7 +751,7 @@ mod tests { let left_col = Arc::new(Column::new("left_watermark", 0)); let right_col = Arc::new(Column::new("right_watermark", 0)); // left_watermark - 1 > right_watermark + 5 AND left_watermark + 3 < right_watermark + 10 - let expr = filter_numeric_expr_generation( + let expr = gen_conjunctive_numeric_expr( left_col.clone(), right_col.clone(), Operator::Minus, @@ -791,7 +791,7 @@ mod tests { let left_col = Arc::new(Column::new("left_watermark", 0)); let right_col = Arc::new(Column::new("right_watermark", 0)); // left_watermark - 1 > right_watermark + 5 AND left_watermark - 3 < right_watermark + 10 - let expr = filter_numeric_expr_generation( + let expr = gen_conjunctive_numeric_expr( left_col.clone(), right_col.clone(), Operator::Minus, @@ -830,7 +830,7 @@ mod tests { let left_col = Arc::new(Column::new("left_watermark", 0)); let right_col = Arc::new(Column::new("right_watermark", 0)); // left_watermark - 10 > right_watermark - 5 AND left_watermark - 3 < right_watermark + 10 - let expr = filter_numeric_expr_generation( + let expr = gen_conjunctive_numeric_expr( left_col.clone(), right_col.clone(), Operator::Minus, @@ -870,7 +870,7 @@ mod tests { let right_col = Arc::new(Column::new("right_watermark", 0)); // left_watermark - 10 > right_watermark - 5 AND left_watermark - 30 < right_watermark - 3 - let expr = filter_numeric_expr_generation( + let expr = gen_conjunctive_numeric_expr( left_col.clone(), right_col.clone(), Operator::Minus, diff --git a/datafusion/physical-expr/src/intervals/mod.rs b/datafusion/physical-expr/src/intervals/mod.rs index 60fb5aa9777f..3016db4dc797 100644 --- a/datafusion/physical-expr/src/intervals/mod.rs +++ b/datafusion/physical-expr/src/intervals/mod.rs @@ -21,5 +21,6 @@ 
mod cp_solver; mod interval_aritmetic; +pub mod test_utils; pub use cp_solver::ExprIntervalGraph; pub use interval_aritmetic::*; diff --git a/datafusion/physical-expr/src/intervals/test_utils.rs b/datafusion/physical-expr/src/intervals/test_utils.rs new file mode 100644 index 000000000000..ba02f4ff7aac --- /dev/null +++ b/datafusion/physical-expr/src/intervals/test_utils.rs @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +//http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Test utilities for the interval arithmetic library + +use std::sync::Arc; + +use crate::expressions::{BinaryExpr, Literal}; +use crate::PhysicalExpr; +use datafusion_common::ScalarValue; +use datafusion_expr::Operator; + +#[allow(clippy::too_many_arguments)] +/// This test function generates a conjunctive statement with two numeric +/// terms with the following form: +/// left_col (op_1) a > right_col (op_2) b AND left_col (op_3) c < right_col (op_4) d +pub fn gen_conjunctive_numeric_expr( + left_col: Arc, + right_col: Arc, + op_1: Operator, + op_2: Operator, + op_3: Operator, + op_4: Operator, + a: i32, + b: i32, + c: i32, + d: i32, +) -> Arc { + let left_and_1 = Arc::new(BinaryExpr::new( + left_col.clone(), + op_1, + Arc::new(Literal::new(ScalarValue::Int32(Some(a)))), + )); + let left_and_2 = Arc::new(BinaryExpr::new( + right_col.clone(), + op_2, + Arc::new(Literal::new(ScalarValue::Int32(Some(b)))), + )); + + let right_and_1 = Arc::new(BinaryExpr::new( + left_col, + op_3, + Arc::new(Literal::new(ScalarValue::Int32(Some(c)))), + )); + let right_and_2 = Arc::new(BinaryExpr::new( + right_col, + op_4, + Arc::new(Literal::new(ScalarValue::Int32(Some(d)))), + )); + let left_expr = Arc::new(BinaryExpr::new(left_and_1, Operator::Gt, left_and_2)); + let right_expr = Arc::new(BinaryExpr::new(right_and_1, Operator::Lt, right_and_2)); + Arc::new(BinaryExpr::new(left_expr, Operator::And, right_expr)) +} diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs index a91804340f1b..2091ab1dac76 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils.rs @@ -16,14 +16,11 @@ // under the License. 
use crate::equivalence::EquivalentClass; -use crate::expressions::Column; -use crate::expressions::UnKnownColumn; -use crate::expressions::{BinaryExpr, Literal}; +use crate::expressions::{BinaryExpr, Column, UnKnownColumn}; use crate::rewrite::TreeNodeRewritable; -use crate::PhysicalSortExpr; -use crate::{EquivalenceProperties, PhysicalExpr}; +use crate::{EquivalenceProperties, PhysicalExpr, PhysicalSortExpr}; use arrow::datatypes::SchemaRef; -use datafusion_common::{Result, ScalarValue}; +use datafusion_common::Result; use datafusion_expr::Operator; use petgraph::graph::NodeIndex; @@ -240,20 +237,6 @@ impl ExprTreeNode { } } } -/// Checks whether given ordering requirements are satisfied by provided [PhysicalSortExpr]s. -pub fn ordering_satisfy EquivalenceProperties>( - provided: Option<&[PhysicalSortExpr]>, - required: Option<&[PhysicalSortExpr]>, - equal_properties: F, -) -> bool { - match (provided, required) { - (_, None) => true, - (None, Some(_)) => false, - (Some(provided), Some(required)) => { - ordering_satisfy_concrete(provided, required, equal_properties) - } - } -} impl TreeNodeRewritable for Arc { fn map_children(self, transform: F) -> Result @@ -284,7 +267,7 @@ fn add_physical_expr_to_graph( graph: &mut StableGraph, input_node: T, ) -> NodeIndex { - // If we visited the node before, we just return it. That means, this node will be have multiple + // If we visited the node before, we just return it. That means, this node will have multiple // parents. match visited .iter() @@ -342,44 +325,19 @@ where Ok((root_tree_node.node.unwrap(), graph)) } -#[allow(clippy::too_many_arguments)] -/// left_col (op_1) a > right_col (op_2) b AND left_col (op_3) c < right_col (op_4) d -pub fn filter_numeric_expr_generation( - left_col: Arc, - right_col: Arc, - op_1: Operator, - op_2: Operator, - op_3: Operator, - op_4: Operator, - a: i32, - b: i32, - c: i32, - d: i32, -) -> Arc { - let left_and_1 = Arc::new(BinaryExpr::new( - left_col.clone(), - op_1, - Arc::new(Literal::new(ScalarValue::Int32(Some(a)))), - )); - let left_and_2 = Arc::new(BinaryExpr::new( - right_col.clone(), - op_2, - Arc::new(Literal::new(ScalarValue::Int32(Some(b)))), - )); - - let right_and_1 = Arc::new(BinaryExpr::new( - left_col.clone(), - op_3, - Arc::new(Literal::new(ScalarValue::Int32(Some(c)))), - )); - let right_and_2 = Arc::new(BinaryExpr::new( - right_col.clone(), - op_4, - Arc::new(Literal::new(ScalarValue::Int32(Some(d)))), - )); - let left_expr = Arc::new(BinaryExpr::new(left_and_1, Operator::Gt, left_and_2)); - let right_expr = Arc::new(BinaryExpr::new(right_and_1, Operator::Lt, right_and_2)); - Arc::new(BinaryExpr::new(left_expr, Operator::And, right_expr)) +/// Checks whether given ordering requirements are satisfied by provided [PhysicalSortExpr]s. 
+pub fn ordering_satisfy EquivalenceProperties>( + provided: Option<&[PhysicalSortExpr]>, + required: Option<&[PhysicalSortExpr]>, + equal_properties: F, +) -> bool { + match (provided, required) { + (_, None) => true, + (None, Some(_)) => false, + (Some(provided), Some(required)) => { + ordering_satisfy_concrete(provided, required, equal_properties) + } + } } pub fn ordering_satisfy_concrete EquivalenceProperties>( From 7e250dfa95507bf8313268e3ce51ae0d553b4640 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Wed, 8 Feb 2023 20:04:11 -0600 Subject: [PATCH 23/39] Simplify DAG creation code --- .../physical-expr/src/intervals/cp_solver.rs | 9 +- datafusion/physical-expr/src/utils.rs | 153 ++++++------------ 2 files changed, 49 insertions(+), 113 deletions(-) diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index c2258859ed54..aad92875b353 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -31,7 +31,7 @@ use petgraph::Outgoing; use crate::expressions::{BinaryExpr, CastExpr, Literal}; use crate::intervals::interval_aritmetic::{apply_operator, Interval}; -use crate::utils::{build_physical_expr_graph, ExprTreeNode}; +use crate::utils::{build_dag, ExprTreeNode}; use crate::PhysicalExpr; // Interval arithmetic provides a way to perform mathematical operations on @@ -164,8 +164,8 @@ impl ExprIntervalGraphNode { /// This function creates a DAEG node from Datafusion's [ExprTreeNode] /// object. Literals are created with definite, singleton intervals while /// any other expression starts with an indefinite interval ([-∞, ∞]). - pub fn make_node(node: Arc) -> ExprIntervalGraphNode { - let expr = node.expr(); + pub fn make_node(node: &ExprTreeNode) -> ExprIntervalGraphNode { + let expr = node.expression().clone(); if let Some(literal) = expr.as_any().downcast_ref::() { let value = literal.value(); let interval = Interval { @@ -272,8 +272,7 @@ pub fn propagate_comparison( impl ExprIntervalGraph { pub fn try_new(expr: Arc) -> Result { // Build the full graph: - let (root, graph) = - build_physical_expr_graph(expr, &ExprIntervalGraphNode::make_node)?; + let (root, graph) = build_dag(expr, &ExprIntervalGraphNode::make_node)?; Ok(Self { graph, root }) } diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs index 2091ab1dac76..a41263ac55d8 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils.rs @@ -185,144 +185,81 @@ pub fn normalize_sort_expr_with_equivalence_properties( } #[derive(Clone, Debug)] -pub struct ExprTreeNode { +pub struct ExprTreeNode { expr: Arc, - node: Option, - child_nodes: Vec>, + data: Option, + child_nodes: Vec>, } -impl PartialEq for ExprTreeNode { - fn eq(&self, other: &ExprTreeNode) -> bool { - self.expr.eq(&other.expr) - } -} - -impl ExprTreeNode { - pub fn new(plan: Arc) -> Self { +impl ExprTreeNode { + pub fn new(expr: Arc) -> Self { ExprTreeNode { - expr: plan, - node: None, + expr, + data: None, child_nodes: vec![], } } - pub fn node(&self) -> &Option { - &self.node - } - - pub fn child_nodes(&self) -> &Vec> { - &self.child_nodes + pub fn expression(&self) -> &Arc { + &self.expr } - pub fn expr(&self) -> Arc { - self.expr.clone() - } - - pub fn children(&self) -> Vec> { - let plan_children = self.expr.children(); - let child_intervals: Vec> = self.child_nodes.to_vec(); - if plan_children.len() == child_intervals.len() { - child_intervals - } else { - 
plan_children - .into_iter() - .map(|child| { - Arc::new(ExprTreeNode { - expr: child.clone(), - node: None, - child_nodes: vec![], - }) - }) - .collect() - } + pub fn children(&self) -> Vec> { + self.expr + .children() + .into_iter() + .map(ExprTreeNode::new) + .collect() } } -impl TreeNodeRewritable for Arc { - fn map_children(self, transform: F) -> Result +impl TreeNodeRewritable for ExprTreeNode { + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { let children = self.children(); if !children.is_empty() { - let new_children: Vec> = children + self.child_nodes = children .into_iter() .map(transform) - .collect::>>>() - .unwrap(); - Ok(Arc::new(ExprTreeNode { - expr: self.expr.clone(), - node: None, - child_nodes: new_children, - })) - } else { - Ok(self) - } - } -} - -fn add_physical_expr_to_graph( - input: Arc, - visited: &mut Vec<(Arc, NodeIndex)>, - graph: &mut StableGraph, - input_node: T, -) -> NodeIndex { - // If we visited the node before, we just return it. That means, this node will have multiple - // parents. - match visited - .iter() - .find(|(visited_expr, _node)| visited_expr.eq(&input.expr())) - { - Some((_, idx)) => *idx, - None => { - let node_idx = graph.add_node(input_node); - visited.push((input.expr().clone(), node_idx)); - input.child_nodes().iter().for_each(|expr_node| { - graph.add_edge(node_idx, expr_node.node().unwrap(), 0); - }); - node_idx + .collect::>>()?; } + Ok(self) } } -fn post_order_tree_traverse_graph_create( - input: Arc, - graph: &mut StableGraph, - constructor: F, - visited: &mut Vec<(Arc, NodeIndex)>, -) -> Result>> -where - F: Fn(Arc) -> T, -{ - let node = constructor(input.clone()); - let node_idx = add_physical_expr_to_graph(input.clone(), visited, graph, node); - Ok(Some(Arc::new(ExprTreeNode { - expr: input.expr.clone(), - node: Some(node_idx), - child_nodes: input.child_nodes.to_vec(), - }))) -} - -pub fn build_physical_expr_graph( +/// This function converts the [PhysicalExpr] tree into a DAG by collecting identical +/// expressions in one node. Caller specifies the node type in this DAG via the +/// `constructor` argument, which constructs nodes in this DAG from the [ExprTreeNode] +/// ancillary object. +pub fn build_dag( expr: Arc, constructor: &F, ) -> Result<(NodeIndex, StableGraph)> where - F: Fn(Arc) -> T, + F: Fn(&ExprTreeNode) -> T, { - let init = Arc::new(ExprTreeNode::new(expr.clone())); - let mut graph: StableGraph = StableGraph::new(); - let mut visited_plans: Vec<(Arc, NodeIndex)> = vec![]; - // Create a graph - let root_tree_node = init.mutable_transform_up(&mut |expr| { - post_order_tree_traverse_graph_create( - expr, - &mut graph, - constructor, - &mut visited_plans, - ) + let init = ExprTreeNode::new(expr); + let mut graph = StableGraph::::new(); + let mut visited_plans = Vec::<(Arc, NodeIndex)>::new(); + let root = init.mutable_transform_up(&mut |mut input| { + let expr = input.expr.clone(); + let node_idx = match visited_plans.iter().find(|(e, _)| expr.eq(e)) { + Some((_, idx)) => *idx, + None => { + let node_idx = graph.add_node(constructor(&input)); + for expr_node in input.child_nodes.iter() { + graph.add_edge(node_idx, expr_node.data.unwrap(), 0); + } + visited_plans.push((expr, node_idx)); + node_idx + } + }; + input.data = Some(node_idx); + Ok(Some(input)) })?; - Ok((root_tree_node.node.unwrap(), graph)) + Ok((root.data.unwrap(), graph)) } /// Checks whether given ordering requirements are satisfied by provided [PhysicalSortExpr]s. 
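To make the simplified DAG contract concrete, here is a minimal, hypothetical sketch (assuming `build_dag` and `ExprTreeNode` remain publicly reachable through `datafusion_physical_expr::utils`, with the signature `build_dag<T, F>(expr, constructor)` where `F: Fn(&ExprTreeNode<NodeIndex>) -> T`, as in this patch). Because `build_dag` looks up previously visited expressions before adding a node, both operands of `a + a` collapse into one shared column node:

use std::sync::Arc;

use datafusion_common::Result;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{BinaryExpr, Column};
use datafusion_physical_expr::utils::build_dag;
use datafusion_physical_expr::PhysicalExpr;

fn main() -> Result<()> {
    // Build `a + a`; both operands are clones of the same expression.
    let a = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
    let expr = Arc::new(BinaryExpr::new(a.clone(), Operator::Plus, a));
    // Use each sub-expression itself as the DAG node weight.
    let (root, dag) = build_dag(expr, &|node| node.expression().clone())?;
    // One BinaryExpr node plus one shared Column node...
    assert_eq!(dag.node_count(), 2);
    // ...linked by two parent-to-child edges, one per operand.
    assert_eq!(dag.edge_count(), 2);
    println!("DAG root: {root:?}");
    Ok(())
}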
From f7c216f0fd4c67fc0b63078faf2009f65fc505a9 Mon Sep 17 00:00:00 2001
From: metesynnada <100111937+metesynnada@users.noreply.github.com>
Date: Thu, 9 Feb 2023 18:07:27 +0300
Subject: [PATCH 24/39] Reducing code change

---
 .../src/physical_optimizer/pipeline_fixer.rs  |  42 +-
 .../core/src/physical_plan/joins/hash_join.rs | 602 ++++++++++++++-
 .../physical_plan/joins/hash_join_utils.rs    | 729 ++----------------
 .../joins/symmetric_hash_join.rs              |  33 +-
 datafusion/physical-expr/src/lib.rs           |   1 -
 .../src/physical_expr_visitor.rs              | 100 ---
 6 files changed, 705 insertions(+), 802 deletions(-)
 delete mode 100644 datafusion/physical-expr/src/physical_expr_visitor.rs

diff --git a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs
index 1d15f36f61cf..ff0388085dfa 100644
--- a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs
+++ b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs
@@ -94,11 +94,13 @@ impl PhysicalOptimizerRule for PipelineFixer {
 /// Currently, [CastExpr], [BinaryExpr], [Column] and [Literal] are supported.
 fn check_support(expr: &Arc<dyn PhysicalExpr>) -> bool {
     let expr_any = expr.as_any();
-    if let Some(binary_expr) = expr_any.downcast_ref::<BinaryExpr>() {
+    let expr_supported = if let Some(binary_expr) = expr_any.downcast_ref::<BinaryExpr>()
+    {
         is_operator_supported(binary_expr.op())
     } else {
         expr_any.is::<Column>() || expr_any.is::<Literal>() || expr_any.is::<CastExpr>()
-    }
+    };
+    expr_supported && expr.children().iter().all(check_support)
 }
 
 /// This function returns whether a given hash join is replaceable by a
@@ -112,8 +114,7 @@ fn is_suitable_for_symmetric_hash_join(hash_join: &HashJoinExec) -> Result<bool>
     let right = hash_join.right();
     if let Some(right_ordering) = right.output_ordering() {
         let expr = filter.expression();
-        let expr_supported =
-            check_support(expr) && expr.children().iter().all(check_support);
+        let expr_supported = check_support(expr);
         let left_convertible = convert_sort_expr_with_filter_schema(
             &JoinSide::Left,
             filter,
@@ -296,6 +297,39 @@ fn apply_subrules_and_check_finiteness_requirements(
     check_finiteness_requirements(input)
 }
 
+#[cfg(test)]
+mod util_tests {
+    use crate::physical_optimizer::pipeline_fixer::check_support;
+    use datafusion_expr::Operator;
+    use datafusion_physical_expr::expressions::{BinaryExpr, Column, NegativeExpr};
+    use datafusion_physical_expr::PhysicalExpr;
+    use std::sync::Arc;
+
+    #[test]
+    fn check_expr_supported() {
+        let supported_expr = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("a", 0)),
+            Operator::Plus,
+            Arc::new(Column::new("a", 0)),
+        )) as Arc<dyn PhysicalExpr>;
+        assert!(check_support(&supported_expr));
+        let supported_expr_2 = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
+        assert!(check_support(&supported_expr_2));
+        let unsupported_expr = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("a", 0)),
+            Operator::Or,
+            Arc::new(Column::new("a", 0)),
+        )) as Arc<dyn PhysicalExpr>;
+        assert!(!check_support(&unsupported_expr));
+        let unsupported_expr_2 = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("a", 0)),
+            Operator::Or,
+            Arc::new(NegativeExpr::new(Arc::new(Column::new("a", 0)))),
+        )) as Arc<dyn PhysicalExpr>;
+        assert!(!check_support(&unsupported_expr_2));
+    }
+}
+
 #[cfg(test)]
 mod hash_join_tests {
     use super::*;
diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs b/datafusion/core/src/physical_plan/joins/hash_join.rs
index 3ca9596d069b..99042d69a457 100644
--- a/datafusion/core/src/physical_plan/joins/hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/hash_join.rs
@@ -18,37 +18,49 @@
 //!
Defines the join plan for executing partitions in parallel and then joining the results //! into a set of partitions. -use std::fmt; +use ahash::RandomState; + +use arrow::{ + array::{ + ArrayData, ArrayRef, BooleanArray, Date32Array, Date64Array, Decimal128Array, + DictionaryArray, LargeStringArray, PrimitiveArray, Time32MillisecondArray, + Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampSecondArray, + UInt32BufferBuilder, UInt64BufferBuilder, + }, + datatypes::{ + Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, + UInt8Type, + }, +}; +use smallvec::{smallvec, SmallVec}; use std::sync::Arc; -use std::task::Poll; use std::{any::Any, usize}; use std::{time::Instant, vec}; -use ahash::RandomState; +use futures::{ready, Stream, StreamExt, TryStreamExt}; +use arrow::array::Array; +use arrow::datatypes::{ArrowNativeType, DataType}; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use futures::{ready, Stream, StreamExt, TryStreamExt}; -use hashbrown::raw::RawTable; -use log::debug; +use arrow::array::{ + Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, + StringArray, TimestampNanosecondArray, UInt16Array, UInt32Array, UInt64Array, + UInt8Array, +}; -use crate::arrow::array::BooleanBufferBuilder; +use datafusion_common::cast::{as_dictionary_array, as_string_array}; + +use hashbrown::raw::RawTable; -use crate::error::{DataFusionError, Result}; -use crate::execution::context::TaskContext; -use crate::logical_expr::JoinType; -use crate::physical_plan::joins::hash_join_utils; -use crate::physical_plan::joins::hash_join_utils::JoinHashMap; -use crate::physical_plan::joins::utils::{ - adjust_indices_by_join_type, build_batch_from_indices, - get_final_indices_from_bit_map, need_produce_result_in_final, JoinSide, -}; use crate::physical_plan::{ coalesce_batches::concat_batches, coalesce_partitions::CoalescePartitionsExec, expressions::Column, expressions::PhysicalSortExpr, + hash_utils::create_hashes, joins::utils::{ adjust_right_output_partitioning, build_join_schema, check_join_is_valid, combine_join_equivalence_properties, estimate_join_statistics, @@ -59,10 +71,44 @@ use crate::physical_plan::{ PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, }; +use crate::error::{DataFusionError, Result}; +use crate::logical_expr::JoinType; + +use crate::arrow::array::BooleanBufferBuilder; +use crate::arrow::datatypes::TimeUnit; +use crate::execution::context::TaskContext; + use super::{ utils::{OnceAsync, OnceFut}, PartitionMode, }; +use crate::physical_plan::joins::utils::{ + adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, + get_final_indices_from_bit_map, need_produce_result_in_final, JoinSide, +}; +use log::debug; +use std::fmt; +use std::task::Poll; + +// Maps a `u64` hash value based on the build ["on" values] to a list of indices with this key's value. +// +// Note that the `u64` keys are not stored in the hashmap (hence the `()` as key), but are only used +// to put the indices in a certain bucket. +// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the build side, +// we make sure that we don't have to re-hash the hashmap, which needs access to the key (the hash in this case) value. +// E.g. 
1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1
+// As the key is a hash value, we need to check possible hash collisions in the probe stage
+// During this stage it might be the case that a row has the same hash value as another row,
+// even though the actual values don't match. Those cases are checked in the [equal_rows] macro
+// TODO: speed up collision check and move away from using a hashbrown HashMap
+// https://github.com/apache/arrow-datafusion/issues/50
+pub struct JoinHashMap(pub RawTable<(u64, SmallVec<[u64; 1]>)>);
+
+impl fmt::Debug for JoinHashMap {
+    fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result {
+        Ok(())
+    }
+}
 
 type JoinLeftData = (JoinHashMap, RecordBatch);
 
@@ -264,7 +310,7 @@ impl ExecutionPlan for HashJoinExec {
 
     /// Specifies whether this plan generates an infinite stream of records.
     /// If the plan does not support pipelining, but its input(s) are
-    /// infinite, returns an error to indicate this.
+    /// infinite, returns an error to indicate this.
     fn unbounded_output(&self, children: &[bool]) -> Result<bool> {
         let (left, right) = (children[0], children[1]);
         // If left is unbounded, or right is unbounded with JoinType::Right,
@@ -482,7 +528,7 @@ async fn collect_left_input(
     for batch in batches.iter() {
         hashes_buffer.clear();
         hashes_buffer.resize(batch.num_rows(), 0);
-        hash_join_utils::update_hash(
+        update_hash(
             &on_left,
             batch,
             &mut hashmap,
@@ -537,7 +583,7 @@ async fn partitioned_left_input(
     for batch in batches.iter() {
         hashes_buffer.clear();
         hashes_buffer.resize(batch.num_rows(), 0);
-        hash_join_utils::update_hash(
+        update_hash(
             &on_left,
             batch,
             &mut hashmap,
@@ -561,6 +607,43 @@ async fn partitioned_left_input(
     Ok((hashmap, single_batch))
 }
 
+/// Updates `hash_map` with new entries from `batch`, evaluated against the expressions `on`,
+/// assuming that the rows of `batch` correspond to global indices starting at `offset`.
+pub fn update_hash(
+    on: &[Column],
+    batch: &RecordBatch,
+    hash_map: &mut JoinHashMap,
+    offset: usize,
+    random_state: &RandomState,
+    hashes_buffer: &mut Vec<u64>,
+) -> Result<()> {
+    // evaluate the keys
+    let keys_values = on
+        .iter()
+        .map(|c| Ok(c.evaluate(batch)?.into_array(batch.num_rows())))
+        .collect::<Result<Vec<_>>>()?;
+
+    // calculate the hash values
+    let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?;
+
+    // insert hashes to key of the hashmap
+    for (row, hash_value) in hash_values.iter().enumerate() {
+        let item = hash_map
+            .0
+            .get_mut(*hash_value, |(hash, _)| *hash_value == *hash);
+        if let Some((_, indices)) = item {
+            indices.push((row + offset) as u64);
+        } else {
+            hash_map.0.insert(
+                *hash_value,
+                (*hash_value, smallvec![(row + offset) as u64]),
+                |(hash, _)| *hash,
+            );
+        }
+    }
+    Ok(())
+}
+
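As a side note, the effect of `update_hash` on the map can be modeled in a few lines of plain Rust. This is only a sketch: std's `HashMap` and `Vec` stand in for `RawTable` and `SmallVec`, the key value itself stands in for the output of `create_hashes`, and the data is invented.

use std::collections::HashMap;

fn main() {
    // Stand-in for JoinHashMap: hash value -> build-side row indices.
    let mut hash_map: HashMap<u64, Vec<u64>> = HashMap::new();
    let key_column = [10u64, 20, 10]; // build-side "on" values
    let offset = 5; // rows of this batch start at global index 5
    for (row, key) in key_column.iter().enumerate() {
        // Append this row's (offset-adjusted) index to the bucket for its hash.
        hash_map.entry(*key).or_default().push((row + offset) as u64);
    }
    assert_eq!(hash_map[&10], vec![5, 7]); // both rows with key 10, shifted by offset
}

 /// A stream that issues [RecordBatch]es as they arrive from the right of the join.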
 struct HashJoinStream {
     /// Input schema
@@ -597,6 +680,483 @@ impl RecordBatchStream for HashJoinStream {
 }
 }
+// Get build and probe indices which satisfies the on condition (include equal_conditon and filter_in_join) in the Join
+#[allow(clippy::too_many_arguments)]
+pub fn build_join_indices(
+    probe_batch: &RecordBatch,
+    build_hashmap: &JoinHashMap,
+    build_input_buffer: &RecordBatch,
+    on_build: &[Column],
+    on_probe: &[Column],
+    filter: Option<&JoinFilter>,
+    random_state: &RandomState,
+    null_equals_null: &bool,
+    hashes_buffer: &mut Vec<u64>,
+    offset: Option<usize>,
+    build_side: JoinSide,
+) -> Result<(UInt64Array, UInt32Array)> {
+    // Get the indices which satisfies the equal join condition, like `left.a1 = right.a2`
+    let (build_indices, probe_indices) = build_equal_condition_join_indices(
+        build_hashmap,
+        build_input_buffer,
+        probe_batch,
+        on_build,
+        on_probe,
+        random_state,
+        null_equals_null,
+        hashes_buffer,
+        offset,
+    )?;
+    if let Some(filter) = filter {
+        // Filter the indices which satisfies the non-equal join condition, like `left.b1 = 10`
+        apply_join_filter_to_indices(
+            build_input_buffer,
+            probe_batch,
+            build_indices,
+            probe_indices,
+            filter,
+            build_side,
+        )
+    } else {
+        Ok((build_indices, probe_indices))
+    }
+}
+
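Before the worked table example below, here is a toy model (invented values, plain Rust) of the collision handling that `build_equal_condition_join_indices` performs when probing a bucket: candidate rows are found by hash, and an `equal_rows`-style value comparison then filters out false positives.

fn main() {
    // Hypothetical bucket for some hash value h: build rows 1, 3 and 4 hashed to h.
    let bucket: Vec<u64> = vec![1, 3, 4];
    // Build-side join-key column; rows 1 and 4 hold 7, row 3 holds a colliding 9.
    let build_keys = [10i64, 7, 3, 9, 7];
    let probe_key = 7i64; // probe row value whose hash is also h
    let mut matches = Vec::new();
    for &i in &bucket {
        // Value-level re-check (the role of `equal_rows`): candidates share a
        // hash, but only rows whose actual values are equal survive.
        if build_keys[i as usize] == probe_key {
            matches.push(i);
        }
    }
    assert_eq!(matches, vec![1, 4]); // row 3 was a hash collision and is dropped
}

+// Returns the index of equal condition join result: build_indices and probe_indices
+// On LEFT.b1 = RIGHT.b2
+// LEFT Table:
+// a1 b1 c1
+// 1 1 10
+// 3 3 30
+// 5 5 50
+// 7 7 70
+// 9 8 90
+// 11 8 110
+// 13 10 130
+// RIGHT Table:
+// a2 b2 c2
+// 2 2 20
+// 4 4 40
+// 6 6 60
+// 8 8 80
+// 10 10 100
+// 12 10 120
+// The result is
+// "+----+----+-----+----+----+-----+",
+// "| a1 | b1 | c1  | a2 | b2 | c2  |",
+// "+----+----+-----+----+----+-----+",
+// "| 11 | 8  | 110 | 8  | 8  | 80  |",
+// "| 13 | 10 | 130 | 10 | 10 | 100 |",
+// "| 13 | 10 | 130 | 12 | 10 | 120 |",
+// "| 9  | 8  | 90  | 8  | 8  | 80  |",
+// "+----+----+-----+----+----+-----+"
+// And the result of build and probe indices
+// build indices: 5, 6, 6, 4
+// probe indices: 3, 4, 5, 3
+#[allow(clippy::too_many_arguments)]
+pub fn build_equal_condition_join_indices(
+    build_hashmap: &JoinHashMap,
+    build_input_buffer: &RecordBatch,
+    probe_batch: &RecordBatch,
+    build_on: &[Column],
+    probe_on: &[Column],
+    random_state: &RandomState,
+    null_equals_null: &bool,
+    hashes_buffer: &mut Vec<u64>,
+    offset: Option<usize>,
+) -> Result<(UInt64Array, UInt32Array)> {
+    let keys_values = probe_on
+        .iter()
+        .map(|c| Ok(c.evaluate(probe_batch)?.into_array(probe_batch.num_rows())))
+        .collect::<Result<Vec<_>>>()?;
+    let build_join_values = build_on
+        .iter()
+        .map(|c| {
+            Ok(c.evaluate(build_input_buffer)?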
+                .into_array(build_input_buffer.num_rows()))
+        })
+        .collect::<Result<Vec<_>>>()?;
+    hashes_buffer.clear();
+    hashes_buffer.resize(probe_batch.num_rows(), 0);
+    let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?;
+    // Using a buffer builder to avoid slower normal builder
+    let mut build_indices = UInt64BufferBuilder::new(0);
+    let mut probe_indices = UInt32BufferBuilder::new(0);
+    let offset_value = offset.unwrap_or(0);
+    // Visit all of the probe rows
+    for (row, hash_value) in hash_values.iter().enumerate() {
+        // Get the hash and find it in the build index
+
+        // For every item on the build and probe we check if it matches
+        // This possibly contains rows with hash collisions,
+        // So we have to check here whether rows are equal or not
+        if let Some((_, indices)) = build_hashmap
+            .0
+            .get(*hash_value, |(hash, _)| *hash_value == *hash)
+        {
+            for &i in indices {
+                let offset_build_index = i as usize - offset_value;
+                // Check hash collisions
+                if equal_rows(
+                    offset_build_index,
+                    row,
+                    &build_join_values,
+                    &keys_values,
+                    *null_equals_null,
+                )? {
+                    build_indices.append(offset_build_index as u64);
+                    probe_indices.append(row as u32);
+                }
+            }
+        }
+    }
+    let build = ArrayData::builder(DataType::UInt64)
+        .len(build_indices.len())
+        .add_buffer(build_indices.finish())
+        .build()
+        .unwrap();
+    let probe = ArrayData::builder(DataType::UInt32)
+        .len(probe_indices.len())
+        .add_buffer(probe_indices.finish())
+        .build()
+        .unwrap();
+
+    Ok((
+        PrimitiveArray::<UInt64Type>::from(build),
+        PrimitiveArray::<UInt32Type>::from(probe),
+    ))
+}
+
+macro_rules! equal_rows_elem {
+    ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $null_equals_null: ident) => {{
+        let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap();
+        let right_array = $r.as_any().downcast_ref::<$array_type>().unwrap();
+
+        match (left_array.is_null($left), right_array.is_null($right)) {
+            (false, false) => left_array.value($left) == right_array.value($right),
+            (true, true) => $null_equals_null,
+            _ => false,
+        }
+    }};
+}
+
+macro_rules!
equal_rows_elem_with_string_dict { + ($key_array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $null_equals_null: ident) => {{ + let left_array: &DictionaryArray<$key_array_type> = + as_dictionary_array::<$key_array_type>($l).unwrap(); + let right_array: &DictionaryArray<$key_array_type> = + as_dictionary_array::<$key_array_type>($r).unwrap(); + + let (left_values, left_values_index) = { + let keys_col = left_array.keys(); + if keys_col.is_valid($left) { + let values_index = keys_col + .value($left) + .to_usize() + .expect("Can not convert index to usize in dictionary"); + + ( + as_string_array(left_array.values()).unwrap(), + Some(values_index), + ) + } else { + (as_string_array(left_array.values()).unwrap(), None) + } + }; + let (right_values, right_values_index) = { + let keys_col = right_array.keys(); + if keys_col.is_valid($right) { + let values_index = keys_col + .value($right) + .to_usize() + .expect("Can not convert index to usize in dictionary"); + + ( + as_string_array(right_array.values()).unwrap(), + Some(values_index), + ) + } else { + (as_string_array(right_array.values()).unwrap(), None) + } + }; + + match (left_values_index, right_values_index) { + (Some(left_values_index), Some(right_values_index)) => { + left_values.value(left_values_index) + == right_values.value(right_values_index) + } + (None, None) => $null_equals_null, + _ => false, + } + }}; +} + +/// Left and right row have equal values +/// If more data types are supported here, please also add the data types in can_hash function +/// to generate hash join logical plan. +pub fn equal_rows( + left: usize, + right: usize, + left_arrays: &[ArrayRef], + right_arrays: &[ArrayRef], + null_equals_null: bool, +) -> Result { + let mut err = None; + let res = left_arrays + .iter() + .zip(right_arrays) + .all(|(l, r)| match l.data_type() { + DataType::Null => { + // lhs and rhs are both `DataType::Null`, so the equal result + // is dependent on `null_equals_null` + null_equals_null + } + DataType::Boolean => { + equal_rows_elem!(BooleanArray, l, r, left, right, null_equals_null) + } + DataType::Int8 => { + equal_rows_elem!(Int8Array, l, r, left, right, null_equals_null) + } + DataType::Int16 => { + equal_rows_elem!(Int16Array, l, r, left, right, null_equals_null) + } + DataType::Int32 => { + equal_rows_elem!(Int32Array, l, r, left, right, null_equals_null) + } + DataType::Int64 => { + equal_rows_elem!(Int64Array, l, r, left, right, null_equals_null) + } + DataType::UInt8 => { + equal_rows_elem!(UInt8Array, l, r, left, right, null_equals_null) + } + DataType::UInt16 => { + equal_rows_elem!(UInt16Array, l, r, left, right, null_equals_null) + } + DataType::UInt32 => { + equal_rows_elem!(UInt32Array, l, r, left, right, null_equals_null) + } + DataType::UInt64 => { + equal_rows_elem!(UInt64Array, l, r, left, right, null_equals_null) + } + DataType::Float32 => { + equal_rows_elem!(Float32Array, l, r, left, right, null_equals_null) + } + DataType::Float64 => { + equal_rows_elem!(Float64Array, l, r, left, right, null_equals_null) + } + DataType::Date32 => { + equal_rows_elem!(Date32Array, l, r, left, right, null_equals_null) + } + DataType::Date64 => { + equal_rows_elem!(Date64Array, l, r, left, right, null_equals_null) + } + DataType::Time32(time_unit) => match time_unit { + TimeUnit::Second => { + equal_rows_elem!(Time32SecondArray, l, r, left, right, null_equals_null) + } + TimeUnit::Millisecond => { + equal_rows_elem!(Time32MillisecondArray, l, r, left, right, null_equals_null) + } + _ => { + err = 
Some(Err(DataFusionError::Internal( + "Unsupported data type in hasher".to_string(), + ))); + false + } + } + DataType::Time64(time_unit) => match time_unit { + TimeUnit::Microsecond => { + equal_rows_elem!(Time64MicrosecondArray, l, r, left, right, null_equals_null) + } + TimeUnit::Nanosecond => { + equal_rows_elem!(Time64NanosecondArray, l, r, left, right, null_equals_null) + } + _ => { + err = Some(Err(DataFusionError::Internal( + "Unsupported data type in hasher".to_string(), + ))); + false + } + } + DataType::Timestamp(time_unit, None) => match time_unit { + TimeUnit::Second => { + equal_rows_elem!( + TimestampSecondArray, + l, + r, + left, + right, + null_equals_null + ) + } + TimeUnit::Millisecond => { + equal_rows_elem!( + TimestampMillisecondArray, + l, + r, + left, + right, + null_equals_null + ) + } + TimeUnit::Microsecond => { + equal_rows_elem!( + TimestampMicrosecondArray, + l, + r, + left, + right, + null_equals_null + ) + } + TimeUnit::Nanosecond => { + equal_rows_elem!( + TimestampNanosecondArray, + l, + r, + left, + right, + null_equals_null + ) + } + }, + DataType::Utf8 => { + equal_rows_elem!(StringArray, l, r, left, right, null_equals_null) + } + DataType::LargeUtf8 => { + equal_rows_elem!(LargeStringArray, l, r, left, right, null_equals_null) + } + DataType::Decimal128(_, lscale) => match r.data_type() { + DataType::Decimal128(_, rscale) => { + if lscale == rscale { + equal_rows_elem!( + Decimal128Array, + l, + r, + left, + right, + null_equals_null + ) + } else { + err = Some(Err(DataFusionError::Internal( + "Inconsistent Decimal data type in hasher, the scale should be same".to_string(), + ))); + false + } + } + _ => { + err = Some(Err(DataFusionError::Internal( + "Unsupported data type in hasher".to_string(), + ))); + false + } + }, + DataType::Dictionary(key_type, value_type) + if *value_type.as_ref() == DataType::Utf8 => + { + match key_type.as_ref() { + DataType::Int8 => { + equal_rows_elem_with_string_dict!( + Int8Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::Int16 => { + equal_rows_elem_with_string_dict!( + Int16Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::Int32 => { + equal_rows_elem_with_string_dict!( + Int32Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::Int64 => { + equal_rows_elem_with_string_dict!( + Int64Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::UInt8 => { + equal_rows_elem_with_string_dict!( + UInt8Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::UInt16 => { + equal_rows_elem_with_string_dict!( + UInt16Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::UInt32 => { + equal_rows_elem_with_string_dict!( + UInt32Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::UInt64 => { + equal_rows_elem_with_string_dict!( + UInt64Type, + l, + r, + left, + right, + null_equals_null + ) + } + _ => { + // should not happen + err = Some(Err(DataFusionError::Internal( + "Unsupported data type in hasher".to_string(), + ))); + false + } + } + } + other => { + // This is internal because we should have caught this before. 
+ err = Some(Err(DataFusionError::Internal(format!( + "Unsupported data type in hasher: {other}" + )))); + false + } + }); + + err.unwrap_or(Ok(res)) +} + impl HashJoinStream { /// Separate implementation function that unpins the [`HashJoinStream`] so /// that partial borrows work correctly @@ -637,7 +1197,7 @@ impl HashJoinStream { let timer = self.join_metrics.probe_time.timer(); // get the matched two indices for the on condition - let left_right_indices = hash_join_utils::build_join_indices( + let left_right_indices = build_join_indices( &batch, &left_data.0, &left_data.1, @@ -748,7 +1308,7 @@ mod tests { use super::*; use crate::physical_expr::expressions::BinaryExpr; - use crate::physical_plan::joins::hash_join_utils::build_equal_condition_join_indices; + use crate::physical_plan::joins::hash_join::build_equal_condition_join_indices; use crate::physical_plan::joins::utils::JoinSide; use crate::prelude::SessionContext; use crate::{ diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs index c516092f4c51..ccf3c224142d 100644 --- a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs +++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs @@ -17,629 +17,53 @@ //! Hash Join related functionality used both on logical and physical plans //! -use std::{fmt, usize}; +use std::usize; -use ahash::RandomState; +use arrow::datatypes::DataType; use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; -use arrow::{ - array::{ - Array, ArrayData, ArrayRef, BooleanArray, Date32Array, Date64Array, - Decimal128Array, DictionaryArray, Float32Array, Float64Array, Int16Array, - Int32Array, Int64Array, Int8Array, LargeStringArray, PrimitiveArray, StringArray, - Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, - Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, - UInt32BufferBuilder, UInt64Array, UInt64BufferBuilder, UInt8Array, - }, - datatypes::{ - ArrowNativeType, DataType, Int16Type, Int32Type, Int64Type, Int8Type, TimeUnit, - UInt16Type, UInt32Type, UInt64Type, UInt8Type, - }, -}; -use hashbrown::raw::RawTable; -use smallvec::{smallvec, SmallVec}; use crate::common::Result; use arrow::compute::CastOptions; -use datafusion_common::cast::{as_dictionary_array, as_string_array}; use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, CastExpr, Column, Literal}; -use datafusion_physical_expr::hash_utils::create_hashes; use datafusion_physical_expr::intervals::Interval; -use datafusion_physical_expr::physical_expr_visitor::{ - PhysicalExprVisitable, PhysicalExpressionVisitor, Recursion, -}; use datafusion_physical_expr::rewrite::TreeNodeRewritable; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use std::collections::HashMap; use std::sync::Arc; -use crate::physical_plan::joins::utils::{ - apply_join_filter_to_indices, JoinFilter, JoinSide, -}; -use crate::physical_plan::sorts::sort::SortOptions; +use crate::physical_plan::joins::utils::{JoinFilter, JoinSide}; -/// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`, -/// assuming that the [RecordBatch] corresponds to the `index`th -pub fn update_hash( - on: &[Column], - batch: &RecordBatch, - hash_map: &mut JoinHashMap, - offset: usize, - random_state: &RandomState, - hashes_buffer: &mut Vec, 
-) -> Result<()> { - // evaluate the keys - let keys_values = on - .iter() - .map(|c| Ok(c.evaluate(batch)?.into_array(batch.num_rows()))) - .collect::>>()?; - - // calculate the hash values - let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; - - // insert hashes to key of the hashmap - for (row, hash_value) in hash_values.iter().enumerate() { - let item = hash_map - .0 - .get_mut(*hash_value, |(hash, _)| *hash_value == *hash); - if let Some((_, indices)) = item { - indices.push((row + offset) as u64); - } else { - hash_map.0.insert( - *hash_value, - (*hash_value, smallvec![(row + offset) as u64]), - |(hash, _)| *hash, - ); - } - } - Ok(()) -} - -// Get build and probe indices which satisfies the on condition (include equal_conditon and filter_in_join) in the Join -#[allow(clippy::too_many_arguments)] -pub fn build_join_indices( - probe_batch: &RecordBatch, - build_hashmap: &JoinHashMap, - build_input_buffer: &RecordBatch, - on_build: &[Column], - on_probe: &[Column], - filter: Option<&JoinFilter>, - random_state: &RandomState, - null_equals_null: &bool, - hashes_buffer: &mut Vec, - offset: Option, - build_side: JoinSide, -) -> Result<(UInt64Array, UInt32Array)> { - // Get the indices which satisfies the equal join condition, like `left.a1 = right.a2` - let (build_indices, probe_indices) = build_equal_condition_join_indices( - build_hashmap, - build_input_buffer, - probe_batch, - on_build, - on_probe, - random_state, - null_equals_null, - hashes_buffer, - offset, - )?; - if let Some(filter) = filter { - // Filter the indices which satisfies the non-equal join condition, like `left.b1 = 10` - apply_join_filter_to_indices( - build_input_buffer, - probe_batch, - build_indices, - probe_indices, - filter, - build_side, - ) - } else { - Ok((build_indices, probe_indices)) - } -} - -macro_rules! equal_rows_elem { - ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $null_equals_null: ident) => {{ - let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap(); - let right_array = $r.as_any().downcast_ref::<$array_type>().unwrap(); - - match (left_array.is_null($left), right_array.is_null($right)) { - (false, false) => left_array.value($left) == right_array.value($right), - (true, true) => $null_equals_null, - _ => false, - } - }}; -} - -macro_rules! 
equal_rows_elem_with_string_dict { - ($key_array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $null_equals_null: ident) => {{ - let left_array: &DictionaryArray<$key_array_type> = - as_dictionary_array::<$key_array_type>($l).unwrap(); - let right_array: &DictionaryArray<$key_array_type> = - as_dictionary_array::<$key_array_type>($r).unwrap(); - - let (left_values, left_values_index) = { - let keys_col = left_array.keys(); - if keys_col.is_valid($left) { - let values_index = keys_col - .value($left) - .to_usize() - .expect("Can not convert index to usize in dictionary"); - - ( - as_string_array(left_array.values()).unwrap(), - Some(values_index), - ) - } else { - (as_string_array(left_array.values()).unwrap(), None) - } - }; - let (right_values, right_values_index) = { - let keys_col = right_array.keys(); - if keys_col.is_valid($right) { - let values_index = keys_col - .value($right) - .to_usize() - .expect("Can not convert index to usize in dictionary"); - - ( - as_string_array(right_array.values()).unwrap(), - Some(values_index), - ) - } else { - (as_string_array(right_array.values()).unwrap(), None) - } - }; - - match (left_values_index, right_values_index) { - (Some(left_values_index), Some(right_values_index)) => { - left_values.value(left_values_index) - == right_values.value(right_values_index) - } - (None, None) => $null_equals_null, - _ => false, - } - }}; -} - -/// Left and right row have equal values -/// If more data types are supported here, please also add the data types in can_hash function -/// to generate hash join logical plan. -pub fn equal_rows( - left: usize, - right: usize, - left_arrays: &[ArrayRef], - right_arrays: &[ArrayRef], - null_equals_null: bool, -) -> Result { - let mut err = None; - let res = left_arrays - .iter() - .zip(right_arrays) - .all(|(l, r)| match l.data_type() { - DataType::Null => { - // lhs and rhs are both `DataType::Null`, so the equal result - // is dependent on `null_equals_null` - null_equals_null - } - DataType::Boolean => { - equal_rows_elem!(BooleanArray, l, r, left, right, null_equals_null) - } - DataType::Int8 => { - equal_rows_elem!(Int8Array, l, r, left, right, null_equals_null) - } - DataType::Int16 => { - equal_rows_elem!(Int16Array, l, r, left, right, null_equals_null) - } - DataType::Int32 => { - equal_rows_elem!(Int32Array, l, r, left, right, null_equals_null) - } - DataType::Int64 => { - equal_rows_elem!(Int64Array, l, r, left, right, null_equals_null) - } - DataType::UInt8 => { - equal_rows_elem!(UInt8Array, l, r, left, right, null_equals_null) - } - DataType::UInt16 => { - equal_rows_elem!(UInt16Array, l, r, left, right, null_equals_null) - } - DataType::UInt32 => { - equal_rows_elem!(UInt32Array, l, r, left, right, null_equals_null) - } - DataType::UInt64 => { - equal_rows_elem!(UInt64Array, l, r, left, right, null_equals_null) - } - DataType::Float32 => { - equal_rows_elem!(Float32Array, l, r, left, right, null_equals_null) - } - DataType::Float64 => { - equal_rows_elem!(Float64Array, l, r, left, right, null_equals_null) - } - DataType::Date32 => { - equal_rows_elem!(Date32Array, l, r, left, right, null_equals_null) - } - DataType::Date64 => { - equal_rows_elem!(Date64Array, l, r, left, right, null_equals_null) - } - DataType::Time32(time_unit) => match time_unit { - TimeUnit::Second => { - equal_rows_elem!(Time32SecondArray, l, r, left, right, null_equals_null) - } - TimeUnit::Millisecond => { - equal_rows_elem!(Time32MillisecondArray, l, r, left, right, null_equals_null) - } - _ => { - err = 
Some(Err(DataFusionError::Internal( - "Unsupported data type in hasher".to_string(), - ))); - false - } - } - DataType::Time64(time_unit) => match time_unit { - TimeUnit::Microsecond => { - equal_rows_elem!(Time64MicrosecondArray, l, r, left, right, null_equals_null) - } - TimeUnit::Nanosecond => { - equal_rows_elem!(Time64NanosecondArray, l, r, left, right, null_equals_null) - } - _ => { - err = Some(Err(DataFusionError::Internal( - "Unsupported data type in hasher".to_string(), - ))); - false - } - } - DataType::Timestamp(time_unit, None) => match time_unit { - TimeUnit::Second => { - equal_rows_elem!( - TimestampSecondArray, - l, - r, - left, - right, - null_equals_null - ) - } - TimeUnit::Millisecond => { - equal_rows_elem!( - TimestampMillisecondArray, - l, - r, - left, - right, - null_equals_null - ) - } - TimeUnit::Microsecond => { - equal_rows_elem!( - TimestampMicrosecondArray, - l, - r, - left, - right, - null_equals_null - ) - } - TimeUnit::Nanosecond => { - equal_rows_elem!( - TimestampNanosecondArray, - l, - r, - left, - right, - null_equals_null - ) - } - }, - DataType::Utf8 => { - equal_rows_elem!(StringArray, l, r, left, right, null_equals_null) - } - DataType::LargeUtf8 => { - equal_rows_elem!(LargeStringArray, l, r, left, right, null_equals_null) - } - DataType::Decimal128(_, lscale) => match r.data_type() { - DataType::Decimal128(_, rscale) => { - if lscale == rscale { - equal_rows_elem!( - Decimal128Array, - l, - r, - left, - right, - null_equals_null - ) - } else { - err = Some(Err(DataFusionError::Internal( - "Inconsistent Decimal data type in hasher, the scale should be same".to_string(), - ))); - false - } - } - _ => { - err = Some(Err(DataFusionError::Internal( - "Unsupported data type in hasher".to_string(), - ))); - false - } - }, - DataType::Dictionary(key_type, value_type) - if *value_type.as_ref() == DataType::Utf8 => - { - match key_type.as_ref() { - DataType::Int8 => { - equal_rows_elem_with_string_dict!( - Int8Type, - l, - r, - left, - right, - null_equals_null - ) - } - DataType::Int16 => { - equal_rows_elem_with_string_dict!( - Int16Type, - l, - r, - left, - right, - null_equals_null - ) - } - DataType::Int32 => { - equal_rows_elem_with_string_dict!( - Int32Type, - l, - r, - left, - right, - null_equals_null - ) - } - DataType::Int64 => { - equal_rows_elem_with_string_dict!( - Int64Type, - l, - r, - left, - right, - null_equals_null - ) - } - DataType::UInt8 => { - equal_rows_elem_with_string_dict!( - UInt8Type, - l, - r, - left, - right, - null_equals_null - ) - } - DataType::UInt16 => { - equal_rows_elem_with_string_dict!( - UInt16Type, - l, - r, - left, - right, - null_equals_null - ) - } - DataType::UInt32 => { - equal_rows_elem_with_string_dict!( - UInt32Type, - l, - r, - left, - right, - null_equals_null - ) - } - DataType::UInt64 => { - equal_rows_elem_with_string_dict!( - UInt64Type, - l, - r, - left, - right, - null_equals_null - ) - } - _ => { - // should not happen - err = Some(Err(DataFusionError::Internal( - "Unsupported data type in hasher".to_string(), - ))); - false - } - } - } - other => { - // This is internal because we should have caught this before. 
- err = Some(Err(DataFusionError::Internal(format!( - "Unsupported data type in hasher: {other}" - )))); - false - } - }); - - err.unwrap_or(Ok(res)) -} - -// Returns the index of equal condition join result: build_indices and probe_indices -// On LEFT.b1 = RIGHT.b2 -// LEFT Table: -// a1 b1 c1 -// 1 1 10 -// 3 3 30 -// 5 5 50 -// 7 7 70 -// 9 8 90 -// 11 8 110 -// 13 10 130 -// RIGHT Table: -// a2 b2 c2 -// 2 2 20 -// 4 4 40 -// 6 6 60 -// 8 8 80 -// 10 10 100 -// 12 10 120 -// The result is -// "+----+----+-----+----+----+-----+", -// "| a1 | b1 | c1 | a2 | b2 | c2 |", -// "+----+----+-----+----+----+-----+", -// "| 11 | 8 | 110 | 8 | 8 | 80 |", -// "| 13 | 10 | 130 | 10 | 10 | 100 |", -// "| 13 | 10 | 130 | 12 | 10 | 120 |", -// "| 9 | 8 | 90 | 8 | 8 | 80 |", -// "+----+----+-----+----+----+-----+" -// And the result of build and probe indices -// build indices: 5, 6, 6, 4 -// probe indices: 3, 4, 5, 3 -#[allow(clippy::too_many_arguments)] -pub fn build_equal_condition_join_indices( - build_hashmap: &JoinHashMap, - build_input_buffer: &RecordBatch, - probe_batch: &RecordBatch, - build_on: &[Column], - probe_on: &[Column], - random_state: &RandomState, - null_equals_null: &bool, - hashes_buffer: &mut Vec, - offset: Option, -) -> Result<(UInt64Array, UInt32Array)> { - let keys_values = probe_on - .iter() - .map(|c| Ok(c.evaluate(probe_batch)?.into_array(probe_batch.num_rows()))) - .collect::>>()?; - let build_join_values = build_on - .iter() - .map(|c| { - Ok(c.evaluate(build_input_buffer)? - .into_array(build_input_buffer.num_rows())) +fn check_filter_expr_contains_sort_information( + expr: &Arc, + checked_expr: &Arc, +) -> bool { + let contains = checked_expr.eq(expr); + contains + || expr.children().iter().any(|child_expr| { + check_filter_expr_contains_sort_information(child_expr, checked_expr) }) - .collect::>>()?; - hashes_buffer.clear(); - hashes_buffer.resize(probe_batch.num_rows(), 0); - let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; - // Using a buffer builder to avoid slower normal builder - let mut build_indices = UInt64BufferBuilder::new(0); - let mut probe_indices = UInt32BufferBuilder::new(0); - let offset_value = offset.unwrap_or(0); - // Visit all of the probe rows - for (row, hash_value) in hash_values.iter().enumerate() { - // Get the hash and find it in the build index - - // For every item on the build and probe we check if it matches - // This possibly contains rows with hash collisions, - // So we have to check here whether rows are equal or not - if let Some((_, indices)) = build_hashmap - .0 - .get(*hash_value, |(hash, _)| *hash_value == *hash) - { - for &i in indices { - // Check hash collisions - let offset_build_index = i as usize - offset_value; - // Check hash collisions - if equal_rows( - offset_build_index, - row, - &build_join_values, - &keys_values, - *null_equals_null, - )? { - build_indices.append(offset_build_index as u64); - probe_indices.append(row as u32); - } - } - } - } - let build = ArrayData::builder(DataType::UInt64) - .len(build_indices.len()) - .add_buffer(build_indices.finish()) - .build() - .unwrap(); - let probe = ArrayData::builder(DataType::UInt32) - .len(probe_indices.len()) - .add_buffer(probe_indices.finish()) - .build() - .unwrap(); - - Ok(( - PrimitiveArray::::from(build), - PrimitiveArray::::from(probe), - )) -} - -// Maps a `u64` hash value based on the build ["on" values] to a list of indices with this key's value. 
-// -// Note that the `u64` keys are not stored in the hashmap (hence the `()` as key), but are only used -// to put the indices in a certain bucket. -// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the build side, -// we make sure that we don't have to re-hash the hashmap, which needs access to the key (the hash in this case) value. -// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1 -// As the key is a hash value, we need to check possible hash collisions in the probe stage -// During this stage it might be the case that a row is contained the same hashmap value, -// but the values don't match. Those are checked in the [equal_rows] macro -// TODO: speed up collision check and move away from using a hashbrown HashMap -// https://github.com/apache/arrow-datafusion/issues/50 -pub struct JoinHashMap(pub RawTable<(u64, SmallVec<[u64; 1]>)>); - -impl fmt::Debug for JoinHashMap { - fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { - Ok(()) - } -} - -#[derive(Debug)] -pub struct CheckFilterExprContainsSortInformation { - /// Supported state - pub contains: bool, - pub checked_expr: Arc, } -impl CheckFilterExprContainsSortInformation { - pub fn new(checked_expr: Arc) -> Self { - Self { - contains: false, - checked_expr, - } - } -} - -impl PhysicalExpressionVisitor for CheckFilterExprContainsSortInformation { - fn pre_visit(mut self, expr: Arc) -> Result> { - self.contains |= self.checked_expr.eq(&expr); - if self.contains { - Ok(Recursion::Stop(self)) - } else { - Ok(Recursion::Continue(self)) +fn recursive_physical_expr_column_collector( + expr: &Arc, + columns: &mut Vec, +) { + if let Some(column) = expr.as_any().downcast_ref::() { + if !columns.iter().any(|c| c.eq(column)) { + columns.push(column.clone()) } } + expr.children().iter().for_each(|child_expr| { + recursive_physical_expr_column_collector(child_expr, columns) + }) } -#[derive(Debug)] -pub struct PhysicalExprColumnCollector { - pub columns: Vec, -} - -impl PhysicalExprColumnCollector { - pub fn new() -> Self { - Self { columns: vec![] } - } -} - -impl PhysicalExpressionVisitor for PhysicalExprColumnCollector { - fn pre_visit(mut self, expr: Arc) -> Result> { - if let Some(column) = expr.as_any().downcast_ref::() { - self.columns.push(column.clone()) - } - Ok(Recursion::Continue(self)) - } +fn physical_expr_column_collector(expr: &Arc) -> Vec { + let mut columns = vec![]; + recursive_physical_expr_column_collector(expr, &mut columns); + columns } /// Create main_col -> filter_col one to one mapping from filter column indices. @@ -679,7 +103,7 @@ pub fn map_origin_col_to_filter_col( /// 2. Collects all columns in the sort expression using the `PhysicalExprColumnCollector` visitor. /// 3. Checks if all columns are included in the `column_mapping_information` map. /// 4. If all columns are included, the sort expression is converted into a filter expression using the `transform_up` and `convert_filter_columns` functions. -/// 5. Searches the converted filter expression in the filter expression using the `CheckFilterExprContainsSortInformation` visitor. +/// 5. Searches the converted filter expression in the filter expression using the `check_filter_expr_contains_sort_information`. /// 6. If an exact match is encountered, returns the converted filter expression as `Some(Arc)`. /// 7. If all columns are not included or the exact match is not encountered, returns `None`. 
/// @@ -700,10 +124,7 @@ pub fn convert_sort_expr_with_filter_schema( map_origin_col_to_filter_col(filter, schema, side)?; let expr = sort_expr.expr.clone(); // Get main schema columns - let expr_columns = expr - .clone() - .accept(PhysicalExprColumnCollector::new())? - .columns; + let expr_columns = physical_expr_column_collector(&expr); // Calculation is possible with 'column_mapping_information' since sort exprs belong to a child. let all_columns_are_included = expr_columns .iter() @@ -716,14 +137,10 @@ pub fn convert_sort_expr_with_filter_schema( // Search converted PhysicalExpr in filter expression // If the exact match is encountered, use this sorted expression in graph traversals. - if filter - .expression() - .clone() - .accept(CheckFilterExprContainsSortInformation::new( - converted_filter_expr.clone(), - ))? - .contains - { + if check_filter_expr_contains_sort_information( + filter.expression(), + &converted_filter_expr, + ) { Ok(Some(converted_filter_expr)) } else { Ok(None) @@ -750,12 +167,7 @@ pub fn build_filter_input_order_v2( order: &PhysicalSortExpr, ) -> Result { match convert_sort_expr_with_filter_schema(&side, filter, schema, order)? { - Some(expr) => Ok(SortedFilterExpr::new( - side, - order.expr.clone(), - expr, - order.options, - )), + Some(expr) => Ok(SortedFilterExpr::new(side, order.clone(), expr)), None => Err(DataFusionError::Plan(format!( "The {side} side of the join does not have an expression sorted." ))), @@ -788,12 +200,10 @@ pub struct SortedFilterExpr { /// Column side pub join_side: JoinSide, /// Sorted expr from a particular join side (child) - pub origin_expr: Arc, + pub origin_sorted_expr: PhysicalSortExpr, /// For interval calculations, one to one mapping of the columns according to filter expression, /// and column indices. pub filter_expr: Arc, - /// Sort option - pub sort_option: SortOptions, /// Interval pub interval: Interval, /// NodeIndex in Graph @@ -804,30 +214,24 @@ impl SortedFilterExpr { /// Constructor pub fn new( join_side: JoinSide, - origin_expr: Arc, + origin_sorted_expr: PhysicalSortExpr, filter_expr: Arc, - sort_option: SortOptions, ) -> Self { Self { join_side, - origin_expr, + origin_sorted_expr, filter_expr, - sort_option, interval: Interval::default(), node_index: 0, } } /// Get origin expr information - pub fn origin_expr(&self) -> Arc { - self.origin_expr.clone() + pub fn origin_sorted_expr(&self) -> PhysicalSortExpr { + self.origin_sorted_expr.clone() } /// Get filter expr information - pub fn filter_expr(&self) -> Arc { - self.filter_expr.clone() - } - /// Get sort information - pub fn sort_option(&self) -> SortOptions { - self.sort_option + pub fn filter_expr(&self) -> &Arc { + &self.filter_expr } /// Get interval information pub fn interval(&self) -> &Interval { @@ -911,28 +315,31 @@ mod tests { use arrow::compute::{CastOptions, SortOptions}; use arrow::datatypes::{DataType, Field, Schema}; use std::sync::Arc; + #[test] + fn test_column_collector() { + let filter_expr = complicated_filter(); + let columns = physical_expr_column_collector(&filter_expr); + assert_eq!(columns.len(), 3) + } + #[test] fn find_expr_inside_expr() -> Result<()> { let filter_expr = complicated_filter(); - let expr_1 = Arc::new(Column::new("gnz", 0)); - assert!( - !filter_expr - .clone() - .accept(CheckFilterExprContainsSortInformation::new(expr_1,))? 
- .contains - ); + let expr_1: Arc = Arc::new(Column::new("gnz", 0)); + assert!(!check_filter_expr_contains_sort_information( + &filter_expr, + &expr_1 + )); - let expr_2 = Arc::new(Column::new("1", 1)); + let expr_2: Arc = Arc::new(Column::new("1", 1)); - assert!( - filter_expr - .clone() - .accept(CheckFilterExprContainsSortInformation::new(expr_2,))? - .contains - ); + assert!(check_filter_expr_contains_sort_information( + &filter_expr, + &expr_2 + )); - let expr_3 = Arc::new(CastExpr::new( + let expr_3: Arc = Arc::new(CastExpr::new( Arc::new(BinaryExpr::new( Arc::new(Column::new("0", 0)), Operator::Plus, @@ -942,21 +349,17 @@ mod tests { CastOptions { safe: false }, )); - assert!( - filter_expr - .clone() - .accept(CheckFilterExprContainsSortInformation::new(expr_3,))? - .contains - ); + assert!(check_filter_expr_contains_sort_information( + &filter_expr, + &expr_3 + )); - let expr_4 = Arc::new(Column::new("1", 42)); + let expr_4: Arc = Arc::new(Column::new("1", 42)); - assert!( - !filter_expr - .clone() - .accept(CheckFilterExprContainsSortInformation::new(expr_4,))? - .contains - ); + assert!(!check_filter_expr_contains_sort_information( + &filter_expr, + &expr_4, + )); Ok(()) } diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs index cf052963ef5a..340854666fec 100644 --- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs @@ -47,9 +47,11 @@ use crate::error::{DataFusionError, Result}; use crate::execution::context::TaskContext; use crate::logical_expr::JoinType; use crate::physical_plan::common::merge_batches; +use crate::physical_plan::joins::hash_join::{ + build_join_indices, update_hash, JoinHashMap, +}; use crate::physical_plan::joins::hash_join_utils::{ - build_filter_input_order_v2, build_join_indices, update_hash, JoinHashMap, - SortedFilterExpr, + build_filter_input_order_v2, SortedFilterExpr, }; use crate::physical_plan::joins::utils::build_batch_from_indices; use crate::physical_plan::{ @@ -293,7 +295,7 @@ impl SymmetricHashJoinExec { let child_node_indexes = physical_expr_graph.gather_node_indices( &filter_columns .iter() - .map(|sorted_expr| sorted_expr.filter_expr()) + .map(|sorted_expr| sorted_expr.filter_expr().clone()) .collect::>(), ); @@ -653,8 +655,7 @@ fn calculate_filter_expr_intervals( // Destructure the SortedFilterExpr let SortedFilterExpr { join_side, - origin_expr, - sort_option, + origin_sorted_expr, interval, .. } = sorted_expr; @@ -662,12 +663,14 @@ fn calculate_filter_expr_intervals( // Get the array of the first or last value for the expression, depending on join side let array = if build_side.eq(join_side) { // Get first value for expr - origin_expr + origin_sorted_expr + .expr .evaluate(&build_input_buffer.slice(0, 1))? .into_array(1) } else { // Get last value for expr - origin_expr + origin_sorted_expr + .expr .evaluate(&probe_batch.slice(probe_batch.num_rows() - 1, 1))? 
.into_array(1) }; @@ -679,7 +682,7 @@ fn calculate_filter_expr_intervals( let infinite = ScalarValue::try_from(value.get_datatype())?; // Update the interval with lower and upper bounds based on the sort option - *interval = if sort_option.descending { + *interval = if origin_sorted_expr.options.descending { Interval { lower: infinite, upper: value, @@ -704,22 +707,26 @@ fn determine_prune_length( .flat_map(|sorted_expr| { let SortedFilterExpr { join_side, - origin_expr, - sort_option, + origin_sorted_expr, interval, .. } = sorted_expr; if build_side.eq(join_side) { - let batch_arr = origin_expr + let batch_arr = origin_sorted_expr + .expr .evaluate(buffer) .unwrap() .into_array(buffer.num_rows()); - let target = if sort_option.descending { + let target = if origin_sorted_expr.options.descending { interval.upper.clone() } else { interval.lower.clone() }; - Some(bisect::(&[batch_arr], &[target], &[*sort_option])) + Some(bisect::( + &[batch_arr], + &[target], + &[origin_sorted_expr.options], + )) } else { None } diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index e15da3a1e753..9719b4b7901e 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -29,7 +29,6 @@ pub mod hash_utils; pub mod intervals; pub mod math_expressions; mod physical_expr; -pub mod physical_expr_visitor; pub mod planner; #[cfg(feature = "regex_expressions")] pub mod regex_expressions; diff --git a/datafusion/physical-expr/src/physical_expr_visitor.rs b/datafusion/physical-expr/src/physical_expr_visitor.rs deleted file mode 100644 index 7bc0227e9ecb..000000000000 --- a/datafusion/physical-expr/src/physical_expr_visitor.rs +++ /dev/null @@ -1,100 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::expressions::{ - BinaryExpr, CastExpr, DateTimeIntervalExpr, GetIndexedFieldExpr, InListExpr, -}; -use crate::PhysicalExpr; - -use datafusion_common::Result; - -use std::sync::Arc; - -/// Controls how the visitor recursion should proceed. -pub enum Recursion { - /// Attempt to visit all the children, recursively, of this expression. - Continue(V), - /// Do not visit the children of this expression, though the walk - /// of parents of this expression will not be affected - Stop(V), -} - -/// Encode the traversal of an expression tree. When passed to -/// `Arc::accept`, `PhysicalExpressionVisitor::visit` is invoked -/// recursively on all nodes of an expression tree. See the comments -/// on `Arc::accept` for details on its use -pub trait PhysicalExpressionVisitor>: - Sized -{ - /// Invoked before any children of `Arc` are visited. - fn pre_visit(self, expr: E) -> Result> - where - Self: PhysicalExpressionVisitor; - - /// Invoked after all children of `Arc` are visited. 
Default - /// implementation does nothing. - fn post_visit(self, _expr: E) -> Result { - Ok(self) - } -} - -/// trait for types that can be visited by [`ExpressionVisitor`] -pub trait PhysicalExprVisitable: Sized { - /// accept a visitor, calling `visit` on all children of this - fn accept>(self, visitor: V) -> Result; -} - -impl PhysicalExprVisitable for Arc { - fn accept>(self, visitor: V) -> Result { - let visitor = match visitor.pre_visit(self.clone())? { - Recursion::Continue(visitor) => visitor, - // If the recursion should stop, do not visit children - Recursion::Stop(visitor) => return Ok(visitor), - }; - let visitor = - if let Some(binary_expr) = self.as_any().downcast_ref::() { - let left_expr = binary_expr.left().clone(); - let right_expr = binary_expr.right().clone(); - let visitor = left_expr.accept(visitor)?; - right_expr.accept(visitor) - } else if let Some(datatime_expr) = - self.as_any().downcast_ref::() - { - let lhs = datatime_expr.lhs().clone(); - let rhs = datatime_expr.rhs().clone(); - let visitor = lhs.accept(visitor)?; - rhs.accept(visitor) - } else if let Some(cast) = self.as_any().downcast_ref::() { - let expr = cast.expr().clone(); - expr.accept(visitor) - } else if let Some(get_index) = - self.as_any().downcast_ref::() - { - let arg = get_index.arg().clone(); - arg.accept(visitor) - } else if let Some(inlist_expr) = self.as_any().downcast_ref::() { - let expr = inlist_expr.expr().clone(); - let list = inlist_expr.list(); - let visitor = expr.clone().accept(visitor)?; - list.iter() - .try_fold(visitor, |visitor, arg| arg.clone().accept(visitor)) - } else { - Ok(visitor) - }?; - visitor.post_visit(self.clone()) - } -} From 24724463d1aa0a3b781a80a9467f5c469d27a3dc Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Thu, 9 Feb 2023 17:43:41 -0600 Subject: [PATCH 25/39] Comment improvements and simplifications --- .../src/physical_optimizer/pipeline_fixer.rs | 3 +- .../core/src/physical_plan/joins/hash_join.rs | 37 ++- .../physical_plan/joins/hash_join_utils.rs | 263 +++++++++--------- .../core/src/physical_plan/joins/mod.rs | 4 +- .../joins/symmetric_hash_join.rs | 210 ++++++-------- .../core/src/physical_plan/joins/utils.rs | 6 +- datafusion/core/src/test_util.rs | 15 +- datafusion/core/tests/fifo.rs | 19 +- datafusion/expr/src/operator.rs | 18 +- datafusion/physical-expr/src/utils.rs | 12 +- 10 files changed, 272 insertions(+), 315 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs index ff0388085dfa..049d3b0f8b9e 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs @@ -113,8 +113,7 @@ fn is_suitable_for_symmetric_hash_join(hash_join: &HashJoinExec) -> Result if let Some(left_ordering) = left.output_ordering() { let right = hash_join.right(); if let Some(right_ordering) = right.output_ordering() { - let expr = filter.expression(); - let expr_supported = check_support(expr); + let expr_supported = check_support(filter.expression()); let left_convertible = convert_sort_expr_with_filter_schema( &JoinSide::Left, filter, diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs b/datafusion/core/src/physical_plan/joins/hash_join.rs index 99042d69a457..d013995cdadb 100644 --- a/datafusion/core/src/physical_plan/joins/hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/hash_join.rs @@ -90,7 +90,7 @@ use log::debug; use std::fmt; use std::task::Poll; -// Maps a `u64` hash value based on 
the build ["on" values] to a list of indices with this key's value. +// Maps a `u64` hash value based on the build side ["on" values] to a list of indices with this key's value. // // Note that the `u64` keys are not stored in the hashmap (hence the `()` as key), but are only used // to put the indices in a certain bucket. @@ -680,7 +680,8 @@ impl RecordBatchStream for HashJoinStream { } } -// Get build and probe indices which satisfies the on condition (include equal_conditon and filter_in_join) in the Join +/// Gets build and probe indices which satisfy the on condition (including +/// the equality condition and the join filter) in the join. #[allow(clippy::too_many_arguments)] pub fn build_join_indices( probe_batch: &RecordBatch, @@ -695,7 +696,7 @@ pub fn build_join_indices( offset: Option, build_side: JoinSide, ) -> Result<(UInt64Array, UInt32Array)> { - // Get the indices which satisfies the equal join condition, like `left.a1 = right.a2` + // Get the indices that satisfy the equality condition, like `left.a1 = right.a2` let (build_indices, probe_indices) = build_equal_condition_join_indices( build_hashmap, build_input_buffer, @@ -708,7 +709,7 @@ pub fn build_join_indices( offset, )?; if let Some(filter) = filter { - // Filter the indices which satisfies the non-equal join condition, like `left.b1 = 10` + // Filter the indices which satisfy the non-equal join condition, like `left.b1 = 10` apply_join_filter_to_indices( build_input_buffer, probe_batch, @@ -722,7 +723,7 @@ pub fn build_join_indices( } } -// Returns the index of equal condition join result: build_indices and probe_indices +// Returns build/probe indices satisfying the equality condition. // On LEFT.b1 = RIGHT.b2 // LEFT Table: // a1 b1 c1 @@ -750,9 +751,9 @@ pub fn build_join_indices( // "| 13 | 10 | 130 | 12 | 10 | 120 |", // "| 9 | 8 | 90 | 8 | 8 | 80 |", // "+----+----+-----+----+----+-----+" -// And the result of build and probe indices -// build indices: 5, 6, 6, 4 -// probe indices: 3, 4, 5, 3 +// And the result of build and probe indices are: +// Build indices: 5, 6, 6, 4 +// Probe indices: 3, 4, 5, 3 #[allow(clippy::too_many_arguments)] pub fn build_equal_condition_join_indices( build_hashmap: &JoinHashMap, @@ -814,13 +815,11 @@ pub fn build_equal_condition_join_indices( let build = ArrayData::builder(DataType::UInt64) .len(build_indices.len()) .add_buffer(build_indices.finish()) - .build() - .unwrap(); + .build()?; let probe = ArrayData::builder(DataType::UInt32) .len(probe_indices.len()) .add_buffer(probe_indices.finish()) - .build() - .unwrap(); + .build()?; Ok(( PrimitiveArray::::from(build), @@ -895,7 +894,7 @@ macro_rules! equal_rows_elem_with_string_dict { /// Left and right row have equal values /// If more data types are supported here, please also add the data types in can_hash function /// to generate hash join logical plan. 
-pub fn equal_rows( +fn equal_rows( left: usize, right: usize, left_arrays: &[ArrayRef], @@ -1308,21 +1307,21 @@ mod tests { use super::*; use crate::physical_expr::expressions::BinaryExpr; - use crate::physical_plan::joins::hash_join::build_equal_condition_join_indices; - use crate::physical_plan::joins::utils::JoinSide; use crate::prelude::SessionContext; use crate::{ assert_batches_sorted_eq, physical_plan::{ - common, expressions::Column, hash_utils::create_hashes, memory::MemoryExec, + common, + expressions::Column, + hash_utils::create_hashes, + joins::{hash_join::build_equal_condition_join_indices, utils::JoinSide}, + memory::MemoryExec, repartition::RepartitionExec, }, test::exec::MockExec, test::{build_table_i32, columns}, }; - use arrow::array::UInt32Builder; - use arrow::array::UInt64Builder; - use arrow::array::{ArrayRef, Date32Array, Int32Array}; + use arrow::array::{ArrayRef, Date32Array, Int32Array, UInt32Builder, UInt64Builder}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_expr::Operator; diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs index ccf3c224142d..dccd34241438 100644 --- a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs +++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs @@ -16,87 +16,86 @@ // under the License. //! Hash Join related functionality used both on logical and physical plans -//! + +use std::collections::HashMap; +use std::sync::Arc; use std::usize; -use arrow::datatypes::DataType; use arrow::datatypes::SchemaRef; -use crate::common::Result; -use arrow::compute::CastOptions; -use datafusion_common::{DataFusionError, ScalarValue}; -use datafusion_expr::Operator; -use datafusion_physical_expr::expressions::{BinaryExpr, CastExpr, Column, Literal}; +use datafusion_common::DataFusionError; +use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::intervals::Interval; use datafusion_physical_expr::rewrite::TreeNodeRewritable; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; -use std::collections::HashMap; -use std::sync::Arc; +use crate::common::Result; use crate::physical_plan::joins::utils::{JoinFilter, JoinSide}; fn check_filter_expr_contains_sort_information( expr: &Arc, - checked_expr: &Arc, + reference: &Arc, ) -> bool { - let contains = checked_expr.eq(expr); - contains - || expr.children().iter().any(|child_expr| { - check_filter_expr_contains_sort_information(child_expr, checked_expr) - }) + expr.eq(reference) + || expr + .children() + .iter() + .any(|e| check_filter_expr_contains_sort_information(e, reference)) } -fn recursive_physical_expr_column_collector( - expr: &Arc, - columns: &mut Vec, -) { +fn collect_columns_recursive(expr: &Arc, columns: &mut Vec) { if let Some(column) = expr.as_any().downcast_ref::() { if !columns.iter().any(|c| c.eq(column)) { columns.push(column.clone()) } } - expr.children().iter().for_each(|child_expr| { - recursive_physical_expr_column_collector(child_expr, columns) - }) + expr.children() + .iter() + .for_each(|e| collect_columns_recursive(e, columns)) } -fn physical_expr_column_collector(expr: &Arc) -> Vec { +fn collect_columns(expr: &Arc) -> Vec { let mut columns = vec![]; - recursive_physical_expr_column_collector(expr, &mut columns); + collect_columns_recursive(expr, &mut columns); columns } -/// Create main_col -> filter_col one to one mapping from filter column indices. 
-/// A column index looks like
-/// ColumnIndex {
-//     index: 0, -> field index in main schema
-//     side: JoinSide::Left, -> child side
-// },
+/// Create a one to one mapping from main columns to filter columns using
+/// filter column indices. A column index looks like:
+/// ```text
+/// ColumnIndex {
+///     index: 0, // field index in main schema
+///     side: JoinSide::Left, // child side
+/// }
+/// ```
 pub fn map_origin_col_to_filter_col(
     filter: &JoinFilter,
     schema: SchemaRef,
     side: &JoinSide,
 ) -> Result<HashMap<Column, Column>> {
+    let filter_schema = filter.schema();
     let mut col_to_col_map: HashMap<Column, Column> = HashMap::new();
     for (filter_schema_index, index) in filter.column_indices().iter().enumerate() {
         if index.side.eq(side) {
-            // Get main field from column index
+            // Get the main field from column index:
             let main_field = schema.field(index.index);
-            // Create a column PhysicalExpr
+            // Create a column expression:
             let main_col = Column::new_with_schema(main_field.name(), schema.as_ref())?;
-            // Since the filter.column_indices() order directly same with intermediate schema fields, we can
-            // get the column.
-            let filter_field = filter.schema().field(filter_schema_index);
+            // Since the order of filter.column_indices() is the same as that
+            // of the intermediate schema fields, we can get the column directly.
+            let filter_field = filter_schema.field(filter_schema_index);
             let filter_col = Column::new(filter_field.name(), filter_schema_index);
-            // Insert mapping
+            // Insert mapping:
             col_to_col_map.insert(main_col, filter_col);
         }
     }
     Ok(col_to_col_map)
 }
 
-/// This function plays an important role in the expression graph traversal process. It is necessary to analyze the `PhysicalSortExpr`
-/// because the sorting of expressions is required for join filter expressions.
+/// This function analyzes [PhysicalSortExpr] graphs with respect to monotonicity
+/// (sorting) properties. This is necessary since monotonically increasing and/or
+/// decreasing expressions are required when using join filter expressions for
+/// data pruning purposes.
 ///
 /// The method works as follows:
 /// 1. Maps the original columns to the filter columns using the `map_origin_col_to_filter_col` function.
@@ -107,13 +106,12 @@ pub fn map_origin_col_to_filter_col(
 /// 6. If an exact match is encountered, returns the converted filter expression as `Some(Arc<dyn PhysicalExpr>)`.
 /// 7. If all columns are not included or the exact match is not encountered, returns `None`.
 ///
-/// Use Cases:
+/// Examples:
 /// Consider the filter expression "a + b > c + 10 AND a + b < c + 100".
 /// 1. If the expression "a@ + d@" is sorted, it will not be accepted since the "d@" column is not part of the filter.
 /// 2. If the expression "d@" is sorted, it will not be accepted since the "d@" column is not part of the filter.
 /// 3. If the expression "a@ + b@ + c@" is sorted, all columns are represented in the filter expression. However,
 ///    there is no exact match, so this expression does not indicate pruning.
-///
 pub fn convert_sort_expr_with_filter_schema(
     side: &JoinSide,
     filter: &JoinFilter,
@@ -123,9 +121,9 @@ pub fn convert_sort_expr_with_filter_schema(
     let column_mapping_information: HashMap<Column, Column> =
         map_origin_col_to_filter_col(filter, schema, side)?;
     let expr = sort_expr.expr.clone();
-    // Get main schema columns
-    let expr_columns = physical_expr_column_collector(&expr);
-    // Calculation is possible with 'column_mapping_information' since sort exprs belong to a child.
+ // Get main schema columns: + let expr_columns = collect_columns(&expr); + // Calculation is possible with `column_mapping_information` since sort exprs belong to a child. let all_columns_are_included = expr_columns .iter() .all(|col| column_mapping_information.contains_key(col)); @@ -134,68 +132,69 @@ pub fn convert_sort_expr_with_filter_schema( // the sort expression into a filter expression. let converted_filter_expr = expr .transform_up(&|p| convert_filter_columns(p, &column_mapping_information))?; - - // Search converted PhysicalExpr in filter expression - // If the exact match is encountered, use this sorted expression in graph traversals. + // Search the converted `PhysicalExpr` in filter expression; if an exact + // match is found, use this sorted expression in graph traversals. if check_filter_expr_contains_sort_information( filter.expression(), &converted_filter_expr, ) { - Ok(Some(converted_filter_expr)) - } else { - Ok(None) + return Ok(Some(converted_filter_expr)); } - } else { - Ok(None) } + Ok(None) } /// This function is used to build the filter expression based on the sort order of input columns. /// -/// It first calls the convert_sort_expr_with_filter_schema method to determine if the sort -/// order of columns can be used in the filter expression. -/// If it returns a Some value, the method wraps the result in a SortedFilterExpr -/// instance with the original sort expression, converted filter expression, and sort options. -/// If it returns a None value, this function returns an error. +/// It first calls the [convert_sort_expr_with_filter_schema] method to determine if the sort +/// order of columns can be used in the filter expression. If it returns a [Some] value, the +/// method wraps the result in a [SortedFilterExpr] instance with the original sort expression, +/// converted filter expression, and sort options. Otherwise, this function returns an error. /// -/// The SortedFilterExpr instance contains information about the sort order of columns -/// that can be used in the filter expression, which can be used to optimize the query execution process. +/// The [SortedFilterExpr] instance contains information about the sort order of columns that can +/// be used in the filter expression, which can be used to optimize the query execution process. pub fn build_filter_input_order_v2( side: JoinSide, filter: &JoinFilter, schema: SchemaRef, order: &PhysicalSortExpr, ) -> Result { - match convert_sort_expr_with_filter_schema(&side, filter, schema, order)? { - Some(expr) => Ok(SortedFilterExpr::new(side, order.clone(), expr)), - None => Err(DataFusionError::Plan(format!( + if let Some(expr) = + convert_sort_expr_with_filter_schema(&side, filter, schema, order)? + { + Ok(SortedFilterExpr::new(side, order.clone(), expr)) + } else { + Err(DataFusionError::Plan(format!( "The {side} side of the join does not have an expression sorted." - ))), + ))) } } -/// Convert a physical expression into a filter expression using a column mapping information. +/// Convert a physical expression into a filter expression using the given +/// column mapping information. fn convert_filter_columns( input: Arc, column_mapping_information: &HashMap, ) -> Result>> { // Attempt to downcast the input expression to a Column type. - if let Some(col) = input.as_any().downcast_ref::() { - // If the downcast is successful, retrieve the corresponding filter column. - let filter_col = column_mapping_information.get(col).unwrap().clone(); - // Return the filter column as an Arc wrapped in an Option. 
- Ok(Some(Arc::new(filter_col))) - } else { - // If the downcast is not successful, return the input expression as is. - Ok(Some(input)) - } + Ok(Some( + if let Some(col) = input.as_any().downcast_ref::() { + // If the downcast is successful, retrieve the corresponding filter column. + let filter_col = column_mapping_information.get(col).unwrap().clone(); + // Return the filter column as an Arc wrapped in an Option. + Arc::new(filter_col) + } else { + // If the downcast fails, return the input expression as is. + input + }, + )) } +/// The SortedFilterExpr object is used to represent a sorted filter expression in the +/// `SymmetricHashJoinExec` struct. It contains information about the join side, the +/// origin expression, the filter expression, and the sort option. The object has +/// several methods to access and modify its fields. #[derive(Debug, Clone)] -/// The SortedFilterExpr struct is used to represent a sorted filter expression in the -/// [SymmetricHashJoinExec] struct. It contains information about the join side, the origin -/// expression, the filter expression, and the sort option. -/// The struct has several methods to access and modify its fields. pub struct SortedFilterExpr { /// Column side pub join_side: JoinSide, @@ -250,75 +249,79 @@ impl SortedFilterExpr { self.node_index = node_index; } } -/// Filter expr for a + b > c + 10 AND a + b < c + 100 -#[allow(dead_code)] -pub(crate) fn complicated_filter() -> Arc { - let left_expr = BinaryExpr::new( - Arc::new(CastExpr::new( - Arc::new(BinaryExpr::new( - Arc::new(Column::new("0", 0)), - Operator::Plus, - Arc::new(Column::new("1", 1)), - )), - DataType::Int64, - CastOptions { safe: false }, - )), - Operator::Gt, - Arc::new(BinaryExpr::new( + +#[cfg(test)] +pub mod tests { + use super::*; + use crate::physical_plan::{ + expressions::Column, + expressions::PhysicalSortExpr, + joins::utils::{ColumnIndex, JoinFilter, JoinSide}, + }; + use arrow::compute::{CastOptions, SortOptions}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::ScalarValue; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::{BinaryExpr, CastExpr, Literal}; + use std::sync::Arc; + + /// Filter expr for a + b > c + 10 AND a + b < c + 100 + pub(crate) fn complicated_filter() -> Arc { + let left_expr = BinaryExpr::new( Arc::new(CastExpr::new( - Arc::new(Column::new("2", 2)), + Arc::new(BinaryExpr::new( + Arc::new(Column::new("0", 0)), + Operator::Plus, + Arc::new(Column::new("1", 1)), + )), DataType::Int64, CastOptions { safe: false }, )), - Operator::Plus, - Arc::new(Literal::new(ScalarValue::Int64(Some(10)))), - )), - ); - - let right_expr = BinaryExpr::new( - Arc::new(CastExpr::new( + Operator::Gt, Arc::new(BinaryExpr::new( - Arc::new(Column::new("0", 0)), + Arc::new(CastExpr::new( + Arc::new(Column::new("2", 2)), + DataType::Int64, + CastOptions { safe: false }, + )), Operator::Plus, - Arc::new(Column::new("1", 1)), + Arc::new(Literal::new(ScalarValue::Int64(Some(10)))), )), - DataType::Int64, - CastOptions { safe: false }, - )), - Operator::Lt, - Arc::new(BinaryExpr::new( + ); + + let right_expr = BinaryExpr::new( Arc::new(CastExpr::new( - Arc::new(Column::new("2", 2)), + Arc::new(BinaryExpr::new( + Arc::new(Column::new("0", 0)), + Operator::Plus, + Arc::new(Column::new("1", 1)), + )), DataType::Int64, CastOptions { safe: false }, )), - Operator::Plus, - Arc::new(Literal::new(ScalarValue::Int64(Some(100)))), - )), - ); - - Arc::new(BinaryExpr::new( - Arc::new(left_expr), - Operator::And, - 
Arc::new(right_expr), - )) -} + Operator::Lt, + Arc::new(BinaryExpr::new( + Arc::new(CastExpr::new( + Arc::new(Column::new("2", 2)), + DataType::Int64, + CastOptions { safe: false }, + )), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int64(Some(100)))), + )), + ); + + Arc::new(BinaryExpr::new( + Arc::new(left_expr), + Operator::And, + Arc::new(right_expr), + )) + } -#[cfg(test)] -mod tests { - use super::*; - use crate::physical_plan::{ - expressions::Column, - expressions::PhysicalSortExpr, - joins::utils::{ColumnIndex, JoinFilter, JoinSide}, - }; - use arrow::compute::{CastOptions, SortOptions}; - use arrow::datatypes::{DataType, Field, Schema}; - use std::sync::Arc; #[test] fn test_column_collector() { let filter_expr = complicated_filter(); - let columns = physical_expr_column_collector(&filter_expr); + let columns = collect_columns(&filter_expr); assert_eq!(columns.len(), 3) } diff --git a/datafusion/core/src/physical_plan/joins/mod.rs b/datafusion/core/src/physical_plan/joins/mod.rs index 629cac4a64a3..fe9954082e11 100644 --- a/datafusion/core/src/physical_plan/joins/mod.rs +++ b/datafusion/core/src/physical_plan/joins/mod.rs @@ -19,11 +19,11 @@ pub use cross_join::CrossJoinExec; pub use hash_join::HashJoinExec; -pub use nested_loop_join::NestedLoopJoinExec; -// Note: SortMergeJoin is not used in plans yet pub use hash_join_utils::{ convert_sort_expr_with_filter_schema, map_origin_col_to_filter_col, SortedFilterExpr, }; +pub use nested_loop_join::NestedLoopJoinExec; +// Note: SortMergeJoin is not used in plans yet pub use sort_merge_join::SortMergeJoinExec; pub use symmetric_hash_join::SymmetricHashJoinExec; mod cross_join; diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs index 340854666fec..ba86d517a5d3 100644 --- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs @@ -27,43 +27,38 @@ use std::vec; use std::{any::Any, usize}; use ahash::RandomState; -use arrow::array::PrimitiveArray; -use arrow::array::{ArrowPrimitiveType, NativeAdapter, PrimitiveBuilder}; +use arrow::array::{ + ArrowPrimitiveType, BooleanBufferBuilder, NativeAdapter, PrimitiveArray, + PrimitiveBuilder, +}; use arrow::compute::concat_batches; -use arrow::datatypes::ArrowNativeType; -use arrow::datatypes::{Schema, SchemaRef}; +use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use futures::{Stream, StreamExt}; -use hashbrown::raw::RawTable; -use hashbrown::HashSet; +use hashbrown::{raw::RawTable, HashSet}; -use datafusion_common::utils::bisect; -use datafusion_common::ScalarValue; -use datafusion_physical_expr::intervals::ExprIntervalGraph; -use datafusion_physical_expr::intervals::Interval; +use datafusion_common::{utils::bisect, ScalarValue}; +use datafusion_physical_expr::intervals::{ExprIntervalGraph, Interval}; -use crate::arrow::array::BooleanBufferBuilder; use crate::error::{DataFusionError, Result}; use crate::execution::context::TaskContext; use crate::logical_expr::JoinType; -use crate::physical_plan::common::merge_batches; -use crate::physical_plan::joins::hash_join::{ - build_join_indices, update_hash, JoinHashMap, -}; -use crate::physical_plan::joins::hash_join_utils::{ - build_filter_input_order_v2, SortedFilterExpr, -}; -use crate::physical_plan::joins::utils::build_batch_from_indices; use crate::physical_plan::{ + common::merge_batches, expressions::Column, 
expressions::PhysicalSortExpr, - joins::utils::{ - build_join_schema, check_join_is_valid, combine_join_equivalence_properties, - partitioned_join_output_partitioning, ColumnIndex, JoinFilter, JoinOn, JoinSide, + joins::{ + hash_join::{build_join_indices, update_hash, JoinHashMap}, + hash_join_utils::{build_filter_input_order_v2, SortedFilterExpr}, + utils::{ + build_batch_from_indices, build_join_schema, check_join_is_valid, + combine_join_equivalence_properties, partitioned_join_output_partitioning, + ColumnIndex, JoinFilter, JoinOn, JoinSide, + }, }, metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, Partitioning, - PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, + RecordBatchStream, SendableRecordBatchStream, Statistics, }; /// A symmetric hash join with range conditions is when both streams are hashed on the @@ -100,8 +95,8 @@ use crate::physical_plan::{ /// | | /// +-------------------------+ /// -/// Prune build side when the new RecordBatch comes to the probe side. We utilize interval arithmetics -/// on JoinFilter's sorted PhysicalExprs to calculate joinable range. +/// Prune build side when the new RecordBatch comes to the probe side. We utilize interval arithmetic +/// on JoinFilter's sorted PhysicalExprs to calculate the joinable range. /// /// /// PROBE SIDE BUILD SIDE @@ -150,19 +145,15 @@ use crate::physical_plan::{ /// that the left side buffer must only keep rows where "leftTime > rightTime - 3 > 1234 - 3 > 1231" , /// making the smallest value in 'left_sorted' 1231 and any rows below (since ascending) /// than that can be dropped from the inner buffer. -/// -/// -/// /// ``` -/// pub struct SymmetricHashJoinExec { - /// left side stream + /// Left side stream pub(crate) left: Arc, - /// right side stream + /// Right side stream pub(crate) right: Arc, /// Set of common columns used to join on pub(crate) on: Vec<(Column, Column)>, - /// Filters which are applied while finding matching rows + /// Filters applied when finding matching rows pub(crate) filter: JoinFilter, /// How the join is performed pub(crate) join_type: JoinType, @@ -184,18 +175,18 @@ pub struct SymmetricHashJoinExec { #[derive(Debug)] struct SymmetricHashJoinSideMetrics { - /// Number of right batches consumed by this operator + /// Number of batches consumed by this operator input_batches: metrics::Count, - /// Number of left rows consumed by this operator + /// Number of rows consumed by this operator input_rows: metrics::Count, } /// Metrics for HashJoinExec #[derive(Debug)] struct SymmetricHashJoinMetrics { - /// Number of left batches consumed by this operator + /// Number of left batches/rows consumed by this operator left: SymmetricHashJoinSideMetrics, - /// Number of right batches consumed by this operator + /// Number of right batches/rows consumed by this operator right: SymmetricHashJoinSideMetrics, /// Number of batches produced by this operator output_batches: metrics::Count, @@ -238,8 +229,10 @@ impl SymmetricHashJoinMetrics { impl SymmetricHashJoinExec { /// Tries to create a new [SymmetricHashJoinExec]. 
/// # Error - /// This function errors when it is not possible to join the left and right sides on keys `on`, - /// iterate and construct [SortedFilterExpr]s,and create [ExprIntervalGraph] + /// This function errors when: + /// - It is not possible to join the left and right sides on keys `on`, or + /// - It fails to construct [SortedFilterExpr]s, or + /// - It fails to create the [ExprIntervalGraph]. pub fn try_new( left: Arc, right: Arc, @@ -373,20 +366,14 @@ impl ExecutionPlan for SymmetricHashJoinExec { } fn unbounded_output(&self, children: &[bool]) -> Result { - let (left, right) = (children[0], children[1]); - Ok(left || right) + Ok(children.iter().any(|u| *u)) } fn required_input_distribution(&self) -> Vec { let (left_expr, right_expr) = self .on .iter() - .map(|(l, r)| { - ( - Arc::new(l.clone()) as Arc, - Arc::new(r.clone()) as Arc, - ) - }) + .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _)) .unzip(); // TODO: This will change when we extend collected executions. vec![ @@ -412,6 +399,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { left_columns_len, ) } + // TODO Output ordering might be kept for some cases. // For example if it is inner join then the stream side order can be kept fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { @@ -523,7 +511,7 @@ struct SymmetricHashJoinStream { right: OneSideHashJoiner, /// Information of index and left / right placement of columns column_indices: Vec, - // Range Prunner. + // Range pruner. physical_expr_graph: ExprIntervalGraph, /// Information of filter columns filter_columns: Vec, @@ -533,7 +521,7 @@ struct SymmetricHashJoinStream { null_equals_null: bool, /// Metrics metrics: SymmetricHashJoinMetrics, - /// There is nothing to process anymore + /// Flag indicating whether there is nothing to process anymore final_result: bool, /// The current probe side. We choose build and probe side according to this attribute. 
probe_side: JoinSide, @@ -577,31 +565,19 @@ fn prune_hash_values( for (hash_value, index_set) in hash_value_map.iter() { if let Some((_, separation_chain)) = hashmap .0 - .get_mut(*hash_value, |(hash, _)| *hash_value == *hash) + .get_mut(*hash_value, |(hash, _)| hash_value == hash) { separation_chain.retain(|n| !index_set.contains(n)); if separation_chain.is_empty() { hashmap .0 - .remove_entry(*hash_value, |(hash, _)| *hash_value == *hash); + .remove_entry(*hash_value, |(hash, _)| hash_value == hash); } } } Ok(()) } -fn prune_visited_rows( - prune_length: usize, - visited_rows: &mut HashSet, - deleted_offset: usize, -) -> Result<()> { - (deleted_offset..(deleted_offset + prune_length)).for_each(|row| { - visited_rows.remove(&row); - }); - Ok(()) -} - -/// /// Calculate the filter expression intervals /// /// This function updates the `interval` field of each `SortedFilterExpr` in `filter_columns` based on @@ -614,7 +590,8 @@ fn prune_visited_rows( /// * `filter_columns` - A mutable list of `SortedFilterExpr` objects to update /// * `build_side` - The side of the join where `build_input_buffer` is used /// -/// # Note +/// ### Note +/// ```text /// /// Build Probe /// +-------+ +-------+ @@ -633,7 +610,7 @@ fn prune_visited_rows( /// Expr2: [b, inf] /// Expr3: [c, inf] /// Expr4: [d, inf] -/// +/// ``` /// /// # Returns /// @@ -645,12 +622,11 @@ fn calculate_filter_expr_intervals( filter_columns: &mut [SortedFilterExpr], build_side: JoinSide, ) -> Result<()> { - // If either build or probe side has no data, return early + // If either build or probe side has no data, return early: if build_input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 { return Ok(()); } - - // Iterate through each SortedFilterExpr in filter_columns + // Iterate through each [SortedFilterExpr] in filter_columns: for sorted_expr in filter_columns.iter_mut() { // Destructure the SortedFilterExpr let SortedFilterExpr { @@ -661,26 +637,20 @@ fn calculate_filter_expr_intervals( } = sorted_expr; // Get the array of the first or last value for the expression, depending on join side + let expr = &origin_sorted_expr.expr; let array = if build_side.eq(join_side) { // Get first value for expr - origin_sorted_expr - .expr - .evaluate(&build_input_buffer.slice(0, 1))? - .into_array(1) + expr.evaluate(&build_input_buffer.slice(0, 1)) } else { // Get last value for expr - origin_sorted_expr - .expr - .evaluate(&probe_batch.slice(probe_batch.num_rows() - 1, 1))? - .into_array(1) - }; + expr.evaluate(&probe_batch.slice(probe_batch.num_rows() - 1, 1)) + }? + .into_array(1); - // Convert the array to a ScalarValue + // Convert the array to a ScalarValue: let value = ScalarValue::try_from_array(&array, 0)?; - - // Create a ScalarValue representing positive or negative infinity for the same data type + // Create a ScalarValue representing positive or negative infinity for the same data type: let infinite = ScalarValue::try_from(value.get_datatype())?; - // Update the interval with lower and upper bounds based on the sort option *interval = if origin_sorted_expr.options.descending { Interval { @@ -731,19 +701,18 @@ fn determine_prune_length( None } }) - .collect::>>() - .into_iter() .collect::>>()? .into_iter() .min() .unwrap()) } + /// This method determines if the result of the join should be produced in the final step or not. /// /// # Arguments /// -/// * build_side - Enum indicating the side of the join used as the build side. -/// * join_type - Enum indicating the type of join to be performed. 
+/// * `build_side` - Enum indicating the side of the join used as the build side. +/// * `join_type` - Enum indicating the type of join to be performed. /// /// # Returns /// @@ -752,8 +721,8 @@ fn determine_prune_length( /// JoinType::Left, JoinType::LeftAnti, JoinType::Full or JoinType::LeftSemi. /// If the build side is JoinSide::Right, the result will be true if the join type /// is one of JoinType::Right, JoinType::RightAnti, JoinType::Full, or JoinType::RightSemi. -fn need_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> bool { - if build_side.eq(&JoinSide::Left) { +fn need_to_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> bool { + if build_side == JoinSide::Left { matches!( join_type, JoinType::Left | JoinType::LeftAnti | JoinType::Full | JoinType::LeftSemi @@ -772,9 +741,9 @@ fn need_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> bo /// /// # Arguments /// -/// * prune_length - The length of the pruned record batch. -/// * deleted_offset - The offset to the indices. -/// * visited_rows - The hash set of visited indices. +/// * `prune_length` - The length of the pruned record batch. +/// * `deleted_offset` - The offset to the indices. +/// * `visited_rows` - The hash set of visited indices. /// /// # Returns /// @@ -790,14 +759,14 @@ where let mut bitmap = BooleanBufferBuilder::new(prune_length); bitmap.append_n(prune_length, false); // mark the indices as true if they are present in the visited hash set - (0..prune_length).for_each(|v| { - let row = &(v + deleted_offset); - bitmap.set_bit(v, visited_rows.contains(row)); - }); + for v in 0..prune_length { + let row = v + deleted_offset; + bitmap.set_bit(v, visited_rows.contains(&row)); + } // get the anti index (0..prune_length) .filter_map(|idx| (!bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx))) - .collect::>() + .collect() } /// This method creates a boolean buffer from the visited rows hash set @@ -807,13 +776,13 @@ where /// /// # Arguments /// -/// * prune_length - The length of the pruned record batch. -/// * deleted_offset - The offset to the indices. -/// * visited_rows - The hash set of visited indices. +/// * `prune_length` - The length of the pruned record batch. +/// * `deleted_offset` - The offset to the indices. +/// * `visited_rows` - The hash set of visited indices. /// /// # Returns /// -/// A PrimitiveArray of the specified type T, containing the semi indices. +/// A [PrimitiveArray] of the specified type T, containing the semi indices. fn get_semi_indices( prune_length: usize, deleted_offset: usize, @@ -848,8 +817,7 @@ fn record_visited_indices( offset: usize, indices: &PrimitiveArray, ) { - let batch_indices: &[T::Native] = indices.values(); - for i in batch_indices { + for i in indices.values() { visited.insert(i.as_usize() + offset); } } @@ -961,7 +929,7 @@ impl OneSideHashJoiner { /// * `batch` - The incoming RecordBatch to be merged with the internal input buffer /// * `random_state` - The random state used to hash values /// - /// # Return + /// # Returns /// /// Returns a Result with an empty Ok if the update was successful, otherwise returns an error. fn update_internal_state( @@ -1009,8 +977,8 @@ impl OneSideHashJoiner { /// /// # Returns /// - /// A `Result` containing an optional record batch if the join type is not one of LeftAnti, RightAnti, LeftSemi or RightSemi. - /// If the join type is one of the above four, the function will return None. 
+ /// A [Result] containing an optional record batch if the join type is not one of `LeftAnti`, `RightAnti`, `LeftSemi` or `RightSemi`. + /// If the join type is one of the above four, the function will return [None]. #[allow(clippy::too_many_arguments)] fn join_with_probe_batch( &mut self, @@ -1041,14 +1009,14 @@ impl OneSideHashJoiner { Some(self.deleted_offset), self.build_side, )?; - if need_produce_result_in_final(self.build_side, join_type) { + if need_to_produce_result_in_final(self.build_side, join_type) { record_visited_indices( &mut self.visited_rows, self.deleted_offset, &build_side, ); } - if need_produce_result_in_final(self.build_side.negate(), join_type) { + if need_to_produce_result_in_final(self.build_side.negate(), join_type) { record_visited_indices(probe_visited, probe_offset, &probe_side); } if matches!( @@ -1089,7 +1057,7 @@ impl OneSideHashJoiner { /// /// # Returns /// - /// * `Option` - The final output record batch if required, otherwise `None`. + /// * `Option` - The final output record batch if required, otherwise [None]. fn build_side_determined_results( &self, output_schema: SchemaRef, @@ -1099,7 +1067,7 @@ impl OneSideHashJoiner { column_indices: &[ColumnIndex], ) -> Result> { // Check if we need to produce a result in the final output - let result = if need_produce_result_in_final(self.build_side, join_type) { + let result = if need_to_produce_result_in_final(self.build_side, join_type) { // Calculate the indices for build and probe sides based on join type and build side let (build_indices, probe_indices) = calculate_indices_by_join_type( self.build_side, @@ -1216,11 +1184,9 @@ impl OneSideHashJoiner { &mut self.row_hash_values, self.deleted_offset as u64, )?; - prune_visited_rows( - prune_length, - &mut self.visited_rows, - self.deleted_offset, - )?; + for row in self.deleted_offset..(self.deleted_offset + prune_length) { + self.visited_rows.remove(&row); + } self.input_buffer = self .input_buffer .slice(prune_length, self.input_buffer.num_rows() - prune_length); @@ -1423,12 +1389,13 @@ mod tests { use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, Column}; use datafusion_physical_expr::intervals::test_utils::gen_conjunctive_numeric_expr; + use datafusion_physical_expr::PhysicalExpr; - use crate::physical_plan::collect; - use crate::physical_plan::joins::hash_join_utils::complicated_filter; - use crate::physical_plan::joins::{HashJoinExec, PartitionMode}; + use crate::physical_plan::joins::{ + hash_join_utils::tests::complicated_filter, HashJoinExec, PartitionMode, + }; use crate::physical_plan::{ - common, memory::MemoryExec, repartition::RepartitionExec, + collect, common, memory::MemoryExec, repartition::RepartitionExec, }; use crate::prelude::{SessionConfig, SessionContext}; use crate::test_util; @@ -1472,12 +1439,12 @@ mod tests { let left_expr = on .iter() - .map(|(l, _)| Arc::new(l.clone()) as Arc) + .map(|(l, _)| Arc::new(l.clone()) as _) .collect::>(); let right_expr = on .iter() - .map(|(_, r)| Arc::new(r.clone()) as Arc) + .map(|(_, r)| Arc::new(r.clone()) as _) .collect::>(); let join = SymmetricHashJoinExec::try_new( @@ -1523,12 +1490,7 @@ mod tests { let (left_expr, right_expr) = on .iter() - .map(|(l, r)| { - ( - Arc::new(l.clone()) as Arc, - Arc::new(r.clone()) as Arc, - ) - }) + .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _)) .unzip(); let join = HashJoinExec::try_new( diff --git a/datafusion/core/src/physical_plan/joins/utils.rs b/datafusion/core/src/physical_plan/joins/utils.rs 
index 5a573d30243e..382c79df73ef 100644 --- a/datafusion/core/src/physical_plan/joins/utils.rs +++ b/datafusion/core/src/physical_plan/joins/utils.rs @@ -228,14 +228,14 @@ pub fn cross_join_equivalence_properties( impl Display for JoinSide { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - JoinSide::Left => write!(f, "Left"), - JoinSide::Right => write!(f, "Right"), + JoinSide::Left => write!(f, "left"), + JoinSide::Right => write!(f, "right"), } } } /// Used in ColumnIndex to distinguish which side the index is for -#[derive(Debug, Clone, PartialEq, Copy)] +#[derive(Debug, Clone, Copy, PartialEq)] pub enum JoinSide { /// Left side of the join Left, diff --git a/datafusion/core/src/test_util.rs b/datafusion/core/src/test_util.rs index 1ee912d7d089..a5f8263f47d3 100644 --- a/datafusion/core/src/test_util.rs +++ b/datafusion/core/src/test_util.rs @@ -388,18 +388,19 @@ mod tests { assert!(PathBuf::from(res).is_dir()); } } -/// It creates unbounded sorted files for tests + +/// This function creates an unbounded sorted file for testing purposes. pub async fn test_create_unbounded_sorted_file( ctx: &SessionContext, file_path: PathBuf, table_name: &str, ) -> datafusion_common::Result<()> { - // Register table + // Create schema: let schema = Arc::new(Schema::new(vec![ Field::new("a1", DataType::UInt32, false), Field::new("a2", DataType::UInt32, false), ])); - // The sort order is specified + // Specify the ordering: let file_sort_order = [datafusion_expr::col("a1")] .into_iter() .map(|e| { @@ -408,20 +409,20 @@ pub async fn test_create_unbounded_sorted_file( e.sort(ascending, nulls_first) }) .collect::>(); - // Mark infinite and provide schema + // Mark infinite and provide schema: let fifo_options = CsvReadOptions::new() .schema(schema.as_ref()) .has_header(false) .mark_infinite(true); - // Get listing options + // Get listing options: let options_sort = fifo_options .to_listing_options(&ctx.copied_config()) .with_file_sort_order(Some(file_sort_order)); - // Register table + // Register table: ctx.register_listing_table( table_name, file_path.as_os_str().to_str().unwrap(), - options_sort.clone(), + options_sort, Some(schema), None, ) diff --git a/datafusion/core/tests/fifo.rs b/datafusion/core/tests/fifo.rs index 9604e8b4da8c..5899e4d479b7 100644 --- a/datafusion/core/tests/fifo.rs +++ b/datafusion/core/tests/fifo.rs @@ -22,10 +22,11 @@ mod unix_test { use arrow::array::Array; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion::test_util::test_create_unbounded_sorted_file; use datafusion::{ prelude::{CsvReadOptions, SessionConfig, SessionContext}, - test_util::{aggr_test_schema, arrow_test_data}, + test_util::{ + aggr_test_schema, arrow_test_data, test_create_unbounded_sorted_file, + }, }; use datafusion_common::{DataFusionError, Result}; use futures::StreamExt; @@ -140,7 +141,7 @@ mod unix_test { } // This test provides a relatively realistic end-to-end scenario where - // we ensure that unbounded execution happened + // we swap join sides to accommodate a FIFO source. #[rstest] #[timeout(std::time::Duration::from_secs(30))] #[tokio::test(flavor = "multi_thread", worker_threads = 5)] @@ -224,6 +225,7 @@ mod unix_test { assert_eq!(interleave(&result), unbounded_file); Ok(()) } + #[derive(Debug, PartialEq)] enum JoinOperation { LeftUnmatched, @@ -232,8 +234,8 @@ mod unix_test { } // This test provides a relatively realistic end-to-end scenario where - // we ensure that we change join into [SymmetricHashJoin] correctly to accommodate the FIFO sources. 
- // Creates 2 FIFO files with unbounded source. + // we change the join into a [SymmetricHashJoin] to accommodate two + // unbounded (FIFO) sources. #[rstest] #[timeout(std::time::Duration::from_secs(30))] #[tokio::test(flavor = "multi_thread")] @@ -264,16 +266,15 @@ mod unix_test { // Reference time to use when deciding to fail the test let execution_start = Instant::now(); // Join filter - let a1_iter = (0..TEST_DATA_SIZE).clone().map(|x| { - let change = rng.gen_range(0.0..1.0); - if change < 0.3 { + let a1_iter = (0..TEST_DATA_SIZE).map(|x| { + if rng.gen_range(0.0..1.0) < 0.3 { x - 1 } else { x } }); // Join key - let a2_iter = (0..TEST_DATA_SIZE).clone().map(|x| x % 10); + let a2_iter = (0..TEST_DATA_SIZE).map(|x| x % 10); for (cnt, (a1, a2)) in a1_iter.zip(a2_iter).enumerate() { // Wait a reading sign for unbounded execution // After first batch FIFO reading, we will wait for a batch created. diff --git a/datafusion/expr/src/operator.rs b/datafusion/expr/src/operator.rs index d4a40ea083e6..e9074b0d8c82 100644 --- a/datafusion/expr/src/operator.rs +++ b/datafusion/expr/src/operator.rs @@ -112,7 +112,7 @@ impl Operator { /// Return true if the operator is a comparison operator. /// - /// For example, 'Binary(a, >, b)' would be a comparison. + /// For example, 'Binary(a, >, b)' would be a comparison expression. pub fn is_comparison_operator(&self) -> bool { matches!( self, @@ -122,6 +122,8 @@ impl Operator { | Operator::LtEq | Operator::Gt | Operator::GtEq + | Operator::IsDistinctFrom + | Operator::IsNotDistinctFrom | Operator::RegexMatch | Operator::RegexIMatch | Operator::RegexNotMatch @@ -131,18 +133,10 @@ impl Operator { /// Return true if the operator is a logic operator. /// - /// For example, 'Binary(Binary(a, >, b), AND, Binary(a, <, b + 3))' would be a logic. + /// For example, 'Binary(Binary(a, >, b), AND, Binary(a, <, b + 3))' would + /// be a logical expression. pub fn is_logic_operator(&self) -> bool { - matches!( - self, - Operator::And - | Operator::Or - | Operator::BitwiseAnd - | Operator::BitwiseOr - | Operator::BitwiseXor - | Operator::BitwiseShiftRight - | Operator::BitwiseShiftLeft - ) + matches!(self, Operator::And | Operator::Or) } /// Return the operator where swapping lhs and rhs wouldn't change the result. 
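The narrowed operator classification above can be sanity-checked directly. A minimal, hypothetical sketch (not part of this patch series) exercising `is_comparison_operator` and `is_logic_operator` as redefined here, assuming only the public `datafusion_expr::Operator` enum:

```rust
use datafusion_expr::Operator;

fn main() {
    // Comparison operators can bound an interval from one side.
    assert!(Operator::Gt.is_comparison_operator());
    // IsDistinctFrom / IsNotDistinctFrom now count as comparisons too.
    assert!(Operator::IsDistinctFrom.is_comparison_operator());
    // Only conjunction and disjunction remain logical operators; bitwise
    // operators no longer qualify after this change.
    assert!(Operator::And.is_logic_operator());
    assert!(!Operator::BitwiseAnd.is_logic_operator());
}
```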
diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs index a41263ac55d8..120ef0f09a83 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils.rs @@ -218,13 +218,11 @@ impl TreeNodeRewritable for ExprTreeNode { where F: FnMut(Self) -> Result, { - let children = self.children(); - if !children.is_empty() { - self.child_nodes = children - .into_iter() - .map(transform) - .collect::>>()?; - } + self.child_nodes = self + .children() + .into_iter() + .map(transform) + .collect::>>()?; Ok(self) } } From 263eab01e1299082890cb75aa02b100c9b35fc77 Mon Sep 17 00:00:00 2001 From: metesynnada <100111937+metesynnada@users.noreply.github.com> Date: Fri, 10 Feb 2023 17:01:31 +0300 Subject: [PATCH 26/39] Revamp SortedFilterExpr usage and enhance comments --- .../core/src/physical_plan/joins/hash_join.rs | 1 - .../physical_plan/joins/hash_join_utils.rs | 33 +- .../joins/symmetric_hash_join.rs | 305 ++++++++++-------- datafusion/core/tests/fifo.rs | 5 +- 4 files changed, 188 insertions(+), 156 deletions(-) diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs b/datafusion/core/src/physical_plan/joins/hash_join.rs index d013995cdadb..0ad7615539d0 100644 --- a/datafusion/core/src/physical_plan/joins/hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/hash_join.rs @@ -777,7 +777,6 @@ pub fn build_equal_condition_join_indices( .into_array(build_input_buffer.num_rows())) }) .collect::>>()?; - hashes_buffer.clear(); hashes_buffer.resize(probe_batch.num_rows(), 0); let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; // Using a buffer builder to avoid slower normal builder diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs index dccd34241438..2d62d0e9b767 100644 --- a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs +++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -//! Hash Join related functionality used both on logical and physical plans +//! Symmetric and regular Hash Join related functionality used in join calculations and optimization +//! rules. use std::collections::HashMap; use std::sync::Arc; @@ -148,12 +149,12 @@ pub fn convert_sort_expr_with_filter_schema( /// /// It first calls the [convert_sort_expr_with_filter_schema] method to determine if the sort /// order of columns can be used in the filter expression. If it returns a [Some] value, the -/// method wraps the result in a [SortedFilterExpr] instance with the original sort expression, -/// converted filter expression, and sort options. Otherwise, this function returns an error. +/// method wraps the result in a [SortedFilterExpr] instance with the original sort expression and +/// converted filter expression. Otherwise, this function returns an error. /// /// The [SortedFilterExpr] instance contains information about the sort order of columns that can /// be used in the filter expression, which can be used to optimize the query execution process. -pub fn build_filter_input_order_v2( +pub fn build_filter_input_order( side: JoinSide, filter: &JoinFilter, schema: SchemaRef, @@ -162,7 +163,7 @@ pub fn build_filter_input_order_v2( if let Some(expr) = convert_sort_expr_with_filter_schema(&side, filter, schema, order)? 
    {
-        Ok(SortedFilterExpr::new(side, order.clone(), expr))
+        Ok(SortedFilterExpr::new(order.clone(), expr))
     } else {
         Err(DataFusionError::Plan(format!(
             "The {side} side of the join does not have an expression sorted."
@@ -191,33 +192,29 @@ fn convert_filter_columns(
 }
 
 /// The SortedFilterExpr object is used to represent a sorted filter expression in the
-/// `SymmetricHashJoinExec` struct. It contains information about the join side, the
+/// `SymmetricHashJoinExec` struct. It contains information about the
 /// origin expression, the filter expression, and the sort option. The object has
 /// several methods to access and modify its fields.
 #[derive(Debug, Clone)]
 pub struct SortedFilterExpr {
-    /// Column side
-    pub join_side: JoinSide,
     /// Sorted expr from a particular join side (child)
-    pub origin_sorted_expr: PhysicalSortExpr,
+    origin_sorted_expr: PhysicalSortExpr,
     /// For interval calculations, one to one mapping of the columns according to filter expression,
     /// and column indices.
-    pub filter_expr: Arc<dyn PhysicalExpr>,
+    filter_expr: Arc<dyn PhysicalExpr>,
     /// Interval
-    pub interval: Interval,
+    interval: Interval,
     /// NodeIndex in Graph
-    pub node_index: usize,
+    node_index: usize,
 }
 
 impl SortedFilterExpr {
     /// Constructor
     pub fn new(
-        join_side: JoinSide,
         origin_sorted_expr: PhysicalSortExpr,
         filter_expr: Arc<dyn PhysicalExpr>,
     ) -> Self {
         Self {
-            join_side,
             origin_sorted_expr,
             filter_expr,
             interval: Interval::default(),
@@ -414,7 +411,7 @@ pub mod tests {
 
         let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
 
-        assert!(build_filter_input_order_v2(
+        assert!(build_filter_input_order(
             JoinSide::Left,
             &filter,
             left_schema.clone(),
@@ -424,7 +421,7 @@ pub mod tests {
             }
         )
         .is_ok());
-        assert!(build_filter_input_order_v2(
+        assert!(build_filter_input_order(
             JoinSide::Left,
             &filter,
             left_schema,
@@ -434,7 +431,7 @@ pub mod tests {
             }
         )
         .is_err());
-        assert!(build_filter_input_order_v2(
+        assert!(build_filter_input_order(
             JoinSide::Right,
             &filter,
             right_schema.clone(),
@@ -444,7 +441,7 @@ pub mod tests {
             }
         )
         .is_ok());
-        assert!(build_filter_input_order_v2(
+        assert!(build_filter_input_order(
             JoinSide::Right,
             &filter,
             right_schema,
diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
index ba86d517a5d3..785c8bbd0d59 100644
--- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
@@ -15,8 +15,14 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines the join plan for executing partitions in parallel and then joining the results
-//! into a set of partitions.
+//! This code implements the symmetric hash join algorithm for executing
+//! partitions in parallel streams while pruning its buffers with range conditions.
+//!
+//! A symmetric hash join plan consumes two sorted child plans and produces
+//! joined output according to the given join type and other options.
+//!
+//! It includes a OneSideHashJoiner struct constructed for each child, which incrementally calculates the join.
+//!
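As an aside on the module documentation above, the following is a minimal toy sketch (not part of the patch; all names invented) of the core idea: each incoming row first probes the opposite side's hash table and is then inserted into its own side's table, so neither input ever needs to be fully materialized. The real operator works on Arrow RecordBatches, supports all join types, and additionally prunes its buffers with interval arithmetic:

```rust
use std::collections::HashMap;

/// Toy symmetric hash join on integer keys, for intuition only.
fn symmetric_join(left: &[(i32, &str)], right: &[(i32, &str)]) -> Vec<(String, String)> {
    let mut left_table: HashMap<i32, Vec<usize>> = HashMap::new();
    let mut right_table: HashMap<i32, Vec<usize>> = HashMap::new();
    let mut output = Vec::new();
    // Interleave the inputs to mimic alternating stream polling.
    for i in 0..left.len().max(right.len()) {
        if let Some(&(key, value)) = left.get(i) {
            // Probe the right table first, then build into the left table.
            for &j in right_table.get(&key).into_iter().flatten() {
                output.push((value.to_string(), right[j].1.to_string()));
            }
            left_table.entry(key).or_default().push(i);
        }
        if let Some(&(key, value)) = right.get(i) {
            for &j in left_table.get(&key).into_iter().flatten() {
                output.push((left[j].1.to_string(), value.to_string()));
            }
            right_table.entry(key).or_default().push(i);
        }
    }
    output
}

fn main() {
    let pairs = symmetric_join(&[(1, "l1"), (2, "l2")], &[(2, "r1"), (1, "r2")]);
    assert_eq!(pairs.len(), 2); // (l2, r1) and (l1, r2), each found exactly once
}
```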
use std::collections::{HashMap, VecDeque}; use std::fmt; @@ -49,7 +55,7 @@ use crate::physical_plan::{ expressions::PhysicalSortExpr, joins::{ hash_join::{build_join_indices, update_hash, JoinHashMap}, - hash_join_utils::{build_filter_input_order_v2, SortedFilterExpr}, + hash_join_utils::{build_filter_input_order, SortedFilterExpr}, utils::{ build_batch_from_indices, build_join_schema, check_join_is_valid, combine_join_equivalence_properties, partitioned_join_output_partitioning, @@ -157,8 +163,12 @@ pub struct SymmetricHashJoinExec { pub(crate) filter: JoinFilter, /// How the join is performed pub(crate) join_type: JoinType, - /// Order information of filter columns - filter_columns: Vec, + /// Order information of filter expressions + sorted_filter_exprs: Vec, + /// Left required sort + left_required_sort_exprs: Vec, + /// Right required sort + right_required_sort_exprs: Vec, /// Expression graph for interval calculations physical_expr_graph: ExprIntervalGraph, /// The schema once the join is applied @@ -265,36 +275,48 @@ impl SymmetricHashJoinExec { let mut physical_expr_graph = ExprIntervalGraph::try_new(filter.expression().clone())?; + let (left_required_sort_exprs, right_required_sort_exprs): (Vec<_>, Vec<_>) = + left.output_ordering() + .unwrap() + .iter() + .zip(right.output_ordering().unwrap()) + .take(1) + .map(|(l, r)| (l.clone(), r.clone())) + .unzip(); + // Build the sorted filter expression for the left child - let left_filter_expression = build_filter_input_order_v2( + let left_filter_expression = build_filter_input_order( JoinSide::Left, &filter, left.schema(), - left.output_ordering().unwrap().get(0).unwrap(), + &left_required_sort_exprs[0], )?; + // Build the sorted filter expression for the right child - let right_filter_expression = build_filter_input_order_v2( + let right_filter_expression = build_filter_input_order( JoinSide::Right, &filter, right.schema(), - right.output_ordering().unwrap().get(0).unwrap(), + &right_required_sort_exprs[0], )?; // Store the left and right sorted filter expressions in a vector - let mut filter_columns = vec![left_filter_expression, right_filter_expression]; + let mut sorted_filter_exprs = + vec![left_filter_expression, right_filter_expression]; // Gather node indices of converted filter expressions in SortedFilterExpr // using the filter columns vector let child_node_indexes = physical_expr_graph.gather_node_indices( - &filter_columns + &sorted_filter_exprs .iter() .map(|sorted_expr| sorted_expr.filter_expr().clone()) .collect::>(), ); // Inject calculated node indexes into SortedFilterExpr - for (sorted_expr, (_, index)) in - filter_columns.iter_mut().zip(child_node_indexes.iter()) + for (sorted_expr, (_, index)) in sorted_filter_exprs + .iter_mut() + .zip(child_node_indexes.iter()) { sorted_expr.set_node_index(*index); } @@ -305,7 +327,9 @@ impl SymmetricHashJoinExec { on, filter, join_type: *join_type, - filter_columns, + sorted_filter_exprs, + left_required_sort_exprs, + right_required_sort_exprs, physical_expr_graph, schema: Arc::new(schema), random_state, @@ -362,7 +386,10 @@ impl ExecutionPlan for SymmetricHashJoinExec { } fn required_input_ordering(&self) -> Vec> { - vec![None, None] + vec![ + Some(&self.left_required_sort_exprs), + Some(&self.right_required_sort_exprs), + ] } fn unbounded_output(&self, children: &[bool]) -> Result { @@ -401,7 +428,6 @@ impl ExecutionPlan for SymmetricHashJoinExec { } // TODO Output ordering might be kept for some cases. 
-    // For example if it is inner join then the stream side order can be kept
     fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
         None
     }
@@ -465,11 +491,18 @@ impl ExecutionPlan for SymmetricHashJoinExec {
     ) -> Result<SendableRecordBatchStream> {
         let on_left = self.on.iter().map(|on| on.0.clone()).collect::<Vec<_>>();
         let on_right = self.on.iter().map(|on| on.1.clone()).collect::<Vec<_>>();
-        // TODO: Currently, working on partitioned left and right. We can also coalesce.
-        let left_side_joiner =
-            OneSideHashJoiner::new(JoinSide::Left, on_left, self.left.clone());
-        let right_side_joiner =
-            OneSideHashJoiner::new(JoinSide::Right, on_right, self.right.clone());
+        let left_side_joiner = OneSideHashJoiner::new(
+            JoinSide::Left,
+            self.sorted_filter_exprs[0].clone(),
+            on_left,
+            self.left.clone(),
+        );
+        let right_side_joiner = OneSideHashJoiner::new(
+            JoinSide::Right,
+            self.sorted_filter_exprs[1].clone(),
+            on_right,
+            self.right.clone(),
+        );
         let left_stream = self.left.execute(partition, context.clone())?;
 
         let right_stream = self.right.execute(partition, context)?;
@@ -486,7 +519,6 @@ impl ExecutionPlan for SymmetricHashJoinExec {
             metrics: SymmetricHashJoinMetrics::new(partition, &self.metrics),
             physical_expr_graph: self.physical_expr_graph.clone(),
             null_equals_null: self.null_equals_null,
-            filter_columns: self.filter_columns.clone(),
             final_result: false,
             probe_side: JoinSide::Left,
         }))
@@ -513,8 +545,6 @@ struct SymmetricHashJoinStream {
     column_indices: Vec<ColumnIndex>,
     // Range pruner.
     physical_expr_graph: ExprIntervalGraph,
-    /// Information of filter columns
-    filter_columns: Vec<SortedFilterExpr>,
     /// Random state used for hashing initialization
     random_state: RandomState,
     /// If null_equals_null is true, null == null else null != null
@@ -580,15 +610,15 @@ fn prune_hash_values(
 
 /// Calculate the filter expression intervals
 ///
-/// This function updates the `interval` field of each `SortedFilterExpr` in `filter_columns` based on
+/// This function updates the `interval` field of each `SortedFilterExpr` based on
 /// the first or last value of the expression in `build_input_buffer` or `probe_batch`.
 ///
 /// # Arguments
 ///
 /// * `build_input_buffer` - The RecordBatch on the build side of the join
+/// * `build_sorted_filter_expr` - Build side `SortedFilterExpr` to update.
 /// * `probe_batch` - The RecordBatch on the probe side of the join
-/// * `filter_columns` - A mutable list of `SortedFilterExpr` objects to update
-/// * `build_side` - The side of the join where `build_input_buffer` is used
+/// * `probe_sorted_filter_expr` - Probe side `SortedFilterExpr` to update.
 ///
 /// ### Note
 /// ```text
@@ -606,105 +636,114 @@ fn prune_hash_values(
 /// | | | ||c |d ||
 /// | | | |+--|--+|
 /// +-------+ +-------+
-/// Expr1: [a, inf]
-/// Expr2: [b, inf]
-/// Expr3: [c, inf]
-/// Expr4: [d, inf]
-/// ```
 ///
-/// # Returns
+/// FV -> First value
+/// LV -> Last value
+///
+/// Interval arithmetic is used to calculate the range of possible join key values for
+/// build-side pruning. This is done by first creating an interval for the join filter values
+/// on the build side, spanning from the side's first value (FV) to infinity, since the stream is unbounded.
+/// If the expression is descending, the interval becomes [-∞, FV] and if it is ascending,
+/// it becomes [FV, ∞]. For the probe side, the last value (LV) of the incoming RecordBatch
+/// is used for the interval calculation instead:
+/// if the expression is descending, the interval becomes [-∞, LV] and if it is ascending,
+/// it becomes [LV, ∞].
+/// +/// In our case: +/// Expr1: [a(FV), inf] +/// Expr2: [b(FV), inf] +/// Expr3: [c(LV), inf] +/// Expr4: [d(LV), inf] +/// ``` /// -/// Result object indicating success or failure /// fn calculate_filter_expr_intervals( build_input_buffer: &RecordBatch, + build_sorted_filter_expr: &mut SortedFilterExpr, probe_batch: &RecordBatch, - filter_columns: &mut [SortedFilterExpr], - build_side: JoinSide, + probe_sorted_filter_expr: &mut SortedFilterExpr, ) -> Result<()> { // If either build or probe side has no data, return early: if build_input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 { return Ok(()); } - // Iterate through each [SortedFilterExpr] in filter_columns: - for sorted_expr in filter_columns.iter_mut() { - // Destructure the SortedFilterExpr - let SortedFilterExpr { - join_side, - origin_sorted_expr, - interval, - .. - } = sorted_expr; - - // Get the array of the first or last value for the expression, depending on join side - let expr = &origin_sorted_expr.expr; - let array = if build_side.eq(join_side) { - // Get first value for expr - expr.evaluate(&build_input_buffer.slice(0, 1)) - } else { - // Get last value for expr - expr.evaluate(&probe_batch.slice(probe_batch.num_rows() - 1, 1)) - }? + // Evaluate build side filter expression and convert the result to an array + let build_array = build_sorted_filter_expr + .origin_sorted_expr() + .expr + .evaluate(&build_input_buffer.slice(0, 1))? + .into_array(1); + // Evaluate probe side filter expression and convert the result to an array + let probe_array = probe_sorted_filter_expr + .origin_sorted_expr() + .expr + .evaluate(&probe_batch.slice(probe_batch.num_rows() - 1, 1))? .into_array(1); + // Update intervals for both build and probe side filter expressions + for (array, sorted_expr) in vec![ + (build_array, build_sorted_filter_expr), + (probe_array, probe_sorted_filter_expr), + ] { // Convert the array to a ScalarValue: let value = ScalarValue::try_from_array(&array, 0)?; // Create a ScalarValue representing positive or negative infinity for the same data type: let infinite = ScalarValue::try_from(value.get_datatype())?; // Update the interval with lower and upper bounds based on the sort option - *interval = if origin_sorted_expr.options.descending { - Interval { - lower: infinite, - upper: value, - } - } else { - Interval { - lower: value, - upper: infinite, - } - }; + sorted_expr.set_interval( + if sorted_expr.origin_sorted_expr().options.descending { + Interval { + lower: infinite, + upper: value, + } + } else { + Interval { + lower: value, + upper: infinite, + } + }, + ); } Ok(()) } +/// Determine the length of the record batch to be pruned. +/// +/// This function evaluates the build side filter expression, converts the result into an array and +/// determines the length of the record batch to be pruned by performing a binary search on the array. +/// +/// # Parameters +/// +/// * `buffer`: The record batch to be pruned. +/// * `build_side_filter_expr`: The filter expression on the build side used to determine the length +/// of the record batch to be pruned. +/// +/// # Returns +/// +/// A Result object that contains the length of the record batch to be pruned. An error will be +/// returned if there is an issue evaluating the build side filter expression. 
fn determine_prune_length( buffer: &RecordBatch, - filter_columns: &[SortedFilterExpr], - build_side: JoinSide, + build_side_filter_expr: &SortedFilterExpr, ) -> Result { - Ok(filter_columns - .iter() - .flat_map(|sorted_expr| { - let SortedFilterExpr { - join_side, - origin_sorted_expr, - interval, - .. - } = sorted_expr; - if build_side.eq(join_side) { - let batch_arr = origin_sorted_expr - .expr - .evaluate(buffer) - .unwrap() - .into_array(buffer.num_rows()); - let target = if origin_sorted_expr.options.descending { - interval.upper.clone() - } else { - interval.lower.clone() - }; - Some(bisect::( - &[batch_arr], - &[target], - &[origin_sorted_expr.options], - )) - } else { - None - } - }) - .collect::>>()? - .into_iter() - .min() - .unwrap()) + let origin_sorted_expr = build_side_filter_expr.origin_sorted_expr(); + let interval = build_side_filter_expr.interval(); + // Evaluate the build side filter expression and convert it into an array + let batch_arr = origin_sorted_expr + .expr + .evaluate(buffer) + .unwrap() + .into_array(buffer.num_rows()); + + // Get the lower or upper interval based on the sort direction + let target = if origin_sorted_expr.options.descending { + interval.upper.clone() + } else { + interval.lower.clone() + }; + + // Perform binary search on the array to determine the length of the record batch to be pruned + bisect::(&[batch_arr], &[target], &[origin_sorted_expr.options]) } /// This method determines if the result of the join should be produced in the final step or not. @@ -882,6 +921,8 @@ where struct OneSideHashJoiner { // Build side build_side: JoinSide, + /// Build side filter sort information + sorted_filter_expr: SortedFilterExpr, /// Inout record batch buffer input_buffer: RecordBatch, /// columns from the side @@ -905,6 +946,7 @@ struct OneSideHashJoiner { impl OneSideHashJoiner { pub fn new( build_side: JoinSide, + sorted_filter_expr: SortedFilterExpr, on: Vec, plan: Arc, ) -> Self { @@ -915,6 +957,7 @@ impl OneSideHashJoiner { hashmap: JoinHashMap(RawTable::with_capacity(10_000)), row_hash_values: VecDeque::new(), hashes_buffer: vec![], + sorted_filter_expr, visited_rows: HashSet::new(), offset: 0, deleted_offset: 0, @@ -1099,15 +1142,15 @@ impl OneSideHashJoiner { /// Prunes the internal buffer. /// - /// The input record batch, `probe_batch`, is used to update the intervals of the sorted filter expressions in `sorted_filter_exprs`. - /// The updated intervals are then used to determine the length of the build side rows to prune. + /// The input record batch and `probe_batch`, is used to update the intervals of the sorted filter expressions. + /// The updated build interval is then used to determine the length of the build side rows to prune. /// If there are rows to prune, they are pruned and the result is returned. /// /// # Arguments /// /// * `schema` - The schema of the final output record batch /// * `probe_batch` - Incoming RecordBatch of the probe side. - /// * `sorted_filter_exprs` - A mutable vector of sorted filter expressions that are used to update the intervals. + /// * `probe_side_sorted_filter_expr` - Probe side mutable sorted filter expression. /// * `join_type` - The type of join (e.g. inner, left, right, etc.). /// * `column_indices` - A vector of column indices that specifies which columns from the /// build side should be included in the output. 
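For intuition before the next hunk: with an ascending build-side filter column, the prune length described above is simply the number of leading rows that fall below the updated lower bound of the feasible interval, found by binary search. A hedged, standalone sketch (not from the patch; `partition_point` stands in for the `bisect` helper used by the real code):

```rust
/// Returns how many leading rows of an ascending, sorted build-side column
/// can be dropped once the feasible interval's lower bound has been updated.
fn prune_length(sorted_build_values: &[i64], new_lower_bound: i64) -> usize {
    sorted_build_values.partition_point(|v| *v < new_lower_bound)
}

fn main() {
    // With values [1, 4, 7, 9] and a new lower bound of 7, the first two
    // rows can never satisfy the filter again and are safe to prune.
    assert_eq!(prune_length(&[1, 4, 7, 9], 7), 2);
}
```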
@@ -1122,7 +1165,7 @@ impl OneSideHashJoiner { &mut self, schema: SchemaRef, probe_batch: &RecordBatch, - sorted_filter_exprs: &mut [SortedFilterExpr], + probe_side_sorted_filter_expr: &mut SortedFilterExpr, join_type: JoinType, column_indices: &[ColumnIndex], physical_expr_graph: &mut ExprIntervalGraph, @@ -1134,37 +1177,29 @@ impl OneSideHashJoiner { // Convert the sorted filter expressions into a vector of (node_index, interval) tuples // for use in updating the interval graph - let mut filter_intervals = sorted_filter_exprs - .iter() - .map(|sorted_expr| (sorted_expr.node_index(), sorted_expr.interval().clone())) - .collect::>(); - - // Use the filter intervals to update the physical expression graph + let mut filter_intervals = vec![ + ( + self.sorted_filter_expr.node_index(), + self.sorted_filter_expr.interval().clone(), + ), + ( + probe_side_sorted_filter_expr.node_index(), + probe_side_sorted_filter_expr.interval().clone(), + ), + ]; + // Use the join filter intervals to update the physical expression graph physical_expr_graph.update_intervals(&mut filter_intervals)?; - // A vector to track changes in the intervals of the sorted filter expressions - let mut change_track = Vec::with_capacity(filter_intervals.len()); - - // Update the intervals of the sorted filter expressions and track any changes - for (sorted_expr, (_, interval)) in sorted_filter_exprs - .iter_mut() - .zip(filter_intervals.into_iter()) - { - if interval.eq(sorted_expr.interval()) { - change_track.push(false); - } else { - sorted_expr.set_interval(interval.clone()); - change_track.push(true); - } - } + // Calculated join filter interval for build side. + let new_build_side_sorted_filter_expr = filter_intervals.remove(0).1; // Determine the length of the rows to prune if there was a change in the intervals - let prune_length = if change_track.iter().any(|x| *x) { - determine_prune_length( - &self.input_buffer, - sorted_filter_exprs, - self.build_side, - )? + let prune_length = if !new_build_side_sorted_filter_expr + .eq(self.sorted_filter_expr.interval()) + { + self.sorted_filter_expr + .set_interval(new_build_side_sorted_filter_expr); + determine_prune_length(&self.input_buffer, &self.sorted_filter_expr)? } else { 0 }; @@ -1309,9 +1344,9 @@ impl SymmetricHashJoinStream { // Calculate filter intervals calculate_filter_expr_intervals( &build_hash_joiner.input_buffer, + &mut build_hash_joiner.sorted_filter_expr, &probe_batch, - &mut self.filter_columns, - build_join_side, + &mut probe_hash_joiner.sorted_filter_expr, )?; // Join the two sides, with let equal_result = build_hash_joiner.join_with_probe_batch( @@ -1332,7 +1367,7 @@ impl SymmetricHashJoinStream { let anti_result = build_hash_joiner.prune_with_probe_batch( self.schema.clone(), &probe_batch, - &mut self.filter_columns, + &mut probe_hash_joiner.sorted_filter_expr, self.join_type, &self.column_indices, &mut self.physical_expr_graph, diff --git a/datafusion/core/tests/fifo.rs b/datafusion/core/tests/fifo.rs index 5899e4d479b7..4a08e39d5752 100644 --- a/datafusion/core/tests/fifo.rs +++ b/datafusion/core/tests/fifo.rs @@ -312,8 +312,9 @@ mod unix_test { }; operations.push(op); } - // It check if the right unmatches or left unmatches generated at the end of the file, which - // should not in unbounded mode. + // It check if the right unmatches or left unmatches generated at the end of the file and only once + // in hash join executor. SymmetricHashJoin exec produce it without waiting the end of the file + // and produce it more than once. 
assert!( operations .iter() From ec6e5ad09beb624637d08795f9112dd6509239da Mon Sep 17 00:00:00 2001 From: metesynnada <100111937+metesynnada@users.noreply.github.com> Date: Fri, 10 Feb 2023 17:02:05 +0300 Subject: [PATCH 27/39] Update fifo.rs --- datafusion/core/tests/fifo.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/tests/fifo.rs b/datafusion/core/tests/fifo.rs index 4a08e39d5752..7910cec1caab 100644 --- a/datafusion/core/tests/fifo.rs +++ b/datafusion/core/tests/fifo.rs @@ -312,7 +312,7 @@ mod unix_test { }; operations.push(op); } - // It check if the right unmatches or left unmatches generated at the end of the file and only once + // It check if the right unmatched or left unmatched results generated at the end of the file and only once // in hash join executor. SymmetricHashJoin exec produce it without waiting the end of the file // and produce it more than once. assert!( From 35c04a088761894925f3297a94c42bdb0737b2ef Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Fri, 10 Feb 2023 20:11:06 -0600 Subject: [PATCH 28/39] Remove unnecessary clones, improve comments and code structure --- .../src/physical_optimizer/pipeline_fixer.rs | 4 +- .../physical_plan/joins/hash_join_utils.rs | 78 ++-- .../core/src/physical_plan/joins/mod.rs | 4 +- .../joins/symmetric_hash_join.rs | 395 +++++++++--------- .../physical-expr/src/intervals/cp_solver.rs | 28 +- 5 files changed, 244 insertions(+), 265 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs index 049d3b0f8b9e..29fce233917c 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs @@ -117,14 +117,14 @@ fn is_suitable_for_symmetric_hash_join(hash_join: &HashJoinExec) -> Result let left_convertible = convert_sort_expr_with_filter_schema( &JoinSide::Left, filter, - left.schema(), + &left.schema(), &left_ordering[0], )? .is_some(); let right_convertible = convert_sort_expr_with_filter_schema( &JoinSide::Right, filter, - hash_join.right().schema(), + &right.schema(), &right_ordering[0], )? .is_some(); diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs index 2d62d0e9b767..7824503afda2 100644 --- a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs +++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. -//! Symmetric and regular Hash Join related functionality used in join calculations and optimization -//! rules. +//! This file contains common subroutines for regular and symmetric hash join +//! related functionality, used both in join calculations and optimization rules. 
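The subroutines in this file hinge on remapping columns between each child's schema and the filter's intermediate schema. As a rough, self-contained illustration of what `map_origin_col_to_filter_col` computes (the `Side`/`FilterColumn` types below are simplified stand-ins, not DataFusion's `JoinSide`/`ColumnIndex`):

```rust
use std::collections::HashMap;

#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
enum Side {
    Left,
    Right,
}

/// One filter-schema column, recording which side and child-schema index it came from.
struct FilterColumn {
    side: Side,
    child_index: usize,
}

/// Maps child-schema column indices to filter-schema column indices for one side.
fn origin_to_filter_map(
    column_indices: &[FilterColumn],
    side: Side,
) -> HashMap<usize, usize> {
    column_indices
        .iter()
        .enumerate()
        .filter(|(_, c)| c.side == side)
        .map(|(filter_idx, c)| (c.child_index, filter_idx))
        .collect()
}

fn main() {
    // Filter schema: [left_1, right_1, right_2]
    let cols = [
        FilterColumn { side: Side::Left, child_index: 0 },
        FilterColumn { side: Side::Right, child_index: 0 },
        FilterColumn { side: Side::Right, child_index: 1 },
    ];
    let right_map = origin_to_filter_map(&cols, Side::Right);
    assert_eq!(right_map[&0], 1); // right child column 0 -> filter column 1
    assert_eq!(right_map[&1], 2); // right child column 1 -> filter column 2
}
```

With such a map in hand, a child's sort expression can be rewritten column by column into the filter schema, which is what `convert_sort_expr_with_filter_schema` does next.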
use std::collections::HashMap; use std::sync::Arc; @@ -71,7 +71,7 @@ fn collect_columns(expr: &Arc) -> Vec { /// ``` pub fn map_origin_col_to_filter_col( filter: &JoinFilter, - schema: SchemaRef, + schema: &SchemaRef, side: &JoinSide, ) -> Result> { let filter_schema = filter.schema(); @@ -116,23 +116,21 @@ pub fn map_origin_col_to_filter_col( pub fn convert_sort_expr_with_filter_schema( side: &JoinSide, filter: &JoinFilter, - schema: SchemaRef, + schema: &SchemaRef, sort_expr: &PhysicalSortExpr, ) -> Result>> { - let column_mapping_information: HashMap = - map_origin_col_to_filter_col(filter, schema, side)?; + let column_map = map_origin_col_to_filter_col(filter, schema, side)?; let expr = sort_expr.expr.clone(); // Get main schema columns: let expr_columns = collect_columns(&expr); - // Calculation is possible with `column_mapping_information` since sort exprs belong to a child. - let all_columns_are_included = expr_columns - .iter() - .all(|col| column_mapping_information.contains_key(col)); + // Calculation is possible with `column_map` since sort exprs belong to a child. + let all_columns_are_included = + expr_columns.iter().all(|col| column_map.contains_key(col)); if all_columns_are_included { // Since we are sure that one to one column mapping includes all columns, we convert // the sort expression into a filter expression. - let converted_filter_expr = expr - .transform_up(&|p| convert_filter_columns(p, &column_mapping_information))?; + let converted_filter_expr = + expr.transform_up(&|p| convert_filter_columns(p, &column_map))?; // Search the converted `PhysicalExpr` in filter expression; if an exact // match is found, use this sorted expression in graph traversals. if check_filter_expr_contains_sort_information( @@ -150,14 +148,14 @@ pub fn convert_sort_expr_with_filter_schema( /// It first calls the [convert_sort_expr_with_filter_schema] method to determine if the sort /// order of columns can be used in the filter expression. If it returns a [Some] value, the /// method wraps the result in a [SortedFilterExpr] instance with the original sort expression and -/// converted filter expression. Otherwise, this function returns an error. +/// the converted filter expression. Otherwise, this function returns an error. /// /// The [SortedFilterExpr] instance contains information about the sort order of columns that can /// be used in the filter expression, which can be used to optimize the query execution process. pub fn build_filter_input_order( side: JoinSide, filter: &JoinFilter, - schema: SchemaRef, + schema: &SchemaRef, order: &PhysicalSortExpr, ) -> Result { if let Some(expr) = @@ -175,36 +173,32 @@ pub fn build_filter_input_order( /// column mapping information. fn convert_filter_columns( input: Arc, - column_mapping_information: &HashMap, + column_map: &HashMap, ) -> Result>> { // Attempt to downcast the input expression to a Column type. - Ok(Some( - if let Some(col) = input.as_any().downcast_ref::() { - // If the downcast is successful, retrieve the corresponding filter column. - let filter_col = column_mapping_information.get(col).unwrap().clone(); - // Return the filter column as an Arc wrapped in an Option. - Arc::new(filter_col) - } else { - // If the downcast fails, return the input expression as is. - input - }, - )) + Ok(if let Some(col) = input.as_any().downcast_ref::() { + // If the downcast is successful, retrieve the corresponding filter column. 
+ column_map.get(col).map(|c| Arc::new(c.clone()) as _) + } else { + // If the downcast fails, return the input expression as is. + Some(input) + }) } -/// The SortedFilterExpr object is used to represent a sorted filter expression in the -/// `SymmetricHashJoinExec` struct. It contains information about the -/// origin expression, the filter expression, and the sort option. The object has -/// several methods to access and modify its fields. +/// The [SortedFilterExpr] object represents a sorted filter expression. It +/// contains the following information: The origin expression, the filter +/// expression, an interval encapsulating expression bounds, and a stable +/// index identifying the expression in the expression DAG. #[derive(Debug, Clone)] pub struct SortedFilterExpr { - /// Sorted expr from a particular join side (child) + /// Sorted expression from a join side (i.e. a child of the join) origin_sorted_expr: PhysicalSortExpr, - /// For interval calculations, one to one mapping of the columns according to filter expression, - /// and column indices. + /// For interval calculations, one to one mapping of the columns + /// according to filter expression and column indices. filter_expr: Arc, - /// Interval + /// Interval containing expression bounds interval: Interval, - /// NodeIndex in Graph + /// Node index in the expression DAG node_index: usize, } @@ -222,8 +216,8 @@ impl SortedFilterExpr { } } /// Get origin expr information - pub fn origin_sorted_expr(&self) -> PhysicalSortExpr { - self.origin_sorted_expr.clone() + pub fn origin_sorted_expr(&self) -> &PhysicalSortExpr { + &self.origin_sorted_expr } /// Get filter expr information pub fn filter_expr(&self) -> &Arc { @@ -414,7 +408,7 @@ pub mod tests { assert!(build_filter_input_order( JoinSide::Left, &filter, - left_schema.clone(), + &left_schema, &PhysicalSortExpr { expr: Arc::new(Column::new("la1", 0)), options: SortOptions::default(), @@ -424,7 +418,7 @@ pub mod tests { assert!(build_filter_input_order( JoinSide::Left, &filter, - left_schema, + &left_schema, &PhysicalSortExpr { expr: Arc::new(Column::new("lt1", 3)), options: SortOptions::default(), @@ -434,7 +428,7 @@ pub mod tests { assert!(build_filter_input_order( JoinSide::Right, &filter, - right_schema.clone(), + &right_schema, &PhysicalSortExpr { expr: Arc::new(Column::new("ra1", 0)), options: SortOptions::default(), @@ -444,7 +438,7 @@ pub mod tests { assert!(build_filter_input_order( JoinSide::Right, &filter, - right_schema, + &right_schema, &PhysicalSortExpr { expr: Arc::new(Column::new("rb1", 1)), options: SortOptions::default(), @@ -500,7 +494,7 @@ pub mod tests { let res = convert_sort_expr_with_filter_schema( &JoinSide::Left, &filter, - schema, + &schema, &sorted, )?; assert!(res.is_none()); diff --git a/datafusion/core/src/physical_plan/joins/mod.rs b/datafusion/core/src/physical_plan/joins/mod.rs index fe9954082e11..8ad50514f0b0 100644 --- a/datafusion/core/src/physical_plan/joins/mod.rs +++ b/datafusion/core/src/physical_plan/joins/mod.rs @@ -19,9 +19,7 @@ pub use cross_join::CrossJoinExec; pub use hash_join::HashJoinExec; -pub use hash_join_utils::{ - convert_sort_expr_with_filter_schema, map_origin_col_to_filter_col, SortedFilterExpr, -}; +pub use hash_join_utils::convert_sort_expr_with_filter_schema; pub use nested_loop_join::NestedLoopJoinExec; // Note: SortMergeJoin is not used in plans yet pub use sort_merge_join::SortMergeJoinExec; diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs 
b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs index 785c8bbd0d59..bbfea952fcff 100644 --- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs @@ -15,14 +15,15 @@ // specific language governing permissions and limitations // under the License. -//! The code implements the symmetric hash join algorithm for executing partitions in parallel -//! streams and buffering with range-condition pruning. +//! This file implements the symmetric hash join algorithm with range-based +//! data pruning to join two (potentially infinite) streams. //! -//! A Symmetric Hash Join plan consumes two sorted children plan and produces -//! joined output by given join type and other options. -//! -//! It includes OneSideHashJoiner struct constructed for both childs and calculates joins. +//! A [SymmetricHashJoinExec] plan takes two children plan (with appropriate +//! output ordering) and produces the join output according to the given join +//! type and other options. //! +//! This plan uses the [OneSideHashJoiner] object to facilitate join calculations +//! for both its children. use std::collections::{HashMap, VecDeque}; use std::fmt; @@ -275,37 +276,42 @@ impl SymmetricHashJoinExec { let mut physical_expr_graph = ExprIntervalGraph::try_new(filter.expression().clone())?; - let (left_required_sort_exprs, right_required_sort_exprs): (Vec<_>, Vec<_>) = - left.output_ordering() - .unwrap() - .iter() - .zip(right.output_ordering().unwrap()) - .take(1) - .map(|(l, r)| (l.clone(), r.clone())) - .unzip(); + let (left_ordering, right_ordering) = match ( + left.output_ordering(), + right.output_ordering(), + ) { + (Some([left_ordering, ..]), Some([right_ordering, ..])) => { + (left_ordering, right_ordering) + } + _ => { + return Err(DataFusionError::Plan( + "Symmetric hash join requires its children to have an output ordering".to_string(), + )); + } + }; - // Build the sorted filter expression for the left child + // Build the sorted filter expression for the left child: let left_filter_expression = build_filter_input_order( JoinSide::Left, &filter, - left.schema(), - &left_required_sort_exprs[0], + &left.schema(), + left_ordering, )?; - // Build the sorted filter expression for the right child + // Build the sorted filter expression for the right child: let right_filter_expression = build_filter_input_order( JoinSide::Right, &filter, - right.schema(), - &right_required_sort_exprs[0], + &right.schema(), + right_ordering, )?; // Store the left and right sorted filter expressions in a vector let mut sorted_filter_exprs = vec![left_filter_expression, right_filter_expression]; - // Gather node indices of converted filter expressions in SortedFilterExpr - // using the filter columns vector + // Gather node indices of converted filter expressions in `SortedFilterExpr` + // using the filter columns vector: let child_node_indexes = physical_expr_graph.gather_node_indices( &sorted_filter_exprs .iter() @@ -313,7 +319,7 @@ impl SymmetricHashJoinExec { .collect::>(), ); - // Inject calculated node indexes into SortedFilterExpr + // Inject calculated node indices into SortedFilterExpr: for (sorted_expr, (_, index)) in sorted_filter_exprs .iter_mut() .zip(child_node_indexes.iter()) @@ -321,6 +327,9 @@ impl SymmetricHashJoinExec { sorted_expr.set_node_index(*index); } + let left_required_sort_exprs = vec![left_ordering.clone()]; + let right_required_sort_exprs = vec![right_ordering.clone()]; + Ok(SymmetricHashJoinExec { left, right, @@ 
-615,39 +624,36 @@ fn prune_hash_values( /// /// # Arguments /// -/// * `build_input_buffer` - The RecordBatch on the build side of the join -/// * `build_sorted_filter_expr` - Build side `SortedFilterExpr` to update. -/// * `probe_batch` - The RecordBatch on the probe side of the join +/// * `build_input_buffer` - The [RecordBatch] on the build side of the join. +/// * `build_sorted_filter_expr` - Build side [SortedFilterExpr] to update. +/// * `probe_batch` - The `RecordBatch` on the probe side of the join. /// * `probe_sorted_filter_expr` - Probe side `SortedFilterExpr` to update. /// /// ### Note /// ```text /// -/// Build Probe -/// +-------+ +-------+ -/// |+-----+| | | | -/// ||a |b || | | | -/// |+--|--+| | | | -/// | | | | | | -/// | | | | | | -/// | | | | | | -/// | | | | | | -/// | | | |+-----+| -/// | | | ||c |d || -/// | | | |+--|--+| -/// +-------+ +-------+ +/// Build Probe +/// +-------+ +-------+ +/// |+-----+| | | | +/// FV ||a |b || | | | +/// |+--|--+| | | | +/// | | | | | | +/// | | | | | | +/// | | | | | | +/// | | | | | | +/// | | | |+-----+| +/// | | | ||c |d || LV +/// | | | |+--|--+| +/// +-------+ +-------+ /// -/// FV -> First value -/// LV -> Last value -/// -/// Interval arithmetic is used to calculate the range of possible join key values for -/// build-side pruning. This is done by first creating intervals for the join filter values -/// in the build side of the join, which is first value of the side and infinity since the stream is also infinite. -/// If the expression is descending, interval becomes [-∞, FV] and if ascending, -/// interval becomes [FV, ∞]. For the probe side, the last of value of the RecordBatch is used -/// for interval calculation. This time we use the last value of the batch, -/// if the expression is descending, the interval becomes [-∞, LV] and if it is ascending, -/// it becomes [LV, ∞]. +/// Interval arithmetic is used to calculate viable join ranges for build-side +/// pruning. This is done by first creating an interval for join filter values in +/// the build side of the join, which spans [-∞, FV] or [FV, ∞] depending on the +/// ordering (descending/ascending) of the filter expression. Here, FV denotes the +/// first value on the build side. This range is then compared with the probe side +/// interval, which either spans [-∞, LV] or [LV, ∞] depending on the ordering +/// (ascending/descending) of the probe side. Here, LV denotes the last value on +/// the probe side. /// /// In our case: /// Expr1: [a(FV), inf] @@ -655,8 +661,6 @@ fn prune_hash_values( /// Expr3: [c(LV), inf] /// Expr4: [d(LV), inf] /// ``` -/// -/// fn calculate_filter_expr_intervals( build_input_buffer: &RecordBatch, build_sorted_filter_expr: &mut SortedFilterExpr, @@ -707,21 +711,22 @@ fn calculate_filter_expr_intervals( Ok(()) } -/// Determine the length of the record batch to be pruned. +/// Determine the pruning length for `buffer`. /// -/// This function evaluates the build side filter expression, converts the result into an array and -/// determines the length of the record batch to be pruned by performing a binary search on the array. +/// This function evaluates the build side filter expression, converts the +/// result into an array and determines the pruning length by performing a +/// binary search on the array. /// -/// # Parameters +/// # Arguments /// /// * `buffer`: The record batch to be pruned. -/// * `build_side_filter_expr`: The filter expression on the build side used to determine the length -/// of the record batch to be pruned. 
+/// * `build_side_filter_expr`: The filter expression on the build side used +/// to determine the pruning length. /// /// # Returns /// -/// A Result object that contains the length of the record batch to be pruned. An error will be -/// returned if there is an issue evaluating the build side filter expression. +/// A [Result] object that contains the pruning length. The function will return +/// an error if there is an issue evaluating the build side filter expression. fn determine_prune_length( buffer: &RecordBatch, build_side_filter_expr: &SortedFilterExpr, @@ -731,8 +736,7 @@ fn determine_prune_length( // Evaluate the build side filter expression and convert it into an array let batch_arr = origin_sorted_expr .expr - .evaluate(buffer) - .unwrap() + .evaluate(buffer)? .into_array(buffer.num_rows()); // Get the lower or upper interval based on the sort direction @@ -919,13 +923,13 @@ where } struct OneSideHashJoiner { - // Build side + /// Build side build_side: JoinSide, /// Build side filter sort information sorted_filter_expr: SortedFilterExpr, - /// Inout record batch buffer + /// Input record batch buffer input_buffer: RecordBatch, - /// columns from the side + /// Columns from the side on: Vec, /// Hashmap hashmap: JoinHashMap, @@ -965,29 +969,26 @@ impl OneSideHashJoiner { } } - /// Updates the internal state of the OneSideHashJoiner with the incoming batch. + /// Updates the internal state of the [OneSideHashJoiner] with the incoming batch. /// /// # Arguments /// - /// * `batch` - The incoming RecordBatch to be merged with the internal input buffer + /// * `batch` - The incoming [RecordBatch] to be merged with the internal input buffer /// * `random_state` - The random state used to hash values /// /// # Returns /// - /// Returns a Result with an empty Ok if the update was successful, otherwise returns an error. + /// Returns a [Result] encapsulating any intermediate errors. fn update_internal_state( &mut self, batch: &RecordBatch, random_state: &RandomState, ) -> Result<()> { - // Merge the incoming batch with the existing input buffer - self.input_buffer = - merge_batches(&self.input_buffer, batch, batch.schema()).unwrap(); - - // Resize the hashes buffer to the number of rows in the incoming batch + // Merge the incoming batch with the existing input buffer: + self.input_buffer = merge_batches(&self.input_buffer, batch, batch.schema())?; + // Resize the hashes buffer to the number of rows in the incoming batch: self.hashes_buffer.resize(batch.num_rows(), 0); - - // Update the hashmap with the join key values and hashes of the incoming batch + // Update the hashmap with the join key values and hashes of the incoming batch: update_hash( &self.on, batch, @@ -996,14 +997,12 @@ impl OneSideHashJoiner { random_state, &mut self.hashes_buffer, )?; - - // Add the hashes buffer to the row_hash_values deque + // Add the hashes buffer to the hash value deque: self.row_hash_values.extend(self.hashes_buffer.iter()); - Ok(()) } - /// This method performs a join between the build side input buffer and probe side RecordBatch. + /// This method performs a join between the build side input buffer and the probe side batch. /// /// # Arguments /// @@ -1012,11 +1011,11 @@ impl OneSideHashJoiner { /// * `on_probe` - An array of columns on which the join will be performed. The columns are from the probe side of the join. /// * `filter` - An optional filter on the join condition. /// * `probe_batch` - The second record batch to be joined. 
- /// * `probe_visited` - A hashset to store the visited indices from the probe batch. + /// * `probe_visited` - A hash set to store the visited indices from the probe batch. /// * `probe_offset` - The offset of the probe side for visited indices calculations. /// * `column_indices` - An array of columns to be selected for the result of the join. /// * `random_state` - The random state for the join. - /// * `null_equals_null` - A boolean indicating whether null values should be treated as equal when joining. + /// * `null_equals_null` - A boolean indicating whether NULL values should be treated as equal when joining. /// /// # Returns /// @@ -1025,7 +1024,7 @@ impl OneSideHashJoiner { #[allow(clippy::too_many_arguments)] fn join_with_probe_batch( &mut self, - schema: SchemaRef, + schema: &SchemaRef, join_type: JoinType, on_probe: &[Column], filter: &JoinFilter, @@ -1037,9 +1036,9 @@ impl OneSideHashJoiner { null_equals_null: &bool, ) -> Result> { if self.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 { - return Ok(Some(RecordBatch::new_empty(schema))); + return Ok(Some(RecordBatch::new_empty(schema.clone()))); } - let (build_side, probe_side) = build_join_indices( + let (build_indices, probe_indices) = build_join_indices( probe_batch, &self.hashmap, &self.input_buffer, @@ -1056,11 +1055,11 @@ impl OneSideHashJoiner { record_visited_indices( &mut self.visited_rows, self.deleted_offset, - &build_side, + &build_indices, ); } if need_to_produce_result_in_final(self.build_side.negate(), join_type) { - record_visited_indices(probe_visited, probe_offset, &probe_side); + record_visited_indices(probe_visited, probe_offset, &probe_indices); } if matches!( join_type, @@ -1071,30 +1070,30 @@ impl OneSideHashJoiner { ) { Ok(None) } else { - let res = build_batch_from_indices( - schema.as_ref(), + build_batch_from_indices( + schema, &self.input_buffer, probe_batch, - build_side.clone(), - probe_side.clone(), + build_indices, + probe_indices, column_indices, self.build_side, - )?; - Ok(Some(res)) + ) + .map(Some) } } - /// Function to produce the unmatched record results based on the build side, + /// This function produces unmatched record results based on the build side, /// join type and other parameters. /// - /// The method use first 'prune_length' rows from the build side input buffer to - /// produce results. + /// The method uses first `prune_length` rows from the build side input buffer + /// to produce results. /// /// # Arguments /// /// * `output_schema` - The schema of the final output record batch. /// * `prune_length` - The length of the determined prune length. - /// * `probe_schema` - The schema of the probe RecordBatch. + /// * `probe_schema` - The schema of the probe [RecordBatch]. /// * `join_type` - The type of join to be performed. /// * `column_indices` - Indices of columns that are being joined. /// @@ -1103,15 +1102,15 @@ impl OneSideHashJoiner { /// * `Option` - The final output record batch if required, otherwise [None]. 
fn build_side_determined_results( &self, - output_schema: SchemaRef, + output_schema: &SchemaRef, prune_length: usize, probe_schema: SchemaRef, join_type: JoinType, column_indices: &[ColumnIndex], ) -> Result> { - // Check if we need to produce a result in the final output - let result = if need_to_produce_result_in_final(self.build_side, join_type) { - // Calculate the indices for build and probe sides based on join type and build side + // Check if we need to produce a result in the final output: + if need_to_produce_result_in_final(self.build_side, join_type) { + // Calculate the indices for build and probe sides based on join type and build side: let (build_indices, probe_indices) = calculate_indices_by_join_type( self.build_side, prune_length, @@ -1120,11 +1119,10 @@ impl OneSideHashJoiner { join_type, )?; - // Create an empty probe record batch + // Create an empty probe record batch: let empty_probe_batch = RecordBatch::new_empty(probe_schema); - - // Build the final result from the indices of build and probe sides - Some(build_batch_from_indices( + // Build the final result from the indices of build and probe sides: + build_batch_from_indices( output_schema.as_ref(), &self.input_buffer, &empty_probe_batch, @@ -1132,19 +1130,20 @@ impl OneSideHashJoiner { probe_indices, column_indices, self.build_side, - )?) + ) + .map(Some) } else { // If we don't need to produce a result, return None - None - }; - Ok(result) + Ok(None) + } } /// Prunes the internal buffer. /// - /// The input record batch and `probe_batch`, is used to update the intervals of the sorted filter expressions. - /// The updated build interval is then used to determine the length of the build side rows to prune. - /// If there are rows to prune, they are pruned and the result is returned. + /// Argument `probe_batch` is used to update the intervals of the sorted + /// filter expressions. The updated build interval determines the new length + /// of the build side. If there are rows to prune, they are removed from the + /// internal buffer. /// /// # Arguments /// @@ -1158,25 +1157,23 @@ impl OneSideHashJoiner { /// /// # Returns /// - /// If there are no rows in the input buffer, returns `Ok(None)`. /// If there are rows to prune, returns the pruned build side record batch wrapped in an `Ok` variant. /// Otherwise, returns `Ok(None)`. fn prune_with_probe_batch( &mut self, - schema: SchemaRef, + schema: &SchemaRef, probe_batch: &RecordBatch, probe_side_sorted_filter_expr: &mut SortedFilterExpr, join_type: JoinType, column_indices: &[ColumnIndex], physical_expr_graph: &mut ExprIntervalGraph, ) -> Result> { - // Return None if there are no rows in the input buffer + // Check if the input buffer is empty: if self.input_buffer.num_rows() == 0 { return Ok(None); } - - // Convert the sorted filter expressions into a vector of (node_index, interval) tuples - // for use in updating the interval graph + // Convert the sorted filter expressions into a vector of (node_index, interval) + // tuples for use when updating the interval graph. let mut filter_intervals = vec![ ( self.sorted_filter_expr.node_index(), @@ -1187,72 +1184,66 @@ impl OneSideHashJoiner { probe_side_sorted_filter_expr.interval().clone(), ), ]; - // Use the join filter intervals to update the physical expression graph + // Use the join filter intervals to update the physical expression graph: physical_expr_graph.update_intervals(&mut filter_intervals)?; - - // Calculated join filter interval for build side. 
+ // Get the new join filter interval for build side: let new_build_side_sorted_filter_expr = filter_intervals.remove(0).1; - - // Determine the length of the rows to prune if there was a change in the intervals - let prune_length = if !new_build_side_sorted_filter_expr - .eq(self.sorted_filter_expr.interval()) - { - self.sorted_filter_expr - .set_interval(new_build_side_sorted_filter_expr); - determine_prune_length(&self.input_buffer, &self.sorted_filter_expr)? - } else { - 0 - }; - - // If there are rows to prune, prune them and return the result - if prune_length > 0 { - let result = self.build_side_determined_results( - schema, - prune_length, - probe_batch.schema(), - join_type, - column_indices, - ); - prune_hash_values( - prune_length, - &mut self.hashmap, - &mut self.row_hash_values, - self.deleted_offset as u64, - )?; - for row in self.deleted_offset..(self.deleted_offset + prune_length) { - self.visited_rows.remove(&row); - } - self.input_buffer = self - .input_buffer - .slice(prune_length, self.input_buffer.num_rows() - prune_length); - self.deleted_offset += prune_length; - result - } else { - Ok(None) + // Check if the intervals changed, exit early if not: + if new_build_side_sorted_filter_expr.eq(self.sorted_filter_expr.interval()) { + return Ok(None); + } + // Determine the pruning length if there was a change in the intervals: + self.sorted_filter_expr + .set_interval(new_build_side_sorted_filter_expr); + let prune_length = + determine_prune_length(&self.input_buffer, &self.sorted_filter_expr)?; + // If we can not prune, exit early: + if prune_length == 0 { + return Ok(None); + } + // Compute the result, and perform pruning if there are rows to prune: + let result = self.build_side_determined_results( + schema, + prune_length, + probe_batch.schema(), + join_type, + column_indices, + ); + prune_hash_values( + prune_length, + &mut self.hashmap, + &mut self.row_hash_values, + self.deleted_offset as u64, + )?; + for row in self.deleted_offset..(self.deleted_offset + prune_length) { + self.visited_rows.remove(&row); } + self.input_buffer = self + .input_buffer + .slice(prune_length, self.input_buffer.num_rows() - prune_length); + self.deleted_offset += prune_length; + result } } -fn combine_two_batches_result( - output_schema: SchemaRef, +fn combine_two_batches( + output_schema: &SchemaRef, left_batch: Option, right_batch: Option, ) -> Result> { match (left_batch, right_batch) { (Some(batch), None) | (None, Some(batch)) => { - // If either left_batch or right_batch is present, return that batch + // If only one of the batches are present, return it: Ok(Some(batch)) } (Some(left_batch), Some(right_batch)) => { - // If both left_batch and right_batch are present, concatenate them - Some( - concat_batches(&output_schema, &[left_batch, right_batch]) - .map_err(DataFusionError::ArrowError), - ) - .transpose() + // If both batches are present, concatenate them: + concat_batches(output_schema, &[left_batch, right_batch]) + .map_err(DataFusionError::ArrowError) + .map(Some) } (None, None) => { - // If both left_batch and right_batch are not present, return an empty batch + // If neither is present, return an empty batch: Ok(None) } } @@ -1261,43 +1252,42 @@ fn combine_two_batches_result( impl SymmetricHashJoinStream { /// Polls the next result of the join operation. /// - /// If the result of the join is ready, it returns the next record batch. If the join has completed and there are no more results, - /// it returns `Poll::Ready(None)`. 
If the join operation is not complete but the current stream is not ready yet, it returns `Poll::Pending`. + /// If the result of the join is ready, it returns the next record batch. + /// If the join has completed and there are no more results, it returns + /// `Poll::Ready(None)`. If the join operation is not complete, but the + /// current stream is not ready yet, it returns `Poll::Pending`. fn poll_next_impl( &mut self, cx: &mut std::task::Context<'_>, ) -> Poll>> { loop { - // If the final result has already been obtained, return `Poll::Ready(None)` + // If the final result has already been obtained, return `Poll::Ready(None)`: if self.final_result { return Poll::Ready(None); } - // If both streams have been exhausted, return the final result + // If both streams have been exhausted, return the final result: if self.right.exhausted && self.left.exhausted { - // Get left side results + // Get left side results: let left_result = self.left.build_side_determined_results( - self.schema.clone(), + &self.schema, self.left.input_buffer.num_rows(), self.right.input_buffer.schema(), self.join_type, &self.column_indices, )?; - // Get right side results + // Get right side results: let right_result = self.right.build_side_determined_results( - self.schema.clone(), + &self.schema, self.right.input_buffer.num_rows(), self.left.input_buffer.schema(), self.join_type, &self.column_indices, )?; self.final_result = true; - // Combine results - let result = combine_two_batches_result( - self.schema.clone(), - left_result, - right_result, - )?; - // Update the metrics if batch is some, if not, continue loop. + // Combine results: + let result = + combine_two_batches(&self.schema, left_result, right_result)?; + // Update the metrics if we have a batch; otherwise, continue the loop. if let Some(batch) = &result { self.metrics.output_batches.add(1); self.metrics.output_rows.add(batch.num_rows()); @@ -1307,7 +1297,8 @@ impl SymmetricHashJoinStream { } } - // Determine which stream should be polled next. The side the RecordBatch comes from becomes probe side. + // Determine which stream should be polled next. The side the + // RecordBatch comes from becomes the probe side. 
let ( input_stream, probe_hash_joiner, @@ -1331,26 +1322,26 @@ impl SymmetricHashJoinStream { &mut self.metrics.right, ) }; - // Poll the next batch from `input_stream` + // Poll the next batch from `input_stream`: match input_stream.poll_next_unpin(cx) { // Batch is available Poll::Ready(Some(Ok(probe_batch))) => { - // Update the metrics for the stream that was polled + // Update the metrics for the stream that was polled: probe_side_metrics.input_batches.add(1); probe_side_metrics.input_rows.add(probe_batch.num_rows()); - // Update the internal state of the hash joiner for build side + // Update the internal state of the hash joiner for the build side: probe_hash_joiner .update_internal_state(&probe_batch, &self.random_state)?; - // Calculate filter intervals + // Calculate filter intervals: calculate_filter_expr_intervals( &build_hash_joiner.input_buffer, &mut build_hash_joiner.sorted_filter_expr, &probe_batch, &mut probe_hash_joiner.sorted_filter_expr, )?; - // Join the two sides, with + // Join the two sides: let equal_result = build_hash_joiner.join_with_probe_batch( - self.schema.clone(), + &self.schema, self.join_type, &probe_hash_joiner.on, &self.filter, @@ -1361,29 +1352,27 @@ impl SymmetricHashJoinStream { &self.random_state, &self.null_equals_null, )?; - // Increment the offset for the probe hash joiner + // Increment the offset for the probe hash joiner: probe_hash_joiner.offset += probe_batch.num_rows(); - // Prune build input buffer using the physical expression graph and filter intervals + // Prune the build side input buffer using the expression + // DAG and filter intervals: let anti_result = build_hash_joiner.prune_with_probe_batch( - self.schema.clone(), + &self.schema, &probe_batch, &mut probe_hash_joiner.sorted_filter_expr, self.join_type, &self.column_indices, &mut self.physical_expr_graph, )?; - // Combine results - let result = combine_two_batches_result( - self.schema.clone(), - equal_result, - anti_result, - )?; - // Choose next pool side. If other side is not exhausted, revert the probe side - // before returning the result. + // Combine results: + let result = + combine_two_batches(&self.schema, equal_result, anti_result)?; + // Choose next poll side. If the other side is not exhausted, + // switch the probe side before returning the result. if !build_hash_joiner.exhausted { self.probe_side = build_join_side; } - // Update the metrics if batch is some, if not, continue loop. + // Update the metrics if we have a batch; otherwise, continue the loop. if let Some(batch) = &result { self.metrics.output_batches.add(1); self.metrics.output_rows.add(batch.num_rows()); @@ -1392,9 +1381,9 @@ impl SymmetricHashJoinStream { } Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), Poll::Ready(None) => { - // Update probe side is exhausted. 
+ // Mark the probe side exhausted: probe_hash_joiner.exhausted = true; - // Change the probe side + // Change the probe side: self.probe_side = build_join_side; } Poll::Pending => { diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index aad92875b353..095e822748f3 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -341,8 +341,7 @@ impl ExprIntervalGraph { .collect::>(); while let Some(node) = bfs.next(graph) { // Get the plan corresponding to this node: - let input = graph.index(node); - let expr = input.expr.clone(); + let expr = &graph.index(node).expr; // If the current expression is among `exprs`, slate its children // for removal: if let Some(value) = exprs.iter().position(|e| expr.eq(e)) { @@ -450,15 +449,14 @@ impl ExprIntervalGraph { ) -> Result { let mut bfs = Bfs::new(&self.graph, self.root); while let Some(node) = bfs.next(&self.graph) { - // Get plan - let input = self.graph.index(node); - // Get calculated interval. BinaryExpr will propagate the interval according to - // this. - let node_interval = input.interval().clone(); let index = node.index(); + // Get the expression DAG node: + let input = self.graph.index(node); + // Get the associated interval, which will used for propagation: + let node_interval = input.interval(); if let Some((_, interval)) = expr_stats.iter_mut().find(|(e, _)| *e == index) { - *interval = node_interval; + *interval = node_interval.clone(); continue; } @@ -477,10 +475,10 @@ impl ExprIntervalGraph { if let (Some(new_left_interval), Some(new_right_interval)) = if op.is_logic_operator() { // TODO: Currently, this implementation only supports the AND operator - // and does not require any further propagation. - // In the future, upon adding support for additional logical operators, - // this method will require modification to support propagating - // the changes accordingly. + // and does not require any further propagation. In the future, + // upon adding support for additional logical operators, this + // method will require modification to support propagating the + // changes accordingly. continue; } else if op.is_comparison_operator() { if let Interval { @@ -489,8 +487,8 @@ impl ExprIntervalGraph { } = node_interval { // TODO: The optimization of handling strictly false clauses through - // conversion to equivalent comparison operators (e.g. GT to LE, LT to GE) - // can be implemented once support for open/closed intervals is added. + // conversion to equivalent comparison operators (e.g. GT to LE, LT to GE) + // can be implemented once open/closed intervals are supported. continue; } // Propagate the comparison operator. @@ -499,7 +497,7 @@ impl ExprIntervalGraph { // Propagate the arithmetic operator. propagate_arithmetic( op, - &node_interval, + node_interval, left_interval, right_interval, )? 
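A self-contained sketch of a single propagation step may be helpful here. It follows the spirit of the `propagate_comparison`/`propagate_arithmetic` calls above for a constraint of the form `a > b - c`, but it is deliberately simplified to closed, finite integer intervals; the real solver also handles unbounded and open bounds, and walks the whole expression DAG.

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
struct Interval {
    lower: i64,
    upper: i64,
}

/// Shrinks the intervals of `a` and `b` under the constraint `a > b - c`.
/// Returns `None` if the constraint is infeasible (an interval becomes empty).
fn propagate_gt_minus(a: Interval, b: Interval, c: i64) -> Option<(Interval, Interval)> {
    // `a` must exceed the smallest possible value of `b - c`:
    let new_a = Interval { lower: a.lower.max(b.lower - c + 1), upper: a.upper };
    // Symmetrically, `b - c` must stay below the largest possible `a`:
    let new_b = Interval { lower: b.lower, upper: b.upper.min(a.upper + c - 1) };
    (new_a.lower <= new_a.upper && new_b.lower <= new_b.upper).then_some((new_a, new_b))
}

fn main() {
    // a ∈ [1, 100], b ∈ [6, 50], constraint: a > b - 3
    let (a, b) = propagate_gt_minus(
        Interval { lower: 1, upper: 100 },
        Interval { lower: 6, upper: 50 },
        3,
    )
    .expect("constraint is feasible");
    assert_eq!(a, Interval { lower: 4, upper: 100 }); // tightened: a ≥ 6 - 3 + 1
    assert_eq!(b, Interval { lower: 6, upper: 50 }); // already within bounds
}
```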
From 0743eb61a603b3be41d1e51012251e2f1d4a37a6 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Sat, 11 Feb 2023 21:43:25 -0600 Subject: [PATCH 29/39] Remove leaf searches from CP iterations, improve code organization/comments --- .../joins/symmetric_hash_join.rs | 4 +- .../physical-expr/src/intervals/cp_solver.rs | 137 ++++++++++-------- datafusion/physical-expr/src/utils.rs | 104 ++++++------- 3 files changed, 127 insertions(+), 118 deletions(-) diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs index bbfea952fcff..d70924279936 100644 --- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs @@ -436,7 +436,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { ) } - // TODO Output ordering might be kept for some cases. + // TODO: Output ordering might be kept for some cases. fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { None } @@ -1185,7 +1185,7 @@ impl OneSideHashJoiner { ), ]; // Use the join filter intervals to update the physical expression graph: - physical_expr_graph.update_intervals(&mut filter_intervals)?; + physical_expr_graph.update_ranges(&mut filter_intervals)?; // Get the new join filter interval for build side: let new_build_side_sorted_filter_expr = filter_intervals.remove(0).1; // Check if the intervals changed, exit early if not: diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index 095e822748f3..c62a2d5c0594 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -25,7 +25,7 @@ use arrow_schema::DataType; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Operator; use petgraph::graph::NodeIndex; -use petgraph::stable_graph::StableGraph; +use petgraph::stable_graph::{DefaultIx, StableGraph}; use petgraph::visit::{Bfs, DfsPostOrder}; use petgraph::Outgoing; @@ -120,7 +120,7 @@ pub struct ExprIntervalGraph { root: NodeIndex, } -/// Result corresponding the 'update_interval' call. +/// This object encapsulates all possible constraint propagation results. #[derive(PartialEq, Debug)] pub enum PropagationResult { CannotPropagate, @@ -341,14 +341,13 @@ impl ExprIntervalGraph { .collect::>(); while let Some(node) = bfs.next(graph) { // Get the plan corresponding to this node: - let expr = &graph.index(node).expr; + let expr = &graph[node].expr; // If the current expression is among `exprs`, slate its children // for removal: if let Some(value) = exprs.iter().position(|e| expr.eq(e)) { // Update the node index of the associated `PhysicalExpr`: expr_node_indices[value].1 = node.index(); - let mut edges = graph.neighbors_directed(node, Outgoing).detach(); - while let Some(n_index) = edges.next_node(graph) { + for n_index in graph.neighbors_directed(node, Outgoing) { // Slate the child for removal, do not remove immediately. removals.push(n_index); } @@ -360,11 +359,31 @@ impl ExprIntervalGraph { expr_node_indices } + /// This function assigns given ranges to expressions in the DAEG. + /// The argument `assignments` associates indices of sought expressions + /// with their corresponding new ranges. 
+ pub fn assign_intervals(&mut self, assignments: &[(usize, Interval)]) { + for (index, interval) in assignments { + let node_index = NodeIndex::from(*index as DefaultIx); + self.graph[node_index].interval = interval.clone(); + } + } + + /// This function fetches ranges of expressions from the DAEG. The argument + /// `assignments` associates indices of sought expressions with their ranges, + /// which this function modifies to reflect the intervals in the DAEG. + pub fn update_intervals(&self, assignments: &mut [(usize, Interval)]) { + for (index, interval) in assignments.iter_mut() { + let node_index = NodeIndex::from(*index as DefaultIx); + *interval = self.graph[node_index].interval.clone(); + } + } + /// Computes bounds for an expression using interval arithmetic via a /// bottom-up traversal. /// /// # Arguments - /// * `expr_stats` - &[(usize, Interval)]. Provide NodeIndex, Interval tuples for leaf variables. + /// * `leaf_bounds` - &[(usize, Interval)]. Provide NodeIndex, Interval tuples for leaf variables. /// /// # Examples /// @@ -394,82 +413,68 @@ impl ExprIntervalGraph { /// }, /// )]; /// // Evaluate bounds for the composite expression: + /// graph.assign_intervals(&intervals); /// assert_eq!( - /// graph.evaluate_bounds(&intervals).unwrap(), - /// Interval { + /// graph.evaluate_bounds().unwrap(), + /// &Interval { /// lower: ScalarValue::Int32(Some(20)), /// upper: ScalarValue::Int32(Some(30)) /// } /// ) /// /// ``` - pub fn evaluate_bounds( - &mut self, - expr_stats: &[(usize, Interval)], - ) -> Result { + pub fn evaluate_bounds(&mut self) -> Result<&Interval> { let mut dfs = DfsPostOrder::new(&self.graph, self.root); while let Some(node) = dfs.next(&self.graph) { - // Get outgoing edges (i.e. children): - let mut edges = self.graph.neighbors_directed(node, Outgoing).detach(); - // If the current expression is a leaf expression, it should have - // externally-given intervals. 
Use this information if available: - let index = node.index(); - if let Some((_, interval)) = expr_stats.iter().find(|(e, _)| *e == index) { - self.graph[node].interval = interval.clone(); + let neighbors = self.graph.neighbors_directed(node, Outgoing); + let children = neighbors.collect::>(); + // If the current expression is a leaf, its interval should already + // be set externally, just continue with the evaluation procedure: + if children.is_empty() { continue; } let expr_any = self.graph[node].expr.as_any(); if let Some(binary) = expr_any.downcast_ref::() { // Get children intervals: - let left_child_idx = edges.next_node(&self.graph).unwrap(); - let right_child_idx = edges.next_node(&self.graph).unwrap(); - let left_interval = self.graph.index(right_child_idx).interval(); - let right_interval = self.graph.index(left_child_idx).interval(); + let left_child_idx = children[1]; + let right_child_idx = children[0]; + let left_interval = self.graph[left_child_idx].interval(); + let right_interval = self.graph[right_child_idx].interval(); // Calculate and replace current node's interval: self.graph[node].interval = apply_operator(binary.op(), left_interval, right_interval)?; } else if let Some(cast_expr) = expr_any.downcast_ref::() { let cast_type = cast_expr.cast_type(); let cast_options = cast_expr.cast_options(); - let child_index = edges.next_node(&self.graph).unwrap(); - let child = self.graph.index(child_index); + let child = self.graph.index(children[0]); // Cast current node's interval to the right type: self.graph[node].interval = child.interval.cast_to(cast_type, cast_options)?; } } - Ok(self.graph.index(self.root).interval.clone()) + Ok(&self.graph[self.root].interval) } /// Updates/shrinks bounds for leaf expressions using interval arithmetic /// via a top-down traversal. - fn propagate_constraints( - &mut self, - expr_stats: &mut [(usize, Interval)], - ) -> Result { + fn propagate_constraints(&mut self) -> Result { let mut bfs = Bfs::new(&self.graph, self.root); while let Some(node) = bfs.next(&self.graph) { - let index = node.index(); - // Get the expression DAG node: - let input = self.graph.index(node); - // Get the associated interval, which will used for propagation: - let node_interval = input.interval(); - if let Some((_, interval)) = expr_stats.iter_mut().find(|(e, _)| *e == index) - { - *interval = node_interval.clone(); + let neighbors = self.graph.neighbors_directed(node, Outgoing); + let children = neighbors.collect::>(); + // If the current expression is a leaf, its range is now final. + // So, just continue with the propagation procedure: + if children.is_empty() { continue; } - - // Get outgoing edges (i.e. 
children): - let mut edges = self.graph.neighbors_directed(node, Outgoing).detach(); - - let expr_any = input.expr.as_any(); + let node_interval = self.graph[node].interval(); + let expr_any = self.graph[node].expr.as_any(); if let Some(binary) = expr_any.downcast_ref::() { - // Get children intervals (note that we traverse in reverse): - let right_child_idx = edges.next_node(&self.graph).unwrap(); - let right_interval = self.graph.index(right_child_idx).interval(); - let left_child_idx = edges.next_node(&self.graph).unwrap(); - let left_interval = self.graph.index(left_child_idx).interval(); + // Get children intervals: + let left_child_idx = children[1]; + let right_child_idx = children[0]; + let left_interval = self.graph[left_child_idx].interval(); + let right_interval = self.graph[right_child_idx].interval(); let op = binary.op(); if let (Some(new_left_interval), Some(new_right_interval)) = @@ -509,7 +514,7 @@ impl ExprIntervalGraph { return Ok(PropagationResult::Infeasible); }; } else if let Some(cast) = expr_any.downcast_ref::() { - let child_index = edges.next_node(&self.graph).unwrap(); + let child_index = children[0]; let child = self.graph.index(child_index); // Get child's data type: let cast_type = child.interval().get_datatype(); @@ -522,22 +527,26 @@ impl ExprIntervalGraph { /// Updates intervals for all expressions in the DAEG by successive /// bottom-up and top-down traversals. - pub fn update_intervals( + pub fn update_ranges( &mut self, - expr_stats: &mut [(usize, Interval)], + leaf_bounds: &mut [(usize, Interval)], ) -> Result { - self.evaluate_bounds(expr_stats) - .and_then(|interval| match interval { - Interval { - upper: ScalarValue::Boolean(Some(true)), - .. - } => self.propagate_constraints(expr_stats), - Interval { - lower: ScalarValue::Boolean(Some(false)), - .. - } => Ok(PropagationResult::Infeasible), - _ => Ok(PropagationResult::CannotPropagate), - }) + self.assign_intervals(leaf_bounds); + match self.evaluate_bounds()? { + Interval { + lower: ScalarValue::Boolean(Some(false)), + upper: ScalarValue::Boolean(Some(false)), + } => Ok(PropagationResult::Infeasible), + Interval { + lower: ScalarValue::Boolean(Some(false)), + upper: ScalarValue::Boolean(Some(true)), + } => { + let result = self.propagate_constraints(); + self.update_intervals(leaf_bounds); + result + } + _ => Ok(PropagationResult::CannotPropagate), + } } } @@ -609,7 +618,7 @@ mod tests { .map(|((_, interval), (_, index))| (*index, interval.clone())) .collect_vec(); - let exp_result = graph.update_intervals(&mut col_stat_nodes[..])?; + let exp_result = graph.update_ranges(&mut col_stat_nodes[..])?; assert_eq!(exp_result, result); col_stat_nodes .iter() diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs index 120ef0f09a83..0c79bb4d8e84 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils.rs @@ -184,6 +184,56 @@ pub fn normalize_sort_expr_with_equivalence_properties( } } +/// Checks whether given ordering requirements are satisfied by provided [PhysicalSortExpr]s. 
+pub fn ordering_satisfy EquivalenceProperties>( + provided: Option<&[PhysicalSortExpr]>, + required: Option<&[PhysicalSortExpr]>, + equal_properties: F, +) -> bool { + match (provided, required) { + (_, None) => true, + (None, Some(_)) => false, + (Some(provided), Some(required)) => { + ordering_satisfy_concrete(provided, required, equal_properties) + } + } +} + +pub fn ordering_satisfy_concrete EquivalenceProperties>( + provided: &[PhysicalSortExpr], + required: &[PhysicalSortExpr], + equal_properties: F, +) -> bool { + if required.len() > provided.len() { + false + } else if required + .iter() + .zip(provided.iter()) + .all(|(order1, order2)| order1.eq(order2)) + { + true + } else if let eq_classes @ [_, ..] = equal_properties().classes() { + let normalized_required_exprs = required + .iter() + .map(|e| { + normalize_sort_expr_with_equivalence_properties(e.clone(), eq_classes) + }) + .collect::>(); + let normalized_provided_exprs = provided + .iter() + .map(|e| { + normalize_sort_expr_with_equivalence_properties(e.clone(), eq_classes) + }) + .collect::>(); + normalized_required_exprs + .iter() + .zip(normalized_provided_exprs.iter()) + .all(|(order1, order2)| order1.eq(order2)) + } else { + false + } +} + #[derive(Clone, Debug)] pub struct ExprTreeNode { expr: Arc, @@ -242,7 +292,7 @@ where let mut graph = StableGraph::::new(); let mut visited_plans = Vec::<(Arc, NodeIndex)>::new(); let root = init.mutable_transform_up(&mut |mut input| { - let expr = input.expr.clone(); + let expr = &input.expr; let node_idx = match visited_plans.iter().find(|(e, _)| expr.eq(e)) { Some((_, idx)) => *idx, None => { @@ -250,7 +300,7 @@ where for expr_node in input.child_nodes.iter() { graph.add_edge(node_idx, expr_node.data.unwrap(), 0); } - visited_plans.push((expr, node_idx)); + visited_plans.push((expr.clone(), node_idx)); node_idx } }; @@ -260,56 +310,6 @@ where Ok((root.data.unwrap(), graph)) } -/// Checks whether given ordering requirements are satisfied by provided [PhysicalSortExpr]s. -pub fn ordering_satisfy EquivalenceProperties>( - provided: Option<&[PhysicalSortExpr]>, - required: Option<&[PhysicalSortExpr]>, - equal_properties: F, -) -> bool { - match (provided, required) { - (_, None) => true, - (None, Some(_)) => false, - (Some(provided), Some(required)) => { - ordering_satisfy_concrete(provided, required, equal_properties) - } - } -} - -pub fn ordering_satisfy_concrete EquivalenceProperties>( - provided: &[PhysicalSortExpr], - required: &[PhysicalSortExpr], - equal_properties: F, -) -> bool { - if required.len() > provided.len() { - false - } else if required - .iter() - .zip(provided.iter()) - .all(|(order1, order2)| order1.eq(order2)) - { - true - } else if let eq_classes @ [_, ..] 
= equal_properties().classes() {
-        let normalized_required_exprs = required
-            .iter()
-            .map(|e| {
-                normalize_sort_expr_with_equivalence_properties(e.clone(), eq_classes)
-            })
-            .collect::<Vec<_>>();
-        let normalized_provided_exprs = provided
-            .iter()
-            .map(|e| {
-                normalize_sort_expr_with_equivalence_properties(e.clone(), eq_classes)
-            })
-            .collect::<Vec<_>>();
-        normalized_required_exprs
-            .iter()
-            .zip(normalized_provided_exprs.iter())
-            .all(|(order1, order2)| order1.eq(order2))
-    } else {
-        false
-    }
-}
-
 #[cfg(test)]
 mod tests {

From a9a49121faf2b12cde206b8bde8902cf988b23fa Mon Sep 17 00:00:00 2001
From: metesynnada <100111937+metesynnada@users.noreply.github.com>
Date: Mon, 13 Feb 2023 16:34:43 +0300
Subject: [PATCH 30/39] Bug fix in cp_solver, revamp some comments

---
 .../core/src/physical_plan/joins/hash_join.rs |   1 +
 .../physical_plan/joins/hash_join_utils.rs    | 115 +++++++++++-
 .../joins/symmetric_hash_join.rs              |  62 ++++---
 datafusion/core/tests/fifo.rs                 |   7 +-
 .../physical-expr/src/intervals/cp_solver.rs  | 170 +++++++++++++++++-
 5 files changed, 321 insertions(+), 34 deletions(-)

diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs b/datafusion/core/src/physical_plan/joins/hash_join.rs
index 0ad7615539d0..d013995cdadb 100644
--- a/datafusion/core/src/physical_plan/joins/hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/hash_join.rs
@@ -777,6 +777,7 @@ pub fn build_equal_condition_join_indices(
                 .into_array(build_input_buffer.num_rows()))
         })
         .collect::<Result<Vec<_>>>()?;
+    hashes_buffer.clear();
     hashes_buffer.resize(probe_batch.num_rows(), 0);
     let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?;
     // Using a buffer builder to avoid slower normal builder
diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs
index 7824503afda2..ff3a88636b8d 100644
--- a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs
+++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs
@@ -189,12 +189,123 @@ fn convert_filter_columns(
 /// contains the following information: The origin expression, the filter
 /// expression, an interval encapsulating expression bounds, and a stable
 /// index identifying the expression in the expression DAG.
+///
+/// The physical schema of a [JoinFilter]'s intermediate batch combines the two
+/// join sides under new column names. To traverse the main filter expression,
+/// we remap sort columns from the child schemas into this filter schema.
+///
+/// To evaluate the internal buffer, we use `origin_sorted_expr`; for interval
+/// traversals, we use `filter_expr`, which is adjusted to the intermediate schema.
+/// +/// ``` +/// use std::sync::Arc; +/// use arrow::compute::SortOptions; +/// use arrow::datatypes::{Field, Schema, DataType}; +/// use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter, JoinSide}; +/// use datafusion_expr::Operator; +/// use datafusion_physical_expr::expressions::{Column, BinaryExpr}; +/// use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +/// use super::*; +/// +/// let left_child_schema = Arc::new(Schema::new(vec![Field::new( +/// "left_1", +/// DataType::Int32, +/// true, +/// )])); +/// // Will be origin expr +/// let left_child_sort_expr = PhysicalSortExpr { +/// expr: Arc::new(Column::new("left_1", 0)), +/// options: SortOptions::default(), +/// }; +/// +/// let right_child_schema = Arc::new(Schema::new(vec![ +/// Field::new("right_1", DataType::Int32, true), +/// Field::new("right_2", DataType::Int32, true), +/// ])); +/// // Will be origin expr +/// let right_child_sort_expr = PhysicalSortExpr { +/// expr: Arc::new(BinaryExpr::new( +/// Arc::new(Column::new("right_1", 0)), +/// Operator::Plus, +/// Arc::new(Column::new("right_2", 1)), +/// )), +/// options: SortOptions::default(), +/// }; +/// +/// let filter_col_1 = Arc::new(Column::new("filter_1", 0)); +/// let filter_col_2 = Arc::new(Column::new("filter_2", 1)); +/// let filter_col_3 = Arc::new(Column::new("filter_3", 2)); +/// +/// let column_indices = vec![ +/// ColumnIndex { +/// index: 0, +/// side: JoinSide::Left, +/// }, +/// ColumnIndex { +/// index: 0, +/// side: JoinSide::Right, +/// }, +/// ColumnIndex { +/// index: 1, +/// side: JoinSide::Right, +/// }, +/// ]; +/// let intermediate_schema = Schema::new(vec![ +/// Field::new(filter_col_1.name(), DataType::Int32, true), +/// Field::new(filter_col_2.name(), DataType::Int32, true), +/// Field::new(filter_col_3.name(), DataType::Int32, true), +/// ]); +/// // left_1 > right_1 + right_2 +/// let filter_expr = Arc::new(BinaryExpr::new( +/// filter_col_1, +/// Operator::Gt, +/// Arc::new(BinaryExpr::new(filter_col_2, Operator::Plus, filter_col_3)), +/// )); +/// let filter = JoinFilter::new( +/// filter_expr, +/// column_indices.clone(), +/// intermediate_schema.clone(), +/// ); +/// +/// let left_sort_filter_expr = build_filter_input_order( +/// JoinSide::Left, +/// &filter, +/// left_child_schema, +/// &left_child_sort_expr, +/// )?; +/// let right_sort_filter_expr = build_filter_input_order( +/// JoinSide::Right, +/// &filter, +/// right_child_schema, +/// &right_child_sort_expr, +/// )?; +/// +/// assert_eq!( +/// left_sort_filter_expr.origin_sorted_expr(), +/// left_child_sort_expr +/// ); +/// // Matches with left_child_sort_expr expression. +/// let expected_filter_expr: Arc = +/// Arc::new(Column::new("filter_1", 0)); +/// assert!(expected_filter_expr.eq(left_sort_filter_expr.filter_expr())); +/// assert_eq!( +/// right_sort_filter_expr.origin_sorted_expr(), +/// right_child_sort_expr +/// ); +/// // Matches with right_child_sort_expr expression. +/// let expected_filter_expr: Arc = Arc::new(BinaryExpr::new( +/// Arc::new(Column::new("filter_2", 1)), +/// Operator::Plus, +/// Arc::new(Column::new("filter_3", 2)), +/// )); +/// assert!(expected_filter_expr.eq(right_sort_filter_expr.filter_expr())); +/// +/// ``` +/// #[derive(Debug, Clone)] pub struct SortedFilterExpr { /// Sorted expression from a join side (i.e. a child of the join) origin_sorted_expr: PhysicalSortExpr, - /// For interval calculations, one to one mapping of the columns - /// according to filter expression and column indices. 
+ /// Expression adjusted for filter schema. filter_expr: Arc, /// Interval containing expression bounds interval: Interval, diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs index d70924279936..c187f8dfa0f4 100644 --- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs @@ -276,6 +276,9 @@ impl SymmetricHashJoinExec { let mut physical_expr_graph = ExprIntervalGraph::try_new(filter.expression().clone())?; + // To produce global sorted expressions for interval operations, we use the first + // PhysicalSortExpr in child output ordering, taking the leftmost expression + // in lexicographical order. let (left_ordering, right_ordering) = match ( left.output_ordering(), right.output_ordering(), @@ -620,7 +623,7 @@ fn prune_hash_values( /// Calculate the filter expression intervals /// /// This function updates the `interval` field of each `SortedFilterExpr` based on -/// the first or last value of the expression in `build_input_buffer` or `probe_batch`. +/// the first or last value of the expression in `build_input_buffer` and `probe_batch`. /// /// # Arguments /// @@ -632,20 +635,6 @@ fn prune_hash_values( /// ### Note /// ```text /// -/// Build Probe -/// +-------+ +-------+ -/// |+-----+| | | | -/// FV ||a |b || | | | -/// |+--|--+| | | | -/// | | | | | | -/// | | | | | | -/// | | | | | | -/// | | | | | | -/// | | | |+-----+| -/// | | | ||c |d || LV -/// | | | |+--|--+| -/// +-------+ +-------+ -/// /// Interval arithmetic is used to calculate viable join ranges for build-side /// pruning. This is done by first creating an interval for join filter values in /// the build side of the join, which spans [-∞, FV] or [FV, ∞] depending on the @@ -655,11 +644,36 @@ fn prune_hash_values( /// (ascending/descending) of the probe side. Here, LV denotes the last value on /// the probe side. /// -/// In our case: -/// Expr1: [a(FV), inf] -/// Expr2: [b(FV), inf] -/// Expr3: [c(LV), inf] -/// Expr4: [d(LV), inf] +/// Let’s check the query: +/// SELECT * FROM left_table, right_table +/// ON +/// left_key = right_key AND +/// a > b - 3 AND a < b + 10 +/// +/// When a new RecordBatch comes to the right side (”a” belongs to the left side, +/// “b” belongs to the right side), a > b - 3 will possibly indicate a prunable range for the left side. +/// When a new RecordBatch comes to the left side, a < b + 10 will possibly indicate +/// prunability for the right side. This is how we manage to prune both sides when we get the +/// data to two sides. Let’s select the build side (new [RecordBatch] comes for the right side) +/// as left: +/// +/// Build Probe +/// +-------+ +-------+ +/// | a | z | | b | y | +/// |+--|--+| |+--|--+| +/// | 1 | 2 | | 4 | 3 | +/// |+--|--+| |+--|--+| +/// | 3 | 1 | | 4 | 3 | +/// |+--|--+| |+--|--+| +/// | 5 | 7 | | 6 | 1 | +/// |+--|--+| |+--|--+| +/// | 7 | 1 | | 6 | 3 | +/// +-------+ +-------+ +/// +/// - The interval for the column “a” is [1, ∞], +/// - The interval for the column “b” is [6, ∞]. +/// +/// Then we calculate intervals and propagate the constraint by traversing the expression graph. 
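+///
+/// For example, propagating through the condition a > b - 3 with the intervals
+/// above: b ∈ [6, ∞] implies b - 3 ∈ [3, ∞], so "a" must lie in
+/// [1, ∞] ∩ [3, ∞] = [3, ∞] to produce any further matches. Build-side rows
+/// whose "a" values fall below this bound (the row with a = 1 above) can
+/// therefore be pruned.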
 /// ```
 fn calculate_filter_expr_intervals(
     build_input_buffer: &RecordBatch,
@@ -1187,14 +1201,14 @@ impl OneSideHashJoiner {
         // Use the join filter intervals to update the physical expression graph:
         physical_expr_graph.update_ranges(&mut filter_intervals)?;
         // Get the new join filter interval for build side:
-        let new_build_side_sorted_filter_expr = filter_intervals.remove(0).1;
+        let calculated_build_side_interval = filter_intervals.remove(0).1;
         // Check if the intervals changed, exit early if not:
-        if new_build_side_sorted_filter_expr.eq(self.sorted_filter_expr.interval()) {
+        if calculated_build_side_interval.eq(self.sorted_filter_expr.interval()) {
             return Ok(None);
         }
         // Determine the pruning length if there was a change in the intervals:
         self.sorted_filter_expr
-            .set_interval(new_build_side_sorted_filter_expr);
+            .set_interval(calculated_build_side_interval);
         let prune_length =
             determine_prune_length(&self.input_buffer, &self.sorted_filter_expr)?;
         // If we can not prune, exit early:
@@ -2167,7 +2181,7 @@ mod tests {
         Ok(())
     }
 
-    #[tokio::test(flavor = "multi_thread", worker_threads = 20)]
+    #[tokio::test(flavor = "multi_thread")]
     async fn join_change_in_planner() -> Result<()> {
         let config = SessionConfig::new().with_target_partitions(1);
         let ctx = SessionContext::with_config(config);
diff --git a/datafusion/core/tests/fifo.rs b/datafusion/core/tests/fifo.rs
index 7910cec1caab..503516db8c5c 100644
--- a/datafusion/core/tests/fifo.rs
+++ b/datafusion/core/tests/fifo.rs
@@ -312,9 +312,10 @@ mod unix_test {
             };
             operations.push(op);
         }
-        // It check if the right unmatched or left unmatched results generated at the end of the file and only once
-        // in hash join executor. SymmetricHashJoin exec produce it without waiting the end of the file
-        // and produce it more than once.
+        // This check ensures that the right unmatched or left unmatched results are not generated
+        // only once by the hash join executor at the end of the file. This is a valid check because
+        // the SymmetricHashJoin executor produces these results without waiting for the end of
+        // the file, and will produce them more than once.
         assert!(
             operations
                 .iter()
diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs
index c62a2d5c0594..80911cfdf221 100644
--- a/datafusion/physical-expr/src/intervals/cp_solver.rs
+++ b/datafusion/physical-expr/src/intervals/cp_solver.rs
@@ -17,6 +17,7 @@
 //! Constraint propagator/solver for custom PhysicalExpr graphs.
 
+use std::collections::{HashSet, VecDeque};
 use std::fmt::{Display, Formatter};
 use std::ops::Index;
 use std::sync::Arc;
@@ -26,7 +27,7 @@
 use datafusion_common::{Result, ScalarValue};
 use datafusion_expr::Operator;
 use petgraph::graph::NodeIndex;
 use petgraph::stable_graph::{DefaultIx, StableGraph};
-use petgraph::visit::{Bfs, DfsPostOrder};
+use petgraph::visit::{Bfs, DfsPostOrder, EdgeRef, VisitMap, Visitable};
 use petgraph::Outgoing;
 
 use crate::expressions::{BinaryExpr, CastExpr, Literal};
@@ -276,6 +277,10 @@ impl ExprIntervalGraph {
         Ok(Self { graph, root })
     }
 
+    pub fn node_count(&self) -> usize {
+        self.graph.node_count()
+    }
+
     // Sometimes, we do not want to calculate and/or propagate intervals all
     // the way down to leaf expressions. For example, assume that we have a
     // `SymmetricHashJoin` which has a child with an output ordering like:
@@ -347,18 +352,59 @@ impl ExprIntervalGraph {
             if let Some(value) = exprs.iter().position(|e| expr.eq(e)) {
                 // Update the node index of the associated `PhysicalExpr`:
                 expr_node_indices[value].1 = node.index();
-                for n_index in graph.neighbors_directed(node, Outgoing) {
+                for edge_reference in graph.edges_directed(node, Outgoing) {
                     // Slate the child for removal, do not remove immediately.
-                    removals.push(n_index);
+                    removals.push(edge_reference.id());
                 }
             }
         }
-        for node in removals {
-            self.graph.remove_node(node);
+        for edge_reference in removals {
+            self.graph.remove_edge(edge_reference);
         }
+        // Get the set of node indices connected to the root node
+        let connected_nodes = self.connected_nodes();
+        // Remove nodes not connected to the root node
+        self.graph
+            .retain_nodes(|_graph, index| connected_nodes.contains(&index));
         expr_node_indices
     }
 
+    /// Returns a set of node indices connected to the root node
+    /// via any directed edges in the graph.
+    fn connected_nodes(&self) -> HashSet<NodeIndex> {
+        // Create a visit map to keep track of visited nodes.
+        let mut visited = self.graph.visit_map();
+        // Create a new set to store the connected nodes.
+        let mut nodes = HashSet::new();
+        // Create a queue to store nodes to be visited.
+        let mut visit_next = VecDeque::new();
+        // Start the traversal from the root node.
+        visit_next.push_back(self.root);
+        // Add the root node to the set of connected nodes.
+        nodes.insert(self.root);
+        // Continue the traversal while there are nodes to be visited.
+        while let Some(node) = visit_next.pop_back() {
+            // If the node has already been visited, skip it.
+            if visited.is_visited(&node) {
+                continue;
+            }
+            // For each outgoing edge from the current node, add the target node to the set of connected nodes
+            // and add it to the queue of nodes to be visited if it hasn't been visited already.
+            for edge in self.graph.edges(node) {
+                let next = edge.target();
+                if visited.is_visited(&next) {
+                    continue;
+                }
+                nodes.insert(next);
+                visit_next.push_back(next);
+            }
+            // Mark the current node as visited.
+            visited.visit(node);
+        }
+        // Return the set of connected nodes.
+        nodes
+    }
+
     /// This function assigns given ranges to expressions in the DAEG.
     /// The argument `assignments` associates indices of sought expressions
     /// with their corresponding new ranges.
@@ -906,4 +952,118 @@ mod tests {
         generate_case::<false>(expr, left_col, right_col, seed, l_lt_r, r_lt_l)?;
         Ok(())
    }
+
+    #[test]
+    fn test_gather_node_indices_not_remove() -> Result<()> {
+        // Expression: a@0 + b@1 + 1 > a@0 - b@1
+        // Provide a@0 + b@1, not remove a@0 or b@1, only remove edges since a@0 - b@1
+        // has also leaf nodes a@0 and b@1.
+        let left_expr = Arc::new(BinaryExpr::new(
+            Arc::new(BinaryExpr::new(
+                Arc::new(Column::new("a", 0)),
+                Operator::Plus,
+                Arc::new(Column::new("b", 1)),
+            )),
+            Operator::Plus,
+            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
+        ));
+
+        let right_expr = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("a", 0)),
+            Operator::Minus,
+            Arc::new(Column::new("b", 1)),
+        ));
+        let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr));
+        let mut graph = ExprIntervalGraph::try_new(expr).unwrap();
+        // Define a test leaf node.
+        let leaf_node = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("a", 0)),
+            Operator::Plus,
+            Arc::new(Column::new("b", 1)),
+        ));
+        // Store the current node count.
+        let prev_node_count = graph.node_count();
+        // Gather the index of the node in the expression graph that matches the test leaf node.
+        graph.gather_node_indices(&[leaf_node]);
+        // Store the final node count.
+        let final_node_count = graph.node_count();
+        // Assert that the final node count is equal to the previous node count (i.e., no node was pruned).
+        assert_eq!(prev_node_count, final_node_count);
+        Ok(())
+    }
+    #[test]
+    fn test_gather_node_indices_remove() -> Result<()> {
+        // Expression: a@0 + b@1 + 1 > y@0 - z@1
+        // We expect that 2 nodes will be pruned since we do not need a@ and b@
+        let left_expr = Arc::new(BinaryExpr::new(
+            Arc::new(BinaryExpr::new(
+                Arc::new(Column::new("a", 0)),
+                Operator::Plus,
+                Arc::new(Column::new("b", 1)),
+            )),
+            Operator::Plus,
+            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
+        ));
+
+        let right_expr = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("y", 0)),
+            Operator::Minus,
+            Arc::new(Column::new("z", 1)),
+        ));
+        let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr));
+        let mut graph = ExprIntervalGraph::try_new(expr).unwrap();
+        // Define a test leaf node.
+        let leaf_node = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("a", 0)),
+            Operator::Plus,
+            Arc::new(Column::new("b", 1)),
+        ));
+        // Store the current node count.
+        let prev_node_count = graph.node_count();
+        // Gather the index of the node in the expression graph that matches the test leaf node.
+        graph.gather_node_indices(&[leaf_node]);
+        // Store the final node count.
+        let final_node_count = graph.node_count();
+        // Assert that the final node count is two less than the previous node count (i.e., two nodes were pruned).
+        assert_eq!(prev_node_count, final_node_count + 2);
+        Ok(())
+    }
+
+    #[test]
+    fn test_gather_node_indices_remove_one() -> Result<()> {
+        // Expression: a@0 + b@1 + 1 > a@0 - z@1
+        // We expect that one node will be pruned since we do not need b@
+        let left_expr = Arc::new(BinaryExpr::new(
+            Arc::new(BinaryExpr::new(
+                Arc::new(Column::new("a", 0)),
+                Operator::Plus,
+                Arc::new(Column::new("b", 1)),
+            )),
+            Operator::Plus,
+            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
+        ));
+
+        let right_expr = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("a", 0)),
+            Operator::Minus,
+            Arc::new(Column::new("z", 1)),
+        ));
+        let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr));
+        let mut graph = ExprIntervalGraph::try_new(expr).unwrap();
+        // Define a test leaf node.
+        let leaf_node = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("a", 0)),
+            Operator::Plus,
+            Arc::new(Column::new("b", 1)),
+        ));
+        // Store the current node count.
+        let prev_node_count = graph.node_count();
+        // Gather the index of the node in the expression graph that matches the test leaf node.
+        graph.gather_node_indices(&[leaf_node]);
+        // Store the final node count.
+        let final_node_count = graph.node_count();
+        // Assert that the final node count is one less than the previous node count (i.e., one node was pruned).
+        assert_eq!(prev_node_count, final_node_count + 1);
+        Ok(())
+    }
 }

From 52e143d30595fe3850134076317c82ccb6d91540 Mon Sep 17 00:00:00 2001
From: metesynnada <100111937+metesynnada@users.noreply.github.com>
Date: Mon, 13 Feb 2023 18:22:42 +0300
Subject: [PATCH 31/39] Update with correct testing

---
 .../physical_plan/joins/hash_join_utils.rs    | 195 ++++++++----------
 datafusion/core/tests/fifo.rs                 |   9 +-
 2 files changed, 94 insertions(+), 110 deletions(-)

diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs
index ff3a88636b8d..2cf2505da3d4 100644
--- a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs
+++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs
@@ -195,112 +195,6 @@ fn convert_filter_columns(
 ///
 /// For evaluate the inner buffer, we use origin_sorted_expr.
 /// For interval traversing, we use filter_expr, adjusted with intermediate schema.
-///
-/// ```
-/// use std::sync::Arc;
-/// use arrow::compute::SortOptions;
-/// use arrow::datatypes::{Field, Schema, DataType};
-/// use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter, JoinSide};
-/// use datafusion_expr::Operator;
-/// use datafusion_physical_expr::expressions::{Column, BinaryExpr};
-/// use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
-/// use super::*;
-///
-/// let left_child_schema = Arc::new(Schema::new(vec![Field::new(
-///     "left_1",
-///     DataType::Int32,
-///     true,
-/// )]));
-/// // Will be origin expr
-/// let left_child_sort_expr = PhysicalSortExpr {
-///     expr: Arc::new(Column::new("left_1", 0)),
-///     options: SortOptions::default(),
-/// };
-///
-/// let right_child_schema = Arc::new(Schema::new(vec![
-///     Field::new("right_1", DataType::Int32, true),
-///     Field::new("right_2", DataType::Int32, true),
-/// ]));
-/// // Will be origin expr
-/// let right_child_sort_expr = PhysicalSortExpr {
-///     expr: Arc::new(BinaryExpr::new(
-///         Arc::new(Column::new("right_1", 0)),
-///         Operator::Plus,
-///         Arc::new(Column::new("right_2", 1)),
-///     )),
-///     options: SortOptions::default(),
-/// };
-///
-/// let filter_col_1 = Arc::new(Column::new("filter_1", 0));
-/// let filter_col_2 = Arc::new(Column::new("filter_2", 1));
-/// let filter_col_3 = Arc::new(Column::new("filter_3", 2));
-///
-/// let column_indices = vec![
-///     ColumnIndex {
-///         index: 0,
-///         side: JoinSide::Left,
-///     },
-///     ColumnIndex {
-///         index: 0,
-///         side: JoinSide::Right,
-///     },
-///     ColumnIndex {
-///         index: 1,
-///         side: JoinSide::Right,
-///     },
-/// ];
-/// let intermediate_schema = Schema::new(vec![
-///     Field::new(filter_col_1.name(), DataType::Int32, true),
-///     Field::new(filter_col_2.name(), DataType::Int32, true),
-///     Field::new(filter_col_3.name(), DataType::Int32, true),
-/// ]);
-/// // left_1 > right_1 + right_2
-/// let filter_expr = Arc::new(BinaryExpr::new(
-///     filter_col_1,
-///     Operator::Gt,
-///     Arc::new(BinaryExpr::new(filter_col_2, Operator::Plus, filter_col_3)),
-/// ));
-/// let filter = JoinFilter::new(
-///     filter_expr,
-///     column_indices.clone(),
-///     intermediate_schema.clone(),
-/// );
-///
-/// let left_sort_filter_expr = build_filter_input_order(
-///     JoinSide::Left,
-///     &filter,
-///     left_child_schema,
-///     &left_child_sort_expr,
-/// )?;
-/// let right_sort_filter_expr = build_filter_input_order(
-///     JoinSide::Right,
-///     &filter,
-///     right_child_schema,
-///     &right_child_sort_expr,
-/// )?;
-///
-/// assert_eq!(
-///     left_sort_filter_expr.origin_sorted_expr(),
-///     left_child_sort_expr
-/// );
-/// // Matches with left_child_sort_expr expression.
-/// let expected_filter_expr: Arc<dyn PhysicalExpr> =
-///     Arc::new(Column::new("filter_1", 0));
-/// assert!(expected_filter_expr.eq(left_sort_filter_expr.filter_expr()));
-/// assert_eq!(
-///     right_sort_filter_expr.origin_sorted_expr(),
-///     right_child_sort_expr
-/// );
-/// // Matches with right_child_sort_expr expression.
-/// let expected_filter_expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
-///     Arc::new(Column::new("filter_2", 1)),
-///     Operator::Plus,
-///     Arc::new(Column::new("filter_3", 2)),
-/// ));
-/// assert!(expected_filter_expr.eq(right_sort_filter_expr.filter_expr()));
-///
-/// ```
-///
 #[derive(Debug, Clone)]
 pub struct SortedFilterExpr {
     /// Sorted expression from a join side (i.e. a child of the join)
@@ -420,6 +314,95 @@ pub mod tests {
         ))
     }
 
+    #[test]
+    fn test_column_exchange() -> Result<()> {
+        let left_child_schema = Arc::new(Schema::new(vec![Field::new(
+            "left_1",
+            DataType::Int32,
+            true,
+        )]));
+        // Will be origin expr
+        let left_child_sort_expr = PhysicalSortExpr {
+            expr: Arc::new(Column::new("left_1", 0)),
+            options: SortOptions::default(),
+        };
+
+        let right_child_schema = Arc::new(Schema::new(vec![
+            Field::new("right_1", DataType::Int32, true),
+            Field::new("right_2", DataType::Int32, true),
+        ]));
+        // Will be origin expr
+        let right_child_sort_expr = PhysicalSortExpr {
+            expr: Arc::new(BinaryExpr::new(
+                Arc::new(Column::new("right_1", 0)),
+                Operator::Plus,
+                Arc::new(Column::new("right_2", 1)),
+            )),
+            options: SortOptions::default(),
+        };
+
+        let filter_col_1 = Arc::new(Column::new("filter_1", 0));
+        let filter_col_2 = Arc::new(Column::new("filter_2", 1));
+        let filter_col_3 = Arc::new(Column::new("filter_3", 2));
+
+        let column_indices = vec![
+            ColumnIndex {
+                index: 0,
+                side: JoinSide::Left,
+            },
+            ColumnIndex {
+                index: 0,
+                side: JoinSide::Right,
+            },
+            ColumnIndex {
+                index: 1,
+                side: JoinSide::Right,
+            },
+        ];
+        let intermediate_schema = Schema::new(vec![
+            Field::new(filter_col_1.name(), DataType::Int32, true),
+            Field::new(filter_col_2.name(), DataType::Int32, true),
+            Field::new(filter_col_3.name(), DataType::Int32, true),
+        ]);
+        // left_1 > right_1 + right_2
+        let filter_expr = Arc::new(BinaryExpr::new(
+            filter_col_1,
+            Operator::Gt,
+            Arc::new(BinaryExpr::new(filter_col_2, Operator::Plus, filter_col_3)),
+        ));
+        let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
+
+        let left_sort_filter_expr = build_filter_input_order(
+            JoinSide::Left,
+            &filter,
+            &left_child_schema,
+            &left_child_sort_expr,
+        )?;
+        let right_sort_filter_expr = build_filter_input_order(
+            JoinSide::Right,
+            &filter,
+            &right_child_schema,
+            &right_child_sort_expr,
+        )?;
+
+        assert!(left_child_sort_expr.eq(left_sort_filter_expr.origin_sorted_expr()),);
+        // Expression that adjusted with filter schema
+        let expected_filter_expr: Arc<dyn PhysicalExpr> =
+            Arc::new(Column::new("filter_1", 0));
+        // Matches with left_child_sort_expr expression.
+        assert!(expected_filter_expr.eq(left_sort_filter_expr.filter_expr()));
+        assert!(right_child_sort_expr.eq(right_sort_filter_expr.origin_sorted_expr()),);
+        // Expression that adjusted with filter schema
+        let expected_filter_expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("filter_2", 1)),
+            Operator::Plus,
+            Arc::new(Column::new("filter_3", 2)),
+        ));
+        // Matches with right_child_sort_expr expression.
+        assert!(expected_filter_expr.eq(right_sort_filter_expr.filter_expr()));
+        Ok(())
+    }
+
     #[test]
     fn test_column_collector() {
         let filter_expr = complicated_filter();
diff --git a/datafusion/core/tests/fifo.rs b/datafusion/core/tests/fifo.rs
index 503516db8c5c..f586d020c3bc 100644
--- a/datafusion/core/tests/fifo.rs
+++ b/datafusion/core/tests/fifo.rs
@@ -312,10 +312,11 @@ mod unix_test {
             };
             operations.push(op);
         }
-        // This check ensures that the right unmatched or left unmatched results are not generated
-        // only once by the hash join executor at the end of the file. This is a valid check because
-        // the SymmetricHashJoin executor produces these results without waiting for the end of
-        // the file, and will produce them more than once.
+
+        // The SymmetricHashJoin executor produces FULL join results at the each prune,
+        // which happens earlier than the end of the file and more than once.
+        // We provide partially joinable data on both sides to ensure that the right unmatched and left unmatched
+        // results are generated more than once.
         assert!(
             operations
                 .iter()

From 5ec2eae1bb7a7c9518263ec16a237c8f4a8cdf1d Mon Sep 17 00:00:00 2001
From: metesynnada <100111937+metesynnada@users.noreply.github.com>
Date: Tue, 14 Feb 2023 10:36:14 +0300
Subject: [PATCH 32/39] Test for future support on fuzzy matches between exprs

---
 .../physical-expr/src/intervals/cp_solver.rs  | 49 +++++++++++++++++--
 1 file changed, 45 insertions(+), 4 deletions(-)

diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs
index 80911cfdf221..415531110293 100644
--- a/datafusion/physical-expr/src/intervals/cp_solver.rs
+++ b/datafusion/physical-expr/src/intervals/cp_solver.rs
@@ -955,8 +955,8 @@ mod tests {
 
     #[test]
     fn test_gather_node_indices_not_remove() -> Result<()> {
-        // Expression: a@0 + b@1 + 1 > a@0 - b@1
-        // Provide a@0 + b@1, not remove a@0 or b@1, only remove edges since a@0 - b@1
+        // Expression: a@0 + b@1 + 1 > a@0 - b@1 -> provide a@0 + b@1
+        // Not remove a@0 or b@1, only remove edges since a@0 - b@1
         // has also leaf nodes a@0 and b@1.
         let left_expr = Arc::new(BinaryExpr::new(
             Arc::new(BinaryExpr::new(
                 Arc::new(Column::new("a", 0)),
@@ -993,7 +993,7 @@ mod tests {
     }
     #[test]
     fn test_gather_node_indices_remove() -> Result<()> {
-        // Expression: a@0 + b@1 + 1 > y@0 - z@1
+        // Expression: a@0 + b@1 + 1 > y@0 - z@1 -> provide a@0 + b@1
         // We expect that 2 nodes will be pruned since we do not need a@ and b@
         let left_expr = Arc::new(BinaryExpr::new(
             Arc::new(BinaryExpr::new(
                 Arc::new(Column::new("a", 0)),
@@ -1031,7 +1031,7 @@ mod tests {
 
     #[test]
     fn test_gather_node_indices_remove_one() -> Result<()> {
-        // Expression: a@0 + b@1 + 1 > a@0 - z@1
+        // Expression: a@0 + b@1 + 1 > a@0 - z@1 -> provide a@0 + b@1
        // We expect that one node will be pruned since we do not need b@
         let left_expr = Arc::new(BinaryExpr::new(
             Arc::new(BinaryExpr::new(
                 Arc::new(Column::new("a", 0)),
@@ -1066,4 +1066,45 @@ mod tests {
         assert_eq!(prev_node_count, final_node_count + 1);
         Ok(())
     }
+
+    #[test]
+    fn test_gather_node_indices_cannot_provide() -> Result<()> {
+        // Expression: a@0 + 1 + b@1 > y@0 - z@1 -> provide a@0 + b@1
+        // TODO: We expect nodes a@0 and b@1 to be pruned, and intervals to be provided from the a@0 + b@1 node.
+        // However, we do not have an exact node for a@0 + b@1 due to the binary tree structure of the expressions.
+        // Pruning and interval providing for BinaryExpr expressions are more challenging without exact matches.
+        // Currently, we only support exact matches for BinaryExprs, but we plan to extend support beyond exact matches in the future.
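+        // Concretely, `a@0 + 1 + b@1` parses as `(a@0 + 1) + b@1`, so the
+        // resulting graph contains no subtree that is exactly equal to
+        // `a@0 + b@1`.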
+        let left_expr = Arc::new(BinaryExpr::new(
+            Arc::new(BinaryExpr::new(
+                Arc::new(Column::new("a", 0)),
+                Operator::Plus,
+                Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
+            )),
+            Operator::Plus,
+            Arc::new(Column::new("b", 1)),
+        ));
+
+        let right_expr = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("y", 0)),
+            Operator::Minus,
+            Arc::new(Column::new("z", 1)),
+        ));
+        let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr));
+        let mut graph = ExprIntervalGraph::try_new(expr).unwrap();
+        // Define a test leaf node.
+        let leaf_node = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("a", 0)),
+            Operator::Plus,
+            Arc::new(Column::new("b", 1)),
+        ));
+        // Store the current node count.
+        let prev_node_count = graph.node_count();
+        // Gather the index of the node in the expression graph that matches the test leaf node.
+        graph.gather_node_indices(&[leaf_node]);
+        // Store the final node count.
+        let final_node_count = graph.node_count();
+        // Assert that the final node count is equal to the previous node count (i.e., no node was pruned).
+        assert_eq!(prev_node_count, final_node_count);
+        Ok(())
+    }
 }

From 175133ba71108dbe861f59e8c50e2a7dc80feb4f Mon Sep 17 00:00:00 2001
From: Mehmet Ozan Kabak
Date: Thu, 16 Feb 2023 20:56:13 -0600
Subject: [PATCH 33/39] Compute connected nodes in CP solver via a DFS, improve comments

---
 .../physical_plan/joins/hash_join_utils.rs    | 29 +++----
 .../joins/symmetric_hash_join.rs              | 62 ++++++------
 datafusion/core/tests/fifo.rs                 |  9 ++-
 .../physical-expr/src/intervals/cp_solver.rs  | 78 +++++++-----------
 4 files changed, 83 insertions(+), 95 deletions(-)

diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs
index 2cf2505da3d4..08b07a79fef1 100644
--- a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs
+++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs
@@ -190,11 +190,11 @@ fn convert_filter_columns(
 /// expression, an interval encapsulating expression bounds, and a stable
 /// index identifying the expression in the expression DAG.
 ///
-/// Physical schema of intermediate batch of a JoinFilter combines two sides with new column names.
-/// For sort information, we do the column exchange, to traverse the main filter expression.
-///
-/// For evaluate the inner buffer, we use origin_sorted_expr.
-/// For interval traversing, we use filter_expr, adjusted with intermediate schema.
+/// Physical schema of a [JoinFilter]'s intermediate batch combines two sides
+/// and uses new column names. In this process, a column exchange is done so
+/// we can utilize sorting information while traversing the filter expression
+/// DAG for interval calculations. When evaluating the inner buffer, we use
+/// `origin_sorted_expr`.
 #[derive(Debug, Clone)]
 pub struct SortedFilterExpr {
     /// Sorted expression from a join side (i.e. a child of the join)
@@ -321,7 +321,7 @@ pub mod tests {
             DataType::Int32,
             true,
         )]));
-        // Will be origin expr
+        // Sorting information for the left side:
         let left_child_sort_expr = PhysicalSortExpr {
             expr: Arc::new(Column::new("left_1", 0)),
             options: SortOptions::default(),
@@ -331,7 +331,7 @@ pub mod tests {
             Field::new("right_1", DataType::Int32, true),
             Field::new("right_2", DataType::Int32, true),
         ]));
-        // Will be origin expr
+        // Sorting information for the right side:
         let right_child_sort_expr = PhysicalSortExpr {
             expr: Arc::new(BinaryExpr::new(
                 Arc::new(Column::new("right_1", 0)),
@@ -364,7 +364,7 @@ pub mod tests {
             Field::new(filter_col_2.name(), DataType::Int32, true),
             Field::new(filter_col_3.name(), DataType::Int32, true),
         ]);
-        // left_1 > right_1 + right_2
+        // Our filter expression is: left_1 > right_1 + right_2.
         let filter_expr = Arc::new(BinaryExpr::new(
             filter_col_1,
             Operator::Gt,
@@ -378,27 +378,28 @@ pub mod tests {
             &left_child_schema,
             &left_child_sort_expr,
         )?;
+        assert!(left_child_sort_expr.eq(left_sort_filter_expr.origin_sorted_expr()));
+
         let right_sort_filter_expr = build_filter_input_order(
             JoinSide::Right,
             &filter,
             &right_child_schema,
             &right_child_sort_expr,
         )?;
+        assert!(right_child_sort_expr.eq(right_sort_filter_expr.origin_sorted_expr()));
 
-        assert!(left_child_sort_expr.eq(left_sort_filter_expr.origin_sorted_expr()),);
-        // Expression that adjusted with filter schema
+        // Filter expression (left) adjusted for filter schema:
         let expected_filter_expr: Arc<dyn PhysicalExpr> =
             Arc::new(Column::new("filter_1", 0));
-        // Matches with left_child_sort_expr expression.
+        // Assert that it matches with the `left_child_sort_expr`:
         assert!(expected_filter_expr.eq(left_sort_filter_expr.filter_expr()));
-        assert!(right_child_sort_expr.eq(right_sort_filter_expr.origin_sorted_expr()),);
-        // Expression that adjusted with filter schema
+        // Filter expression (right) adjusted for filter schema:
         let expected_filter_expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
             Arc::new(Column::new("filter_2", 1)),
             Operator::Plus,
             Arc::new(Column::new("filter_3", 2)),
         ));
-        // Matches with right_child_sort_expr expression.
+        // Assert that it matches with the `right_child_sort_expr`:
         assert!(expected_filter_expr.eq(right_sort_filter_expr.filter_expr()));
         Ok(())
     }
diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
index c187f8dfa0f4..797f568d6aef 100644
--- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
@@ -255,30 +255,33 @@ impl SymmetricHashJoinExec {
         let left_schema = left.schema();
         let right_schema = right.schema();
 
-        // Return error if the on constraints in the join are empty
+        // Error out if no "on" constraints are given:
         if on.is_empty() {
             return Err(DataFusionError::Plan(
-                "On constraints in HashJoinExec should be non-empty".to_string(),
+                "On constraints in SymmetricHashJoinExec should be non-empty".to_string(),
             ));
         }
 
-        // Check if the join is valid with the given on constraints
+        // Check if the join is valid with the given on constraints:
         check_join_is_valid(&left_schema, &right_schema, &on)?;
 
-        // Build the join schema from the left and right schemas
+        // Build the join schema from the left and right schemas:
         let (schema, column_indices) =
             build_join_schema(&left_schema, &right_schema, join_type);
 
-        // Set a random state for the join
+        // Set a random state for the join:
         let random_state = RandomState::with_seeds(0, 0, 0, 0);
 
-        // Create expression graph for PhysicalExpr
+        // Create an expression DAG for the join filter:
         let mut physical_expr_graph =
             ExprIntervalGraph::try_new(filter.expression().clone())?;
 
-        // To produce global sorted expressions for interval operations, we use the first
-        // PhysicalSortExpr in child output ordering, taking the leftmost expression
-        // in lexicographical order.
+        // Interval calculations require each column to exhibit monotonicity
+        // independently. However, a `PhysicalSortExpr` object defines a
+        // lexicographical ordering, so we can only use their first elements
+        // when deducing column monotonicities.
+        // TODO: Extend the `PhysicalSortExpr` mechanism to express independent
+        // (i.e. simultaneous) ordering properties of columns.
         let (left_ordering, right_ordering) = match (
             left.output_ordering(),
             right.output_ordering(),
@@ -620,10 +623,11 @@ fn prune_hash_values(
     Ok(())
 }
 
-/// Calculate the filter expression intervals
+/// Calculate the filter expression intervals.
 ///
-/// This function updates the `interval` field of each `SortedFilterExpr` based on
-/// the first or last value of the expression in `build_input_buffer` and `probe_batch`.
+/// This function updates the `interval` field of each `SortedFilterExpr` based
+/// on the first or the last value of the expression in `build_input_buffer`
+/// and `probe_batch`.
 ///
 /// # Arguments
 ///
@@ -644,18 +648,21 @@ fn prune_hash_values(
 /// (ascending/descending) of the probe side. Here, LV denotes the last value on
 /// the probe side.
 ///
-/// Let’s check the query:
-/// SELECT * FROM left_table, right_table
-/// ON
-///   left_key = right_key AND
-///   a > b - 3 AND a < b + 10
+/// As a concrete example, consider the following query:
 ///
-/// When a new RecordBatch comes to the right side (”a” belongs to the left side,
-/// “b” belongs to the right side), a > b - 3 will possibly indicate a prunable range for the left side.
-/// When a new RecordBatch comes to the left side, a < b + 10 will possibly indicate
-/// prunability for the right side. This is how we manage to prune both sides when we get the
-/// data to two sides. Let’s select the build side (new [RecordBatch] comes for the right side)
-/// as left:
+///   SELECT * FROM left_table, right_table
+///   WHERE
+///     left_key = right_key AND
+///     a > b - 3 AND
+///     a < b + 10
+///
+/// where columns "a" and "b" come from tables "left_table" and "right_table",
+/// respectively. When a new `RecordBatch` arrives at the right side, the
+/// condition a > b - 3 will possibly indicate a prunable range for the left
+/// side. Conversely, when a new `RecordBatch` arrives at the left side, the
+/// condition a < b + 10 will possibly indicate prunability for the right side.
+/// Let’s inspect what happens when a new `RecordBatch` arrives at the right
+/// side (i.e. when the left side is the build side):
 ///
 ///          Build             Probe
 ///        +-------+         +-------+
 ///        | a | z |         | b | y |
 ///        |+--|--+|         |+--|--+|
 ///        | 1 | 2 |         | 4 | 3 |
 ///        |+--|--+|         |+--|--+|
 ///        | 3 | 1 |         | 4 | 3 |
 ///        |+--|--+|         |+--|--+|
 ///        | 5 | 7 |         | 6 | 1 |
 ///        |+--|--+|         |+--|--+|
 ///        | 7 | 1 |         | 6 | 3 |
 ///        +-------+         +-------+
 ///
-/// - The interval for the column “a” is [1, ∞],
-/// - The interval for the column “b” is [6, ∞].
-///
-/// Then we calculate intervals and propagate the constraint by traversing the expression graph.
+/// In this case, the interval representing viable (i.e. joinable) values for
+/// column "a" is [1, ∞], and the interval representing possible future values
+/// for column "b" is [6, ∞]. With these intervals at hand, we next calculate
+/// intervals for the whole filter expression and propagate the join constraint
+/// by traversing the expression graph.
 /// ```
 fn calculate_filter_expr_intervals(
     build_input_buffer: &RecordBatch,
diff --git a/datafusion/core/tests/fifo.rs b/datafusion/core/tests/fifo.rs
index f586d020c3bc..d8b66b0a29cd 100644
--- a/datafusion/core/tests/fifo.rs
+++ b/datafusion/core/tests/fifo.rs
@@ -313,10 +313,11 @@ mod unix_test {
             operations.push(op);
         }
 
-        // The SymmetricHashJoin executor produces FULL join results at the each prune,
-        // which happens earlier than the end of the file and more than once.
-        // We provide partially joinable data on both sides to ensure that the right unmatched and left unmatched
-        // results are generated more than once.
+        // The SymmetricHashJoin executor produces FULL join results at every
+        // pruning, which happens before it reaches the end of input and more
+        // than once. In this test, we feed partially joinable data to both
+        // sides in order to ensure that both left/right unmatched results are
+        // generated more than once during the test.
         assert!(
             operations
                 .iter()
diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs
index 415531110293..8eb43393b46e 100644
--- a/datafusion/physical-expr/src/intervals/cp_solver.rs
+++ b/datafusion/physical-expr/src/intervals/cp_solver.rs
@@ -17,7 +17,7 @@
 //! Constraint propagator/solver for custom PhysicalExpr graphs.
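+//!
+//! A minimal usage sketch (here, `initial_interval_for` is a hypothetical
+//! helper that supplies the known bounds of each tracked sub-expression):
+//!
+//! ```ignore
+//! // Build a DAEG from a filter expression, pick the sub-expressions whose
+//! // intervals we can supply, then propagate the known ranges through it:
+//! let mut graph = ExprIntervalGraph::try_new(filter_expr)?;
+//! let mut node_intervals = graph
+//!     .gather_node_indices(&tracked_exprs)
+//!     .into_iter()
+//!     .map(|(expr, idx)| (idx, initial_interval_for(&expr)))
+//!     .collect::<Vec<_>>();
+//! graph.update_ranges(&mut node_intervals)?;
+//! ```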
-use std::collections::{HashSet, VecDeque};
+use std::collections::HashSet;
 use std::fmt::{Display, Formatter};
 use std::ops::Index;
 use std::sync::Arc;
@@ -27,7 +27,7 @@
 use datafusion_common::{Result, ScalarValue};
 use datafusion_expr::Operator;
 use petgraph::graph::NodeIndex;
 use petgraph::stable_graph::{DefaultIx, StableGraph};
-use petgraph::visit::{Bfs, DfsPostOrder, EdgeRef, VisitMap, Visitable};
+use petgraph::visit::{Bfs, Dfs, DfsPostOrder, EdgeRef};
 use petgraph::Outgoing;
@@ -352,56 +352,31 @@ impl ExprIntervalGraph {
             if let Some(value) = exprs.iter().position(|e| expr.eq(e)) {
                 // Update the node index of the associated `PhysicalExpr`:
                 expr_node_indices[value].1 = node.index();
-                for edge_reference in graph.edges_directed(node, Outgoing) {
+                for edge in graph.edges_directed(node, Outgoing) {
                     // Slate the child for removal, do not remove immediately.
-                    removals.push(edge_reference.id());
+                    removals.push(edge.id());
                 }
             }
         }
-        for edge_reference in removals {
-            self.graph.remove_edge(edge_reference);
+        for edge_idx in removals {
+            self.graph.remove_edge(edge_idx);
         }
-        // Get the set of node indices connected to the root node
+        // Get the set of node indices reachable from the root node:
         let connected_nodes = self.connected_nodes();
-        // Remove nodes not connected to the root node
+        // Remove nodes not connected to the root node:
         self.graph
-            .retain_nodes(|_graph, index| connected_nodes.contains(&index));
+            .retain_nodes(|_, index| connected_nodes.contains(&index));
         expr_node_indices
     }
 
-    /// Returns a set of node indices connected to the root node
-    /// via any directed edges in the graph.
+    /// Returns the set of node indices reachable from the root node via a
+    /// simple depth-first search.
     fn connected_nodes(&self) -> HashSet<NodeIndex> {
-        // Create a visit map to keep track of visited nodes.
-        let mut visited = self.graph.visit_map();
-        // Create a new set to store the connected nodes.
         let mut nodes = HashSet::new();
-        // Create a queue to store nodes to be visited.
-        let mut visit_next = VecDeque::new();
-        // Start the traversal from the root node.
-        visit_next.push_back(self.root);
-        // Add the root node to the set of connected nodes.
-        nodes.insert(self.root);
-        // Continue the traversal while there are nodes to be visited.
-        while let Some(node) = visit_next.pop_back() {
-            // If the node has already been visited, skip it.
-            if visited.is_visited(&node) {
-                continue;
-            }
-            // For each outgoing edge from the current node, add the target node to the set of connected nodes
-            // and add it to the queue of nodes to be visited if it hasn't been visited already.
-            for edge in self.graph.edges(node) {
-                let next = edge.target();
-                if visited.is_visited(&next) {
-                    continue;
-                }
-                nodes.insert(next);
-                visit_next.push_back(next);
-            }
-            // Mark the current node as visited.
-            visited.visit(node);
+        let mut dfs = Dfs::new(&self.graph, self.root);
+        while let Some(node) = dfs.next(&self.graph) {
+            nodes.insert(node);
         }
-        // Return the set of connected nodes.
         nodes
     }
@@ -929,10 +929,10 @@ mod tests {
     }
 
     #[test]
-    fn test_gather_node_indices_not_remove() -> Result<()> {
-        // Expression: a@0 + b@1 + 1 > a@0 - b@1 -> provide a@0 + b@1
-        // Not remove a@0 or b@1, only remove edges since a@0 - b@1
-        // has also leaf nodes a@0 and b@1.
+    fn test_gather_node_indices_dont_remove() -> Result<()> {
+        // Expression: a@0 + b@1 + 1 > a@0 - b@1, given a@0 + b@1.
+        // Do not remove a@0 or b@1, only remove edges since a@0 - b@1 also
+        // depends on leaf nodes a@0 and b@1.
         let left_expr = Arc::new(BinaryExpr::new(
             Arc::new(BinaryExpr::new(
                 Arc::new(Column::new("a", 0)),
@@ -962,7 +962,8 @@ mod tests {
         graph.gather_node_indices(&[leaf_node]);
         // Store the final node count.
         let final_node_count = graph.node_count();
-        // Assert that the final node count is equal to the previous node count (i.e., no node was pruned).
+        // Assert that the final node count is equal to the previous node count.
+        // This means we did not remove any node.
         assert_eq!(prev_node_count, final_node_count);
         Ok(())
     }
     #[test]
     fn test_gather_node_indices_remove() -> Result<()> {
-        // Expression: a@0 + b@1 + 1 > y@0 - z@1 -> provide a@0 + b@1
-        // We expect that 2 nodes will be pruned since we do not need a@ and b@
+        // Expression: a@0 + b@1 + 1 > y@0 - z@1, given a@0 + b@1.
+        // We expect to remove two nodes since we do not need a@ and b@.
         let left_expr = Arc::new(BinaryExpr::new(
             Arc::new(BinaryExpr::new(
                 Arc::new(Column::new("a", 0)),
@@ -1000,7 +1001,8 @@ mod tests {
         graph.gather_node_indices(&[leaf_node]);
         // Store the final node count.
         let final_node_count = graph.node_count();
-        // Assert that the final node count is two less than the previous node count (i.e., two nodes were pruned).
+        // Assert that the final node count is two less than the previous node
+        // count; i.e. that we did remove two nodes.
         assert_eq!(prev_node_count, final_node_count + 2);
         Ok(())
     }
 
     #[test]
     fn test_gather_node_indices_remove_one() -> Result<()> {
-        // Expression: a@0 + b@1 + 1 > a@0 - z@1 -> provide a@0 + b@1
-        // We expect that one node will be pruned since we do not need b@
+        // Expression: a@0 + b@1 + 1 > a@0 - z@1, given a@0 + b@1.
+        // We expect to remove one node since we still need a@ but not b@.
         let left_expr = Arc::new(BinaryExpr::new(
             Arc::new(BinaryExpr::new(
                 Arc::new(Column::new("a", 0)),
@@ -1039,7 +1041,8 @@ mod tests {
         graph.gather_node_indices(&[leaf_node]);
         // Store the final node count.
         let final_node_count = graph.node_count();
-        // Assert that the final node count is one less than the previous node count (i.e., one node was pruned).
+        // Assert that the final node count is one less than the previous node
+        // count; i.e. that we did remove one node.
         assert_eq!(prev_node_count, final_node_count + 1);
         Ok(())
     }

From fc833869b9dfafdeadc563acdf32fc18e697b24e Mon Sep 17 00:00:00 2001
From: metesynnada <100111937+metesynnada@users.noreply.github.com>
Date: Fri, 17 Feb 2023 13:36:04 +0300
Subject: [PATCH 34/39] Revamp OneSideHashJoin constructor and new unit test

---
 .../joins/symmetric_hash_join.rs              | 221 +++++++++++++++---
 1 file changed, 184 insertions(+), 37 deletions(-)

diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
index 797f568d6aef..2c1a84669eb3 100644
--- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
@@ -510,13 +510,13 @@ impl ExecutionPlan for SymmetricHashJoinExec {
             JoinSide::Left,
             self.sorted_filter_exprs[0].clone(),
             on_left,
-            self.left.clone(),
+            self.left.schema(),
         );
         let right_side_joiner = OneSideHashJoiner::new(
             JoinSide::Right,
             self.sorted_filter_exprs[1].clone(),
             on_right,
-            self.right.clone(),
+            self.right.schema(),
         );
         let left_stream = self.left.execute(partition, context.clone())?;
         let right_stream = self.right.execute(partition, context)?;
@@ -974,11 +974,11 @@ impl OneSideHashJoiner {
         build_side: JoinSide,
         sorted_filter_expr: SortedFilterExpr,
         on: Vec<Column>,
-        plan: Arc<dyn ExecutionPlan>,
+        schema: SchemaRef,
     ) -> Self {
         Self {
             build_side,
-            input_buffer: RecordBatch::new_empty(plan.schema()),
+            input_buffer: RecordBatch::new_empty(schema),
             on,
             hashmap: JoinHashMap(RawTable::with_capacity(10_000)),
             row_hash_values: VecDeque::new(),
@@ -1532,7 +1532,7 @@ mod tests {
         null_equals_null: bool,
         context: Arc<TaskContext>,
     ) -> Result<Vec<RecordBatch>> {
-        let partition_count = 1;
+        let partition_count = 4;
 
         let (left_expr, right_expr) = on
             .iter()
@@ -1572,15 +1572,15 @@ mod tests {
     pub fn split_record_batches(
         batch: &RecordBatch,
-        num_split: usize,
+        batch_size: usize,
     ) -> Result<Vec<RecordBatch>> {
         let row_num = batch.num_rows();
-        let number_of_batch = row_num / num_split;
-        let mut sizes = vec![num_split; number_of_batch];
-        sizes.push(row_num - (num_split * number_of_batch));
+        let number_of_batch = row_num / batch_size;
+        let mut sizes = vec![batch_size; number_of_batch];
+        sizes.push(row_num - (batch_size * number_of_batch));
         let mut result = vec![];
         for (i, size) in sizes.iter().enumerate() {
-            result.push(batch.slice(i * num_split, *size));
+            result.push(batch.slice(i * batch_size, *size));
         }
         Ok(result)
     }
@@ -1604,14 +1604,14 @@ mod tests {
     fn join_expr_tests_fixture(
         expr_id: usize,
-        left_watermark: Arc<dyn PhysicalExpr>,
-        right_watermark: Arc<dyn PhysicalExpr>,
+        left_col: Arc<dyn PhysicalExpr>,
+        right_col: Arc<dyn PhysicalExpr>,
     ) -> Arc<dyn PhysicalExpr> {
         match expr_id {
-            // left_watermark + 1 > right_watermark + 5 AND left_watermark + 3 < right_watermark + 10
+            // left_col + 1 > right_col + 5 AND left_col + 3 < right_col + 10
             0 => gen_conjunctive_numeric_expr(
-                left_watermark,
-                right_watermark,
+                left_col,
+                right_col,
                 Operator::Plus,
                 Operator::Plus,
                 Operator::Plus,
                 Operator::Plus,
                 1,
                 5,
                 3,
                 10,
             ),
-            // left_watermark - 1 > right_watermark + 5 AND left_watermark + 3 < right_watermark + 10
+            // left_col - 1 > right_col + 5 AND left_col + 3 < right_col + 10
             1 => gen_conjunctive_numeric_expr(
-                left_watermark,
-                right_watermark,
+                left_col,
+                right_col,
                 Operator::Minus,
                 Operator::Plus,
                 Operator::Plus,
                 Operator::Plus,
                 1,
                 5,
                 3,
                 10,
             ),
-            // left_watermark - 1 > right_watermark + 5 AND left_watermark - 3 < right_watermark + 10
+            // left_col - 1 > right_col + 5 AND left_col - 3 < right_col + 10
             2 => gen_conjunctive_numeric_expr(
-                left_watermark,
-                right_watermark,
+                left_col,
+                right_col,
                 Operator::Minus,
                 Operator::Plus,
                 Operator::Minus,
                 Operator::Plus,
                 1,
                 5,
                 3,
                 10,
             ),
-            // left_watermark - 10 > right_watermark - 5 AND left_watermark - 3 < right_watermark + 10
+            // left_col - 10 > right_col - 5 AND left_col - 3 < right_col + 10
             3 => gen_conjunctive_numeric_expr(
-                left_watermark,
-                right_watermark,
+                left_col,
+                right_col,
                 Operator::Minus,
                 Operator::Minus,
                 Operator::Minus,
                 Operator::Plus,
                 10,
                 5,
                 3,
                 10,
             ),
-            // left_watermark - 10 > right_watermark - 5 AND left_watermark - 30 < right_watermark - 3
+            // left_col - 10 > right_col - 5 AND left_col - 30 < right_col - 3
             4 => gen_conjunctive_numeric_expr(
-                left_watermark,
-                right_watermark,
+                left_col,
+                right_col,
                 Operator::Minus,
                 Operator::Minus,
                 Operator::Minus,
                 Operator::Minus,
                 10,
                 5,
                 30,
                 3,
             ),
@@ -1836,16 +1836,17 @@ mod tests {
         right_batch: RecordBatch,
         left_sorted: Vec<PhysicalSortExpr>,
         right_sorted: Vec<PhysicalSortExpr>,
+        batch_size: usize,
     ) -> Result<(Arc<dyn ExecutionPlan>, Arc<dyn ExecutionPlan>)> {
         Ok((
             Arc::new(MemoryExec::try_new_with_sort_information(
-                &[split_record_batches(&left_batch, 13).unwrap()],
+                &[split_record_batches(&left_batch, batch_size).unwrap()],
                 left_batch.schema(),
                 None,
                 Some(left_sorted),
            )?),
            Arc::new(MemoryExec::try_new_with_sort_information(
-                &[split_record_batches(&right_batch, 13).unwrap()],
+                &[split_record_batches(&right_batch, batch_size).unwrap()],
                 right_batch.schema(),
                 None,
                 Some(right_sorted),
@@ -1927,7 +1928,7 @@ mod tests {
             options: SortOptions::default(),
         }];
         let (left, right) =
-            create_memory_table(left_batch, right_batch, left_sorted, right_sorted)?;
+            create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?;
 
         let on = vec![(
             Column::new_with_schema("lc1", &left.schema())?,
@@ -2007,7 +2008,7 @@ mod tests {
             options: SortOptions::default(),
         }];
         let (left, right) =
-            create_memory_table(left_batch, right_batch, left_sorted, right_sorted)?;
+            create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?;
 
         let on = vec![(
             Column::new_with_schema("lc1", &left.schema())?,
@@ -2070,7 +2071,7 @@ mod tests {
             },
         }];
         let (left, right) =
-            create_memory_table(left_batch, right_batch, left_sorted, right_sorted)?;
+            create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?;
 
         let on = vec![(
             Column::new_with_schema("lc1", &left.schema())?,
@@ -2151,7 +2152,7 @@ mod tests {
             },
         }];
         let (left, right) =
-            create_memory_table(left_batch, right_batch, left_sorted, right_sorted)?;
+            create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?;
 
         let on = vec![(
             Column::new_with_schema("lc1", &left.schema())?,
@@ -2253,7 +2254,7 @@ mod tests {
             },
         }];
         let (left, right) =
-            create_memory_table(left_batch, right_batch, left_sorted, right_sorted)?;
+            create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?;
 
         let on = vec![(
             Column::new_with_schema("lc1", &left.schema())?,
@@ -2321,7 +2322,7 @@ mod tests {
             },
         }];
         let (left, right) =
-            create_memory_table(left_batch, right_batch, left_sorted, right_sorted)?;
+            create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?;
 
         let on = vec![(
             Column::new_with_schema("lc1", &left.schema())?,
@@ -2390,7 +2391,7 @@ mod tests {
             },
         }];
         let (left, right) =
-            create_memory_table(left_batch, right_batch, left_sorted, right_sorted)?;
+            create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?;
 
         let on = vec![(
             Column::new_with_schema("lc1", &left.schema())?,
@@ -2449,7 +2450,7 @@ mod tests {
             options: SortOptions::default(),
         }];
         let (left, right) =
-            create_memory_table(left_batch, right_batch, left_sorted, right_sorted)?;
+            create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?;
 
         let on = vec![(
             Column::new_with_schema("lc1", &left.schema())?,
@@ -2491,4 +2492,150 @@ mod tests {
         experiment(left, right, filter, join_type, on, task_ctx).await?;
         Ok(())
     }
+    #[rstest]
+    #[tokio::test(flavor = "multi_thread")]
+    async fn test_one_side_hash_joiner_visited_rows(
+        #[values(
+            (JoinType::Inner, true),
+            (JoinType::Left, false),
+            (JoinType::Right, true),
+            (JoinType::RightSemi, true),
+            (JoinType::LeftSemi, false),
+            (JoinType::LeftAnti, false),
+            (JoinType::RightAnti, true),
+            (JoinType::Full, false),
+        )]
+        case: (JoinType, bool),
+    ) -> Result<()> {
+        // Set a random state for the join
+        let join_type = case.0;
+        let should_be_empty = case.1;
+        let random_state = RandomState::with_seeds(0, 0, 0, 0);
+        let config = SessionConfig::new().with_repartition_joins(false);
+        let session_ctx = SessionContext::with_config(config);
+        let task_ctx = session_ctx.task_ctx();
+        // Ensure there will be matching rows
+        let (left_batch, right_batch) = build_sides_record_batches(20, (1, 1))?;
+        let (left_schema, right_schema) = (left_batch.schema(), right_batch.schema());
+
+        // Build the join schema from the left and right schemas
+        let (schema, join_column_indices) =
+            build_join_schema(&left_schema, &right_schema, &join_type);
+        let join_schema = Arc::new(schema);
+
+        // Sort information for MemoryExec
+        let left_sorted = vec![PhysicalSortExpr {
+            expr: Arc::new(Column::new_with_schema("la1", &left_batch.schema())?),
+            options: SortOptions::default(),
+        }];
+        // Sort information for MemoryExec
+        let right_sorted = vec![PhysicalSortExpr {
+            expr: Arc::new(Column::new_with_schema("ra1", &right_batch.schema())?),
+            options: SortOptions::default(),
+        }];
+        // Construct MemoryExec
+        let (left, right) =
+            create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 10)?;
+
+        // Filter columns
+        let filter_col_0 = Arc::new(Column::new("0", 0));
+        let filter_col_1 = Arc::new(Column::new("1", 1));
+
+        let column_indices = vec![
+            ColumnIndex {
+                index: 0,
+                side: JoinSide::Left,
+            },
+            ColumnIndex {
+                index: 0,
+                side: JoinSide::Right,
+            },
+        ];
+        let intermediate_schema = Schema::new(vec![
+            Field::new(filter_col_0.name(), DataType::Int32, true),
+            Field::new(filter_col_1.name(), DataType::Int32, true),
+        ]);
+
+        // Ensure first batches will have matching rows.
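+        // Following the (operator, constant) pattern of `gen_conjunctive_numeric_expr`
+        // (see `join_expr_tests_fixture` above), this builds the filter:
+        //   filter_col_0 + 0 > filter_col_1 - 3 AND filter_col_0 + 0 < filter_col_1 + 3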
+        let filter_expr = gen_conjunctive_numeric_expr(
+            filter_col_0.clone(),
+            filter_col_1.clone(),
+            Operator::Plus,
+            Operator::Minus,
+            Operator::Plus,
+            Operator::Plus,
+            0,
+            3,
+            0,
+            3,
+        );
+
+        let filter = JoinFilter::new(
+            filter_expr,
+            column_indices.clone(),
+            intermediate_schema.clone(),
+        );
+
+        let left_sorted_filter_expr = SortedFilterExpr::new(
+            PhysicalSortExpr {
+                expr: Arc::new(Column::new_with_schema("la1", &left_schema)?),
+                options: SortOptions::default(),
+            },
+            Arc::new(Column::new(filter_col_0.name(), 0)),
+        );
+
+        let mut left_side_joiner = OneSideHashJoiner::new(
+            JoinSide::Left,
+            left_sorted_filter_expr,
+            vec![Column::new_with_schema("lc1", &left_schema)?],
+            left_schema,
+        );
+
+        let right_sorted_filter_expr = SortedFilterExpr::new(
+            PhysicalSortExpr {
+                expr: Arc::new(Column::new_with_schema("ra1", &right_schema)?),
+                options: SortOptions::default(),
+            },
+            Arc::new(Column::new(filter_col_1.name(), 0)),
+        );
+
+        let mut right_side_joiner = OneSideHashJoiner::new(
+            JoinSide::Right,
+            right_sorted_filter_expr,
+            vec![Column::new_with_schema("rc1", &right_schema)?],
+            right_schema,
+        );
+
+        let mut left_stream = left.execute(0, task_ctx.clone())?;
+        let mut right_stream = right.execute(0, task_ctx)?;
+
+        let initial_left_batch = left_stream.next().await.unwrap()?;
+        left_side_joiner.update_internal_state(&initial_left_batch, &random_state)?;
+        assert_eq!(
+            left_side_joiner.input_buffer.num_rows(),
+            initial_left_batch.num_rows()
+        );
+
+        let initial_right_batch = right_stream.next().await.unwrap()?;
+        right_side_joiner.update_internal_state(&initial_right_batch, &random_state)?;
+        assert_eq!(
+            right_side_joiner.input_buffer.num_rows(),
+            initial_right_batch.num_rows()
+        );
+
+        left_side_joiner.join_with_probe_batch(
+            &join_schema,
+            join_type,
+            &right_side_joiner.on,
+            &filter,
+            &initial_right_batch,
+            &mut right_side_joiner.visited_rows,
+            right_side_joiner.offset,
+            &join_column_indices,
+            &random_state,
+            &false,
+        )?;
+        assert_eq!(left_side_joiner.visited_rows.is_empty(), should_be_empty);
+        Ok(())
+    }
 }

From 3bef26d8b0e8e2fc4df0c878a0961cb28c4d3dad Mon Sep 17 00:00:00 2001
From: metesynnada <100111937+metesynnada@users.noreply.github.com>
Date: Fri, 17 Feb 2023 13:49:13 +0300
Subject: [PATCH 35/39] Update on concat_batches usage

---
 datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
index 2c1a84669eb3..5aa03f78a01c 100644
--- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
@@ -51,7 +51,6 @@ use crate::error::{DataFusionError, Result};
 use crate::execution::context::TaskContext;
 use crate::logical_expr::JoinType;
 use crate::physical_plan::{
-    common::merge_batches,
     expressions::Column,
     expressions::PhysicalSortExpr,
     joins::{
@@ -1007,7 +1006,7 @@ impl OneSideHashJoiner {
         random_state: &RandomState,
     ) -> Result<()> {
         // Merge the incoming batch with the existing input buffer:
-        self.input_buffer = merge_batches(&self.input_buffer, batch, batch.schema())?;
+        self.input_buffer = concat_batches(&batch.schema(), [&self.input_buffer, batch])?;
         // Resize the hashes buffer to the number of rows in the incoming batch:
         self.hashes_buffer.resize(batch.num_rows(), 0);
         // Update the hashmap with the join key values and hashes of the incoming batch:

From 537bbe54d38e56a5e5ab1a86440af3d39cb3e0f3 Mon Sep 17 00:00:00 2001
From: metesynnada <100111937+metesynnada@users.noreply.github.com>
Date: Mon, 27 Feb 2023 10:30:12 +0300
Subject: [PATCH 36/39] Revamping according to comments.

---
 .../physical_plan/joins/hash_join_utils.rs    | 163 +++++------
 .../joins/symmetric_hash_join.rs              | 255 ++++++------------
 datafusion/core/src/physical_plan/memory.rs   |  25 +-
 .../physical-expr/src/expressions/binary.rs   |  52 ++++
 .../physical-expr/src/expressions/cast.rs     |  21 ++
 .../physical-expr/src/intervals/cp_solver.rs  | 105 +++----
 datafusion/physical-expr/src/intervals/mod.rs |   4 +-
 datafusion/physical-expr/src/physical_expr.rs |  15 ++
 datafusion/physical-expr/src/utils.rs         | 198 ++++++++++++--
 9 files changed, 467 insertions(+), 371 deletions(-)

diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs
index 08b07a79fef1..df71475c6bb4 100644
--- a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs
+++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs
@@ -254,64 +254,59 @@ pub mod tests {
         expressions::PhysicalSortExpr,
         joins::utils::{ColumnIndex, JoinFilter, JoinSide},
     };
-    use arrow::compute::{CastOptions, SortOptions};
+    use arrow::compute::SortOptions;
     use arrow::datatypes::{DataType, Field, Schema};
     use datafusion_common::ScalarValue;
     use datafusion_expr::Operator;
-    use datafusion_physical_expr::expressions::{BinaryExpr, CastExpr, Literal};
+    use datafusion_physical_expr::expressions::{binary, cast, col, lit, BinaryExpr};
     use std::sync::Arc;
 
     /// Filter expr for a + b > c + 10 AND a + b < c + 100
-    pub(crate) fn complicated_filter() -> Arc<dyn PhysicalExpr> {
-        let left_expr = BinaryExpr::new(
-            Arc::new(CastExpr::new(
-                Arc::new(BinaryExpr::new(
-                    Arc::new(Column::new("0", 0)),
+    pub(crate) fn complicated_filter(
+        filter_schema: &Schema,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        let left_expr = binary(
+            cast(
+                binary(
+                    col("0", filter_schema)?,
                     Operator::Plus,
-                    Arc::new(Column::new("1", 1)),
-                )),
+                    col("1", filter_schema)?,
+                    filter_schema,
+                )?,
+                filter_schema,
                 DataType::Int64,
-                CastOptions { safe: false },
-            )),
+            )?,
             Operator::Gt,
-            Arc::new(BinaryExpr::new(
-                Arc::new(CastExpr::new(
-                    Arc::new(Column::new("2", 2)),
-                    DataType::Int64,
-                    CastOptions { safe: false },
-                )),
+            binary(
+                cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?,
                 Operator::Plus,
-                Arc::new(Literal::new(ScalarValue::Int64(Some(10)))),
-            )),
-        );
-
-        let right_expr = BinaryExpr::new(
-            Arc::new(CastExpr::new(
-                Arc::new(BinaryExpr::new(
-                    Arc::new(Column::new("0", 0)),
+                lit(ScalarValue::Int64(Some(10))),
+                filter_schema,
+            )?,
+            filter_schema,
+        )?;
+
+        let right_expr = binary(
+            cast(
+                binary(
+                    col("0", filter_schema)?,
                     Operator::Plus,
-                    Arc::new(Column::new("1", 1)),
-                )),
+                    col("1", filter_schema)?,
+                    filter_schema,
+                )?,
+                filter_schema,
                 DataType::Int64,
-                CastOptions { safe: false },
-            )),
+            )?,
             Operator::Lt,
-            Arc::new(BinaryExpr::new(
-                Arc::new(CastExpr::new(
-                    Arc::new(Column::new("2", 2)),
-                    DataType::Int64,
-                    CastOptions { safe: false },
-                )),
+            binary(
+                cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?,
                 Operator::Plus,
-                Arc::new(Literal::new(ScalarValue::Int64(Some(100)))),
-            )),
-        );
-
-        Arc::new(BinaryExpr::new(
-            Arc::new(left_expr),
-            Operator::And,
-            Arc::new(right_expr),
-        ))
+                lit(ScalarValue::Int64(Some(100))),
+                filter_schema,
+            )?,
+            filter_schema,
+        )?;
+        binary(left_expr, Operator::And, right_expr, filter_schema)
     }
 
     #[test]
@@ -318,7 +313,7 @@ pub mod tests {
             DataType::Int32,
             true,
         )]));
         // Sorting information for the left side:
         let left_child_sort_expr = PhysicalSortExpr {
-            expr: Arc::new(Column::new("left_1", 0)),
+            expr: col("left_1", left_child_schema.as_ref())?,
             options: SortOptions::default(),
         };
@@ -331,7 +328,7 @@ pub mod tests {
             Field::new("right_1", DataType::Int32, true),
             Field::new("right_2", DataType::Int32, true),
         ]));
         // Sorting information for the right side:
         let right_child_sort_expr = PhysicalSortExpr {
-            expr: Arc::new(BinaryExpr::new(
-                Arc::new(Column::new("right_1", 0)),
+            expr: binary(
+                col("right_1", right_child_schema.as_ref())?,
                 Operator::Plus,
-                Arc::new(Column::new("right_2", 1)),
-            )),
+                col("right_2", right_child_schema.as_ref())?,
+                right_child_schema.as_ref(),
+            )?,
             options: SortOptions::default(),
         };
@@ -405,15 +401,26 @@ pub mod tests {
     }
 
     #[test]
-    fn test_column_collector() {
-        let filter_expr = complicated_filter();
+    fn test_column_collector() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("0", DataType::Int32, true),
+            Field::new("1", DataType::Int32, true),
+            Field::new("2", DataType::Int32, true),
+        ]));
+        let filter_expr = complicated_filter(schema.as_ref())?;
         let columns = collect_columns(&filter_expr);
-        assert_eq!(columns.len(), 3)
+        assert_eq!(columns.len(), 3);
+        Ok(())
     }
 
     #[test]
     fn find_expr_inside_expr() -> Result<()> {
-        let filter_expr = complicated_filter();
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("0", DataType::Int32, true),
+            Field::new("1", DataType::Int32, true),
+            Field::new("2", DataType::Int32, true),
+        ]));
+        let filter_expr = complicated_filter(schema.as_ref())?;
 
         let expr_1: Arc<dyn PhysicalExpr> = Arc::new(Column::new("gnz", 0));
         assert!(!check_filter_expr_contains_sort_information(
             &filter_expr,
             &expr_1
         ));
 
-        let expr_2: Arc<dyn PhysicalExpr> = Arc::new(Column::new("1", 1));
+        let expr_2: Arc<dyn PhysicalExpr> = col("1", schema.as_ref())?;
 
         assert!(check_filter_expr_contains_sort_information(
             &filter_expr,
             &expr_2
         ));
 
-        let expr_3: Arc<dyn PhysicalExpr> = Arc::new(CastExpr::new(
-            Arc::new(BinaryExpr::new(
-                Arc::new(Column::new("0", 0)),
+        let expr_3: Arc<dyn PhysicalExpr> = cast(
+            binary(
+                col("0", schema.as_ref())?,
                 Operator::Plus,
-                Arc::new(Column::new("1", 1)),
-            )),
+                col("1", schema.as_ref())?,
+                schema.as_ref(),
+            )?,
+            schema.as_ref(),
             DataType::Int64,
-            CastOptions { safe: false },
-        ));
+        )?;
 
         assert!(check_filter_expr_contains_sort_information(
             &filter_expr,
@@ -496,16 +504,17 @@ pub mod tests {
             Field::new(filter_col_2.name(), DataType::Int32, true),
         ]);
 
-        let filter_expr = complicated_filter();
+        let filter_expr = complicated_filter(&intermediate_schema)?;
 
-        let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
+        let filter =
+            JoinFilter::new(filter_expr, column_indices, intermediate_schema.clone());
 
         assert!(build_filter_input_order(
             JoinSide::Left,
             &filter,
             &left_schema,
             &PhysicalSortExpr {
-                expr: Arc::new(Column::new("la1", 0)),
+                expr: col("la1", left_schema.as_ref())?,
                 options: SortOptions::default(),
             }
         )
@@ -515,7 +524,7 @@ pub mod tests {
             &filter,
             &left_schema,
             &PhysicalSortExpr {
-                expr: Arc::new(Column::new("lt1", 3)),
+                expr: col("lt1", left_schema.as_ref())?,
                 options: SortOptions::default(),
             }
         )
@@ -525,7 +534,7 @@ pub mod tests {
             &filter,
             &right_schema,
             &PhysicalSortExpr {
-                expr: Arc::new(Column::new("ra1", 0)),
+                expr: col("ra1", right_schema.as_ref())?,
                 options: SortOptions::default(),
             }
         )
@@ -535,7 +544,7 @@ pub mod tests {
             &filter,
             &right_schema,
             &PhysicalSortExpr {
-                expr: Arc::new(Column::new("rb1", 1)),
+                expr: col("rb1", right_schema.as_ref())?,
                 options: SortOptions::default(),
             }
         )
@@ -564,25 +573,27 @@ pub mod tests {
             Field::new(filter_col_1.name(), DataType::Int32,
true), ]); - let filter_expr = Arc::new(BinaryExpr::new( - Arc::new(Column::new("0", 0)), + let filter_expr = binary( + col("0", &intermediate_schema)?, Operator::Minus, - Arc::new(Column::new("1", 1)), - )); + col("1", &intermediate_schema)?, + &intermediate_schema, + )?; let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int64, false), + Field::new("b", DataType::Int32, false), ])); let sorted = PhysicalSortExpr { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("a", 0)), + expr: binary( + col("a", schema.as_ref())?, Operator::Plus, - Arc::new(Column::new("b", 1)), - )), + col("b", schema.as_ref())?, + schema.as_ref(), + )?, options: SortOptions::default(), }; diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs index 5aa03f78a01c..c49510b8c536 100644 --- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs @@ -1423,7 +1423,7 @@ impl SymmetricHashJoinStream { mod tests { use std::fs::File; - use arrow::array::{Array, ArrayRef}; + use arrow::array::ArrayRef; use arrow::array::{Int32Array, TimestampNanosecondArray}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; @@ -1584,23 +1584,6 @@ mod tests { Ok(result) } - fn build_record_batch(columns: Vec<(&str, ArrayRef)>) -> Result { - let schema = Schema::new( - columns - .iter() - .map(|(name, array)| { - let null = array.null_count() > 0; - Field::new(*name, array.data_type().clone(), null) - }) - .collect(), - ); - let batch = RecordBatch::try_new( - Arc::new(schema), - columns.into_iter().map(|(_, array)| array).collect(), - )?; - Ok(batch) - } - fn join_expr_tests_fixture( expr_id: usize, left_col: Arc, @@ -1683,149 +1666,69 @@ mod tests { let initial_range = 0..table_size; let index = (table_size as f64 * null_ratio).round() as i32; let rest_of = index..table_size; - let left = build_record_batch(vec![ - ( - "la1", - Arc::new(Int32Array::from_iter( - initial_range.clone().collect::>(), - )), - ), - ( - "lb1", - Arc::new(Int32Array::from_iter( - initial_range.clone().map(|x| x % 4).collect::>(), - )), - ), - ( - "lc1", - Arc::new(Int32Array::from_iter( - initial_range - .clone() - .map(|x| x % key_cardinality.0) - .collect::>(), - )), - ), - ( - "lt1", - Arc::new(TimestampNanosecondArray::from( - initial_range - .clone() - .map(|x| 1664264591000000000 + (5000000000 * (x as i64))) - .collect::>(), - )), - ), - ( - "la2", - Arc::new(Int32Array::from_iter( - initial_range.clone().collect::>(), - )), - ), - ( - "la1_des", - Arc::new(Int32Array::from_iter( - initial_range.clone().rev().collect::>(), - )), - ), - ( - "l_asc_null_first", - Arc::new(Int32Array::from_iter({ - std::iter::repeat(None) - .take(index as usize) - .chain(rest_of.clone().map(Some)) - .collect::>>() - })), - ), - ( - "l_asc_null_last", - Arc::new(Int32Array::from_iter({ - rest_of - .clone() - .map(Some) - .chain(std::iter::repeat(None).take(index as usize)) - .collect::>>() - })), - ), - ( - "l_desc_null_first", - Arc::new(Int32Array::from_iter({ - std::iter::repeat(None) - .take(index as usize) - .chain(rest_of.clone().rev().map(Some)) - .collect::>>() - })), - ), + let ordered: ArrayRef = Arc::new(Int32Array::from_iter( + initial_range.clone().collect::>(), + )); + let ordered_des: ArrayRef = Arc::new(Int32Array::from_iter( + 
initial_range.clone().rev().collect::>(), + )); + let cardinality: ArrayRef = Arc::new(Int32Array::from_iter( + initial_range.clone().map(|x| x % 4).collect::>(), + )); + let cardinality_key: ArrayRef = Arc::new(Int32Array::from_iter( + initial_range + .clone() + .map(|x| x % key_cardinality.0) + .collect::>(), + )); + let ordered_asc_null_first: ArrayRef = Arc::new(Int32Array::from_iter({ + std::iter::repeat(None) + .take(index as usize) + .chain(rest_of.clone().map(Some)) + .collect::>>() + })); + let ordered_asc_null_last: ArrayRef = Arc::new(Int32Array::from_iter({ + rest_of + .clone() + .map(Some) + .chain(std::iter::repeat(None).take(index as usize)) + .collect::>>() + })); + + let ordered_desc_null_first: ArrayRef = Arc::new(Int32Array::from_iter({ + std::iter::repeat(None) + .take(index as usize) + .chain(rest_of.rev().map(Some)) + .collect::>>() + })); + + let time: ArrayRef = Arc::new(TimestampNanosecondArray::from( + initial_range + .map(|x| 1664264591000000000 + (5000000000 * (x as i64))) + .collect::>(), + )); + + let left = RecordBatch::try_from_iter(vec![ + ("la1", ordered.clone()), + ("lb1", cardinality.clone()), + ("lc1", cardinality_key.clone()), + ("lt1", time.clone()), + ("la2", ordered.clone()), + ("la1_des", ordered_des.clone()), + ("l_asc_null_first", ordered_asc_null_first.clone()), + ("l_asc_null_last", ordered_asc_null_last.clone()), + ("l_desc_null_first", ordered_desc_null_first.clone()), ])?; - let right = build_record_batch(vec![ - ( - "ra1", - Arc::new(Int32Array::from_iter( - initial_range.clone().collect::>(), - )), - ), - ( - "rb1", - Arc::new(Int32Array::from_iter( - initial_range.clone().map(|x| x % 7).collect::>(), - )), - ), - ( - "rc1", - Arc::new(Int32Array::from_iter( - initial_range - .clone() - .map(|x| x % key_cardinality.1) - .collect::>(), - )), - ), - ( - "rt1", - Arc::new(TimestampNanosecondArray::from( - initial_range - .clone() - .map(|x| 1664264591000000000 + (5000000000 * (x as i64))) - .collect::>(), - )), - ), - ( - "ra2", - Arc::new(Int32Array::from_iter( - initial_range.clone().collect::>(), - )), - ), - ( - "ra1_des", - Arc::new(Int32Array::from_iter( - initial_range.rev().collect::>(), - )), - ), - ( - "r_asc_null_first", - Arc::new(Int32Array::from_iter({ - std::iter::repeat(None) - .take(index as usize) - .chain(rest_of.clone().map(Some)) - .collect::>>() - })), - ), - ( - "r_asc_null_last", - Arc::new(Int32Array::from_iter({ - rest_of - .clone() - .map(Some) - .chain(std::iter::repeat(None).take(index as usize)) - .collect::>>() - })), - ), - ( - "r_desc_null_first", - Arc::new(Int32Array::from_iter({ - std::iter::repeat(None) - .take(index as usize) - .chain(rest_of.rev().map(Some)) - .collect::>>() - })), - ), + let right = RecordBatch::try_from_iter(vec![ + ("ra1", ordered.clone()), + ("rb1", cardinality), + ("rc1", cardinality_key), + ("rt1", time), + ("ra2", ordered), + ("ra1_des", ordered_des), + ("r_asc_null_first", ordered_asc_null_first), + ("r_asc_null_last", ordered_asc_null_last), + ("r_desc_null_first", ordered_desc_null_first), ])?; Ok((left, right)) } @@ -1838,18 +1741,22 @@ mod tests { batch_size: usize, ) -> Result<(Arc, Arc)> { Ok(( - Arc::new(MemoryExec::try_new_with_sort_information( - &[split_record_batches(&left_batch, batch_size).unwrap()], - left_batch.schema(), - None, - Some(left_sorted), - )?), - Arc::new(MemoryExec::try_new_with_sort_information( - &[split_record_batches(&right_batch, batch_size).unwrap()], - right_batch.schema(), - None, - Some(right_sorted), - )?), + Arc::new( + MemoryExec::try_new( 
+ &[split_record_batches(&left_batch, batch_size).unwrap()], + left_batch.schema(), + None, + )? + .with_sort_information(left_sorted), + ), + Arc::new( + MemoryExec::try_new( + &[split_record_batches(&right_batch, batch_size).unwrap()], + right_batch.schema(), + None, + )? + .with_sort_information(right_sorted), + ), )) } @@ -1958,7 +1865,7 @@ mod tests { Field::new(filter_col_2.name(), DataType::Int32, true), ]); - let filter_expr = complicated_filter(); + let filter_expr = complicated_filter(&intermediate_schema)?; let filter = JoinFilter::new( filter_expr, @@ -2480,7 +2387,7 @@ mod tests { Field::new(filter_col_2.name(), DataType::Int32, true), ]); - let filter_expr = complicated_filter(); + let filter_expr = complicated_filter(&intermediate_schema)?; let filter = JoinFilter::new( filter_expr, diff --git a/datafusion/core/src/physical_plan/memory.rs b/datafusion/core/src/physical_plan/memory.rs index 6e2a0892508a..f0cd48fa4f9d 100644 --- a/datafusion/core/src/physical_plan/memory.rs +++ b/datafusion/core/src/physical_plan/memory.rs @@ -149,23 +149,14 @@ impl MemoryExec { sort_information: None, }) } - /// Create a new execution plan for reading in-memory record batches - /// The provided `schema` should not have the projection applied. Also, you can specify sort - /// information on PhysicalExprs. - pub fn try_new_with_sort_information( - partitions: &[Vec], - schema: SchemaRef, - projection: Option>, - sort_information: Option>, - ) -> Result { - let projected_schema = project_schema(&schema, projection.as_ref())?; - Ok(Self { - partitions: partitions.to_vec(), - schema, - projected_schema, - projection, - sort_information, - }) + + /// Set sort information + pub fn with_sort_information( + mut self, + sort_information: Vec, + ) -> Self { + self.sort_information = Some(sort_information); + self } } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index ec9096ffc02a..ed3b7fbddd5b 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -74,6 +74,8 @@ use arrow::datatypes::{DataType, Schema, TimeUnit}; use arrow::record_batch::RecordBatch; use super::column::Column; +use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; +use crate::intervals::{apply_operator, Interval}; use crate::physical_expr::down_cast_any_ref; use crate::{analysis_expect, AnalysisContext, ExprBoundaries, PhysicalExpr}; use datafusion_common::cast::{as_boolean_array, as_decimal128_array}; @@ -786,6 +788,56 @@ impl PhysicalExpr for BinaryExpr { _ => context.with_boundaries(None), } } + + fn evaluate_bound(&self, _child_intervals: &[&Interval]) -> Result { + // Get children intervals: + let left_interval = _child_intervals[1]; + let right_interval = _child_intervals[0]; + // Calculate current node's interval + apply_operator(&self.op, left_interval, right_interval) + } + + fn propagate_constraints( + &self, + _parent_interval: &Interval, + _child_intervals: &[&Interval], + ) -> Result>> { + // Get children intervals. Graph brings + let left_interval = _child_intervals[0]; + let right_interval = _child_intervals[1]; + if self.op.is_logic_operator() { + // TODO: Currently, this implementation only supports the AND operator + // and does not require any further propagation. In the future, + // upon adding support for additional logical operators, this + // method will require modification to support propagating the + // changes accordingly. 
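+            // Returning an empty vector is the "nothing to propagate" signal
+            // that the graph propagation loop checks for (see `cp_solver`).
+            // For the arithmetic branch below, a worked example of the
+            // intended shrinking, assuming the closed-interval semantics of
+            // `interval_aritmetic`:
+            //
+            //   node:     a + b, with parent interval [5, 7]
+            //   children: a in [0, 10], b in [4, 4]
+            //
+            //   a shrinks to ([5, 7] - [4, 4]) intersect [0, 10] = [1, 3]
+            //   b shrinks to ([5, 7] - [1, 3]) intersect [4, 4]  = [4, 4]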
+ Ok(vec![]) + } else if self.op.is_comparison_operator() { + if let Interval { + lower: ScalarValue::Boolean(Some(false)), + upper: ScalarValue::Boolean(Some(false)), + } = _parent_interval + { + // TODO: The optimization of handling strictly false clauses through + // conversion to equivalent comparison operators (e.g. GT to LE, LT to GE) + // can be implemented once open/closed intervals are supported. + return Ok(vec![]); + } + // Propagate the comparison operator. + let (left, right) = + propagate_comparison(&self.op, left_interval, right_interval)?; + Ok(vec![left, right]) + } else { + // Propagate the arithmetic operator. + let (left, right) = propagate_arithmetic( + &self.op, + _parent_interval, + left_interval, + right_interval, + )?; + Ok(vec![left, right]) + } + } } impl PartialEq for BinaryExpr { diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 975ad3c6cd1b..6242546c1cb4 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -19,6 +19,7 @@ use std::any::Any; use std::fmt; use std::sync::Arc; +use crate::intervals::Interval; use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use arrow::compute; @@ -115,6 +116,26 @@ impl PhysicalExpr for CastExpr { }, ))) } + + fn evaluate_bound(&self, _child_intervals: &[&Interval]) -> Result { + let interval = _child_intervals[0]; + // Cast current node's interval to the right type + interval.cast_to(&self.cast_type, &self.cast_options) + } + + fn propagate_constraints( + &self, + _parent_interval: &Interval, + _child_intervals: &[&Interval], + ) -> Result>> { + let child_interval = _child_intervals[0]; + // Get child's datatype + let cast_type = child_interval.get_datatype(); + + Ok(vec![Some( + _parent_interval.cast_to(&cast_type, self.cast_options())?, + )]) + } } impl PartialEq for CastExpr { diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index 8eb43393b46e..31a3b7c24b14 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -19,7 +19,6 @@ use std::collections::HashSet; use std::fmt::{Display, Formatter}; -use std::ops::Index; use std::sync::Arc; use arrow_schema::DataType; @@ -30,7 +29,7 @@ use petgraph::stable_graph::{DefaultIx, StableGraph}; use petgraph::visit::{Bfs, Dfs, DfsPostOrder, EdgeRef}; use petgraph::Outgoing; -use crate::expressions::{BinaryExpr, CastExpr, Literal}; +use crate::expressions::Literal; use crate::intervals::interval_aritmetic::{apply_operator, Interval}; use crate::utils::{build_dag, ExprTreeNode}; use crate::PhysicalExpr; @@ -454,24 +453,12 @@ impl ExprIntervalGraph { if children.is_empty() { continue; } - let expr_any = self.graph[node].expr.as_any(); - if let Some(binary) = expr_any.downcast_ref::() { - // Get children intervals: - let left_child_idx = children[1]; - let right_child_idx = children[0]; - let left_interval = self.graph[left_child_idx].interval(); - let right_interval = self.graph[right_child_idx].interval(); - // Calculate and replace current node's interval: - self.graph[node].interval = - apply_operator(binary.op(), left_interval, right_interval)?; - } else if let Some(cast_expr) = expr_any.downcast_ref::() { - let cast_type = cast_expr.cast_type(); - let cast_options = cast_expr.cast_options(); - let child = self.graph.index(children[0]); - // Cast current node's interval to the right type: - 
self.graph[node].interval = - child.interval.cast_to(cast_type, cast_options)?; - } + let children_intervals = children + .into_iter() + .map(|child| self.graph[child].interval()) + .collect::>(); + self.graph[node].interval = + self.graph[node].expr.evaluate_bound(&children_intervals)?; } Ok(&self.graph[self.root].interval) } @@ -482,66 +469,36 @@ impl ExprIntervalGraph { let mut bfs = Bfs::new(&self.graph, self.root); while let Some(node) = bfs.next(&self.graph) { let neighbors = self.graph.neighbors_directed(node, Outgoing); - let children = neighbors.collect::>(); + let mut children = neighbors.collect::>(); + // Reverse the children index order to align it with the ExecutionPlan's children order. + children.reverse(); // If the current expression is a leaf, its range is now final. // So, just continue with the propagation procedure: if children.is_empty() { continue; } let node_interval = self.graph[node].interval(); - let expr_any = self.graph[node].expr.as_any(); - if let Some(binary) = expr_any.downcast_ref::() { - // Get children intervals: - let left_child_idx = children[1]; - let right_child_idx = children[0]; - let left_interval = self.graph[left_child_idx].interval(); - let right_interval = self.graph[right_child_idx].interval(); - - let op = binary.op(); - if let (Some(new_left_interval), Some(new_right_interval)) = - if op.is_logic_operator() { - // TODO: Currently, this implementation only supports the AND operator - // and does not require any further propagation. In the future, - // upon adding support for additional logical operators, this - // method will require modification to support propagating the - // changes accordingly. - continue; - } else if op.is_comparison_operator() { - if let Interval { - lower: ScalarValue::Boolean(Some(false)), - upper: ScalarValue::Boolean(Some(false)), - } = node_interval - { - // TODO: The optimization of handling strictly false clauses through - // conversion to equivalent comparison operators (e.g. GT to LE, LT to GE) - // can be implemented once open/closed intervals are supported. - continue; - } - // Propagate the comparison operator. - propagate_comparison(op, left_interval, right_interval)? - } else { - // Propagate the arithmetic operator. - propagate_arithmetic( - op, - node_interval, - left_interval, - right_interval, - )? - } - { - self.graph[left_child_idx].interval = new_left_interval; - self.graph[right_child_idx].interval = new_right_interval; - } else { - return Ok(PropagationResult::Infeasible); - }; - } else if let Some(cast) = expr_any.downcast_ref::() { - let child_index = children[0]; - let child = self.graph.index(child_index); - // Get child's data type: - let cast_type = child.interval().get_datatype(); - self.graph[child_index].interval = - node_interval.cast_to(&cast_type, cast.cast_options())?; + let children_intervals = children + .iter() + .map(|child| self.graph[*child].interval()) + .collect::>(); + let propagated_intervals = self.graph[node] + .expr + .propagate_constraints(node_interval, &children_intervals)?; + // No new interval is calculated for node leafs. 
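+            // An empty vector means the node had nothing to push down to its
+            // children (e.g. the AND branch, or a comparison that is already
+            // strictly false); a `None` entry, in contrast, marks an
+            // infeasible child range and stops propagation below.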
+ if propagated_intervals.is_empty() { + continue; + } + // There is a feasibility problem + if propagated_intervals.iter().any(|i| i.is_none()) { + return Ok(PropagationResult::Infeasible); } + children + .iter() + .zip(propagated_intervals.into_iter()) + .for_each(|(child_index, new_interval)| { + self.graph[*child_index].interval = new_interval.unwrap() + }); } Ok(PropagationResult::Success) } @@ -577,7 +534,7 @@ mod tests { use crate::intervals::test_utils::gen_conjunctive_numeric_expr; use itertools::Itertools; - use crate::expressions::Column; + use crate::expressions::{BinaryExpr, Column}; use datafusion_common::ScalarValue; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; diff --git a/datafusion/physical-expr/src/intervals/mod.rs b/datafusion/physical-expr/src/intervals/mod.rs index 3016db4dc797..45616534cb17 100644 --- a/datafusion/physical-expr/src/intervals/mod.rs +++ b/datafusion/physical-expr/src/intervals/mod.rs @@ -18,8 +18,8 @@ //! Interval calculations //! -mod cp_solver; -mod interval_aritmetic; +pub mod cp_solver; +pub mod interval_aritmetic; pub mod test_utils; pub use cp_solver::ExprIntervalGraph; diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index 459ce8cd7b15..a454fab053d6 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -30,6 +30,7 @@ use std::fmt::{Debug, Display}; use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; use arrow::compute::{and_kleene, filter_record_batch, is_not_null, SlicesIterator}; +use crate::intervals::Interval; use std::any::Any; use std::sync::Arc; @@ -81,6 +82,20 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq { fn analyze(&self, context: AnalysisContext) -> AnalysisContext { context } + + /// Computes bounds for the expression using interval arithmetic + fn evaluate_bound(&self, _child_intervals: &[&Interval]) -> Result { + Err(DataFusionError::Internal("Not supported".to_owned())) + } + + /// Updates/shrinks bounds for the expression using interval arithmetic + fn propagate_constraints( + &self, + _parent_interval: &Interval, + _child_intervals: &[&Interval], + ) -> Result>> { + Err(DataFusionError::Internal("Not supported".to_owned())) + } } /// The shared context used during the analysis of an expression. Includes diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs index 0c79bb4d8e84..e28e4d7cc342 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils.rs @@ -17,7 +17,7 @@ use crate::equivalence::EquivalentClass; use crate::expressions::{BinaryExpr, Column, UnKnownColumn}; -use crate::rewrite::TreeNodeRewritable; +use crate::rewrite::{TreeNodeRewritable, TreeNodeRewriter}; use crate::{EquivalenceProperties, PhysicalExpr, PhysicalSortExpr}; use arrow::datatypes::SchemaRef; use datafusion_common::Result; @@ -277,10 +277,55 @@ impl TreeNodeRewritable for ExprTreeNode { } } -/// This function converts the [PhysicalExpr] tree into a DAG by collecting identical -/// expressions in one node. Caller specifies the node type in this DAG via the -/// `constructor` argument, which constructs nodes in this DAG from the [ExprTreeNode] -/// ancillary object. +/// This struct converts the [PhysicalExpr] tree into a DAG by collecting identical +/// expressions in one node using a rewriter to transform an input expression tree. 
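+/// ("Identical" here means structurally equal according to `PhysicalExpr::eq`,
+/// tracked through the `visited_plans` vector below.)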
+/// Caller specifies the node type in this DAG via the `constructor` argument, +/// which constructs nodes in this DAG from the [ExprTreeNode] ancillary object. +struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&'_ ExprTreeNode) -> T> { + // A graph containing physical expression trees. + graph: StableGraph, + // A vector of visited expression nodes and their corresponding node indices. + visited_plans: Vec<(Arc, NodeIndex)>, + // A function to convert an input expression node to T. + constructor: &'a F, +} + +impl<'a, T, F: Fn(&'_ ExprTreeNode) -> T> + TreeNodeRewriter> for PhysicalExprDAEGBuilder<'a, T, F> +{ + // This method mutates an expression node by transforming it to a physical expression + // and adding it to the graph. The method returns the mutated expression node. + fn mutate( + &mut self, + mut node: ExprTreeNode, + ) -> Result> { + // Get the expression associated with the input expression node. + let expr = &node.expr; + + // Check if the expression has already been visited. + let node_idx = match self.visited_plans.iter().find(|(e, _)| expr.eq(e)) { + // If the expression has been visited, return the corresponding node index. + Some((_, idx)) => *idx, + // If the expression has not been visited, add a new node to the graph and + // add edges to its child nodes. Add the visited expression to the vector + // of visited expressions and return the newly created node index. + None => { + let node_idx = self.graph.add_node((self.constructor)(&node)); + for expr_node in node.child_nodes.iter() { + self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0); + } + self.visited_plans.push((expr.clone(), node_idx)); + node_idx + } + }; + // Set the data field of the input expression node to the corresponding node index. + node.data = Some(node_idx); + // Return the mutated expression node. + Ok(node) + } +} + +// A function that builds a directed acyclic graph of physical expression trees. pub fn build_dag( expr: Arc, constructor: &F, @@ -288,40 +333,137 @@ pub fn build_dag( where F: Fn(&ExprTreeNode) -> T, { + // Create a new expression tree node from the input expression. let init = ExprTreeNode::new(expr); - let mut graph = StableGraph::::new(); - let mut visited_plans = Vec::<(Arc, NodeIndex)>::new(); - let root = init.mutable_transform_up(&mut |mut input| { - let expr = &input.expr; - let node_idx = match visited_plans.iter().find(|(e, _)| expr.eq(e)) { - Some((_, idx)) => *idx, - None => { - let node_idx = graph.add_node(constructor(&input)); - for expr_node in input.child_nodes.iter() { - graph.add_edge(node_idx, expr_node.data.unwrap(), 0); - } - visited_plans.push((expr.clone(), node_idx)); - node_idx - } - }; - input.data = Some(node_idx); - Ok(Some(input)) - })?; - Ok((root.data.unwrap(), graph)) + // Create a new PhysicalExprDAEGBuilder instance. + let mut builder = PhysicalExprDAEGBuilder { + graph: StableGraph::::new(), + visited_plans: Vec::<(Arc, NodeIndex)>::new(), + constructor, + }; + // Use the builder to transform the expression tree node into a DAG. + let root = init.transform_using(&mut builder)?; + // Return a tuple containing the root node index and the DAG. 
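+    // `visited_plans` is what makes the result a DAG rather than a tree:
+    // structurally equal sub-expressions map to a single graph node, so an
+    // expression like `a + b > 5 AND a + b < 10` ends up with one `a + b`
+    // node that has two parents.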
+ Ok((root.data.unwrap(), builder.graph)) } #[cfg(test)] mod tests { - use super::*; - use crate::expressions::Column; + use crate::expressions::{binary, cast, col, lit, Column, Literal}; use crate::PhysicalSortExpr; use arrow::compute::SortOptions; - use datafusion_common::Result; + use datafusion_common::{Result, ScalarValue}; + use std::fmt::{Display, Formatter}; - use arrow_schema::Schema; + use arrow_schema::{DataType, Field, Schema}; + use petgraph::visit::Bfs; use std::sync::Arc; + fn make_dummy_nodes(node: &ExprTreeNode) -> PhysicalExprDummyNode { + let expr = node.expression().clone(); + let dummy_property = if expr.as_any().is::() { + "Binary" + } else if expr.as_any().is::() { + "Column" + } else if expr.as_any().is::() { + "Literal" + } else { + "Other" + } + .to_owned(); + PhysicalExprDummyNode { + expr, + property: DummyProperty { + expr_type: dummy_property, + }, + } + } + #[derive(Clone)] + struct DummyProperty { + expr_type: String, + } + + /// This is a node in the DAEG; it encapsulates a reference to the actual + /// [PhysicalExpr] as well as a dummy property to use it in DAEG traversal. + #[derive(Clone)] + struct PhysicalExprDummyNode { + pub expr: Arc, + pub property: DummyProperty, + } + + impl Display for PhysicalExprDummyNode { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.expr) + } + } + + #[test] + fn testing_build_dag() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + Field::new("2", DataType::Int32, true), + ]); + let expr = binary( + cast( + binary( + col("0", &schema)?, + Operator::Plus, + col("1", &schema)?, + &schema, + )?, + &schema, + DataType::Int64, + )?, + Operator::Gt, + binary( + cast(col("2", &schema)?, &schema, DataType::Int64)?, + Operator::Plus, + lit(ScalarValue::Int64(Some(10))), + &schema, + )?, + &schema, + )?; + let mut vector_dummy_props = vec![]; + let (root, graph) = build_dag(expr, &make_dummy_nodes)?; + let mut bfs = Bfs::new(&graph, root); + while let Some(node_index) = bfs.next(&graph) { + let node = &graph[node_index]; + vector_dummy_props.push(node.property.clone()); + } + + assert_eq!( + vector_dummy_props + .iter() + .filter(|property| property.expr_type == "Binary") + .count(), + 3 + ); + assert_eq!( + vector_dummy_props + .iter() + .filter(|property| property.expr_type == "Column") + .count(), + 3 + ); + assert_eq!( + vector_dummy_props + .iter() + .filter(|property| property.expr_type == "Literal") + .count(), + 1 + ); + assert_eq!( + vector_dummy_props + .iter() + .filter(|property| property.expr_type == "Other") + .count(), + 2 + ); + Ok(()) + } + #[test] fn expr_list_eq_test() -> Result<()> { let list1: Vec> = vec![ From 70a9905e01feaaf645fd57692f6c3cf207393070 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Mon, 27 Feb 2023 20:59:47 -0600 Subject: [PATCH 37/39] Simplifications, refactoring --- .../src/physical_optimizer/pipeline_fixer.rs | 5 + .../physical_plan/joins/hash_join_utils.rs | 174 ++++---- .../joins/symmetric_hash_join.rs | 418 +++++++----------- .../physical-expr/src/expressions/binary.rs | 43 +- .../physical-expr/src/expressions/cast.rs | 18 +- .../physical-expr/src/intervals/cp_solver.rs | 43 +- datafusion/physical-expr/src/physical_expr.rs | 21 +- datafusion/physical-expr/src/rewrite.rs | 15 - datafusion/physical-expr/src/utils.rs | 60 +-- 9 files changed, 345 insertions(+), 452 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs 
b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs index 29fce233917c..1936870b2570 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs @@ -56,8 +56,13 @@ impl PipelineFixer { Self {} } } +/// [PipelineFixer] subrules are functions of this type. Such functions take a +/// single [PipelineStatePropagator] argument, which stores state variables +/// indicating the unboundedness status of the current [ExecutionPlan] as +/// the [PipelineFixer] rule traverses the entire plan tree. type PipelineFixerSubrule = dyn Fn(PipelineStatePropagator) -> Option>; + impl PhysicalOptimizerRule for PipelineFixer { fn optimize( &self, diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs index df71475c6bb4..b9c71f3c7469 100644 --- a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs +++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs @@ -258,7 +258,7 @@ pub mod tests { use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::ScalarValue; use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::{binary, cast, col, lit, BinaryExpr}; + use datafusion_physical_expr::expressions::{binary, cast, col, lit}; use std::sync::Arc; /// Filter expr for a + b > c + 10 AND a + b < c + 100 @@ -311,36 +311,48 @@ pub mod tests { #[test] fn test_column_exchange() -> Result<()> { - let left_child_schema = Arc::new(Schema::new(vec![Field::new( - "left_1", - DataType::Int32, - true, - )])); + let left_child_schema = + Schema::new(vec![Field::new("left_1", DataType::Int32, true)]); // Sorting information for the left side: let left_child_sort_expr = PhysicalSortExpr { - expr: col("left_1", left_child_schema.as_ref())?, + expr: col("left_1", &left_child_schema)?, options: SortOptions::default(), }; - let right_child_schema = Arc::new(Schema::new(vec![ + let right_child_schema = Schema::new(vec![ Field::new("right_1", DataType::Int32, true), Field::new("right_2", DataType::Int32, true), - ])); + ]); // Sorting information for the right side: let right_child_sort_expr = PhysicalSortExpr { expr: binary( - col("right_1", right_child_schema.as_ref())?, + col("right_1", &right_child_schema)?, Operator::Plus, - col("right_2", right_child_schema.as_ref())?, - right_child_schema.as_ref(), + col("right_2", &right_child_schema)?, + &right_child_schema, )?, options: SortOptions::default(), }; - let filter_col_1 = Arc::new(Column::new("filter_1", 0)); - let filter_col_2 = Arc::new(Column::new("filter_2", 1)); - let filter_col_3 = Arc::new(Column::new("filter_3", 2)); - + let intermediate_schema = Schema::new(vec![ + Field::new("filter_1", DataType::Int32, true), + Field::new("filter_2", DataType::Int32, true), + Field::new("filter_3", DataType::Int32, true), + ]); + // Our filter expression is: left_1 > right_1 + right_2. 
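+        // The test below verifies that each child's sort expression survives
+        // the mapping into the filter schema: `left_1` maps to `filter_1`,
+        // and `right_1 + right_2` maps to `filter_2 + filter_3`, which is
+        // exactly what the final assertions compare against.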
+ let filter_left = col("filter_1", &intermediate_schema)?; + let filter_right = binary( + col("filter_2", &intermediate_schema)?, + Operator::Plus, + col("filter_3", &intermediate_schema)?, + &intermediate_schema, + )?; + let filter_expr = binary( + filter_left.clone(), + Operator::Gt, + filter_right.clone(), + &intermediate_schema, + )?; let column_indices = vec![ ColumnIndex { index: 0, @@ -355,23 +367,12 @@ pub mod tests { side: JoinSide::Right, }, ]; - let intermediate_schema = Schema::new(vec![ - Field::new(filter_col_1.name(), DataType::Int32, true), - Field::new(filter_col_2.name(), DataType::Int32, true), - Field::new(filter_col_3.name(), DataType::Int32, true), - ]); - // Our filter expression is: left_1 > right_1 + right_2. - let filter_expr = Arc::new(BinaryExpr::new( - filter_col_1, - Operator::Gt, - Arc::new(BinaryExpr::new(filter_col_2, Operator::Plus, filter_col_3)), - )); let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); let left_sort_filter_expr = build_filter_input_order( JoinSide::Left, &filter, - &left_child_schema, + &Arc::new(left_child_schema), &left_child_sort_expr, )?; assert!(left_child_sort_expr.eq(left_sort_filter_expr.origin_sorted_expr())); @@ -379,35 +380,26 @@ pub mod tests { let right_sort_filter_expr = build_filter_input_order( JoinSide::Right, &filter, - &right_child_schema, + &Arc::new(right_child_schema), &right_child_sort_expr, )?; assert!(right_child_sort_expr.eq(right_sort_filter_expr.origin_sorted_expr())); - // Filter expression (left) adjusted for filter schema: - let expected_filter_expr: Arc = - Arc::new(Column::new("filter_1", 0)); - // Assert that it matches with the `left_child_sort_expr`: - assert!(expected_filter_expr.eq(left_sort_filter_expr.filter_expr())); - // Filter expression (right) adjusted for filter schema: - let expected_filter_expr: Arc = Arc::new(BinaryExpr::new( - Arc::new(Column::new("filter_2", 1)), - Operator::Plus, - Arc::new(Column::new("filter_3", 2)), - )); - // Assert that it matches with the `right_child_sort_expr`: - assert!(expected_filter_expr.eq(right_sort_filter_expr.filter_expr())); + // Assert that adjusted (left) filter expression matches with `left_child_sort_expr`: + assert!(filter_left.eq(left_sort_filter_expr.filter_expr())); + // Assert that adjusted (right) filter expression matches with `right_child_sort_expr`: + assert!(filter_right.eq(right_sort_filter_expr.filter_expr())); Ok(()) } #[test] fn test_column_collector() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ + let schema = Schema::new(vec![ Field::new("0", DataType::Int32, true), Field::new("1", DataType::Int32, true), Field::new("2", DataType::Int32, true), - ])); - let filter_expr = complicated_filter(schema.as_ref())?; + ]); + let filter_expr = complicated_filter(&schema)?; let columns = collect_columns(&filter_expr); assert_eq!(columns.len(), 3); Ok(()) @@ -415,34 +407,34 @@ pub mod tests { #[test] fn find_expr_inside_expr() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ + let schema = Schema::new(vec![ Field::new("0", DataType::Int32, true), Field::new("1", DataType::Int32, true), Field::new("2", DataType::Int32, true), - ])); - let filter_expr = complicated_filter(schema.as_ref())?; + ]); + let filter_expr = complicated_filter(&schema)?; - let expr_1: Arc = Arc::new(Column::new("gnz", 0)); + let expr_1 = Arc::new(Column::new("gnz", 0)) as _; assert!(!check_filter_expr_contains_sort_information( &filter_expr, &expr_1 )); - let expr_2: Arc = col("1", schema.as_ref())?; + let expr_2 = 
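+            // These fixtures now use the `col`/`binary`/`cast` helper
+            // constructors, which validate names and indices against the
+            // schema instead of hand-written `Column::new(name, index)`
+            // pairs; `as _` lets the compiler coerce to the expected
+            // `Arc<dyn PhysicalExpr>`.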
col("1", &schema)? as _; assert!(check_filter_expr_contains_sort_information( &filter_expr, &expr_2 )); - let expr_3: Arc = cast( + let expr_3 = cast( binary( - col("0", schema.as_ref())?, + col("0", &schema)?, Operator::Plus, - col("1", schema.as_ref())?, - schema.as_ref(), + col("1", &schema)?, + &schema, )?, - schema.as_ref(), + &schema, DataType::Int64, )?; @@ -451,7 +443,7 @@ pub mod tests { &expr_3 )); - let expr_4: Arc = Arc::new(Column::new("1", 42)); + let expr_4 = Arc::new(Column::new("1", 42)) as _; assert!(!check_filter_expr_contains_sort_information( &filter_expr, @@ -462,28 +454,30 @@ pub mod tests { #[test] fn build_sorted_expr() -> Result<()> { - let left_schema = Arc::new(Schema::new(vec![ + let left_schema = Schema::new(vec![ Field::new("la1", DataType::Int32, false), Field::new("lb1", DataType::Int32, false), Field::new("lc1", DataType::Int32, false), Field::new("lt1", DataType::Int32, false), Field::new("la2", DataType::Int32, false), Field::new("la1_des", DataType::Int32, false), - ])); + ]); - let right_schema = Arc::new(Schema::new(vec![ + let right_schema = Schema::new(vec![ Field::new("ra1", DataType::Int32, false), Field::new("rb1", DataType::Int32, false), Field::new("rc1", DataType::Int32, false), Field::new("rt1", DataType::Int32, false), Field::new("ra2", DataType::Int32, false), Field::new("ra1_des", DataType::Int32, false), - ])); - - let filter_col_0 = Arc::new(Column::new("0", 0)); - let filter_col_1 = Arc::new(Column::new("1", 1)); - let filter_col_2 = Arc::new(Column::new("2", 2)); + ]); + let intermediate_schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + Field::new("2", DataType::Int32, true), + ]); + let filter_expr = complicated_filter(&intermediate_schema)?; let column_indices = vec![ ColumnIndex { index: 0, @@ -498,16 +492,10 @@ pub mod tests { side: JoinSide::Right, }, ]; - let intermediate_schema = Schema::new(vec![ - Field::new(filter_col_0.name(), DataType::Int32, true), - Field::new(filter_col_1.name(), DataType::Int32, true), - Field::new(filter_col_2.name(), DataType::Int32, true), - ]); - - let filter_expr = complicated_filter(&intermediate_schema)?; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); - let filter = - JoinFilter::new(filter_expr, column_indices, intermediate_schema.clone()); + let left_schema = Arc::new(left_schema); + let right_schema = Arc::new(right_schema); assert!(build_filter_input_order( JoinSide::Left, @@ -552,12 +540,20 @@ pub mod tests { Ok(()) } - // if one side is sorted by ORDER BY (a+b), and join filter condition includes (a-b). + + // Test the case when we have an "ORDER BY a + b", and join filter condition includes "a - b". 
#[test] fn sorted_filter_expr_build() -> Result<()> { - let filter_col_0 = Arc::new(Column::new("0", 0)); - let filter_col_1 = Arc::new(Column::new("1", 1)); - + let intermediate_schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + ]); + let filter_expr = binary( + col("0", &intermediate_schema)?, + Operator::Minus, + col("1", &intermediate_schema)?, + &intermediate_schema, + )?; let column_indices = vec![ ColumnIndex { index: 0, @@ -568,31 +564,19 @@ pub mod tests { side: JoinSide::Left, }, ]; - let intermediate_schema = Schema::new(vec![ - Field::new(filter_col_0.name(), DataType::Int32, true), - Field::new(filter_col_1.name(), DataType::Int32, true), - ]); - - let filter_expr = binary( - col("0", &intermediate_schema)?, - Operator::Minus, - col("1", &intermediate_schema)?, - &intermediate_schema, - )?; - let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); - let schema = Arc::new(Schema::new(vec![ + let schema = Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), - ])); + ]); let sorted = PhysicalSortExpr { expr: binary( - col("a", schema.as_ref())?, + col("a", &schema)?, Operator::Plus, - col("b", schema.as_ref())?, - schema.as_ref(), + col("b", &schema)?, + &schema, )?, options: SortOptions::default(), }; @@ -600,7 +584,7 @@ pub mod tests { let res = convert_sort_expr_with_filter_schema( &JoinSide::Left, &filter, - &schema, + &Arc::new(schema), &sorted, )?; assert!(res.is_none()); diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs index c49510b8c536..c377be6d1ea0 100644 --- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs @@ -1432,7 +1432,7 @@ mod tests { use tempfile::TempDir; use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::{BinaryExpr, Column}; + use datafusion_physical_expr::expressions::{binary, col, Column}; use datafusion_physical_expr::intervals::test_utils::gen_conjunctive_numeric_expr; use datafusion_physical_expr::PhysicalExpr; @@ -1669,25 +1669,25 @@ mod tests { let ordered: ArrayRef = Arc::new(Int32Array::from_iter( initial_range.clone().collect::>(), )); - let ordered_des: ArrayRef = Arc::new(Int32Array::from_iter( + let ordered_des = Arc::new(Int32Array::from_iter( initial_range.clone().rev().collect::>(), )); - let cardinality: ArrayRef = Arc::new(Int32Array::from_iter( + let cardinality = Arc::new(Int32Array::from_iter( initial_range.clone().map(|x| x % 4).collect::>(), )); - let cardinality_key: ArrayRef = Arc::new(Int32Array::from_iter( + let cardinality_key = Arc::new(Int32Array::from_iter( initial_range .clone() .map(|x| x % key_cardinality.0) .collect::>(), )); - let ordered_asc_null_first: ArrayRef = Arc::new(Int32Array::from_iter({ + let ordered_asc_null_first = Arc::new(Int32Array::from_iter({ std::iter::repeat(None) .take(index as usize) .chain(rest_of.clone().map(Some)) .collect::>>() })); - let ordered_asc_null_last: ArrayRef = Arc::new(Int32Array::from_iter({ + let ordered_asc_null_last = Arc::new(Int32Array::from_iter({ rest_of .clone() .map(Some) @@ -1695,14 +1695,14 @@ mod tests { .collect::>>() })); - let ordered_desc_null_first: ArrayRef = Arc::new(Int32Array::from_iter({ + let ordered_desc_null_first = Arc::new(Int32Array::from_iter({ std::iter::repeat(None) .take(index as usize) .chain(rest_of.rev().map(Some)) 
.collect::>>() })); - let time: ArrayRef = Arc::new(TimestampNanosecondArray::from( + let time = Arc::new(TimestampNanosecondArray::from( initial_range .map(|x| 1664264591000000000 + (5000000000 * (x as i64))) .collect::>(), @@ -1779,13 +1779,7 @@ mod tests { ) .await?; let second_batches = partitioned_hash_join_with_filter( - left.clone(), - right.clone(), - on.clone(), - filter.clone(), - &join_type, - false, - task_ctx.clone(), + left, right, on, filter, &join_type, false, task_ctx, ) .await?; compare_batches(&first_batches, &second_batches); @@ -1807,10 +1801,10 @@ mod tests { )] join_type: JoinType, #[values( - (4, 5), - (11, 21), - (31, 71), - (99, 12), + (4, 5), + (11, 21), + (31, 71), + (99, 12), )] cardinality: (i32, i32), ) -> Result<()> { @@ -1820,31 +1814,35 @@ mod tests { let task_ctx = session_ctx.task_ctx(); let (left_batch, right_batch) = build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_schema = &left_batch.schema(); + let right_schema = &right_batch.schema(); let left_sorted = vec![PhysicalSortExpr { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new_with_schema("la1", &left_batch.schema())?), + expr: binary( + col("la1", left_schema)?, Operator::Plus, - Arc::new(Column::new_with_schema("la2", &left_batch.schema())?), - )), + col("la2", left_schema)?, + left_schema, + )?, options: SortOptions::default(), }]; - let right_sorted = vec![PhysicalSortExpr { - expr: Arc::new(Column::new_with_schema("ra1", &right_batch.schema())?), + expr: col("ra1", right_schema)?, options: SortOptions::default(), }]; let (left, right) = create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?; let on = vec![( - Column::new_with_schema("lc1", &left.schema())?, - Column::new_with_schema("rc1", &right.schema())?, + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, )]; - let filter_col_0 = Arc::new(Column::new("0", 0)); - let filter_col_1 = Arc::new(Column::new("1", 1)); - let filter_col_2 = Arc::new(Column::new("2", 2)); - + let intermediate_schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + Field::new("2", DataType::Int32, true), + ]); + let filter_expr = complicated_filter(&intermediate_schema)?; let column_indices = vec![ ColumnIndex { index: 0, @@ -1859,19 +1857,7 @@ mod tests { side: JoinSide::Right, }, ]; - let intermediate_schema = Schema::new(vec![ - Field::new(filter_col_0.name(), DataType::Int32, true), - Field::new(filter_col_1.name(), DataType::Int32, true), - Field::new(filter_col_2.name(), DataType::Int32, true), - ]); - - let filter_expr = complicated_filter(&intermediate_schema)?; - - let filter = JoinFilter::new( - filter_expr, - column_indices.clone(), - intermediate_schema.clone(), - ); + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); experiment(left, right, filter, join_type, on, task_ctx).await?; Ok(()) @@ -1892,10 +1878,10 @@ mod tests { )] join_type: JoinType, #[values( - (4, 5), - (11, 21), - (31, 71), - (99, 12), + (4, 5), + (11, 21), + (31, 71), + (99, 12), )] cardinality: (i32, i32), #[values(0, 1, 2, 3, 4)] case_expr: usize, @@ -1905,25 +1891,33 @@ mod tests { let task_ctx = session_ctx.task_ctx(); let (left_batch, right_batch) = build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_schema = &left_batch.schema(); + let right_schema = &right_batch.schema(); let left_sorted = vec![PhysicalSortExpr { - expr: Arc::new(Column::new_with_schema("la1", &left_batch.schema())?), + expr: 
col("la1", left_schema)?, options: SortOptions::default(), }]; let right_sorted = vec![PhysicalSortExpr { - expr: Arc::new(Column::new_with_schema("ra1", &right_batch.schema())?), + expr: col("ra1", right_schema)?, options: SortOptions::default(), }]; let (left, right) = create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?; let on = vec![( - Column::new_with_schema("lc1", &left.schema())?, - Column::new_with_schema("rc1", &right.schema())?, + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, )]; - let left_col = Arc::new(Column::new("left", 0)); - let right_col = Arc::new(Column::new("right", 1)); - + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Int32, true), + Field::new("right", DataType::Int32, true), + ]); + let filter_expr = join_expr_tests_fixture( + case_expr, + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + ); let column_indices = vec![ ColumnIndex { index: 0, @@ -1934,19 +1928,7 @@ mod tests { side: JoinSide::Right, }, ]; - let intermediate_schema = Schema::new(vec![ - Field::new(left_col.name(), DataType::Int32, true), - Field::new(right_col.name(), DataType::Int32, true), - ]); - - let filter_expr = - join_expr_tests_fixture(case_expr, left_col.clone(), right_col.clone()); - - let filter = JoinFilter::new( - filter_expr, - column_indices.clone(), - intermediate_schema.clone(), - ); + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); experiment(left, right, filter, join_type, on, task_ctx).await?; Ok(()) @@ -1962,15 +1944,17 @@ mod tests { let task_ctx = session_ctx.task_ctx(); let (left_batch, right_batch) = build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_schema = &left_batch.schema(); + let right_schema = &right_batch.schema(); let left_sorted = vec![PhysicalSortExpr { - expr: Arc::new(Column::new_with_schema("la1_des", &left_batch.schema())?), + expr: col("la1_des", left_schema)?, options: SortOptions { descending: true, nulls_first: true, }, }]; let right_sorted = vec![PhysicalSortExpr { - expr: Arc::new(Column::new_with_schema("ra1_des", &right_batch.schema())?), + expr: col("ra1_des", right_schema)?, options: SortOptions { descending: true, nulls_first: true, @@ -1980,13 +1964,19 @@ mod tests { create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?; let on = vec![( - Column::new_with_schema("lc1", &left.schema())?, - Column::new_with_schema("rc1", &right.schema())?, + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, )]; - let left_col = Arc::new(Column::new("left", 0)); - let right_col = Arc::new(Column::new("right", 1)); - + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Int32, true), + Field::new("right", DataType::Int32, true), + ]); + let filter_expr = join_expr_tests_fixture( + case_expr, + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + ); let column_indices = vec![ ColumnIndex { index: 5, @@ -1997,19 +1987,7 @@ mod tests { side: JoinSide::Right, }, ]; - let intermediate_schema = Schema::new(vec![ - Field::new(left_col.name(), DataType::Int32, true), - Field::new(right_col.name(), DataType::Int32, true), - ]); - - let filter_expr = - join_expr_tests_fixture(case_expr, left_col.clone(), right_col.clone()); - - let filter = JoinFilter::new( - filter_expr, - column_indices.clone(), - intermediate_schema.clone(), - ); + let filter = JoinFilter::new(filter_expr, column_indices, 
intermediate_schema); experiment(left, right, filter, join_type, on, task_ctx).await?; Ok(()) @@ -2030,10 +2008,10 @@ mod tests { )] join_type: JoinType, #[values( - (4, 5), - (11, 21), - (31, 71), - (99, 12), + (4, 5), + (11, 21), + (31, 71), + (99, 12), )] cardinality: (i32, i32), #[values(0, 1, 2, 3, 4)] case_expr: usize, @@ -2043,15 +2021,17 @@ mod tests { let task_ctx = session_ctx.task_ctx(); let (left_batch, right_batch) = build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_schema = &left_batch.schema(); + let right_schema = &right_batch.schema(); let left_sorted = vec![PhysicalSortExpr { - expr: Arc::new(Column::new_with_schema("la1_des", &left_batch.schema())?), + expr: col("la1_des", left_schema)?, options: SortOptions { descending: true, nulls_first: true, }, }]; let right_sorted = vec![PhysicalSortExpr { - expr: Arc::new(Column::new_with_schema("ra1_des", &right_batch.schema())?), + expr: col("ra1_des", right_schema)?, options: SortOptions { descending: true, nulls_first: true, @@ -2061,13 +2041,19 @@ mod tests { create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?; let on = vec![( - Column::new_with_schema("lc1", &left.schema())?, - Column::new_with_schema("rc1", &right.schema())?, + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, )]; - let left_col = Arc::new(Column::new("left", 0)); - let right_col = Arc::new(Column::new("right", 1)); - + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Int32, true), + Field::new("right", DataType::Int32, true), + ]); + let filter_expr = join_expr_tests_fixture( + case_expr, + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + ); let column_indices = vec![ ColumnIndex { index: 5, @@ -2078,19 +2064,7 @@ mod tests { side: JoinSide::Right, }, ]; - let intermediate_schema = Schema::new(vec![ - Field::new(left_col.name(), DataType::Int32, true), - Field::new(right_col.name(), DataType::Int32, true), - ]); - - let filter_expr = - join_expr_tests_fixture(case_expr, left_col.clone(), right_col.clone()); - - let filter = JoinFilter::new( - filter_expr, - column_indices.clone(), - intermediate_schema.clone(), - ); + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); experiment(left, right, filter, join_type, on, task_ctx).await?; Ok(()) @@ -2139,21 +2113,17 @@ mod tests { let task_ctx = session_ctx.task_ctx(); let (left_batch, right_batch) = build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_schema = &left_batch.schema(); + let right_schema = &right_batch.schema(); let left_sorted = vec![PhysicalSortExpr { - expr: Arc::new(Column::new_with_schema( - "l_asc_null_first", - &left_batch.schema(), - )?), + expr: col("l_asc_null_first", left_schema)?, options: SortOptions { descending: false, nulls_first: true, }, }]; let right_sorted = vec![PhysicalSortExpr { - expr: Arc::new(Column::new_with_schema( - "r_asc_null_first", - &right_batch.schema(), - )?), + expr: col("r_asc_null_first", right_schema)?, options: SortOptions { descending: false, nulls_first: true, @@ -2163,13 +2133,19 @@ mod tests { create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?; let on = vec![( - Column::new_with_schema("lc1", &left.schema())?, - Column::new_with_schema("rc1", &right.schema())?, + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, )]; - let left_col = Arc::new(Column::new("left", 0)); - let right_col = 
Arc::new(Column::new("right", 1)); - + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Int32, true), + Field::new("right", DataType::Int32, true), + ]); + let filter_expr = join_expr_tests_fixture( + case_expr, + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + ); let column_indices = vec![ ColumnIndex { index: 6, @@ -2180,19 +2156,7 @@ mod tests { side: JoinSide::Right, }, ]; - let intermediate_schema = Schema::new(vec![ - Field::new(left_col.name(), DataType::Int32, true), - Field::new(right_col.name(), DataType::Int32, true), - ]); - - let filter_expr = - join_expr_tests_fixture(case_expr, left_col.clone(), right_col.clone()); - - let filter = JoinFilter::new( - filter_expr, - column_indices.clone(), - intermediate_schema.clone(), - ); + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); experiment(left, right, filter, join_type, on, task_ctx).await?; Ok(()) } @@ -2207,21 +2171,17 @@ mod tests { let task_ctx = session_ctx.task_ctx(); let (left_batch, right_batch) = build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_schema = &left_batch.schema(); + let right_schema = &right_batch.schema(); let left_sorted = vec![PhysicalSortExpr { - expr: Arc::new(Column::new_with_schema( - "l_asc_null_last", - &left_batch.schema(), - )?), + expr: col("l_asc_null_last", left_schema)?, options: SortOptions { descending: false, nulls_first: false, }, }]; let right_sorted = vec![PhysicalSortExpr { - expr: Arc::new(Column::new_with_schema( - "r_asc_null_last", - &right_batch.schema(), - )?), + expr: col("r_asc_null_last", right_schema)?, options: SortOptions { descending: false, nulls_first: false, @@ -2231,13 +2191,19 @@ mod tests { create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?; let on = vec![( - Column::new_with_schema("lc1", &left.schema())?, - Column::new_with_schema("rc1", &right.schema())?, + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, )]; - let left_col = Arc::new(Column::new("left", 0)); - let right_col = Arc::new(Column::new("right", 1)); - + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Int32, true), + Field::new("right", DataType::Int32, true), + ]); + let filter_expr = join_expr_tests_fixture( + case_expr, + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + ); let column_indices = vec![ ColumnIndex { index: 7, @@ -2248,19 +2214,7 @@ mod tests { side: JoinSide::Right, }, ]; - let intermediate_schema = Schema::new(vec![ - Field::new(left_col.name(), DataType::Int32, true), - Field::new(right_col.name(), DataType::Int32, true), - ]); - - let filter_expr = - join_expr_tests_fixture(case_expr, left_col.clone(), right_col.clone()); - - let filter = JoinFilter::new( - filter_expr, - column_indices.clone(), - intermediate_schema.clone(), - ); + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); experiment(left, right, filter, join_type, on, task_ctx).await?; Ok(()) @@ -2276,21 +2230,17 @@ mod tests { let task_ctx = session_ctx.task_ctx(); let (left_batch, right_batch) = build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_schema = &left_batch.schema(); + let right_schema = &right_batch.schema(); let left_sorted = vec![PhysicalSortExpr { - expr: Arc::new(Column::new_with_schema( - "l_desc_null_first", - &left_batch.schema(), - )?), + expr: col("l_desc_null_first", left_schema)?, options: SortOptions { descending: true, nulls_first: 
                 nulls_first: true,
             },
         }];
         let right_sorted = vec![PhysicalSortExpr {
-            expr: Arc::new(Column::new_with_schema(
-                "r_desc_null_first",
-                &right_batch.schema(),
-            )?),
+            expr: col("r_desc_null_first", right_schema)?,
             options: SortOptions {
                 descending: true,
                 nulls_first: true,
@@ -2300,13 +2250,19 @@ mod tests {
             create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?;
         let on = vec![(
-            Column::new_with_schema("lc1", &left.schema())?,
-            Column::new_with_schema("rc1", &right.schema())?,
+            Column::new_with_schema("lc1", left_schema)?,
+            Column::new_with_schema("rc1", right_schema)?,
         )];
-        let left_col = Arc::new(Column::new("left", 0));
-        let right_col = Arc::new(Column::new("right", 1));
-
+        let intermediate_schema = Schema::new(vec![
+            Field::new("left", DataType::Int32, true),
+            Field::new("right", DataType::Int32, true),
+        ]);
+        let filter_expr = join_expr_tests_fixture(
+            case_expr,
+            col("left", &intermediate_schema)?,
+            col("right", &intermediate_schema)?,
+        );
         let column_indices = vec![
             ColumnIndex {
                 index: 8,
@@ -2317,19 +2273,7 @@ mod tests {
                 side: JoinSide::Right,
             },
         ];
-        let intermediate_schema = Schema::new(vec![
-            Field::new(left_col.name(), DataType::Int32, true),
-            Field::new(right_col.name(), DataType::Int32, true),
-        ]);
-
-        let filter_expr =
-            join_expr_tests_fixture(case_expr, left_col.clone(), right_col.clone());
-
-        let filter = JoinFilter::new(
-            filter_expr,
-            column_indices.clone(),
-            intermediate_schema.clone(),
-        );
+        let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
         experiment(left, right, filter, join_type, on, task_ctx).await?;
         Ok(())
@@ -2346,27 +2290,31 @@ mod tests {
         let task_ctx = session_ctx.task_ctx();
         let (left_batch, right_batch) =
             build_sides_record_batches(TABLE_SIZE, cardinality)?;
+        let left_schema = &left_batch.schema();
+        let right_schema = &right_batch.schema();
         let left_sorted = vec![PhysicalSortExpr {
-            expr: Arc::new(Column::new_with_schema("la1", &left_batch.schema())?),
+            expr: col("la1", left_schema)?,
             options: SortOptions::default(),
         }];
         let right_sorted = vec![PhysicalSortExpr {
-            expr: Arc::new(Column::new_with_schema("ra1", &right_batch.schema())?),
+            expr: col("ra1", right_schema)?,
             options: SortOptions::default(),
         }];
         let (left, right) =
             create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?;
         let on = vec![(
-            Column::new_with_schema("lc1", &left.schema())?,
-            Column::new_with_schema("rc1", &right.schema())?,
+            Column::new_with_schema("lc1", left_schema)?,
+            Column::new_with_schema("rc1", right_schema)?,
         )];
-        let filter_col_0 = Arc::new(Column::new("0", 0));
-        let filter_col_1 = Arc::new(Column::new("1", 1));
-        let filter_col_2 = Arc::new(Column::new("2", 2));
-
+        let intermediate_schema = Schema::new(vec![
+            Field::new("0", DataType::Int32, true),
+            Field::new("1", DataType::Int32, true),
+            Field::new("2", DataType::Int32, true),
+        ]);
+        let filter_expr = complicated_filter(&intermediate_schema)?;
         let column_indices = vec![
             ColumnIndex {
                 index: 0,
@@ -2381,23 +2329,12 @@ mod tests {
                 side: JoinSide::Right,
             },
         ];
-        let intermediate_schema = Schema::new(vec![
-            Field::new(filter_col_0.name(), DataType::Int32, true),
-            Field::new(filter_col_1.name(), DataType::Int32, true),
-            Field::new(filter_col_2.name(), DataType::Int32, true),
-        ]);
-
-        let filter_expr = complicated_filter(&intermediate_schema)?;
-
-        let filter = JoinFilter::new(
-            filter_expr,
-            column_indices.clone(),
-            intermediate_schema.clone(),
-        );
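+        // `complicated_filter` builds a deeper, nested predicate over all
+        // three intermediate filter columns, exercising interval propagation
+        // through a larger expression tree than the two-column fixtures above.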
+        let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
         experiment(left, right, filter, join_type, on, task_ctx).await?;
         Ok(())
     }
+
     #[rstest]
     #[tokio::test(flavor = "multi_thread")]
     async fn test_one_side_hash_joiner_visited_rows(
@@ -2422,7 +2359,8 @@ mod tests {
         let task_ctx = session_ctx.task_ctx();
         // Ensure there will be matching rows
         let (left_batch, right_batch) = build_sides_record_batches(20, (1, 1))?;
-        let (left_schema, right_schema) = (left_batch.schema(), right_batch.schema());
+        let left_schema = left_batch.schema();
+        let right_schema = right_batch.schema();
         // Build the join schema from the left and right schemas
         let (schema, join_column_indices) =
@@ -2431,41 +2369,26 @@ mod tests {
         // Sort information for MemoryExec
         let left_sorted = vec![PhysicalSortExpr {
-            expr: Arc::new(Column::new_with_schema("la1", &left_batch.schema())?),
+            expr: col("la1", &left_schema)?,
             options: SortOptions::default(),
         }];
         // Sort information for MemoryExec
         let right_sorted = vec![PhysicalSortExpr {
-            expr: Arc::new(Column::new_with_schema("ra1", &right_batch.schema())?),
+            expr: col("ra1", &right_schema)?,
             options: SortOptions::default(),
         }];
         // Construct MemoryExec
         let (left, right) =
             create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 10)?;
-        // Filter columns
-        let filter_col_0 = Arc::new(Column::new("0", 0));
-        let filter_col_1 = Arc::new(Column::new("1", 1));
-
-        let column_indices = vec![
-            ColumnIndex {
-                index: 0,
-                side: JoinSide::Left,
-            },
-            ColumnIndex {
-                index: 0,
-                side: JoinSide::Right,
-            },
-        ];
+        // Filter columns, ensure first batches will have matching rows.
         let intermediate_schema = Schema::new(vec![
-            Field::new(filter_col_0.name(), DataType::Int32, true),
-            Field::new(filter_col_1.name(), DataType::Int32, true),
+            Field::new("0", DataType::Int32, true),
+            Field::new("1", DataType::Int32, true),
         ]);
-
-        // Ensure first batches will have matching rows.
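+        // `gen_conjunctive_numeric_expr` below builds a predicate roughly of
+        // the form `(col0 <op> a) > (col1 <op> b) AND (col0 <op> c) < (col1 <op> d)`,
+        // so the earliest batches on each side are guaranteed to match.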
         let filter_expr = gen_conjunctive_numeric_expr(
-            filter_col_0.clone(),
-            filter_col_1.clone(),
+            col("0", &intermediate_schema)?,
+            col("1", &intermediate_schema)?,
             Operator::Plus,
             Operator::Minus,
             Operator::Plus,
@@ -2475,21 +2398,25 @@ mod tests {
             0,
             3,
         );
-
-        let filter = JoinFilter::new(
-            filter_expr,
-            column_indices.clone(),
-            intermediate_schema.clone(),
-        );
+        let column_indices = vec![
+            ColumnIndex {
+                index: 0,
+                side: JoinSide::Left,
+            },
+            ColumnIndex {
+                index: 0,
+                side: JoinSide::Right,
+            },
+        ];
+        let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
         let left_sorted_filter_expr = SortedFilterExpr::new(
             PhysicalSortExpr {
-                expr: Arc::new(Column::new_with_schema("la1", &left_schema)?),
+                expr: col("la1", &left_schema)?,
                 options: SortOptions::default(),
             },
-            Arc::new(Column::new(filter_col_0.name(), 0)),
+            Arc::new(Column::new("0", 0)),
         );
-
         let mut left_side_joiner = OneSideHashJoiner::new(
             JoinSide::Left,
             left_sorted_filter_expr,
@@ -2499,12 +2426,11 @@ mod tests {
         let right_sorted_filter_expr = SortedFilterExpr::new(
             PhysicalSortExpr {
-                expr: Arc::new(Column::new_with_schema("ra1", &right_schema)?),
+                expr: col("ra1", &right_schema)?,
                 options: SortOptions::default(),
             },
-            Arc::new(Column::new(filter_col_1.name(), 0)),
+            Arc::new(Column::new("1", 0)),
         );
-
         let mut right_side_joiner = OneSideHashJoiner::new(
             JoinSide::Right,
             right_sorted_filter_expr,
diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs
index ed3b7fbddd5b..2c304cf7b355 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -789,54 +789,47 @@ impl PhysicalExpr for BinaryExpr {
         }
     }
 
-    fn evaluate_bound(&self, _child_intervals: &[&Interval]) -> Result<Interval> {
+    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
         // Get children intervals:
-        let left_interval = _child_intervals[1];
-        let right_interval = _child_intervals[0];
-        // Calculate current node's interval
+        let left_interval = children[0];
+        let right_interval = children[1];
+        // Calculate current node's interval:
         apply_operator(&self.op, left_interval, right_interval)
     }
 
     fn propagate_constraints(
         &self,
-        _parent_interval: &Interval,
-        _child_intervals: &[&Interval],
+        interval: &Interval,
+        children: &[&Interval],
     ) -> Result<Vec<Option<Interval>>> {
         // Get children intervals:
-        let left_interval = _child_intervals[0];
-        let right_interval = _child_intervals[1];
-        if self.op.is_logic_operator() {
+        let left_interval = children[0];
+        let right_interval = children[1];
+        let (left, right) = if self.op.is_logic_operator() {
             // TODO: Currently, this implementation only supports the AND operator
             // and does not require any further propagation. In the future,
             // upon adding support for additional logical operators, this
             // method will require modification to support propagating the
             // changes accordingly.
-            Ok(vec![])
+            return Ok(vec![]);
         } else if self.op.is_comparison_operator() {
             if let Interval {
                 lower: ScalarValue::Boolean(Some(false)),
                 upper: ScalarValue::Boolean(Some(false)),
-            } = _parent_interval
+            } = interval
             {
-                // TODO: The optimization of handling strictly false clauses through
-                // conversion to equivalent comparison operators (e.g. GT to LE, LT to GE)
-                // can be implemented once open/closed intervals are supported.
+                // TODO: We will handle strictly false clauses by negating
+                // the comparison operator (e.g. GT to LE, LT to GE)
+                // once open/closed intervals are supported.
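+                // Illustration (closed-interval approximation): if `a > b`
+                // is known to be false with a = [4, 10] and b = [0, 7], then
+                // propagating the negated `a <= b` would shrink both sides
+                // to a = [4, 7] and b = [4, 7].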
                return Ok(vec![]);
             }
             // Propagate the comparison operator.
-            let (left, right) =
-                propagate_comparison(&self.op, left_interval, right_interval)?;
-            Ok(vec![left, right])
+            propagate_comparison(&self.op, left_interval, right_interval)?
         } else {
             // Propagate the arithmetic operator.
-            let (left, right) = propagate_arithmetic(
-                &self.op,
-                _parent_interval,
-                left_interval,
-                right_interval,
-            )?;
-            Ok(vec![left, right])
-        }
+            propagate_arithmetic(&self.op, interval, left_interval, right_interval)?
+        };
+        Ok(vec![left, right])
     }
 }
diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs
index 6242546c1cb4..e3d1b1d67317 100644
--- a/datafusion/physical-expr/src/expressions/cast.rs
+++ b/datafusion/physical-expr/src/expressions/cast.rs
@@ -117,23 +117,21 @@ impl PhysicalExpr for CastExpr {
         )))
     }
 
-    fn evaluate_bound(&self, _child_intervals: &[&Interval]) -> Result<Interval> {
-        let interval = _child_intervals[0];
-        // Cast current node's interval to the right type
-        interval.cast_to(&self.cast_type, &self.cast_options)
+    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
+        // Cast current node's interval to the right type:
+        children[0].cast_to(&self.cast_type, &self.cast_options)
     }
 
     fn propagate_constraints(
         &self,
-        _parent_interval: &Interval,
-        _child_intervals: &[&Interval],
+        interval: &Interval,
+        children: &[&Interval],
     ) -> Result<Vec<Option<Interval>>> {
-        let child_interval = _child_intervals[0];
-        // Get child's datatype
+        let child_interval = children[0];
+        // Get child's datatype:
         let cast_type = child_interval.get_datatype();
         Ok(vec![Some(
-            _parent_interval.cast_to(&cast_type, self.cast_options())?,
+            interval.cast_to(&cast_type, &self.cast_options)?,
         )])
     }
 }
diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs
index 31a3b7c24b14..302a86cdc927 100644
--- a/datafusion/physical-expr/src/intervals/cp_solver.rs
+++ b/datafusion/physical-expr/src/intervals/cp_solver.rs
@@ -447,18 +447,17 @@ impl ExprIntervalGraph {
         let mut dfs = DfsPostOrder::new(&self.graph, self.root);
         while let Some(node) = dfs.next(&self.graph) {
             let neighbors = self.graph.neighbors_directed(node, Outgoing);
-            let children = neighbors.collect::<Vec<_>>();
+            let mut children_intervals = neighbors
+                .map(|child| self.graph[child].interval())
+                .collect::<Vec<_>>();
             // If the current expression is a leaf, its interval should already
             // be set externally, just continue with the evaluation procedure:
-            if children.is_empty() {
-                continue;
+            if !children_intervals.is_empty() {
+                // Reverse to align with [PhysicalExpr]'s children:
+                children_intervals.reverse();
+                self.graph[node].interval =
+                    self.graph[node].expr.evaluate_bounds(&children_intervals)?;
             }
-            let children_intervals = children
-                .into_iter()
-                .map(|child| self.graph[child].interval())
-                .collect::<Vec<_>>();
-            self.graph[node].interval =
-                self.graph[node].expr.evaluate_bound(&children_intervals)?;
         }
         Ok(&self.graph[self.root].interval)
     }
@@ -470,35 +469,29 @@ impl ExprIntervalGraph {
         while let Some(node) = bfs.next(&self.graph) {
             let neighbors = self.graph.neighbors_directed(node, Outgoing);
             let mut children = neighbors.collect::<Vec<_>>();
-            // Reverse the children index order to align it with the ExecutionPlan's children order.
-            children.reverse();
             // If the current expression is a leaf, its range is now final.
             // So, just continue with the propagation procedure:
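+            // For intuition, with closed intervals: pinning the root of
+            // `a + b <= 4` to [true, true] where a = [0, 10] and b = [2, 3]
+            // first tightens the `a + b` node to [2, 4], after which the
+            // child update `a = (a + b) - b` shrinks `a` to [0, 2].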
             if children.is_empty() {
                 continue;
             }
-            let node_interval = self.graph[node].interval();
+            // Reverse to align with [PhysicalExpr]'s children:
+            children.reverse();
             let children_intervals = children
                 .iter()
                 .map(|child| self.graph[*child].interval())
                 .collect::<Vec<_>>();
+            let node_interval = self.graph[node].interval();
             let propagated_intervals = self.graph[node]
                 .expr
                 .propagate_constraints(node_interval, &children_intervals)?;
-            // No new interval is calculated for node leafs.
-            if propagated_intervals.is_empty() {
-                continue;
-            }
-            // There is a feasibility problem
-            if propagated_intervals.iter().any(|i| i.is_none()) {
-                return Ok(PropagationResult::Infeasible);
+            for (child, interval) in children.into_iter().zip(propagated_intervals) {
+                if let Some(interval) = interval {
+                    self.graph[child].interval = interval;
+                } else {
+                    // The constraint is infeasible, report:
+                    return Ok(PropagationResult::Infeasible);
+                }
             }
-            children
-                .iter()
-                .zip(propagated_intervals.into_iter())
-                .for_each(|(child_index, new_interval)| {
-                    self.graph[*child_index].interval = new_interval.unwrap()
-                });
         }
         Ok(PropagationResult::Success)
     }
diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs
index a454fab053d6..3bc85fccffc4 100644
--- a/datafusion/physical-expr/src/physical_expr.rs
+++ b/datafusion/physical-expr/src/physical_expr.rs
@@ -83,18 +83,25 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq<dyn Any> {
         context
     }
 
-    /// Computes bounds for the expression using interval arithmetic
-    fn evaluate_bound(&self, _child_intervals: &[&Interval]) -> Result<Interval> {
-        Err(DataFusionError::Internal("Not supported".to_owned()))
+    /// Computes bounds for the expression using interval arithmetic.
+    fn evaluate_bounds(&self, _children: &[&Interval]) -> Result<Interval> {
+        Err(DataFusionError::NotImplemented(format!(
+            "Not implemented for {self}"
+        )))
     }
 
-    /// Updates/shrinks bounds for the expression using interval arithmetic
+    /// Updates/shrinks bounds for the expression using interval arithmetic.
+    /// If constraint propagation reveals an infeasibility, returns [None] for
+    /// the child causing infeasibility. If none of the children intervals
+    /// change, may return an empty vector instead of cloning `children`.
     fn propagate_constraints(
         &self,
-        _parent_interval: &Interval,
-        _child_intervals: &[&Interval],
+        _interval: &Interval,
+        _children: &[&Interval],
     ) -> Result<Vec<Option<Interval>>> {
-        Err(DataFusionError::Internal("Not supported".to_owned()))
+        Err(DataFusionError::NotImplemented(format!(
+            "Not implemented for {self}"
+        )))
     }
 }
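The two default implementations above define the opt-in contract for the interval machinery: `evaluate_bounds` folds child intervals upward, while `propagate_constraints` pushes a known parent interval back down, answering `None` for a child whose interval becomes empty. As a rough sketch of that contract (not part of this patch), a hypothetical unary negation expression might implement the pair as follows; `NegativeExpr`, `arithmetic_negate` and `intersect` are assumed names, not items this patch introduces:

// Sketch only: assumes helpers shaped like
// `Interval::arithmetic_negate(&self) -> Result<Interval>` and
// `Interval::intersect(&self, other: &Interval) -> Result<Option<Interval>>`.
impl PhysicalExpr for NegativeExpr {
    // ... other trait methods elided ...

    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
        // The bounds of -x are the negated, swapped bounds of x: -[a, b] = [-b, -a].
        children[0].arithmetic_negate()
    }

    fn propagate_constraints(
        &self,
        interval: &Interval,
        children: &[&Interval],
    ) -> Result<Vec<Option<Interval>>> {
        // The child must lie inside the negation of the parent's interval;
        // an empty intersection (None) reports infeasibility to the solver.
        let child = children[0].intersect(&interval.arithmetic_negate()?)?;
        Ok(vec![child])
    }
}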
diff --git a/datafusion/physical-expr/src/rewrite.rs b/datafusion/physical-expr/src/rewrite.rs
index c729fafb1062..327eabd4b0d0 100644
--- a/datafusion/physical-expr/src/rewrite.rs
+++ b/datafusion/physical-expr/src/rewrite.rs
@@ -113,21 +113,6 @@ pub trait TreeNodeRewritable: Clone {
         Ok(new_node)
     }
 
-    fn mutable_transform_up<F>(self, op: &mut F) -> Result<Self>
-    where
-        F: FnMut(Self) -> Result<Option<Self>>,
-    {
-        let after_op_children =
-            self.map_children(|node| node.mutable_transform_up(op))?;
-
-        let after_op_children_clone = after_op_children.clone();
-        let new_node = match op(after_op_children)? {
-            Some(value) => value,
-            None => after_op_children_clone,
-        };
-        Ok(new_node)
-    }
-
     /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder)
     fn map_children<F>(self, transform: F) -> Result<Self>
     where
diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs
index e28e4d7cc342..0908250db487 100644
--- a/datafusion/physical-expr/src/utils.rs
+++ b/datafusion/physical-expr/src/utils.rs
@@ -277,12 +277,13 @@ impl<T: Clone> TreeNodeRewritable for ExprTreeNode<T> {
     }
 }
 
-/// This struct converts the [PhysicalExpr] tree into a DAG by collecting identical
-/// expressions in one node using a rewriter to transform an input expression tree.
-/// Caller specifies the node type in this DAG via the `constructor` argument,
-/// which constructs nodes in this DAG from the [ExprTreeNode] ancillary object.
-struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&'_ ExprTreeNode<NodeIndex>) -> T> {
-    // A graph containing physical expression trees.
+/// This struct facilitates the [TreeNodeRewriter] mechanism to convert a
+/// [PhysicalExpr] tree into a DAEG (i.e. an expression DAG) by collecting
+/// identical expressions in one node. Caller specifies the node type in the
+/// DAEG via the `constructor` argument, which constructs nodes in the DAEG
+/// from the [ExprTreeNode] ancillary object.
+struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> T> {
+    // The resulting DAEG (expression DAG).
     graph: StableGraph<T, usize>,
     // A vector of visited expression nodes and their corresponding node indices.
     visited_plans: Vec<(Arc<dyn PhysicalExpr>, NodeIndex)>,
@@ -290,7 +291,7 @@ struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&'_ ExprTreeNode<NodeIndex>) -> T>
     constructor: &'a F,
 }
 
-impl<'a, T, F: Fn(&'_ ExprTreeNode<NodeIndex>) -> T>
+impl<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> T>
     TreeNodeRewriter<ExprTreeNode<NodeIndex>> for PhysicalExprDAEGBuilder<'a, T, F>
 {
     // This method mutates an expression node by transforming it to a physical expression
@@ -335,7 +336,7 @@ where
 {
     // Create a new expression tree node from the input expression.
     let init = ExprTreeNode::new(expr);
-    // Create a new PhysicalExprDAEGBuilder instance.
+    // Create a new `PhysicalExprDAEGBuilder` instance.
     let mut builder = PhysicalExprDAEGBuilder {
         graph: StableGraph::<T, usize>::new(),
         visited_plans: Vec::<(Arc<dyn PhysicalExpr>, NodeIndex)>::new(),
@@ -360,7 +361,26 @@ mod tests {
     use petgraph::visit::Bfs;
     use std::sync::Arc;
 
-    fn make_dummy_nodes(node: &ExprTreeNode<NodeIndex>) -> PhysicalExprDummyNode {
+    #[derive(Clone)]
+    struct DummyProperty {
+        expr_type: String,
+    }
+
+    /// This is a dummy node in the DAEG; it stores a reference to the actual
+    /// [PhysicalExpr] as well as a dummy property.
+    #[derive(Clone)]
+    struct PhysicalExprDummyNode {
+        pub expr: Arc<dyn PhysicalExpr>,
+        pub property: DummyProperty,
+    }
+
+    impl Display for PhysicalExprDummyNode {
+        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+            write!(f, "{}", self.expr)
+        }
+    }
+
+    fn make_dummy_node(node: &ExprTreeNode<NodeIndex>) -> PhysicalExprDummyNode {
         let expr = node.expression().clone();
         let dummy_property = if expr.as_any().is::<BinaryExpr>() {
             "Binary"
@@ -379,27 +399,9 @@ mod tests {
             },
         }
     }
-    #[derive(Clone)]
-    struct DummyProperty {
-        expr_type: String,
-    }
-
-    /// This is a node in the DAEG; it encapsulates a reference to the actual
-    /// [PhysicalExpr] as well as a dummy property to use it in DAEG traversal.
-    #[derive(Clone)]
-    struct PhysicalExprDummyNode {
-        pub expr: Arc<dyn PhysicalExpr>,
-        pub property: DummyProperty,
-    }
-
-    impl Display for PhysicalExprDummyNode {
-        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-            write!(f, "{}", self.expr)
-        }
-    }
     #[test]
-    fn testing_build_dag() -> Result<()> {
+    fn test_build_dag() -> Result<()> {
         let schema = Schema::new(vec![
             Field::new("0", DataType::Int32, true),
             Field::new("1", DataType::Int32, true),
@@ -426,7 +428,7 @@ mod tests {
             &schema,
         )?;
         let mut vector_dummy_props = vec![];
-        let (root, graph) = build_dag(expr, &make_dummy_nodes)?;
+        let (root, graph) = build_dag(expr, &make_dummy_node)?;
         let mut bfs = Bfs::new(&graph, root);
         while let Some(node_index) = bfs.next(&graph) {
             let node = &graph[node_index];

From adbc128c8d272f26895e60d47c575f5f5825b504 Mon Sep 17 00:00:00 2001
From: metesynnada <100111937+metesynnada@users.noreply.github.com>
Date: Tue, 28 Feb 2023 11:00:00 +0300
Subject: [PATCH 38/39] Minor fix

---
 datafusion/core/src/test_util.rs | 1 +
 1 file changed, 1 insertion(+)

diff --git a/datafusion/core/src/test_util.rs b/datafusion/core/src/test_util.rs
index 34b5168c17c8..37f3a1667f37 100644
--- a/datafusion/core/src/test_util.rs
+++ b/datafusion/core/src/test_util.rs
@@ -27,6 +27,7 @@ use crate::datasource::datasource::TableProviderFactory;
 use crate::datasource::{empty::EmptyTable, provider_as_source, TableProvider};
 use crate::error::Result;
 use crate::execution::context::{SessionState, TaskContext};
+use crate::execution::options::ReadOptions;
 use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE};
 use crate::physical_plan::{
     DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream,

From f95e3a6b328d818e9356bf081e900556510b9c44 Mon Sep 17 00:00:00 2001
From: Mehmet Ozan Kabak
Date: Tue, 28 Feb 2023 10:47:30 -0600
Subject: [PATCH 39/39] Fix typo in the new_zero function

---
 datafusion/common/src/scalar.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 3e99ea63d7ab..40d1eb25e07a 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -1032,8 +1032,8 @@ impl ScalarValue {
             DataType::UInt16 => ScalarValue::UInt16(Some(0)),
             DataType::UInt32 => ScalarValue::UInt32(Some(0)),
             DataType::UInt64 => ScalarValue::UInt64(Some(0)),
-            DataType::Float32 => ScalarValue::UInt64(Some(0)),
-            DataType::Float64 => ScalarValue::UInt64(Some(0)),
+            DataType::Float32 => ScalarValue::Float32(Some(0.0)),
+            DataType::Float64 => ScalarValue::Float64(Some(0.0)),
             _ => {
                 return Err(DataFusionError::NotImplemented(format!(
                     "Can't create a zero scalar from data_type \"{datatype:?}\""