diff --git a/crates/polars-plan/src/logical_plan/functions/mod.rs b/crates/polars-plan/src/logical_plan/functions/mod.rs index d1beeb3adaac..a692084a5539 100644 --- a/crates/polars-plan/src/logical_plan/functions/mod.rs +++ b/crates/polars-plan/src/logical_plan/functions/mod.rs @@ -50,6 +50,7 @@ pub enum FunctionNode { Count { paths: Arc<[PathBuf]>, scan_type: FileScan, + alias: Option>, }, #[cfg_attr(feature = "serde", serde(skip))] Pipeline { @@ -198,9 +199,11 @@ impl FunctionNode { Ok(Cow::Owned(Arc::new(schema))) }, DropNulls { .. } => Ok(Cow::Borrowed(input_schema)), - Count { .. } => { + Count { alias, .. } => { let mut schema: Schema = Schema::with_capacity(1); - schema.insert_at_index(0, SmartString::from("len"), IDX_DTYPE)?; + let name = + SmartString::from(alias.as_ref().map(|alias| alias.as_ref()).unwrap_or("len")); + schema.insert_at_index(0, name, IDX_DTYPE)?; Ok(Cow::Owned(Arc::new(schema))) }, Rechunk => Ok(Cow::Borrowed(input_schema)), @@ -323,7 +326,9 @@ impl FunctionNode { } }, DropNulls { subset } => df.drop_nulls(Some(subset.as_ref())), - Count { paths, scan_type } => count::count_rows(paths, scan_type), + Count { + paths, scan_type, .. + } => count::count_rows(paths, scan_type), Rechunk => { df.as_single_chunk_par(); Ok(df) diff --git a/crates/polars-plan/src/logical_plan/optimizer/count_star.rs b/crates/polars-plan/src/logical_plan/optimizer/count_star.rs index fcbe1e61d762..357e1bd60839 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/count_star.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/count_star.rs @@ -15,59 +15,70 @@ impl OptimizationRule for CountStar { fn optimize_plan( &mut self, lp_arena: &mut Arena, - _expr_arena: &mut Arena, + expr_arena: &mut Arena, node: Node, ) -> Option { - let mut paths = Vec::new(); - visit_logical_plan_for_scan_paths(&mut paths, node, lp_arena).map(|scan_type| { - // MapFunction needs a leaf node, hence we create a dummy placeholder node - let placeholder = ALogicalPlan::DataFrameScan { - df: Arc::new(Default::default()), - schema: Arc::new(Default::default()), - output_schema: None, - projection: None, - selection: None, - }; - let placeholder_node = lp_arena.add(placeholder); + visit_logical_plan_for_scan_paths(node, lp_arena, expr_arena, false).map( + |count_star_expr| { + // MapFunction needs a leaf node, hence we create a dummy placeholder node + let placeholder = ALogicalPlan::DataFrameScan { + df: Arc::new(Default::default()), + schema: Arc::new(Default::default()), + output_schema: None, + projection: None, + selection: None, + }; + let placeholder_node = lp_arena.add(placeholder); - let sliced_paths: Arc<[PathBuf]> = paths.into(); + let alp = ALogicalPlan::MapFunction { + input: placeholder_node, + function: FunctionNode::Count { + paths: count_star_expr.paths, + scan_type: count_star_expr.scan_type, + alias: count_star_expr.alias, + }, + }; - let alp = ALogicalPlan::MapFunction { - input: placeholder_node, - function: FunctionNode::Count { - paths: sliced_paths, - scan_type, - }, - }; - lp_arena.replace(node, alp.clone()); - alp - }) + lp_arena.replace(count_star_expr.node, alp.clone()); + alp + }, + ) } } -// Visit the logical plan and return the file paths / scan type +struct CountStarExpr { + // Top node of the projection to replace + node: Node, + // Paths to the input files + paths: Arc<[PathBuf]>, + // File Type + scan_type: FileScan, + // Column Alias + alias: Option>, +} + +// Visit the logical plan and return CountStarExpr with the expr information gathered // Return None if query is not a simple COUNT(*) FROM SOURCE fn visit_logical_plan_for_scan_paths( - all_paths: &mut Vec, node: Node, lp_arena: &Arena, -) -> Option { + expr_arena: &Arena, + inside_union: bool, // Inside union's we do not check for COUNT(*) expression +) -> Option { match lp_arena.get(node) { ALogicalPlan::Union { inputs, .. } => { - // Preallocate right amount in case of globbing - if all_paths.is_empty() { - let _ = std::mem::replace(all_paths, Vec::with_capacity(inputs.len())); - } - let mut scan_type = None; + let mut scan_type: Option = None; + let mut paths = Vec::with_capacity(inputs.len()); for input in inputs { - match visit_logical_plan_for_scan_paths(all_paths, *input, lp_arena) { - Some(leaf_scan_type) => { + match visit_logical_plan_for_scan_paths(*input, lp_arena, expr_arena, true) { + Some(expr) => { + paths.extend(expr.paths.iter().cloned()); match &scan_type { - None => scan_type = Some(leaf_scan_type), + None => scan_type = Some(expr.scan_type), Some(scan_type) => { // All scans must be of the same type (e.g. csv / parquet) if std::mem::discriminant(scan_type) - != std::mem::discriminant(&leaf_scan_type) + != std::mem::discriminant(&expr.scan_type) { return None; } @@ -77,17 +88,46 @@ fn visit_logical_plan_for_scan_paths( None => return None, } } - scan_type + Some(CountStarExpr { + paths: paths.into(), + scan_type: scan_type.unwrap(), + node, + alias: None, + }) }, ALogicalPlan::Scan { scan_type, paths, .. - } if !matches!(scan_type, FileScan::Anonymous { .. }) => { - all_paths.extend(paths.iter().cloned()); - Some(scan_type.clone()) - }, - ALogicalPlan::Projection { input, .. } => { - visit_logical_plan_for_scan_paths(all_paths, *input, lp_arena) + } if !matches!(scan_type, FileScan::Anonymous { .. }) => Some(CountStarExpr { + paths: paths.clone(), + scan_type: scan_type.clone(), + node, + alias: None, + }), + ALogicalPlan::Projection { input, expr, .. } => { + if expr.len() == 1 { + let (valid, alias) = is_valid_count_expr(expr[0], expr_arena); + if valid || inside_union { + return visit_logical_plan_for_scan_paths(*input, lp_arena, expr_arena, false) + .map(|mut expr| { + expr.alias = alias; + expr.node = node; + expr + }); + } + } + None }, _ => None, } } + +fn is_valid_count_expr(node: Node, expr_arena: &Arena) -> (bool, Option>) { + match expr_arena.get(node) { + AExpr::Alias(node, alias) => { + let (valid, _) = is_valid_count_expr(*node, expr_arena); + (valid, Some(alias.clone())) + }, + AExpr::Len => (true, None), + _ => (false, None), + } +}