Skip to content

Commit

Permalink
Make PruningPredicate's rewrite public
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Oct 10, 2024
1 parent 43d0bcf commit 9c49413
Showing 1 changed file with 158 additions and 23 deletions.
181 changes: 158 additions & 23 deletions datafusion/core/src/physical_optimizer/pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,37 @@ pub struct PruningPredicate {
literal_guarantees: Vec<LiteralGuarantee>,
}

/// Hook to handle predicates that DataFusion cannot rewrite itself, e.g. certain
/// complex expressions or predicates that reference columns that are not in the schema.
///
/// The expression returned by the hook is used in the rewritten pruning predicate
/// in place of the expression that could not be handled.
pub trait UnhandledPredicateHook {
    /// Called when a predicate cannot be handled by DataFusion's transformation rules
    /// or references a column that is not in the schema.
    ///
    /// `expr` is the expression that could not be rewritten; the return value
    /// is the expression to substitute for it.
    fn handle(&self, expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr>;
}

/// An [`UnhandledPredicateHook`] that answers every unhandled predicate with
/// the same, pre-built replacement expression.
#[derive(Debug, Clone)]
struct ConstantUnhandledPredicateHook {
    /// Expression handed back for every unhandled predicate.
    default: Arc<dyn PhysicalExpr>,
}

impl ConstantUnhandledPredicateHook {
fn new(default: Arc<dyn PhysicalExpr>) -> Self {
Self { default }
}
}

impl UnhandledPredicateHook for ConstantUnhandledPredicateHook {
    /// Ignores the unhandled expression and returns another handle to the
    /// stored replacement expression.
    fn handle(&self, _unhandled: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
        Arc::clone(&self.default)
    }
}

fn default_unhandled_hook() -> Arc<dyn UnhandledPredicateHook> {
Arc::new(ConstantUnhandledPredicateHook::new(Arc::new(
phys_expr::Literal::new(ScalarValue::Boolean(Some(true))),
)))
}

impl PruningPredicate {
/// Try to create a new instance of [`PruningPredicate`]
///
Expand All @@ -502,10 +533,16 @@ impl PruningPredicate {
/// See the struct level documentation on [`PruningPredicate`] for more
/// details.
pub fn try_new(expr: Arc<dyn PhysicalExpr>, schema: SchemaRef) -> Result<Self> {
let unhandled_hook = default_unhandled_hook();

// build predicate expression once
let mut required_columns = RequiredColumns::new();
let predicate_expr =
build_predicate_expression(&expr, schema.as_ref(), &mut required_columns);
let predicate_expr = build_predicate_expression(
&expr,
schema.as_ref(),
&mut required_columns,
&unhandled_hook,
);

let literal_guarantees = LiteralGuarantee::analyze(&expr);

Expand Down Expand Up @@ -1316,23 +1353,43 @@ const MAX_LIST_VALUE_SIZE_REWRITE: usize = 20;
/// expression that will evaluate to FALSE if it can be determined no
/// rows between the min/max values could pass the predicates.
///
/// Any predicates that can not be translated will be passed to `unhandled_hook`.
///
/// Returns the pruning predicate as an [`PhysicalExpr`]
///
/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will be rewritten to TRUE
/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will fall back to calling `unhandled_hook`
pub fn rewrite_predicate_to_statistics_predicate(
expr: &Arc<dyn PhysicalExpr>,
schema: &Schema,
unhandled_hook: Option<Arc<dyn UnhandledPredicateHook>>,
) -> Arc<dyn PhysicalExpr> {
let unhandled_hook = unhandled_hook.unwrap_or(default_unhandled_hook());

let mut required_columns = RequiredColumns::new();

build_predicate_expression(expr, schema, &mut required_columns, &unhandled_hook)
}

/// Translate logical filter expression into pruning predicate
/// expression that will evaluate to FALSE if it can be determined no
/// rows between the min/max values could pass the predicates.
///
/// Any predicate that cannot be translated is passed to `unhandled_hook`, and the hook's result is used in its place.
///
/// Returns the pruning predicate as an [`PhysicalExpr`]
///
/// Notice: A [`phys_expr::InListExpr`] with more than 20 values is not rewritten; it falls back to calling `unhandled_hook`
fn build_predicate_expression(
expr: &Arc<dyn PhysicalExpr>,
schema: &Schema,
required_columns: &mut RequiredColumns,
unhandled_hook: &Arc<dyn UnhandledPredicateHook>,
) -> Arc<dyn PhysicalExpr> {
// Returned for unsupported expressions. Such expressions are
// converted to TRUE.
let unhandled = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))));

// predicate expression can only be a binary expression
let expr_any = expr.as_any();
if let Some(is_null) = expr_any.downcast_ref::<phys_expr::IsNullExpr>() {
return build_is_null_column_expr(is_null.arg(), schema, required_columns, false)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
}
if let Some(is_not_null) = expr_any.downcast_ref::<phys_expr::IsNotNullExpr>() {
return build_is_null_column_expr(
Expand All @@ -1341,19 +1398,19 @@ fn build_predicate_expression(
required_columns,
true,
)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
}
if let Some(col) = expr_any.downcast_ref::<phys_expr::Column>() {
return build_single_column_expr(col, schema, required_columns, false)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
}
if let Some(not) = expr_any.downcast_ref::<phys_expr::NotExpr>() {
// match !col (don't do so recursively)
if let Some(col) = not.arg().as_any().downcast_ref::<phys_expr::Column>() {
return build_single_column_expr(col, schema, required_columns, true)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
} else {
return unhandled;
return unhandled_hook.handle(expr);
}
}
if let Some(in_list) = expr_any.downcast_ref::<phys_expr::InListExpr>() {
Expand Down Expand Up @@ -1382,9 +1439,14 @@ fn build_predicate_expression(
})
.reduce(|a, b| Arc::new(phys_expr::BinaryExpr::new(a, re_op, b)) as _)
.unwrap();
return build_predicate_expression(&change_expr, schema, required_columns);
return build_predicate_expression(
&change_expr,
schema,
required_columns,
unhandled_hook,
);
} else {
return unhandled;
return unhandled_hook.handle(expr);
}
}

Expand All @@ -1396,21 +1458,23 @@ fn build_predicate_expression(
bin_expr.right().clone(),
)
} else {
return unhandled;
return unhandled_hook.handle(expr);
}
};

if op == Operator::And || op == Operator::Or {
let left_expr = build_predicate_expression(&left, schema, required_columns);
let right_expr = build_predicate_expression(&right, schema, required_columns);
let left_expr =
build_predicate_expression(&left, schema, required_columns, unhandled_hook);
let right_expr =
build_predicate_expression(&right, schema, required_columns, unhandled_hook);
// simplify boolean expression if applicable
let expr = match (&left_expr, op, &right_expr) {
(left, Operator::And, _) if is_always_true(left) => right_expr,
(_, Operator::And, right) if is_always_true(right) => left_expr,
(left, Operator::Or, right)
if is_always_true(left) || is_always_true(right) =>
{
unhandled
Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))))
}
_ => Arc::new(phys_expr::BinaryExpr::new(left_expr, op, right_expr)),
};
Expand All @@ -1423,12 +1487,11 @@ fn build_predicate_expression(
Ok(builder) => builder,
// allow partial failure in predicate expression generation
// this can still produce a useful predicate when multiple conditions are joined using AND
Err(_) => {
return unhandled;
}
Err(_) => return unhandled_hook.handle(expr),
};

build_statistics_expr(&mut expr_builder).unwrap_or(unhandled)
build_statistics_expr(&mut expr_builder)
.unwrap_or_else(|_| unhandled_hook.handle(expr))
}

fn build_statistics_expr(
Expand Down Expand Up @@ -1582,6 +1645,8 @@ mod tests {
use arrow_array::UInt64Array;
use datafusion_expr::expr::InList;
use datafusion_expr::{cast, is_null, try_cast, Expr};
use datafusion_functions_nested::expr_fn::{array_has, make_array};
use datafusion_physical_expr::expressions as phys_expr;
use datafusion_physical_expr::planner::logical2physical;

#[derive(Debug, Default)]
Expand Down Expand Up @@ -3397,6 +3462,75 @@ mod tests {
// TODO: add test for other case and op
}

/// Checks that a caller-supplied [`UnhandledPredicateHook`] is consulted for
/// predicates the rewrite cannot handle, both when the unhandled predicate is
/// the whole expression and when it is one side of an `AND`.
#[test]
fn test_rewrite_expr_to_prunable_custom_unhandled_hook() {
    // Hook that ignores its input and replaces every unhandled predicate with
    // the Int32 literal 42. The replacement is arbitrary; it only has to be
    // distinguishable from the default hook's `true`.
    struct CustomUnhandledHook;

    impl UnhandledPredicateHook for CustomUnhandledHook {
        fn handle(&self, _expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
            Arc::new(phys_expr::Literal::new(ScalarValue::Int32(Some(42))))
        }
    }

    // `schema` is what the rewrite is given; `schema_with_b` is only used to
    // plan expressions mentioning `b`, a column missing from `schema`.
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
    let schema_with_b = Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, true),
    ]);

    let transform_expr = |expr| {
        let expr = logical2physical(&expr, &schema_with_b);
        rewrite_predicate_to_statistics_predicate(
            &expr,
            &schema,
            Some(Arc::new(CustomUnhandledHook {})),
        )
    };

    // transform an arbitrary valid expression that we know is handled
    let known_expression = col("a").eq(lit(ScalarValue::Int32(Some(12))));
    let known_expression_transformed = rewrite_predicate_to_statistics_predicate(
        &logical2physical(&known_expression, &schema),
        &schema,
        None,
    );

    // Expected renderings: the hook's output on its own, and the hook's
    // output AND-ed with the rewritten known expression.
    let hook_output = logical2physical(&lit(42), &schema);
    let expected_hook_only = hook_output.to_string();
    let expected_known_and_hook = phys_expr::BinaryExpr::new(
        known_expression_transformed,
        Operator::And,
        hook_output,
    )
    .to_string();

    // an expression referencing an unknown column (that is not in the schema) gets passed to the hook
    let unknown_column = col("b").eq(lit(ScalarValue::Int32(Some(12))));
    let transformed = transform_expr(unknown_column.clone());
    assert_eq!(transformed.to_string(), expected_hook_only);

    // more complex case with unknown column
    let input = known_expression.clone().and(unknown_column);
    let transformed = transform_expr(input);
    assert_eq!(transformed.to_string(), expected_known_and_hook);

    // an unknown expression gets passed to the hook
    let unknown_expression = array_has(make_array(vec![lit(1)]), col("a"));
    let transformed = transform_expr(unknown_expression.clone());
    assert_eq!(transformed.to_string(), expected_hook_only);

    // more complex case with unknown expression
    let input = known_expression.and(unknown_expression);
    let transformed = transform_expr(input);
    assert_eq!(transformed.to_string(), expected_known_and_hook);
}

#[test]
fn test_rewrite_expr_to_prunable_error() {
// cast string value to numeric value
Expand Down Expand Up @@ -3886,6 +4020,7 @@ mod tests {
required_columns: &mut RequiredColumns,
) -> Arc<dyn PhysicalExpr> {
let expr = logical2physical(expr, schema);
build_predicate_expression(&expr, schema, required_columns)
let unhandled_hook = default_unhandled_hook();
build_predicate_expression(&expr, schema, required_columns, &unhandled_hook)
}
}

0 comments on commit 9c49413

Please sign in to comment.