From a016ed641d3fcafde34f2ceda6118d05f386d2ab Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Tue, 9 Sep 2025 12:58:22 -0500
Subject: [PATCH] Refactor HashJoinExec to progressively accumulate dynamic
 filter bounds instead of computing them after data is accumulated (#17444)

(cherry picked from commit 5b833b9a63018fd91baa9bb9031b4ed19dbfd721)
---
 Cargo.lock                                   |   1 -
 datafusion/physical-plan/Cargo.toml          |   2 +-
 .../physical-plan/src/joins/hash_join.rs     | 190 +++++++++++++++---
 3 files changed, 159 insertions(+), 34 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 20b894864647..a088005a0f19 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2506,7 +2506,6 @@ dependencies = [
  "datafusion-execution",
  "datafusion-expr",
  "datafusion-functions-aggregate",
- "datafusion-functions-aggregate-common",
  "datafusion-functions-window",
  "datafusion-functions-window-common",
  "datafusion-physical-expr",
diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml
index 9889b45cc5a5..6dc42472d68a 100644
--- a/datafusion/physical-plan/Cargo.toml
+++ b/datafusion/physical-plan/Cargo.toml
@@ -53,7 +53,7 @@ datafusion-common = { workspace = true, default-features = true }
 datafusion-common-runtime = { workspace = true, default-features = true }
 datafusion-execution = { workspace = true }
 datafusion-expr = { workspace = true }
-datafusion-functions-aggregate-common = { workspace = true }
+datafusion-functions-aggregate = { workspace = true }
 datafusion-functions-window-common = { workspace = true }
 datafusion-physical-expr = { workspace = true, default-features = true }
 datafusion-physical-expr-common = { workspace = true }
diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs
index ec228d4b40b8..bf1d713a96ec 100644
--- a/datafusion/physical-plan/src/joins/hash_join.rs
+++ b/datafusion/physical-plan/src/joins/hash_join.rs
@@ -72,6 +72,7 @@ use arrow::datatypes::{Schema, SchemaRef};
 use arrow::error::ArrowError;
 use arrow::record_batch::RecordBatch;
 use arrow::util::bit_util;
+use arrow_schema::DataType;
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::utils::memory::estimate_memory_size;
 use datafusion_common::{
@@ -80,8 +81,9 @@ use datafusion_common::{
 };
 use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
 use datafusion_execution::TaskContext;
+use datafusion_expr::Accumulator;
 use datafusion_expr::Operator;
-use datafusion_functions_aggregate_common::min_max::{max_batch, min_batch};
+use datafusion_functions_aggregate::min_max::{MaxAccumulator, MinAccumulator};
 use datafusion_physical_expr::equivalence::{
     join_equivalence_properties, ProjectionMapping,
 };
@@ -1430,29 +1432,123 @@ impl ExecutionPlan for HashJoinExec {
     }
 }
 
-/// Compute min/max bounds for each column in the given arrays
-fn compute_bounds(arrays: &[ArrayRef]) -> Result<Vec<ColumnBounds>> {
-    arrays
-        .iter()
-        .map(|array| {
-            if array.is_empty() {
-                // Return NULL values for empty arrays
-                return Ok(ColumnBounds::new(
-                    ScalarValue::try_from(array.data_type())?,
-                    ScalarValue::try_from(array.data_type())?,
-                ));
+/// Accumulator for collecting min/max bounds from build-side data during hash join.
+///
+/// This struct encapsulates the logic for progressively computing column bounds
+/// (minimum and maximum values) for a specific join key expression as batches
+/// are processed during the build phase of a hash join.
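+///
+/// # Example (illustrative)
+///
+/// A rough sketch of how the build phase drives one of these accumulators;
+/// the names (`join_key_expr`, `build_side_batches`) are simplified and the
+/// snippet is not compiled:
+///
+/// ```text
+/// let mut acc = CollectLeftAccumulator::try_new(join_key_expr, &schema)?;
+/// for batch in build_side_batches {
+///     acc.update_batch(&batch)?; // folds the batch's min/max into the running bounds
+/// }
+/// let bounds = acc.evaluate()?;  // final ColumnBounds (min, max) for the dynamic filter
+/// ```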
+///
+/// The bounds are used for dynamic filter pushdown optimization, where filters
+/// based on the actual data ranges can be pushed down to the probe side to
+/// eliminate unnecessary data early.
+struct CollectLeftAccumulator {
+    /// The physical expression to evaluate for each batch
+    expr: Arc<dyn PhysicalExpr>,
+    /// Accumulator for tracking the minimum value across all batches
+    min: MinAccumulator,
+    /// Accumulator for tracking the maximum value across all batches
+    max: MaxAccumulator,
+}
+
+impl CollectLeftAccumulator {
+    /// Creates a new accumulator for tracking bounds of a join key expression.
+    ///
+    /// # Arguments
+    /// * `expr` - The physical expression to track bounds for
+    /// * `schema` - The schema of the input data
+    ///
+    /// # Returns
+    /// A new `CollectLeftAccumulator` instance configured for the expression's data type
+    fn try_new(expr: Arc<dyn PhysicalExpr>, schema: &SchemaRef) -> Result<Self> {
+        /// Recursively unwraps dictionary types to get the underlying value type.
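+        ///
+        /// E.g. `Dictionary(Int32, Dictionary(Int32, Utf8))` unwraps to `Utf8`;
+        /// any non-dictionary type is returned unchanged.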
+ /// + /// # Returns + /// The `ColumnBounds` containing the minimum and maximum values observed + fn evaluate(mut self) -> Result { + Ok(ColumnBounds::new( + self.min.evaluate()?, + self.max.evaluate()?, + )) + } +} + +/// State for collecting the build-side data during hash join +struct BuildSideState { + batches: Vec, + num_rows: usize, + metrics: BuildProbeJoinMetrics, + reservation: MemoryReservation, + bounds_accumulators: Option>, +} + +impl BuildSideState { + /// Create a new BuildSideState with optional accumulators for bounds computation + fn try_new( + metrics: BuildProbeJoinMetrics, + reservation: MemoryReservation, + on_left: Vec>, + schema: &SchemaRef, + should_compute_bounds: bool, + ) -> Result { + Ok(Self { + batches: Vec::new(), + num_rows: 0, + metrics, + reservation, + bounds_accumulators: should_compute_bounds + .then(|| { + on_left + .iter() + .map(|expr| { + CollectLeftAccumulator::try_new(Arc::clone(expr), schema) + }) + .collect::>>() + }) + .transpose()?, }) - .collect() + } } -#[expect(clippy::too_many_arguments)] /// Collects all batches from the left (build) side stream and creates a hash map for joining. /// /// This function is responsible for: @@ -1481,6 +1577,7 @@ fn compute_bounds(arrays: &[ArrayRef]) -> Result> { /// # Returns /// `JoinLeftData` containing the hash map, consolidated batch, join key values, /// visited indices bitmap, and computed bounds (if requested). +#[allow(clippy::too_many_arguments)] async fn collect_left_input( random_state: RandomState, left_stream: SendableRecordBatchStream, @@ -1496,24 +1593,48 @@ async fn collect_left_input( // This operation performs 2 steps at once: // 1. creates a [JoinHashMap] of all batches from the stream // 2. stores the batches in a vector. - let initial = (Vec::new(), 0, metrics, reservation); - let (batches, num_rows, metrics, mut reservation) = left_stream - .try_fold(initial, |mut acc, batch| async { + let initial = BuildSideState::try_new( + metrics, + reservation, + on_left.clone(), + &schema, + should_compute_bounds, + )?; + + let state = left_stream + .try_fold(initial, |mut state, batch| async move { + // Update accumulators if computing bounds + if let Some(ref mut accumulators) = state.bounds_accumulators { + for accumulator in accumulators { + accumulator.update_batch(&batch)?; + } + } + + // Decide if we spill or not let batch_size = get_record_batch_memory_size(&batch); // Reserve memory for incoming batch - acc.3.try_grow(batch_size)?; + state.reservation.try_grow(batch_size)?; // Update metrics - acc.2.build_mem_used.add(batch_size); - acc.2.build_input_batches.add(1); - acc.2.build_input_rows.add(batch.num_rows()); + state.metrics.build_mem_used.add(batch_size); + state.metrics.build_input_batches.add(1); + state.metrics.build_input_rows.add(batch.num_rows()); // Update row count - acc.1 += batch.num_rows(); + state.num_rows += batch.num_rows(); // Push batch to output - acc.0.push(batch); - Ok(acc) + state.batches.push(batch); + Ok(state) }) .await?; + // Extract fields from state + let BuildSideState { + batches, + num_rows, + metrics, + mut reservation, + bounds_accumulators, + } = state; + // Estimation of memory size, required for hashtable, prior to allocation. // Final result can be verified using `RawTable.allocation_info()` let fixed_size_u32 = size_of::(); @@ -1580,10 +1701,15 @@ async fn collect_left_input( .collect::>>()?; // Compute bounds for dynamic filter if enabled - let bounds = if should_compute_bounds && num_rows > 0 { - Some(compute_bounds(&left_values)?) 
+    let bounds = match bounds_accumulators {
+        Some(accumulators) if num_rows > 0 => {
+            let bounds = accumulators
+                .into_iter()
+                .map(CollectLeftAccumulator::evaluate)
+                .collect::<Result<Vec<_>>>()?;
+            Some(bounds)
+        }
+        _ => None,
     };
 
     let data = JoinLeftData::new(