fix(9678): short circuiting prevented population of visited stack, for common subexpr elimination optimization #9685

Merged
merged 11 commits into from
Mar 21, 2024
130 changes: 79 additions & 51 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -29,8 +29,7 @@ use datafusion_common::tree_node::{
TreeNodeVisitor,
};
use datafusion_common::{
internal_datafusion_err, internal_err, Column, DFField, DFSchema, DFSchemaRef,
DataFusionError, Result,
internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result,
};
use datafusion_expr::expr::Alias;
use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection, Window};
@@ -42,8 +41,36 @@ use datafusion_expr::{col, Expr, ExprSchemable};
/// - DataType of this expression.
type ExprSet = HashMap<Identifier, (Expr, usize, DataType)>;

/// Identifier type. Current implementation use describe of an expression (type String) as
/// Identifier.
/// An ordered map of Identifiers assigned by `ExprIdentifierVisitor` in an
/// initial expression walk.
///
/// Used by `CommonSubexprRewriter`, which rewrites the expressions to remove
/// common subexpressions.
///
/// Elements in this array are created on the walk down the expression tree
/// during `f_down`. Thus element 0 is the root of the expression tree. The
/// tuple contains:
/// - series_number.
/// - Incremented during `f_up`, start from 1.
/// - Thus, items with higher idx have the lower series_number.
/// - [`Identifier`]
/// - Identifier of the expression. If empty (`""`), expr should not be considered for common elimination.
///
/// # Example
/// An expression like `(a + b)` would have the following `IdArray`:
/// ```text
/// [
/// (3, "a + b"),
/// (2, "a"),
/// (1, "b")
/// ]
/// ```
type IdArray = Vec<(usize, Identifier)>;
Contributor Author

@wiedld wiedld Mar 20, 2024

This abstraction really helps to understand how the 2 visitors work together.

The first visitor builds the IdArray and ExprSet. The first visitor sets the vector idx used on f_down, whereas the series_number is set on f_up. An empty string for the expr_id means that this expression is not considered for rewrite (is skipped) in the second visitor.
e.g. [(3, expr_id),(2, expr_id),(1, expr_id)]

The second visitor then performs the mutation, based upon the IdArray and ExprSet.
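
As a rough illustration of the two counters described above, here is a minimal sketch in plain Rust. It does not use DataFusion: `Expr`, `Walker`, and `build_id_array` are hypothetical stand-ins, and unlike the real visitor this toy gives every node (including leaves) a non-empty identifier. The slot index is fixed on the way down, and the series number is filled in on the way up.

```rust
// Toy model only: these types are illustrative stand-ins, not DataFusion's.
enum Expr {
    Column(&'static str),
    Add(Box<Expr>, Box<Expr>),
}

impl Expr {
    // Stand-in for the `Display`-based identifier.
    fn display(&self) -> String {
        match self {
            Expr::Column(name) => name.to_string(),
            Expr::Add(l, r) => format!("{} + {}", l.display(), r.display()),
        }
    }
}

type Identifier = String;
type IdArray = Vec<(usize, Identifier)>;

struct Walker {
    id_array: IdArray,
    series_number: usize,
}

impl Walker {
    fn visit(&mut self, expr: &Expr) {
        // "f_down": reserve a slot, so the index follows entry (pre-order).
        let idx = self.id_array.len();
        self.id_array.push((0, String::new()));

        if let Expr::Add(l, r) = expr {
            self.visit(l);
            self.visit(r);
        }

        // "f_up": the series number follows the order in which nodes are left.
        self.series_number += 1;
        self.id_array[idx] = (self.series_number, expr.display());
    }
}

fn build_id_array(expr: &Expr) -> IdArray {
    let mut walker = Walker { id_array: Vec::new(), series_number: 0 };
    walker.visit(expr);
    walker.id_array
}

fn main() {
    let expr = Expr::Add(Box::new(Expr::Column("a")), Box::new(Expr::Column("b")));
    for (series, id) in build_id_array(&expr) {
        println!("({series}, {id:?})");
    }
}
```

For `a + b` this toy walk yields `[(3, "a + b"), (1, "a"), (2, "b")]`. With a left-to-right walk the left leaf is left first, so it receives series number 1; a parent always ends up with a larger series number than any of its children.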

Contributor

Thank you. I agree that introducing a new typedef helps make the code easier to understand.


/// Identifier for each subexpression.
///
/// Note that the current implementation uses the `Display` of an expression
/// (a `String`) as `Identifier`.
///
/// An identifier should (ideally) be able to "hash", "accumulate", "equal" and "have no
/// collision (as low as possible)"
@@ -293,8 +320,9 @@ impl CommonSubexprEliminate {
agg_exprs.push(expr.alias(&name));
proj_exprs.push(Expr::Column(Column::from_name(name)));
} else {
let id =
ExprIdentifierVisitor::<'static>::desc_expr(&expr_rewritten);
let id = ExprIdentifierVisitor::<'static>::expr_identifier(
&expr_rewritten,
);
let out_name =
expr_rewritten.to_field(&new_input_schema)?.qualified_name();
agg_exprs.push(expr_rewritten.alias(&id));
@@ -557,15 +585,15 @@ impl ExprMask {
/// This visitor implementation use a stack `visit_stack` to track traversal, which
/// lets us know when a sub-tree's visiting is finished. When `pre_visit` is called
/// (traversing to a new node), an `EnterMark` and an `ExprItem` will be pushed into stack.
/// And try to pop out a `EnterMark` on leaving a node (`post_visit()`). All `ExprItem`
/// And try to pop out a `EnterMark` on leaving a node (`f_up()`). All `ExprItem`
/// before the first `EnterMark` is considered to be sub-tree of the leaving node.
///
/// This visitor also records identifier in `id_array`. Makes the following traverse
/// pass can get the identifier of a node without recalculate it. We assign each node
/// in the expr tree a series number, start from 1, maintained by `series_number`.
/// Series number represents the order we left (`post_visit`) a node. Has the property
/// Series number represents the order we left (`f_up()`) a node. Has the property
/// that child node's series number always smaller than parent's. While `id_array` is
/// organized in the order we enter (`pre_visit`) a node. `node_count` helps us to
/// organized in the order we enter (`f_down()`) a node. `node_count` helps us to
/// get the index of `id_array` for each node.
///
/// `Expr` without sub-expr (column, literal etc.) will not have identifier
@@ -574,15 +602,15 @@ struct ExprIdentifierVisitor<'a> {
// param
expr_set: &'a mut ExprSet,
/// series number (usize) and identifier.
id_array: &'a mut Vec<(usize, Identifier)>,
id_array: &'a mut IdArray,
/// input schema for the node that we're optimizing, so we can determine the correct datatype
/// for each subexpression
input_schema: DFSchemaRef,
// inner states
visit_stack: Vec<VisitRecord>,
/// increased in pre_visit, start from 0.
/// increased in fn_down, start from 0.
node_count: usize,
/// increased in post_visit, start from 1.
/// increased in fn_up, start from 1.
series_number: usize,
/// which expression should be skipped?
expr_mask: ExprMask,
@@ -593,66 +621,73 @@ enum VisitRecord {
/// `usize` is the monotone increasing series number assigned in pre_visit().
/// Starts from 0. Is used to index the identifier array `id_array` in post_visit().
EnterMark(usize),
/// the node's children were skipped => jump to f_up on same node
Contributor

This is the key fix, as I understand it -- the TreeNodeVisitor rewrite removed the notion of skipping sibling nodes during recursion, so this notion must be explicitly encoded in the subexpression rewrite pass.

Contributor Author

@wiedld wiedld Mar 20, 2024

For my understanding (please correct me here 🙏🏼): prior to the refactoring, the two visitors (TreeNodeVisitor and TreeNodeRewriter) had different traversal rules, and the semantics of skip vs. stop mapped in odd ways (as nicely described in the TreeNode refactor's "Rationale for Change").

Prior to the refactor:

  • the first visitor used skips (VisitRecursion::Skip), which never calls f_up afterwards.
  • the second visitor used RewriteRecursion::Stop which is functionally the same as VisitRecursion::Skip

Fix for the first visitor, for the short-circuited nodes:

Prior to the refactor, the first tree walk did not add short-circuited nodes to the stack:

  1. short-circuited nodes were skipped, without adding to the stack
  2. the old semantics meant skipping the post-visit (up) call
  3. it never looked in the stack (since no post-visit on the old semantics)

Whereas with the new TreeNode refactor the jump became a skip-children-but-call-f_up, which meant:

  1. prior to the fix in this PR, it was using the old control flow. So it still didn't add to visit_stack.
  2. refactor means that it now does perform the f_up on jump
  3. f_up checks the stack (in self.pop_enter_mark()) and would not find it
  4. the original refactor got around this by removing the panic and adding an option
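
The post-refactor behaviour listed above (children skipped, but `f_up` still called on the same node) can be sketched with a toy traversal. This is only a model of the semantics as described in this thread: `Node`, `Recursion`, `f_down`, and `walk` are hypothetical names, not DataFusion's `TreeNode` API.

```rust
// Toy model only: not DataFusion's `TreeNode` API.
enum Recursion {
    Continue,
    Jump,
}

struct Node {
    name: &'static str,
    short_circuits: bool,
    children: Vec<Node>,
}

// Stand-in for the visitor's `f_down` decision.
fn f_down(node: &Node) -> Recursion {
    if node.short_circuits {
        Recursion::Jump
    } else {
        Recursion::Continue
    }
}

fn walk(node: &Node, log: &mut Vec<String>) {
    log.push(format!("down {}", node.name));
    // On `Jump` the children are skipped...
    if let Recursion::Continue = f_down(node) {
        for child in &node.children {
            walk(child, log);
        }
    }
    // ...but the "up" step still runs for the node itself.
    log.push(format!("up {}", node.name));
}

fn leaf(name: &'static str) -> Node {
    Node { name, short_circuits: false, children: Vec::new() }
}

// Shape loosely based on `acos(coalesce(1.0 / f64, 0.0))`.
fn sample_tree() -> Node {
    Node {
        name: "acos",
        short_circuits: false,
        children: vec![Node {
            name: "coalesce",
            short_circuits: true,
            children: vec![leaf("1.0 / f64"), leaf("0.0")],
        }],
    }
}

fn main() {
    let mut log = Vec::new();
    walk(&sample_tree(), &mut log);
    println!("{log:?}");
}
```

In this model the log is `down acos`, `down coalesce`, `up coalesce`, `up acos`: the children of `coalesce` are never entered, yet `coalesce` still gets its "up" call, which is why any state consumed on the way up has to be prepared on the way down.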

Edited: For the changes done in the TreeNodeVisitor for the short-circuited nodes, I basically made it work with a jump that does the post-visit f_up:

  • add a placeholder to the id_array so that it exists on f_up
    • keep the idx counts correct, for that IdArray abstraction
  • fixing the stack management:
    • push into the stack so it exists on the f_up
      • before this bugfix, it was either not finding it in the stack (None), or occasionally matching another/wrong entry in the stack (such as in our bug test case).
    • make a VisitRecord::JumpMark to be used in the stack
      • therefore the stack entry should always exist, including on short-circuited jumps.
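
A minimal sketch of the stack-management point above, assuming a simplified stack: it uses a local `VisitRecord` enum with the same variant names as the diff, but `pop_mark` and `child_f_up` are hypothetical helpers, and `pop_mark` returns an `Option` instead of hitting `unreachable!` so the mismatch can be observed without a panic.

```rust
// Toy model only: mirrors the variant names in the diff, nothing else.
#[derive(Debug)]
enum VisitRecord {
    EnterMark(usize),
    JumpMark(usize),
    ExprItem(String),
}

// Pop until the nearest mark, accumulating any `ExprItem`s on the way.
fn pop_mark(stack: &mut Vec<VisitRecord>) -> Option<(usize, String)> {
    let mut desc = String::new();
    while let Some(item) = stack.pop() {
        match item {
            VisitRecord::EnterMark(idx) | VisitRecord::JumpMark(idx) => {
                return Some((idx, desc));
            }
            VisitRecord::ExprItem(id) => desc.push_str(&id),
        }
    }
    None
}

// Simulates the "up" call of a short-circuiting child (index 1) whose
// parent (index 0) has already been entered.
fn child_f_up(push_jump_mark: bool) -> Option<(usize, String)> {
    let mut stack = vec![VisitRecord::EnterMark(0)];
    if push_jump_mark {
        // The fix: record the jumped node so its own "up" call finds it.
        stack.push(VisitRecord::JumpMark(1));
    }
    pop_mark(&mut stack)
}

fn main() {
    println!("without mark: {:?}", child_f_up(false));
    println!("with mark:    {:?}", child_f_up(true));
}
```

Without the mark, the child's "up" call pops the parent's `EnterMark(0)`, which corresponds to the "matching another/wrong entry" case; with the mark, it gets its own index 1.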

Contributor

Yes, that's how the API worked before and after #8891. I think the issue came in with this peter-toth#1 adjustment.
Thanks @wiedld for fixing it!

JumpMark(usize),
/// Accumulated identifier of sub expression.
ExprItem(Identifier),
}

impl ExprIdentifierVisitor<'_> {
fn desc_expr(expr: &Expr) -> String {
fn expr_identifier(expr: &Expr) -> Identifier {
format!("{expr}")
}

/// Find the first `EnterMark` in the stack, and accumulates every `ExprItem`
/// before it.
fn pop_enter_mark(&mut self) -> Option<(usize, Identifier)> {
fn pop_enter_mark(&mut self) -> (usize, Identifier) {
let mut desc = String::new();

while let Some(item) = self.visit_stack.pop() {
match item {
VisitRecord::EnterMark(idx) => {
return Some((idx, desc));
VisitRecord::EnterMark(idx) | VisitRecord::JumpMark(idx) => {
return (idx, desc);
}
VisitRecord::ExprItem(s) => {
desc.push_str(&s);
VisitRecord::ExprItem(id) => {
desc.push_str(&id);
}
}
}
None
unreachable!("Enter mark should paired with node number");
Contributor Author

@wiedld wiedld Mar 20, 2024

Restore the original pop_enter_mark() contract (from before the TreeNode refactor), while also adding a new JumpMark.

}
}

impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
type Node = Expr;

fn f_down(&mut self, expr: &Expr) -> Result<TreeNodeRecursion> {
// put placeholder, sets the proper array length
self.id_array.push((0, "".to_string()));

// related to https://github.com/apache/arrow-datafusion/issues/8814
// If the expr contain volatile expression or is a short-circuit expression, skip it.
if expr.short_circuits() || is_volatile_expression(expr)? {
return Ok(TreeNodeRecursion::Jump);
self.visit_stack
.push(VisitRecord::JumpMark(self.node_count));
return Ok(TreeNodeRecursion::Jump); // go to f_up
}

self.visit_stack
.push(VisitRecord::EnterMark(self.node_count));
self.node_count += 1;
// put placeholder
self.id_array.push((0, "".to_string()));

Ok(TreeNodeRecursion::Continue)
}

fn f_up(&mut self, expr: &Expr) -> Result<TreeNodeRecursion> {
self.series_number += 1;

let Some((idx, sub_expr_desc)) = self.pop_enter_mark() else {
return Ok(TreeNodeRecursion::Continue);
};
Contributor Author

@wiedld wiedld Mar 20, 2024

This was one of the bugs, which occurred for short-circuited exprs. The behavior was changed by the TreeNode refactor. Previously, it always returned a sub_expr_desc and then executed the rest of the function block.

The change I made here was to (1) bring it back closer to the original code, while (2) incorporating the jump contract from the consolidated visitor.

I worked this out over several commits. The bug was fixed by building the proper visit_stack; later I tweaked the fix as I understood a bit more.

let (idx, sub_expr_identifier) = self.pop_enter_mark();

// skip exprs should not be recognize.
if self.expr_mask.ignores(expr) {
self.id_array[idx].0 = self.series_number;
let desc = Self::desc_expr(expr);
self.visit_stack.push(VisitRecord::ExprItem(desc));
let curr_expr_identifier = Self::expr_identifier(expr);
self.visit_stack
.push(VisitRecord::ExprItem(curr_expr_identifier));
self.id_array[idx].0 = self.series_number; // leave Identifer as empty "", since will not use as common expr
return Ok(TreeNodeRecursion::Continue);
}
let mut desc = Self::desc_expr(expr);
desc.push_str(&sub_expr_desc);
let mut desc = Self::expr_identifier(expr);
desc.push_str(&sub_expr_identifier);

self.id_array[idx] = (self.series_number, desc.clone());
self.visit_stack.push(VisitRecord::ExprItem(desc.clone()));
@@ -693,7 +728,7 @@ fn expr_to_identifier(
/// evaluate result of replaced expression.
struct CommonSubexprRewriter<'a> {
expr_set: &'a ExprSet,
id_array: &'a [(usize, Identifier)],
id_array: &'a IdArray,
/// Which identifier is replaced.
affected_id: &'a mut BTreeSet<Identifier>,

@@ -715,20 +750,26 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
if expr.short_circuits() || is_volatile_expression(&expr)? {
return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
}

let (series_number, curr_id) = &self.id_array[self.curr_index];

// halting conditions
if self.curr_index >= self.id_array.len()
|| self.max_series_number > self.id_array[self.curr_index].0
|| self.max_series_number > *series_number
{
return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
}

let curr_id = &self.id_array[self.curr_index].1;
// skip `Expr`s without identifier (empty identifier).
if curr_id.is_empty() {
self.curr_index += 1;
self.curr_index += 1; // incr idx for id_array, when not jumping
return Ok(Transformed::no(expr));
}

// lookup previously visited expression
match self.expr_set.get(curr_id) {
Some((_, counter, _)) => {
// if has a commonly used (a.k.a. 1+ use) expr
if *counter > 1 {
self.affected_id.insert(curr_id.clone());

@@ -741,23 +782,10 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
));
}

let (series_number, id) = &self.id_array[self.curr_index];
Contributor Author

@wiedld wiedld Mar 20, 2024

This was one of the bugs (see the regression test). It uses a self.curr_index that has already been incremented, when it really needs the original self.curr_index for the expr node currently being visited. The fix is to read the IdArray[self.curr_index] values once (see new line 717), before any of the self.curr_index increments.
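
The general pattern behind this kind of bug can be sketched as follows. This is a simplified illustration rather than the rewriter's actual control flow: `read_buggy` and `read_fixed` are hypothetical helpers, and the identifiers in the array are made up.

```rust
// Toy model only: a cursor over an id array, read before or after advancing.
type IdArray = Vec<(usize, String)>;

// Buggy pattern: the index is advanced first, so the entry that is read
// belongs to the *next* node, not the one being visited.
fn read_buggy(id_array: &IdArray, curr_index: &mut usize) -> Option<(usize, String)> {
    *curr_index += 1;
    id_array.get(*curr_index).cloned()
}

// Fixed pattern: capture the current entry once, then advance.
fn read_fixed(id_array: &IdArray, curr_index: &mut usize) -> Option<(usize, String)> {
    let entry = id_array.get(*curr_index).cloned();
    *curr_index += 1;
    entry
}

fn sample_ids() -> IdArray {
    vec![(2, "acos(x)".to_string()), (1, "x".to_string())]
}

fn main() {
    let ids = sample_ids();

    let mut idx = 0;
    println!("buggy: {:?}", read_buggy(&ids, &mut idx));

    let mut idx = 0;
    println!("fixed: {:?}", read_fixed(&ids, &mut idx));
}
```

Starting from index 0, the buggy read returns the entry for `x` while the fixed read returns the entry for `acos(x)`; both leave the index at 1 afterwards.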

// incr idx for id_array, when not jumping
self.curr_index += 1;
// Skip sub-node of a replaced tree, or without identifier, or is not repeated expr.
let expr_set_item = self.expr_set.get(id).ok_or_else(|| {
internal_datafusion_err!("expr_set invalid state")
})?;
if *series_number < self.max_series_number
Contributor Author

These checks are already done earlier. Remove.

|| id.is_empty()
|| expr_set_item.1 <= 1
{
return Ok(Transformed::new(
expr,
false,
TreeNodeRecursion::Jump,
));
}

// series_number was the inverse number ordering (when doing f_up)
self.max_series_number = *series_number;
// step index to skip all sub-node (which has smaller series number).
while self.curr_index < self.id_array.len()
@@ -771,7 +799,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
// `projection_push_down` optimizer use "expr name" to eliminate useless
// projections.
Ok(Transformed::new(
col(id).alias(expr_name),
col(curr_id).alias(expr_name),
true,
TreeNodeRecursion::Jump,
))
@@ -787,7 +815,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {

fn replace_common_expr(
expr: Expr,
id_array: &[(usize, Identifier)],
id_array: &IdArray,
expr_set: &ExprSet,
affected_id: &mut BTreeSet<Identifier>,
) -> Result<Expr> {
36 changes: 36 additions & 0 deletions datafusion/sqllogictest/test_files/expr.slt
@@ -2205,3 +2205,39 @@ false true false true
NULL NULL NULL NULL
false false true true
false false true false


#############
## Common Subexpr Eliminate Tests
#############

statement ok
CREATE TABLE doubles (
f64 DOUBLE
) as VALUES
(10.1)
;

# common subexpr with alias
query RRR rowsort
select f64, round(1.0 / f64) as i64_1, acos(round(1.0 / f64)) from doubles;
----
10.1 0 1.570796326795

# common subexpr with coalesce (short-circuited)
query RRR rowsort
select f64, coalesce(1.0 / f64, 0.0), acos(coalesce(1.0 / f64, 0.0)) from doubles;
----
10.1 0.09900990099 1.471623942989

# common subexpr with coalesce (short-circuited) and alias
query RRR rowsort
select f64, coalesce(1.0 / f64, 0.0) as f64_1, acos(coalesce(1.0 / f64, 0.0)) from doubles;
----
10.1 0.09900990099 1.471623942989

# common subexpr with case (short-circuited)
query RRR rowsort
select f64, case when f64 > 0 then 1.0 / f64 else null end, acos(case when f64 > 0 then 1.0 / f64 else null end) from doubles;
----
10.1 0.09900990099 1.471623942989
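
The expected values in these tests can be sanity-checked with plain `f64` arithmetic, independent of DataFusion. The sketch below is only an arithmetic check under that assumption; `expected_row` is a hypothetical helper and says nothing about how the query engine evaluates or formats the results.

```rust
// Arithmetic check only: recomputes the test values with std f64 math.
fn expected_row(f64_value: f64) -> (f64, f64, f64, f64) {
    let ratio = 1.0 / f64_value; // 1.0 / f64
    let rounded = ratio.round(); // round(1.0 / f64)
    (ratio, ratio.acos(), rounded, rounded.acos())
}

fn main() {
    let (ratio, acos_ratio, rounded, acos_rounded) = expected_row(10.1);
    println!("1.0 / f64              = {ratio:.11}");
    println!("acos(1.0 / f64)        = {acos_ratio:.12}");
    println!("round(1.0 / f64)       = {rounded}");
    println!("acos(round(1.0 / f64)) = {acos_rounded:.12}");
}
```

Since `1.0 / 10.1` is about 0.099, it rounds to 0 and `acos(0)` is π/2, which matches the first test; the `coalesce` and `case` variants all reduce to `acos(1.0 / 10.1)` for this single row.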