diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs index 7b04694792e5..eaf3be2b86ed 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs @@ -922,6 +922,11 @@ async fn test_hashjoin_dynamic_filter_pushdown() { let plan = FilterPushdown::new_post_optimization() .optimize(plan, &config) .unwrap(); + + // Test for https://github.com/apache/datafusion/pull/17371: dynamic filter linking survives `with_new_children` + let children = plan.children().into_iter().map(Arc::clone).collect(); + let plan = plan.with_new_children(children).unwrap(); + let config = SessionConfig::new().with_batch_size(10); let session_ctx = SessionContext::new_with_config(config); session_ctx.register_object_store( diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index 5710bfefb530..16e2166bb829 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -354,12 +354,18 @@ pub struct HashJoinExec { /// Cache holding plan properties like equivalences, output partitioning etc. cache: PlanProperties, /// Dynamic filter for pushing down to the probe side - /// Set when dynamic filter pushdown is detected in handle_child_pushdown_result - dynamic_filter: Option<Arc<DynamicFilterPhysicalExpr>>, - /// Shared bounds accumulator for coordinating dynamic filter updates across partitions - /// Only created when dynamic filter pushdown is enabled. - /// Lazily initialized at execution time to use actual runtime partition counts - bounds_accumulator: Option<OnceLock<Arc<SharedBoundsAccumulator>>>, + /// Set when dynamic filter pushdown is detected in handle_child_pushdown_result. + /// HashJoinExec also needs to keep a shared bounds accumulator for coordinating updates.
+ dynamic_filter: Option<HashJoinExecDynamicFilter>, +} + +#[derive(Clone)] +struct HashJoinExecDynamicFilter { + /// Dynamic filter that we'll update with the results of the build side once that is done. + filter: Arc<DynamicFilterPhysicalExpr>, + /// Bounds accumulator to keep track of the min/max bounds on the join keys for each partition. + /// It is lazily initialized during execution to make sure we use the actual execution time partition counts. + bounds_accumulator: OnceLock<Arc<SharedBoundsAccumulator>>, } impl fmt::Debug for HashJoinExec { @@ -453,7 +459,6 @@ impl HashJoinExec { null_equality, cache, dynamic_filter: None, - bounds_accumulator: None, }) } @@ -837,7 +842,6 @@ impl ExecutionPlan for HashJoinExec { )?, // Keep the dynamic filter, bounds accumulator will be reset dynamic_filter: self.dynamic_filter.clone(), - bounds_accumulator: None, })) } @@ -860,7 +864,6 @@ impl ExecutionPlan for HashJoinExec { cache: self.cache.clone(), // Reset dynamic filter and bounds accumulator to initial state dynamic_filter: None, - bounds_accumulator: None, })) } @@ -942,32 +945,28 @@ impl ExecutionPlan for HashJoinExec { let batch_size = context.session_config().batch_size(); // Initialize bounds_accumulator lazily with runtime partition counts (only if enabled) - let bounds_accumulator = if enable_dynamic_filter_pushdown - && self.dynamic_filter.is_some() - { - if let Some(ref bounds_accumulator_oncelock) = self.bounds_accumulator { - let dynamic_filter = Arc::clone(self.dynamic_filter.as_ref().unwrap()); - let on_right = self - .on - .iter() - .map(|(_, right_expr)| Arc::clone(right_expr)) - .collect::<Vec<_>>(); - - Some(Arc::clone(bounds_accumulator_oncelock.get_or_init(|| { - Arc::new(SharedBoundsAccumulator::new_from_partition_mode( - self.mode, - self.left.as_ref(), - self.right.as_ref(), - dynamic_filter, - on_right, - )) - }))) - } else { - None - } - } else { - None - }; + let bounds_accumulator = enable_dynamic_filter_pushdown + .then(|| { + self.dynamic_filter.as_ref().map(|df| { + let filter = Arc::clone(&df.filter); + let on_right = self
+ .on + .iter() + .map(|(_, right_expr)| Arc::clone(right_expr)) + .collect::<Vec<_>>(); + Some(Arc::clone(df.bounds_accumulator.get_or_init(|| { + Arc::new(SharedBoundsAccumulator::new_from_partition_mode( + self.mode, + self.left.as_ref(), + self.right.as_ref(), + filter, + on_right, + )) + }))) + }) + }) + .flatten() + .flatten(); // we have the batches and the hash map with their keys. We can how create a stream // over the right that uses this information to issue new batches. @@ -1162,8 +1161,10 @@ impl ExecutionPlan for HashJoinExec { column_indices: self.column_indices.clone(), null_equality: self.null_equality, cache: self.cache.clone(), - dynamic_filter: Some(dynamic_filter), - bounds_accumulator: Some(OnceLock::new()), + dynamic_filter: Some(HashJoinExecDynamicFilter { + filter: dynamic_filter, + bounds_accumulator: OnceLock::new(), + }), }); result = result.with_updated_node(new_node as Arc<dyn ExecutionPlan>); }