fix(9678): short circuiting prevented population of visited stack, for common subexpr elimination optimization #9685
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,8 +29,7 @@ use datafusion_common::tree_node::{ | |
TreeNodeVisitor, | ||
}; | ||
use datafusion_common::{ | ||
internal_datafusion_err, internal_err, Column, DFField, DFSchema, DFSchemaRef, | ||
DataFusionError, Result, | ||
internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, | ||
}; | ||
use datafusion_expr::expr::Alias; | ||
use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection, Window}; | ||
|
@@ -42,8 +41,36 @@ use datafusion_expr::{col, Expr, ExprSchemable}; | |
/// - DataType of this expression. | ||
type ExprSet = HashMap<Identifier, (Expr, usize, DataType)>; | ||
|
||
/// Identifier type. Current implementation use describe of an expression (type String) as | ||
/// Identifier. | ||
/// An ordered map of Identifiers assigned by `ExprIdentifierVisitor` in an | ||
/// initial expression walk. | ||
/// | ||
/// Used by `CommonSubexprRewriter`, which rewrites the expressions to remove | ||
/// common subexpressions. | ||
/// | ||
/// Elements in this array are created on the walk down the expression tree | ||
/// during `f_down`. Thus element 0 is the root of the expression tree. The | ||
/// tuple contains: | ||
/// - series_number. | ||
/// - Incremented during `f_up`, start from 1. | ||
/// - Thus, items with higher idx have the lower series_number. | ||
/// - [`Identifier`] | ||
/// - Identifier of the expression. If empty (`""`), expr should not be considered for common elimination. | ||
/// | ||
/// # Example | ||
/// An expression like `(a + b)` would have the following `IdArray`: | ||
/// ```text | ||
/// [ | ||
/// (3, "a + b"), | ||
/// (2, "a"), | ||
/// (1, "b") | ||
/// ] | ||
/// ``` | ||
type IdArray = Vec<(usize, Identifier)>; | ||
|
||
/// Identifier for each subexpression. | ||
/// | ||
/// Note that the current implementation uses the `Display` of an expression | ||
/// (a `String`) as `Identifier`. | ||
/// | ||
/// An identifier should (ideally) be able to "hash", "accumulate", "equal" and "have no | ||
/// collision (as low as possible)" | ||
|
@@ -293,8 +320,9 @@ impl CommonSubexprEliminate { | |
agg_exprs.push(expr.alias(&name)); | ||
proj_exprs.push(Expr::Column(Column::from_name(name))); | ||
} else { | ||
let id = | ||
ExprIdentifierVisitor::<'static>::desc_expr(&expr_rewritten); | ||
let id = ExprIdentifierVisitor::<'static>::expr_identifier( | ||
&expr_rewritten, | ||
); | ||
let out_name = | ||
expr_rewritten.to_field(&new_input_schema)?.qualified_name(); | ||
agg_exprs.push(expr_rewritten.alias(&id)); | ||
|
@@ -557,15 +585,15 @@ impl ExprMask { | |
/// This visitor implementation use a stack `visit_stack` to track traversal, which | ||
/// lets us know when a sub-tree's visiting is finished. When `pre_visit` is called | ||
/// (traversing to a new node), an `EnterMark` and an `ExprItem` will be pushed into stack. | ||
/// And try to pop out a `EnterMark` on leaving a node (`post_visit()`). All `ExprItem` | ||
/// And try to pop out a `EnterMark` on leaving a node (`f_up()`). All `ExprItem` | ||
/// before the first `EnterMark` is considered to be sub-tree of the leaving node. | ||
/// | ||
/// This visitor also records identifier in `id_array`. Makes the following traverse | ||
/// pass can get the identifier of a node without recalculate it. We assign each node | ||
/// in the expr tree a series number, start from 1, maintained by `series_number`. | ||
/// Series number represents the order we left (`post_visit`) a node. Has the property | ||
/// Series number represents the order we left (`f_up()`) a node. Has the property | ||
/// that child node's series number always smaller than parent's. While `id_array` is | ||
/// organized in the order we enter (`pre_visit`) a node. `node_count` helps us to | ||
/// organized in the order we enter (`f_down()`) a node. `node_count` helps us to | ||
/// get the index of `id_array` for each node. | ||
/// | ||
/// `Expr` without sub-expr (column, literal etc.) will not have identifier | ||
|
@@ -574,15 +602,15 @@ struct ExprIdentifierVisitor<'a> { | |
// param | ||
expr_set: &'a mut ExprSet, | ||
/// series number (usize) and identifier. | ||
id_array: &'a mut Vec<(usize, Identifier)>, | ||
id_array: &'a mut IdArray, | ||
/// input schema for the node that we're optimizing, so we can determine the correct datatype | ||
/// for each subexpression | ||
input_schema: DFSchemaRef, | ||
// inner states | ||
visit_stack: Vec<VisitRecord>, | ||
/// increased in pre_visit, start from 0. | ||
/// increased in fn_down, start from 0. | ||
node_count: usize, | ||
/// increased in post_visit, start from 1. | ||
/// increased in fn_up, start from 1. | ||
series_number: usize, | ||
/// which expression should be skipped? | ||
expr_mask: ExprMask, | ||
|
@@ -593,66 +621,73 @@ enum VisitRecord { | |
/// `usize` is the monotone increasing series number assigned in pre_visit(). | ||
/// Starts from 0. Is used to index the identifier array `id_array` in post_visit(). | ||
EnterMark(usize), | ||
/// the node's children were skipped => jump to f_up on same node | ||
This is the key fix, as I understand it -- the TreeNodeVisitor rewrite removed the notion of skipping sibling nodes during recursion, so this notion must be explicitly encoded in the subexpression rewrite pass.
For my understanding (please correct here 🙏🏼 ), prior to refactoring the 2 different visitors ( Prior to the refactor:
Fix for the first visitor, for the short-circuited nodes: Prior to the refactor, the first tree walk did not add short-circuited nodes to the stack:
Whereas with the new TreeNode refactor the jump became a skip-children-but-call-f_up, which meant:
Edited: For the changes done in the
Yes, that's how the API worked before and after #8891. I think the issue came in with this peter-toth#1 adjustment.
||
JumpMark(usize), | ||
/// Accumulated identifier of sub expression. | ||
ExprItem(Identifier), | ||
} | ||
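The review thread above describes `Jump` as "skip children, but still call `f_up` on the same node". A minimal sketch of that traversal contract follows. It uses a toy `Node` tree and hand-written `f_down`/`f_up` functions, all of which are hypothetical stand-ins and not DataFusion's `Expr` or `TreeNodeVisitor`.

```rust
/// Toy expression tree used only for this sketch (not DataFusion's `Expr`).
pub enum Node {
    Leaf(&'static str),
    Op(&'static str, Vec<Node>),
    /// Stands in for a short-circuiting or volatile expression.
    Opaque(&'static str, Vec<Node>),
}

impl Node {
    fn name(&self) -> &'static str {
        match self {
            Node::Leaf(n) | Node::Op(n, _) | Node::Opaque(n, _) => *n,
        }
    }

    fn children(&self) -> &[Node] {
        match self {
            Node::Leaf(_) => &[],
            Node::Op(_, c) | Node::Opaque(_, c) => c.as_slice(),
        }
    }
}

/// Hypothetical two-state version of the recursion signal.
pub enum Recursion {
    Continue,
    Jump,
}

/// Called when entering a node; asks to skip the children of opaque nodes.
fn f_down(node: &Node, trace: &mut Vec<String>) -> Recursion {
    trace.push(format!("down:{}", node.name()));
    match node {
        Node::Opaque(..) => Recursion::Jump,
        _ => Recursion::Continue,
    }
}

/// Called when leaving a node, including nodes whose children were skipped.
fn f_up(node: &Node, trace: &mut Vec<String>) {
    trace.push(format!("up:{}", node.name()));
}

fn visit(node: &Node, trace: &mut Vec<String>) {
    if let Recursion::Continue = f_down(node, trace) {
        for child in node.children() {
            visit(child, trace);
        }
    }
    // `f_up` runs even after a `Jump`, so it needs something to pop.
    f_up(node, trace);
}

/// Returns the sequence of enter/leave events for the whole tree.
pub fn trace(root: &Node) -> Vec<String> {
    let mut events = Vec::new();
    visit(root, &mut events);
    events
}
```

Because `f_up` still fires for the skipped node, a visitor that pushes nothing in `f_down` for that node would pop a record belonging to some other node. That is the situation the `JumpMark` variant is meant to avoid.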
|
||
impl ExprIdentifierVisitor<'_> { | ||
fn desc_expr(expr: &Expr) -> String { | ||
fn expr_identifier(expr: &Expr) -> Identifier { | ||
format!("{expr}") | ||
} | ||
|
||
/// Find the first `EnterMark` in the stack, and accumulates every `ExprItem` | ||
/// before it. | ||
fn pop_enter_mark(&mut self) -> Option<(usize, Identifier)> { | ||
fn pop_enter_mark(&mut self) -> (usize, Identifier) { | ||
let mut desc = String::new(); | ||
|
||
while let Some(item) = self.visit_stack.pop() { | ||
match item { | ||
VisitRecord::EnterMark(idx) => { | ||
return Some((idx, desc)); | ||
VisitRecord::EnterMark(idx) | VisitRecord::JumpMark(idx) => { | ||
return (idx, desc); | ||
} | ||
VisitRecord::ExprItem(s) => { | ||
desc.push_str(&s); | ||
VisitRecord::ExprItem(id) => { | ||
desc.push_str(&id); | ||
} | ||
} | ||
} | ||
None | ||
unreachable!("Enter mark should paired with node number"); | ||
Bring back the pop_enter_mark() contract to the original (before the TreeNode refactor), while also adding in a new JumpMark.
||
} | ||
} | ||
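To illustrate the pairing that `pop_enter_mark` relies on, here is a standalone sketch of the same stack logic operating on a plain `Vec`. The free-function form and the `pub` visibility are assumptions made so the example compiles on its own; the matching logic mirrors the diff above.

```rust
pub type Identifier = String;

pub enum VisitRecord {
    /// Pushed in `f_down` for a node whose children will be visited.
    EnterMark(usize),
    /// Pushed in `f_down` for a node whose children are skipped.
    JumpMark(usize),
    /// Accumulated identifier of an already-finished sub-expression.
    ExprItem(Identifier),
}

/// Pops records until the nearest mark, concatenating every `ExprItem`
/// found on the way. Items are appended in pop order (last pushed first).
pub fn pop_enter_mark(stack: &mut Vec<VisitRecord>) -> (usize, Identifier) {
    let mut desc = String::new();
    while let Some(item) = stack.pop() {
        match item {
            VisitRecord::EnterMark(idx) | VisitRecord::JumpMark(idx) => {
                return (idx, desc);
            }
            VisitRecord::ExprItem(id) => desc.push_str(&id),
        }
    }
    unreachable!("every f_up should find the mark pushed by its own f_down");
}
```

With a stack of `[EnterMark(0), JumpMark(1)]`, the call returns index 1 and leaves `EnterMark(0)` in place for the parent. Without the `JumpMark`, the skipped child's `f_up` would consume the parent's `EnterMark(0)` instead.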
|
||
impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { | ||
type Node = Expr; | ||
|
||
fn f_down(&mut self, expr: &Expr) -> Result<TreeNodeRecursion> { | ||
// put placeholder, sets the proper array length | ||
self.id_array.push((0, "".to_string())); | ||
|
||
// related to https://github.com/apache/arrow-datafusion/issues/8814 | ||
// If the expr contain volatile expression or is a short-circuit expression, skip it. | ||
if expr.short_circuits() || is_volatile_expression(expr)? { | ||
return Ok(TreeNodeRecursion::Jump); | ||
self.visit_stack | ||
.push(VisitRecord::JumpMark(self.node_count)); | ||
return Ok(TreeNodeRecursion::Jump); // go to f_up | ||
} | ||
|
||
self.visit_stack | ||
.push(VisitRecord::EnterMark(self.node_count)); | ||
self.node_count += 1; | ||
// put placeholder | ||
self.id_array.push((0, "".to_string())); | ||
|
||
Ok(TreeNodeRecursion::Continue) | ||
} | ||
|
||
fn f_up(&mut self, expr: &Expr) -> Result<TreeNodeRecursion> { | ||
self.series_number += 1; | ||
|
||
let Some((idx, sub_expr_desc)) = self.pop_enter_mark() else { | ||
return Ok(TreeNodeRecursion::Continue); | ||
}; | ||
This was one of the bugs, which occurred for short-circuited expr. The behavior was changed with the TreeNode refactor. Previously, it was always returning a
The changes I made here were to (1) bring it back closer to the original code, while (2) incorporating the jump contract from the consolidated visitor. The actual process to figure this out was done over several commits. The bug was fixed by building the proper visited_stack, then later I tweaked the fix as I figured out a bit more.
||
let (idx, sub_expr_identifier) = self.pop_enter_mark(); | ||
|
||
// skip exprs should not be recognize. | ||
if self.expr_mask.ignores(expr) { | ||
self.id_array[idx].0 = self.series_number; | ||
let desc = Self::desc_expr(expr); | ||
self.visit_stack.push(VisitRecord::ExprItem(desc)); | ||
let curr_expr_identifier = Self::expr_identifier(expr); | ||
self.visit_stack | ||
.push(VisitRecord::ExprItem(curr_expr_identifier)); | ||
self.id_array[idx].0 = self.series_number; // leave Identifer as empty "", since will not use as common expr | ||
return Ok(TreeNodeRecursion::Continue); | ||
} | ||
let mut desc = Self::desc_expr(expr); | ||
desc.push_str(&sub_expr_desc); | ||
let mut desc = Self::expr_identifier(expr); | ||
desc.push_str(&sub_expr_identifier); | ||
|
||
self.id_array[idx] = (self.series_number, desc.clone()); | ||
self.visit_stack.push(VisitRecord::ExprItem(desc.clone())); | ||
|
@@ -693,7 +728,7 @@ fn expr_to_identifier( | |
/// evaluate result of replaced expression. | ||
struct CommonSubexprRewriter<'a> { | ||
expr_set: &'a ExprSet, | ||
id_array: &'a [(usize, Identifier)], | ||
id_array: &'a IdArray, | ||
/// Which identifier is replaced. | ||
affected_id: &'a mut BTreeSet<Identifier>, | ||
|
||
|
@@ -715,20 +750,26 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { | |
if expr.short_circuits() || is_volatile_expression(&expr)? { | ||
return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); | ||
} | ||
|
||
let (series_number, curr_id) = &self.id_array[self.curr_index]; | ||
|
||
|
||
// halting conditions | ||
if self.curr_index >= self.id_array.len() | ||
|| self.max_series_number > self.id_array[self.curr_index].0 | ||
|| self.max_series_number > *series_number | ||
{ | ||
return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); | ||
} | ||
|
||
let curr_id = &self.id_array[self.curr_index].1; | ||
// skip `Expr`s without identifier (empty identifier). | ||
if curr_id.is_empty() { | ||
self.curr_index += 1; | ||
self.curr_index += 1; // incr idx for id_array, when not jumping | ||
return Ok(Transformed::no(expr)); | ||
} | ||
|
||
// lookup previously visited expression | ||
match self.expr_set.get(curr_id) { | ||
Some((_, counter, _)) => { | ||
// if has a commonly used (a.k.a. 1+ use) expr | ||
if *counter > 1 { | ||
self.affected_id.insert(curr_id.clone()); | ||
|
||
|
@@ -741,23 +782,10 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { | |
)); | ||
} | ||
|
||
let (series_number, id) = &self.id_array[self.curr_index]; | ||
This was one of the bugs (see the regression test). It uses a
||
// incr idx for id_array, when not jumping | ||
self.curr_index += 1; | ||
// Skip sub-node of a replaced tree, or without identifier, or is not repeated expr. | ||
let expr_set_item = self.expr_set.get(id).ok_or_else(|| { | ||
internal_datafusion_err!("expr_set invalid state") | ||
})?; | ||
if *series_number < self.max_series_number | ||
These checks are already done earlier. Remove.
||
|| id.is_empty() | ||
|| expr_set_item.1 <= 1 | ||
{ | ||
return Ok(Transformed::new( | ||
expr, | ||
false, | ||
TreeNodeRecursion::Jump, | ||
)); | ||
} | ||
|
||
// series_number was the inverse number ordering (when doing f_up) | ||
self.max_series_number = *series_number; | ||
// step index to skip all sub-node (which has smaller series number). | ||
while self.curr_index < self.id_array.len() | ||
|
@@ -771,7 +799,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { | |
// `projection_push_down` optimizer use "expr name" to eliminate useless | ||
// projections. | ||
Ok(Transformed::new( | ||
col(id).alias(expr_name), | ||
col(curr_id).alias(expr_name), | ||
true, | ||
TreeNodeRecursion::Jump, | ||
)) | ||
|
@@ -787,7 +815,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { | |
|
||
fn replace_common_expr( | ||
expr: Expr, | ||
id_array: &[(usize, Identifier)], | ||
id_array: &IdArray, | ||
expr_set: &ExprSet, | ||
affected_id: &mut BTreeSet<Identifier>, | ||
) -> Result<Expr> { | ||
|
This abstraction really helps to understand how the 2 visitors work together.
The first visitor builds the `IdArray` and `ExprSet`. The first visitor sets the vector idx used on f_down, whereas the series_number is set on f_up. An empty string for the expr_id means that this expression is not considered for rewrite (is skipped) in the second visitor.
e.g. [(3, expr_id),(2, expr_id),(1, expr_id)]
The second visitor then performs the mutation, based upon the `IdArray` and `ExprSet`.
Thank you -- I agree introducing a new typedef helps to make the code easier to understand
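As a rough illustration of how the first walk fills an `IdArray`, the sketch below reserves a slot on the way down and writes the series number on the way up. The toy `Expr` enum, `display`, and `build_id_array` are hypothetical. The sketch also simplifies: it uses each node's display string as its identifier, whereas the real visitor leaves leaf identifiers empty and appends the identifiers of sub-expressions. Children are assumed to be visited left to right.

```rust
pub type Identifier = String;
pub type IdArray = Vec<(usize, Identifier)>;

/// Toy expression type for this sketch only (not DataFusion's `Expr`).
pub enum Expr {
    Column(&'static str),
    Add(Box<Expr>, Box<Expr>),
}

fn display(expr: &Expr) -> String {
    match expr {
        Expr::Column(name) => name.to_string(),
        Expr::Add(l, r) => format!("{} + {}", display(l), display(r)),
    }
}

fn visit(expr: &Expr, id_array: &mut IdArray, series_number: &mut usize) {
    // "f_down": reserve a slot, so the index follows the order of entry.
    let idx = id_array.len();
    id_array.push((0, String::new()));

    if let Expr::Add(l, r) = expr {
        visit(l, id_array, series_number);
        visit(r, id_array, series_number);
    }

    // "f_up": the series number follows the order of leaving, from 1.
    *series_number += 1;
    id_array[idx] = (*series_number, display(expr));
}

pub fn build_id_array(expr: &Expr) -> IdArray {
    let mut id_array = IdArray::new();
    let mut series_number = 0;
    visit(expr, &mut id_array, &mut series_number);
    id_array
}
```

Under the left-to-right assumption, `a + b` yields `[(3, "a + b"), (1, "a"), (2, "b")]`: the root sits at index 0 with the highest series number, and each child's number is smaller than its parent's. The relative numbering of siblings depends on the visiting order.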