Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix bug where alias was ignored in COUNT(*) optimization #14738

Merged
merged 1 commit into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions crates/polars-plan/src/logical_plan/functions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ pub enum FunctionNode {
Count {
paths: Arc<[PathBuf]>,
scan_type: FileScan,
alias: Option<Arc<str>>,
},
#[cfg_attr(feature = "serde", serde(skip))]
Pipeline {
Expand Down Expand Up @@ -198,9 +199,11 @@ impl FunctionNode {
Ok(Cow::Owned(Arc::new(schema)))
},
DropNulls { .. } => Ok(Cow::Borrowed(input_schema)),
Count { .. } => {
Count { alias, .. } => {
let mut schema: Schema = Schema::with_capacity(1);
schema.insert_at_index(0, SmartString::from("len"), IDX_DTYPE)?;
let name =
SmartString::from(alias.as_ref().map(|alias| alias.as_ref()).unwrap_or("len"));
schema.insert_at_index(0, name, IDX_DTYPE)?;
Ok(Cow::Owned(Arc::new(schema)))
},
Rechunk => Ok(Cow::Borrowed(input_schema)),
Expand Down Expand Up @@ -323,7 +326,9 @@ impl FunctionNode {
}
},
DropNulls { subset } => df.drop_nulls(Some(subset.as_ref())),
Count { paths, scan_type } => count::count_rows(paths, scan_type),
Count {
paths, scan_type, ..
} => count::count_rows(paths, scan_type),
Rechunk => {
df.as_single_chunk_par();
Ok(df)
Expand Down
124 changes: 82 additions & 42 deletions crates/polars-plan/src/logical_plan/optimizer/count_star.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,59 +15,70 @@ impl OptimizationRule for CountStar {
fn optimize_plan(
&mut self,
lp_arena: &mut Arena<ALogicalPlan>,
_expr_arena: &mut Arena<AExpr>,
expr_arena: &mut Arena<AExpr>,
node: Node,
) -> Option<ALogicalPlan> {
let mut paths = Vec::new();
visit_logical_plan_for_scan_paths(&mut paths, node, lp_arena).map(|scan_type| {
// MapFunction needs a leaf node, hence we create a dummy placeholder node
let placeholder = ALogicalPlan::DataFrameScan {
df: Arc::new(Default::default()),
schema: Arc::new(Default::default()),
output_schema: None,
projection: None,
selection: None,
};
let placeholder_node = lp_arena.add(placeholder);
visit_logical_plan_for_scan_paths(node, lp_arena, expr_arena, false).map(
|count_star_expr| {
// MapFunction needs a leaf node, hence we create a dummy placeholder node
let placeholder = ALogicalPlan::DataFrameScan {
df: Arc::new(Default::default()),
schema: Arc::new(Default::default()),
output_schema: None,
projection: None,
selection: None,
};
let placeholder_node = lp_arena.add(placeholder);

let sliced_paths: Arc<[PathBuf]> = paths.into();
let alp = ALogicalPlan::MapFunction {
input: placeholder_node,
function: FunctionNode::Count {
paths: count_star_expr.paths,
scan_type: count_star_expr.scan_type,
alias: count_star_expr.alias,
},
};

let alp = ALogicalPlan::MapFunction {
input: placeholder_node,
function: FunctionNode::Count {
paths: sliced_paths,
scan_type,
},
};
lp_arena.replace(node, alp.clone());
alp
})
lp_arena.replace(count_star_expr.node, alp.clone());
alp
},
)
}
}

// Visit the logical plan and return the file paths / scan type
struct CountStarExpr {
// Top node of the projection to replace
node: Node,
// Paths to the input files
paths: Arc<[PathBuf]>,
// File Type
scan_type: FileScan,
// Column Alias
alias: Option<Arc<str>>,
}

// Visit the logical plan and return CountStarExpr with the expr information gathered
// Return None if query is not a simple COUNT(*) FROM SOURCE
fn visit_logical_plan_for_scan_paths(
all_paths: &mut Vec<PathBuf>,
node: Node,
lp_arena: &Arena<ALogicalPlan>,
) -> Option<FileScan> {
expr_arena: &Arena<AExpr>,
inside_union: bool, // Inside union's we do not check for COUNT(*) expression
) -> Option<CountStarExpr> {
match lp_arena.get(node) {
ALogicalPlan::Union { inputs, .. } => {
// Preallocate right amount in case of globbing
if all_paths.is_empty() {
let _ = std::mem::replace(all_paths, Vec::with_capacity(inputs.len()));
}
let mut scan_type = None;
let mut scan_type: Option<FileScan> = None;
let mut paths = Vec::with_capacity(inputs.len());
for input in inputs {
match visit_logical_plan_for_scan_paths(all_paths, *input, lp_arena) {
Some(leaf_scan_type) => {
match visit_logical_plan_for_scan_paths(*input, lp_arena, expr_arena, true) {
Some(expr) => {
paths.extend(expr.paths.iter().cloned());
match &scan_type {
None => scan_type = Some(leaf_scan_type),
None => scan_type = Some(expr.scan_type),
Some(scan_type) => {
// All scans must be of the same type (e.g. csv / parquet)
if std::mem::discriminant(scan_type)
!= std::mem::discriminant(&leaf_scan_type)
!= std::mem::discriminant(&expr.scan_type)
{
return None;
}
Expand All @@ -77,17 +88,46 @@ fn visit_logical_plan_for_scan_paths(
None => return None,
}
}
scan_type
Some(CountStarExpr {
paths: paths.into(),
scan_type: scan_type.unwrap(),
node,
alias: None,
})
},
ALogicalPlan::Scan {
scan_type, paths, ..
} if !matches!(scan_type, FileScan::Anonymous { .. }) => {
all_paths.extend(paths.iter().cloned());
Some(scan_type.clone())
},
ALogicalPlan::Projection { input, .. } => {
visit_logical_plan_for_scan_paths(all_paths, *input, lp_arena)
} if !matches!(scan_type, FileScan::Anonymous { .. }) => Some(CountStarExpr {
paths: paths.clone(),
scan_type: scan_type.clone(),
node,
alias: None,
}),
ALogicalPlan::Projection { input, expr, .. } => {
if expr.len() == 1 {
let (valid, alias) = is_valid_count_expr(expr[0], expr_arena);
if valid || inside_union {
return visit_logical_plan_for_scan_paths(*input, lp_arena, expr_arena, false)
.map(|mut expr| {
expr.alias = alias;
expr.node = node;
expr
});
}
}
None
},
_ => None,
}
}

fn is_valid_count_expr(node: Node, expr_arena: &Arena<AExpr>) -> (bool, Option<Arc<str>>) {
match expr_arena.get(node) {
AExpr::Alias(node, alias) => {
let (valid, _) = is_valid_count_expr(*node, expr_arena);
(valid, Some(alias.clone()))
},
AExpr::Len => (true, None),
_ => (false, None),
}
}
Loading