Skip to content

Commit 29d63c1

Browse files
authored
Optimize PhysicalExprSimplifier (#20111)
## Which issue does this PR close? - Related to #20078 ## Rationale for this change An attempt at reducing the cost of physical expression simplification ## What changes are included in this PR? 1. The most important change in this PR is that if an expression is already a literal, we don't transform it, which means we can stop transforming the tree much earlier. Currently on main, even expressions like `lit(5)` end up running through the loop 5 times. This takes this PR to ~96% improvement on the benchmark. 2. Allocate a single dummy record batch for simplifying const expressions, instead of one per `simplify_const_expr` call. 3. Adds the benchmark I've been using to test the impact of changes 4. `simplify_not_expr` and `simplify_const_expr` now take an `Arc` instead of `&Arc` ## Are these changes tested? All existing tests pass with minor modifications. ## Are there any user-facing changes? Two of the individual recursive simplification functions (`simplify_not_expr` and `simplify_const_expr`) are public. This PR breaks their signature, but I think we should consider also making them private. --------- Signed-off-by: Adam Gutglick <adamgsal@gmail.com>
1 parent 4557033 commit 29d63c1

File tree

6 files changed

+351
-35
lines changed

6 files changed

+351
-35
lines changed

datafusion/physical-expr/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,5 +85,9 @@ name = "is_null"
8585
harness = false
8686
name = "binary_op"
8787

88+
[[bench]]
89+
harness = false
90+
name = "simplify"
91+
8892
[package.metadata.cargo-machete]
8993
ignored = ["half"]
Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! This is an attempt at reproducing some predicates generated by TPC-DS query #76,
19+
//! and trying to figure out how long it takes to simplify them.
20+
21+
use arrow::datatypes::{DataType, Field, Schema};
22+
use criterion::{Criterion, criterion_group, criterion_main};
23+
use datafusion_physical_expr::PhysicalExpr;
24+
use datafusion_physical_expr::simplifier::PhysicalExprSimplifier;
25+
use std::hint::black_box;
26+
use std::sync::Arc;
27+
28+
use datafusion_common::ScalarValue;
29+
use datafusion_expr::Operator;
30+
31+
use datafusion_physical_expr::expressions::{
32+
BinaryExpr, CaseExpr, Column, IsNullExpr, Literal,
33+
};
34+
35+
fn catalog_sales_schema() -> Schema {
36+
Schema::new(vec![
37+
Field::new("cs_sold_date_sk", DataType::Int64, true), // 0
38+
Field::new("cs_sold_time_sk", DataType::Int64, true), // 1
39+
Field::new("cs_ship_date_sk", DataType::Int64, true), // 2
40+
Field::new("cs_bill_customer_sk", DataType::Int64, true), // 3
41+
Field::new("cs_bill_cdemo_sk", DataType::Int64, true), // 4
42+
Field::new("cs_bill_hdemo_sk", DataType::Int64, true), // 5
43+
Field::new("cs_bill_addr_sk", DataType::Int64, true), // 6
44+
Field::new("cs_ship_customer_sk", DataType::Int64, true), // 7
45+
Field::new("cs_ship_cdemo_sk", DataType::Int64, true), // 8
46+
Field::new("cs_ship_hdemo_sk", DataType::Int64, true), // 9
47+
Field::new("cs_ship_addr_sk", DataType::Int64, true), // 10
48+
Field::new("cs_call_center_sk", DataType::Int64, true), // 11
49+
Field::new("cs_catalog_page_sk", DataType::Int64, true), // 12
50+
Field::new("cs_ship_mode_sk", DataType::Int64, true), // 13
51+
Field::new("cs_warehouse_sk", DataType::Int64, true), // 14
52+
Field::new("cs_item_sk", DataType::Int64, true), // 15
53+
Field::new("cs_promo_sk", DataType::Int64, true), // 16
54+
Field::new("cs_order_number", DataType::Int64, true), // 17
55+
Field::new("cs_quantity", DataType::Int64, true), // 18
56+
Field::new("cs_wholesale_cost", DataType::Decimal128(7, 2), true),
57+
Field::new("cs_list_price", DataType::Decimal128(7, 2), true),
58+
Field::new("cs_sales_price", DataType::Decimal128(7, 2), true),
59+
Field::new("cs_ext_discount_amt", DataType::Decimal128(7, 2), true),
60+
Field::new("cs_ext_sales_price", DataType::Decimal128(7, 2), true),
61+
Field::new("cs_ext_wholesale_cost", DataType::Decimal128(7, 2), true),
62+
Field::new("cs_ext_list_price", DataType::Decimal128(7, 2), true),
63+
Field::new("cs_ext_tax", DataType::Decimal128(7, 2), true),
64+
Field::new("cs_coupon_amt", DataType::Decimal128(7, 2), true),
65+
Field::new("cs_ext_ship_cost", DataType::Decimal128(7, 2), true),
66+
Field::new("cs_net_paid", DataType::Decimal128(7, 2), true),
67+
Field::new("cs_net_paid_inc_tax", DataType::Decimal128(7, 2), true),
68+
Field::new("cs_net_paid_inc_ship", DataType::Decimal128(7, 2), true),
69+
Field::new("cs_net_paid_inc_ship_tax", DataType::Decimal128(7, 2), true),
70+
Field::new("cs_net_profit", DataType::Decimal128(7, 2), true),
71+
])
72+
}
73+
74+
fn web_sales_schema() -> Schema {
75+
Schema::new(vec![
76+
Field::new("ws_sold_date_sk", DataType::Int64, true),
77+
Field::new("ws_sold_time_sk", DataType::Int64, true),
78+
Field::new("ws_ship_date_sk", DataType::Int64, true),
79+
Field::new("ws_item_sk", DataType::Int64, true),
80+
Field::new("ws_bill_customer_sk", DataType::Int64, true),
81+
Field::new("ws_bill_cdemo_sk", DataType::Int64, true),
82+
Field::new("ws_bill_hdemo_sk", DataType::Int64, true),
83+
Field::new("ws_bill_addr_sk", DataType::Int64, true),
84+
Field::new("ws_ship_customer_sk", DataType::Int64, true),
85+
Field::new("ws_ship_cdemo_sk", DataType::Int64, true),
86+
Field::new("ws_ship_hdemo_sk", DataType::Int64, true),
87+
Field::new("ws_ship_addr_sk", DataType::Int64, true),
88+
Field::new("ws_web_page_sk", DataType::Int64, true),
89+
Field::new("ws_web_site_sk", DataType::Int64, true),
90+
Field::new("ws_ship_mode_sk", DataType::Int64, true),
91+
Field::new("ws_warehouse_sk", DataType::Int64, true),
92+
Field::new("ws_promo_sk", DataType::Int64, true),
93+
Field::new("ws_order_number", DataType::Int64, true),
94+
Field::new("ws_quantity", DataType::Int64, true),
95+
Field::new("ws_wholesale_cost", DataType::Decimal128(7, 2), true),
96+
Field::new("ws_list_price", DataType::Decimal128(7, 2), true),
97+
Field::new("ws_sales_price", DataType::Decimal128(7, 2), true),
98+
Field::new("ws_ext_discount_amt", DataType::Decimal128(7, 2), true),
99+
Field::new("ws_ext_sales_price", DataType::Decimal128(7, 2), true),
100+
Field::new("ws_ext_wholesale_cost", DataType::Decimal128(7, 2), true),
101+
Field::new("ws_ext_list_price", DataType::Decimal128(7, 2), true),
102+
Field::new("ws_ext_tax", DataType::Decimal128(7, 2), true),
103+
Field::new("ws_coupon_amt", DataType::Decimal128(7, 2), true),
104+
Field::new("ws_ext_ship_cost", DataType::Decimal128(7, 2), true),
105+
Field::new("ws_net_paid", DataType::Decimal128(7, 2), true),
106+
Field::new("ws_net_paid_inc_tax", DataType::Decimal128(7, 2), true),
107+
Field::new("ws_net_paid_inc_ship", DataType::Decimal128(7, 2), true),
108+
Field::new("ws_net_paid_inc_ship_tax", DataType::Decimal128(7, 2), true),
109+
Field::new("ws_net_profit", DataType::Decimal128(7, 2), true),
110+
])
111+
}
112+
113+
// Helper to create a literal
114+
fn lit_i64(val: i64) -> Arc<dyn PhysicalExpr> {
115+
Arc::new(Literal::new(ScalarValue::Int64(Some(val))))
116+
}
117+
118+
fn lit_i32(val: i32) -> Arc<dyn PhysicalExpr> {
119+
Arc::new(Literal::new(ScalarValue::Int32(Some(val))))
120+
}
121+
122+
fn lit_bool(val: bool) -> Arc<dyn PhysicalExpr> {
123+
Arc::new(Literal::new(ScalarValue::Boolean(Some(val))))
124+
}
125+
126+
// Helper to create binary expressions
127+
fn and(
128+
left: Arc<dyn PhysicalExpr>,
129+
right: Arc<dyn PhysicalExpr>,
130+
) -> Arc<dyn PhysicalExpr> {
131+
Arc::new(BinaryExpr::new(left, Operator::And, right))
132+
}
133+
134+
fn gte(
135+
left: Arc<dyn PhysicalExpr>,
136+
right: Arc<dyn PhysicalExpr>,
137+
) -> Arc<dyn PhysicalExpr> {
138+
Arc::new(BinaryExpr::new(left, Operator::GtEq, right))
139+
}
140+
141+
fn lte(
142+
left: Arc<dyn PhysicalExpr>,
143+
right: Arc<dyn PhysicalExpr>,
144+
) -> Arc<dyn PhysicalExpr> {
145+
Arc::new(BinaryExpr::new(left, Operator::LtEq, right))
146+
}
147+
148+
fn modulo(
149+
left: Arc<dyn PhysicalExpr>,
150+
right: Arc<dyn PhysicalExpr>,
151+
) -> Arc<dyn PhysicalExpr> {
152+
Arc::new(BinaryExpr::new(left, Operator::Modulo, right))
153+
}
154+
155+
fn eq(
156+
left: Arc<dyn PhysicalExpr>,
157+
right: Arc<dyn PhysicalExpr>,
158+
) -> Arc<dyn PhysicalExpr> {
159+
Arc::new(BinaryExpr::new(left, Operator::Eq, right))
160+
}
161+
162+
/// Build a predicate similar to TPC-DS q76 catalog_sales filter.
163+
/// Uses placeholder columns instead of hash expressions.
164+
pub fn catalog_sales_predicate(num_partitions: usize) -> Arc<dyn PhysicalExpr> {
165+
let cs_sold_date_sk: Arc<dyn PhysicalExpr> =
166+
Arc::new(Column::new("cs_sold_date_sk", 0));
167+
let cs_ship_addr_sk: Arc<dyn PhysicalExpr> =
168+
Arc::new(Column::new("cs_ship_addr_sk", 10));
169+
let cs_item_sk: Arc<dyn PhysicalExpr> = Arc::new(Column::new("cs_item_sk", 15));
170+
171+
// Use a simple modulo expression as placeholder for hash
172+
let item_hash_mod = modulo(cs_item_sk.clone(), lit_i64(num_partitions as i64));
173+
let date_hash_mod = modulo(cs_sold_date_sk.clone(), lit_i64(num_partitions as i64));
174+
175+
// cs_ship_addr_sk IS NULL
176+
let is_null_expr: Arc<dyn PhysicalExpr> = Arc::new(IsNullExpr::new(cs_ship_addr_sk));
177+
178+
// Build item_sk CASE expression with num_partitions branches
179+
let item_when_then: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> = (0
180+
..num_partitions)
181+
.map(|partition| {
182+
let when_expr = eq(item_hash_mod.clone(), lit_i32(partition as i32));
183+
let then_expr = and(
184+
gte(cs_item_sk.clone(), lit_i64(partition as i64)),
185+
lte(cs_item_sk.clone(), lit_i64(18000)),
186+
);
187+
(when_expr, then_expr)
188+
})
189+
.collect();
190+
191+
let item_case_expr: Arc<dyn PhysicalExpr> =
192+
Arc::new(CaseExpr::try_new(None, item_when_then, Some(lit_bool(false))).unwrap());
193+
194+
// Build sold_date_sk CASE expression with num_partitions branches
195+
let date_when_then: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> = (0
196+
..num_partitions)
197+
.map(|partition| {
198+
let when_expr = eq(date_hash_mod.clone(), lit_i32(partition as i32));
199+
let then_expr = and(
200+
gte(cs_sold_date_sk.clone(), lit_i64(2415000 + partition as i64)),
201+
lte(cs_sold_date_sk.clone(), lit_i64(2488070)),
202+
);
203+
(when_expr, then_expr)
204+
})
205+
.collect();
206+
207+
let date_case_expr: Arc<dyn PhysicalExpr> =
208+
Arc::new(CaseExpr::try_new(None, date_when_then, Some(lit_bool(false))).unwrap());
209+
210+
// Final: is_null AND item_case AND date_case
211+
and(and(is_null_expr, item_case_expr), date_case_expr)
212+
}
213+
/// Build a predicate similar to TPC-DS q76 web_sales filter.
214+
/// Uses placeholder columns instead of hash expressions.
215+
fn web_sales_predicate(num_partitions: usize) -> Arc<dyn PhysicalExpr> {
216+
let ws_sold_date_sk: Arc<dyn PhysicalExpr> =
217+
Arc::new(Column::new("ws_sold_date_sk", 0));
218+
let ws_item_sk: Arc<dyn PhysicalExpr> = Arc::new(Column::new("ws_item_sk", 3));
219+
let ws_ship_customer_sk: Arc<dyn PhysicalExpr> =
220+
Arc::new(Column::new("ws_ship_customer_sk", 8));
221+
222+
// Use simple modulo expression as placeholder for hash
223+
let item_hash_mod = modulo(ws_item_sk.clone(), lit_i64(num_partitions as i64));
224+
let date_hash_mod = modulo(ws_sold_date_sk.clone(), lit_i64(num_partitions as i64));
225+
226+
// ws_ship_customer_sk IS NULL
227+
let is_null_expr: Arc<dyn PhysicalExpr> =
228+
Arc::new(IsNullExpr::new(ws_ship_customer_sk));
229+
230+
// Build item_sk CASE expression with num_partitions branches
231+
let item_when_then: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> = (0
232+
..num_partitions)
233+
.map(|partition| {
234+
let when_expr = eq(item_hash_mod.clone(), lit_i32(partition as i32));
235+
let then_expr = and(
236+
gte(ws_item_sk.clone(), lit_i64(partition as i64)),
237+
lte(ws_item_sk.clone(), lit_i64(18000)),
238+
);
239+
(when_expr, then_expr)
240+
})
241+
.collect();
242+
243+
let item_case_expr: Arc<dyn PhysicalExpr> =
244+
Arc::new(CaseExpr::try_new(None, item_when_then, Some(lit_bool(false))).unwrap());
245+
246+
// Build sold_date_sk CASE expression with num_partitions branches
247+
let date_when_then: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> = (0
248+
..num_partitions)
249+
.map(|partition| {
250+
let when_expr = eq(date_hash_mod.clone(), lit_i32(partition as i32));
251+
let then_expr = and(
252+
gte(ws_sold_date_sk.clone(), lit_i64(2415000 + partition as i64)),
253+
lte(ws_sold_date_sk.clone(), lit_i64(2488070)),
254+
);
255+
(when_expr, then_expr)
256+
})
257+
.collect();
258+
259+
let date_case_expr: Arc<dyn PhysicalExpr> =
260+
Arc::new(CaseExpr::try_new(None, date_when_then, Some(lit_bool(false))).unwrap());
261+
262+
and(and(is_null_expr, item_case_expr), date_case_expr)
263+
}
264+
265+
/// Measures how long `PhysicalExprSimplifier::simplify` takes for a given expression.
266+
fn bench_simplify(
267+
c: &mut Criterion,
268+
name: &str,
269+
schema: &Schema,
270+
expr: &Arc<dyn PhysicalExpr>,
271+
) {
272+
let simplifier = PhysicalExprSimplifier::new(schema);
273+
c.bench_function(name, |b| {
274+
b.iter(|| black_box(simplifier.simplify(black_box(Arc::clone(expr))).unwrap()))
275+
});
276+
}
277+
278+
fn criterion_benchmark(c: &mut Criterion) {
279+
let cs_schema = catalog_sales_schema();
280+
let ws_schema = web_sales_schema();
281+
282+
for num_partitions in [16, 128] {
283+
bench_simplify(
284+
c,
285+
&format!("tpc-ds/q76/cs/{num_partitions}"),
286+
&cs_schema,
287+
&catalog_sales_predicate(num_partitions),
288+
);
289+
bench_simplify(
290+
c,
291+
&format!("tpc-ds/q76/ws/{num_partitions}"),
292+
&ws_schema,
293+
&web_sales_predicate(num_partitions),
294+
);
295+
}
296+
}
297+
298+
criterion_group!(benches, criterion_benchmark);
299+
criterion_main!(benches);

datafusion/physical-expr/src/simplifier/const_evaluator.rs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,22 @@ use crate::expressions::{Column, Literal};
4040
/// - `(1 + 2) * 3` -> `9` (with bottom-up traversal)
4141
/// - `'hello' || ' world'` -> `'hello world'`
4242
pub fn simplify_const_expr(
43-
expr: &Arc<dyn PhysicalExpr>,
43+
expr: Arc<dyn PhysicalExpr>,
4444
) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
45-
if !can_evaluate_as_constant(expr) {
46-
return Ok(Transformed::no(Arc::clone(expr)));
47-
}
45+
simplify_const_expr_with_dummy(expr, &create_dummy_batch()?)
46+
}
4847

49-
// Create a 1-row dummy batch for evaluation
50-
let batch = create_dummy_batch()?;
48+
pub(crate) fn simplify_const_expr_with_dummy(
49+
expr: Arc<dyn PhysicalExpr>,
50+
batch: &RecordBatch,
51+
) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
52+
// Nothing to simplify if expr is already a literal or can't be evaluated into one.
53+
if expr.as_any().is::<Literal>() || (!can_evaluate_as_constant(&expr)) {
54+
return Ok(Transformed::no(expr));
55+
}
5156

5257
// Evaluate the expression
53-
match expr.evaluate(&batch) {
58+
match expr.evaluate(batch) {
5459
Ok(ColumnarValue::Scalar(scalar)) => {
5560
Ok(Transformed::yes(Arc::new(Literal::new(scalar))))
5661
}
@@ -61,13 +66,13 @@ pub fn simplify_const_expr(
6166
}
6267
Ok(_) => {
6368
// Unexpected result - keep original expression
64-
Ok(Transformed::no(Arc::clone(expr)))
69+
Ok(Transformed::no(expr))
6570
}
6671
Err(_) => {
6772
// On error, keep original expression
6873
// The expression might succeed at runtime due to short-circuit evaluation
6974
// or other runtime conditions
70-
Ok(Transformed::no(Arc::clone(expr)))
75+
Ok(Transformed::no(expr))
7176
}
7277
}
7378
}
@@ -95,7 +100,7 @@ fn can_evaluate_as_constant(expr: &Arc<dyn PhysicalExpr>) -> bool {
95100
/// that only contain literals, the batch content is irrelevant.
96101
///
97102
/// This is the same approach used in the logical expression `ConstEvaluator`.
98-
fn create_dummy_batch() -> Result<RecordBatch> {
103+
pub(crate) fn create_dummy_batch() -> Result<RecordBatch> {
99104
// RecordBatch requires at least one column
100105
let dummy_schema = Arc::new(Schema::new(vec![Field::new("_", DataType::Null, true)]));
101106
let col = new_null_array(&DataType::Null, 1);

0 commit comments

Comments
 (0)