diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 1b23beeaa37c..7e61be3a16ae 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -85,5 +85,9 @@ name = "is_null" harness = false name = "binary_op" +[[bench]] +harness = false +name = "simplify" + [package.metadata.cargo-machete] ignored = ["half"] diff --git a/datafusion/physical-expr/benches/simplify.rs b/datafusion/physical-expr/benches/simplify.rs new file mode 100644 index 000000000000..cc00c710004e --- /dev/null +++ b/datafusion/physical-expr/benches/simplify.rs @@ -0,0 +1,299 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This is an attempt at reproducing some predicates generated by TPC-DS query #76, +//! and trying to figure out how long it takes to simplify them. + +use arrow::datatypes::{DataType, Field, Schema}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::simplifier::PhysicalExprSimplifier; +use std::hint::black_box; +use std::sync::Arc; + +use datafusion_common::ScalarValue; +use datafusion_expr::Operator; + +use datafusion_physical_expr::expressions::{ + BinaryExpr, CaseExpr, Column, IsNullExpr, Literal, +}; + +fn catalog_sales_schema() -> Schema { + Schema::new(vec![ + Field::new("cs_sold_date_sk", DataType::Int64, true), // 0 + Field::new("cs_sold_time_sk", DataType::Int64, true), // 1 + Field::new("cs_ship_date_sk", DataType::Int64, true), // 2 + Field::new("cs_bill_customer_sk", DataType::Int64, true), // 3 + Field::new("cs_bill_cdemo_sk", DataType::Int64, true), // 4 + Field::new("cs_bill_hdemo_sk", DataType::Int64, true), // 5 + Field::new("cs_bill_addr_sk", DataType::Int64, true), // 6 + Field::new("cs_ship_customer_sk", DataType::Int64, true), // 7 + Field::new("cs_ship_cdemo_sk", DataType::Int64, true), // 8 + Field::new("cs_ship_hdemo_sk", DataType::Int64, true), // 9 + Field::new("cs_ship_addr_sk", DataType::Int64, true), // 10 + Field::new("cs_call_center_sk", DataType::Int64, true), // 11 + Field::new("cs_catalog_page_sk", DataType::Int64, true), // 12 + Field::new("cs_ship_mode_sk", DataType::Int64, true), // 13 + Field::new("cs_warehouse_sk", DataType::Int64, true), // 14 + Field::new("cs_item_sk", DataType::Int64, true), // 15 + Field::new("cs_promo_sk", DataType::Int64, true), // 16 + Field::new("cs_order_number", DataType::Int64, true), // 17 + Field::new("cs_quantity", DataType::Int64, true), // 18 + Field::new("cs_wholesale_cost", DataType::Decimal128(7, 2), true), + Field::new("cs_list_price", DataType::Decimal128(7, 2), true), + Field::new("cs_sales_price", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_discount_amt", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_sales_price", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_wholesale_cost", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_list_price", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_tax", DataType::Decimal128(7, 2), true), + Field::new("cs_coupon_amt", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_ship_cost", DataType::Decimal128(7, 2), true), + Field::new("cs_net_paid", DataType::Decimal128(7, 2), true), + Field::new("cs_net_paid_inc_tax", DataType::Decimal128(7, 2), true), + Field::new("cs_net_paid_inc_ship", DataType::Decimal128(7, 2), true), + Field::new("cs_net_paid_inc_ship_tax", DataType::Decimal128(7, 2), true), + Field::new("cs_net_profit", DataType::Decimal128(7, 2), true), + ]) +} + +fn web_sales_schema() -> Schema { + Schema::new(vec![ + Field::new("ws_sold_date_sk", DataType::Int64, true), + Field::new("ws_sold_time_sk", DataType::Int64, true), + Field::new("ws_ship_date_sk", DataType::Int64, true), + Field::new("ws_item_sk", DataType::Int64, true), + Field::new("ws_bill_customer_sk", DataType::Int64, true), + Field::new("ws_bill_cdemo_sk", DataType::Int64, true), + Field::new("ws_bill_hdemo_sk", DataType::Int64, true), + Field::new("ws_bill_addr_sk", DataType::Int64, true), + Field::new("ws_ship_customer_sk", DataType::Int64, true), + Field::new("ws_ship_cdemo_sk", DataType::Int64, true), + Field::new("ws_ship_hdemo_sk", DataType::Int64, true), + Field::new("ws_ship_addr_sk", DataType::Int64, true), + Field::new("ws_web_page_sk", DataType::Int64, true), + Field::new("ws_web_site_sk", DataType::Int64, true), + Field::new("ws_ship_mode_sk", DataType::Int64, true), + Field::new("ws_warehouse_sk", DataType::Int64, true), + Field::new("ws_promo_sk", DataType::Int64, true), + Field::new("ws_order_number", DataType::Int64, true), + Field::new("ws_quantity", DataType::Int64, true), + Field::new("ws_wholesale_cost", DataType::Decimal128(7, 2), true), + Field::new("ws_list_price", DataType::Decimal128(7, 2), true), + Field::new("ws_sales_price", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_discount_amt", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_sales_price", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_wholesale_cost", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_list_price", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_tax", DataType::Decimal128(7, 2), true), + Field::new("ws_coupon_amt", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_ship_cost", DataType::Decimal128(7, 2), true), + Field::new("ws_net_paid", DataType::Decimal128(7, 2), true), + Field::new("ws_net_paid_inc_tax", DataType::Decimal128(7, 2), true), + Field::new("ws_net_paid_inc_ship", DataType::Decimal128(7, 2), true), + Field::new("ws_net_paid_inc_ship_tax", DataType::Decimal128(7, 2), true), + Field::new("ws_net_profit", DataType::Decimal128(7, 2), true), + ]) +} + +// Helper to create a literal +fn lit_i64(val: i64) -> Arc { + Arc::new(Literal::new(ScalarValue::Int64(Some(val)))) +} + +fn lit_i32(val: i32) -> Arc { + Arc::new(Literal::new(ScalarValue::Int32(Some(val)))) +} + +fn lit_bool(val: bool) -> Arc { + Arc::new(Literal::new(ScalarValue::Boolean(Some(val)))) +} + +// Helper to create binary expressions +fn and( + left: Arc, + right: Arc, +) -> Arc { + Arc::new(BinaryExpr::new(left, Operator::And, right)) +} + +fn gte( + left: Arc, + right: Arc, +) -> Arc { + Arc::new(BinaryExpr::new(left, Operator::GtEq, right)) +} + +fn lte( + left: Arc, + right: Arc, +) -> Arc { + Arc::new(BinaryExpr::new(left, Operator::LtEq, right)) +} + +fn modulo( + left: Arc, + right: Arc, +) -> Arc { + Arc::new(BinaryExpr::new(left, Operator::Modulo, right)) +} + +fn eq( + left: Arc, + right: Arc, +) -> Arc { + Arc::new(BinaryExpr::new(left, Operator::Eq, right)) +} + +/// Build a predicate similar to TPC-DS q76 catalog_sales filter. +/// Uses placeholder columns instead of hash expressions. +pub fn catalog_sales_predicate(num_partitions: usize) -> Arc { + let cs_sold_date_sk: Arc = + Arc::new(Column::new("cs_sold_date_sk", 0)); + let cs_ship_addr_sk: Arc = + Arc::new(Column::new("cs_ship_addr_sk", 10)); + let cs_item_sk: Arc = Arc::new(Column::new("cs_item_sk", 15)); + + // Use a simple modulo expression as placeholder for hash + let item_hash_mod = modulo(cs_item_sk.clone(), lit_i64(num_partitions as i64)); + let date_hash_mod = modulo(cs_sold_date_sk.clone(), lit_i64(num_partitions as i64)); + + // cs_ship_addr_sk IS NULL + let is_null_expr: Arc = Arc::new(IsNullExpr::new(cs_ship_addr_sk)); + + // Build item_sk CASE expression with num_partitions branches + let item_when_then: Vec<(Arc, Arc)> = (0 + ..num_partitions) + .map(|partition| { + let when_expr = eq(item_hash_mod.clone(), lit_i32(partition as i32)); + let then_expr = and( + gte(cs_item_sk.clone(), lit_i64(partition as i64)), + lte(cs_item_sk.clone(), lit_i64(18000)), + ); + (when_expr, then_expr) + }) + .collect(); + + let item_case_expr: Arc = + Arc::new(CaseExpr::try_new(None, item_when_then, Some(lit_bool(false))).unwrap()); + + // Build sold_date_sk CASE expression with num_partitions branches + let date_when_then: Vec<(Arc, Arc)> = (0 + ..num_partitions) + .map(|partition| { + let when_expr = eq(date_hash_mod.clone(), lit_i32(partition as i32)); + let then_expr = and( + gte(cs_sold_date_sk.clone(), lit_i64(2415000 + partition as i64)), + lte(cs_sold_date_sk.clone(), lit_i64(2488070)), + ); + (when_expr, then_expr) + }) + .collect(); + + let date_case_expr: Arc = + Arc::new(CaseExpr::try_new(None, date_when_then, Some(lit_bool(false))).unwrap()); + + // Final: is_null AND item_case AND date_case + and(and(is_null_expr, item_case_expr), date_case_expr) +} +/// Build a predicate similar to TPC-DS q76 web_sales filter. +/// Uses placeholder columns instead of hash expressions. +fn web_sales_predicate(num_partitions: usize) -> Arc { + let ws_sold_date_sk: Arc = + Arc::new(Column::new("ws_sold_date_sk", 0)); + let ws_item_sk: Arc = Arc::new(Column::new("ws_item_sk", 3)); + let ws_ship_customer_sk: Arc = + Arc::new(Column::new("ws_ship_customer_sk", 8)); + + // Use simple modulo expression as placeholder for hash + let item_hash_mod = modulo(ws_item_sk.clone(), lit_i64(num_partitions as i64)); + let date_hash_mod = modulo(ws_sold_date_sk.clone(), lit_i64(num_partitions as i64)); + + // ws_ship_customer_sk IS NULL + let is_null_expr: Arc = + Arc::new(IsNullExpr::new(ws_ship_customer_sk)); + + // Build item_sk CASE expression with num_partitions branches + let item_when_then: Vec<(Arc, Arc)> = (0 + ..num_partitions) + .map(|partition| { + let when_expr = eq(item_hash_mod.clone(), lit_i32(partition as i32)); + let then_expr = and( + gte(ws_item_sk.clone(), lit_i64(partition as i64)), + lte(ws_item_sk.clone(), lit_i64(18000)), + ); + (when_expr, then_expr) + }) + .collect(); + + let item_case_expr: Arc = + Arc::new(CaseExpr::try_new(None, item_when_then, Some(lit_bool(false))).unwrap()); + + // Build sold_date_sk CASE expression with num_partitions branches + let date_when_then: Vec<(Arc, Arc)> = (0 + ..num_partitions) + .map(|partition| { + let when_expr = eq(date_hash_mod.clone(), lit_i32(partition as i32)); + let then_expr = and( + gte(ws_sold_date_sk.clone(), lit_i64(2415000 + partition as i64)), + lte(ws_sold_date_sk.clone(), lit_i64(2488070)), + ); + (when_expr, then_expr) + }) + .collect(); + + let date_case_expr: Arc = + Arc::new(CaseExpr::try_new(None, date_when_then, Some(lit_bool(false))).unwrap()); + + and(and(is_null_expr, item_case_expr), date_case_expr) +} + +/// Measures how long `PhysicalExprSimplifier::simplify` takes for a given expression. +fn bench_simplify( + c: &mut Criterion, + name: &str, + schema: &Schema, + expr: &Arc, +) { + let simplifier = PhysicalExprSimplifier::new(schema); + c.bench_function(name, |b| { + b.iter(|| black_box(simplifier.simplify(black_box(Arc::clone(expr))).unwrap())) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let cs_schema = catalog_sales_schema(); + let ws_schema = web_sales_schema(); + + for num_partitions in [16, 128] { + bench_simplify( + c, + &format!("tpc-ds/q76/cs/{num_partitions}"), + &cs_schema, + &catalog_sales_predicate(num_partitions), + ); + bench_simplify( + c, + &format!("tpc-ds/q76/ws/{num_partitions}"), + &ws_schema, + &web_sales_predicate(num_partitions), + ); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/physical-expr/src/simplifier/const_evaluator.rs b/datafusion/physical-expr/src/simplifier/const_evaluator.rs index 8a2368c4040a..1e62e47ce206 100644 --- a/datafusion/physical-expr/src/simplifier/const_evaluator.rs +++ b/datafusion/physical-expr/src/simplifier/const_evaluator.rs @@ -40,17 +40,22 @@ use crate::expressions::{Column, Literal}; /// - `(1 + 2) * 3` -> `9` (with bottom-up traversal) /// - `'hello' || ' world'` -> `'hello world'` pub fn simplify_const_expr( - expr: &Arc, + expr: Arc, ) -> Result>> { - if !can_evaluate_as_constant(expr) { - return Ok(Transformed::no(Arc::clone(expr))); - } + simplify_const_expr_with_dummy(expr, &create_dummy_batch()?) +} - // Create a 1-row dummy batch for evaluation - let batch = create_dummy_batch()?; +pub(crate) fn simplify_const_expr_with_dummy( + expr: Arc, + batch: &RecordBatch, +) -> Result>> { + // If expr is already a const literal or can't be evaluated into one. + if expr.as_any().is::() || (!can_evaluate_as_constant(&expr)) { + return Ok(Transformed::no(expr)); + } // Evaluate the expression - match expr.evaluate(&batch) { + match expr.evaluate(batch) { Ok(ColumnarValue::Scalar(scalar)) => { Ok(Transformed::yes(Arc::new(Literal::new(scalar)))) } @@ -61,13 +66,13 @@ pub fn simplify_const_expr( } Ok(_) => { // Unexpected result - keep original expression - Ok(Transformed::no(Arc::clone(expr))) + Ok(Transformed::no(expr)) } Err(_) => { // On error, keep original expression // The expression might succeed at runtime due to short-circuit evaluation // or other runtime conditions - Ok(Transformed::no(Arc::clone(expr))) + Ok(Transformed::no(expr)) } } } @@ -95,7 +100,7 @@ fn can_evaluate_as_constant(expr: &Arc) -> bool { /// that only contain literals, the batch content is irrelevant. /// /// This is the same approach used in the logical expression `ConstEvaluator`. -fn create_dummy_batch() -> Result { +pub(crate) fn create_dummy_batch() -> Result { // RecordBatch requires at least one column let dummy_schema = Arc::new(Schema::new(vec![Field::new("_", DataType::Null, true)])); let col = new_null_array(&DataType::Null, 1); diff --git a/datafusion/physical-expr/src/simplifier/mod.rs b/datafusion/physical-expr/src/simplifier/mod.rs index 3bd4683c167c..45ead82a0a93 100644 --- a/datafusion/physical-expr/src/simplifier/mod.rs +++ b/datafusion/physical-expr/src/simplifier/mod.rs @@ -21,7 +21,14 @@ use arrow::datatypes::Schema; use datafusion_common::{Result, tree_node::TreeNode}; use std::sync::Arc; -use crate::{PhysicalExpr, simplifier::not::simplify_not_expr}; +use crate::{ + PhysicalExpr, + simplifier::{ + const_evaluator::{create_dummy_batch, simplify_const_expr_with_dummy}, + not::simplify_not_expr, + unwrap_cast::unwrap_cast_in_comparison, + }, +}; pub mod const_evaluator; pub mod not; @@ -50,6 +57,8 @@ impl<'a> PhysicalExprSimplifier<'a> { let mut count = 0; let schema = self.schema; + let batch = create_dummy_batch()?; + while count < MAX_LOOP_COUNT { count += 1; let result = current_expr.transform(|node| { @@ -58,11 +67,11 @@ impl<'a> PhysicalExprSimplifier<'a> { // Apply NOT expression simplification first, then unwrap cast optimization, // then constant expression evaluation - let rewritten = simplify_not_expr(&node, schema)? + let rewritten = simplify_not_expr(node, schema)? + .transform_data(|node| unwrap_cast_in_comparison(node, schema))? .transform_data(|node| { - unwrap_cast::unwrap_cast_in_comparison(node, schema) - })? - .transform_data(|node| const_evaluator::simplify_const_expr(&node))?; + simplify_const_expr_with_dummy(node, &batch) + })?; #[cfg(debug_assertions)] assert_eq!( diff --git a/datafusion/physical-expr/src/simplifier/not.rs b/datafusion/physical-expr/src/simplifier/not.rs index 9b65d5cba95a..ea5467d0a4b4 100644 --- a/datafusion/physical-expr/src/simplifier/not.rs +++ b/datafusion/physical-expr/src/simplifier/not.rs @@ -44,13 +44,13 @@ use crate::expressions::{BinaryExpr, InListExpr, Literal, NotExpr, in_list, lit} /// TreeNodeRewriter, multiple passes will automatically be applied until no more /// transformations are possible. pub fn simplify_not_expr( - expr: &Arc, + expr: Arc, schema: &Schema, ) -> Result>> { // Check if this is a NOT expression let not_expr = match expr.as_any().downcast_ref::() { Some(not_expr) => not_expr, - None => return Ok(Transformed::no(Arc::clone(expr))), + None => return Ok(Transformed::no(expr)), }; let inner_expr = not_expr.arg(); @@ -120,5 +120,5 @@ pub fn simplify_not_expr( } // If no simplification possible, return the original expression - Ok(Transformed::no(Arc::clone(expr))) + Ok(Transformed::no(expr)) } diff --git a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs index ae6da9c5e0dc..0de517cd36c8 100644 --- a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs +++ b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs @@ -34,10 +34,7 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Schema}; -use datafusion_common::{ - Result, ScalarValue, - tree_node::{Transformed, TreeNode}, -}; +use datafusion_common::{Result, ScalarValue, tree_node::Transformed}; use datafusion_expr::Operator; use datafusion_expr_common::casts::try_cast_literal_to_type; @@ -49,14 +46,12 @@ pub(crate) fn unwrap_cast_in_comparison( expr: Arc, schema: &Schema, ) -> Result>> { - expr.transform_down(|e| { - if let Some(binary) = e.as_any().downcast_ref::() - && let Some(unwrapped) = try_unwrap_cast_binary(binary, schema)? - { - return Ok(Transformed::yes(unwrapped)); - } - Ok(Transformed::no(e)) - }) + if let Some(binary) = expr.as_any().downcast_ref::() + && let Some(unwrapped) = try_unwrap_cast_binary(binary, schema)? + { + return Ok(Transformed::yes(unwrapped)); + } + Ok(Transformed::no(expr)) } /// Try to unwrap casts in binary expressions @@ -144,7 +139,7 @@ mod tests { use super::*; use crate::expressions::{col, lit}; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::ScalarValue; + use datafusion_common::{ScalarValue, tree_node::TreeNode}; use datafusion_expr::Operator; /// Check if an expression is a cast expression @@ -484,8 +479,10 @@ mod tests { let and_expr = Arc::new(BinaryExpr::new(compare1, Operator::And, compare2)); - // Apply unwrap cast optimization - let result = unwrap_cast_in_comparison(and_expr, &schema).unwrap(); + // Apply unwrap cast optimization recursively + let result = (and_expr as Arc) + .transform_down(|node| unwrap_cast_in_comparison(node, &schema)) + .unwrap(); // Should be transformed assert!(result.transformed); @@ -602,8 +599,10 @@ mod tests { // Create AND expression let and_expr = Arc::new(BinaryExpr::new(c1_binary, Operator::And, c2_binary)); - // Apply unwrap cast optimization - let result = unwrap_cast_in_comparison(and_expr, &schema).unwrap(); + // Apply unwrap cast optimization recursively + let result = (and_expr as Arc) + .transform_down(|node| unwrap_cast_in_comparison(node, &schema)) + .unwrap(); // Should be transformed assert!(result.transformed);