diff --git a/src/frontend/src/optimizer/plan_node/logical_agg.rs b/src/frontend/src/optimizer/plan_node/logical_agg.rs index 1ad28a8fe0489..25f46f1b9d7c4 100644 --- a/src/frontend/src/optimizer/plan_node/logical_agg.rs +++ b/src/frontend/src/optimizer/plan_node/logical_agg.rs @@ -466,7 +466,7 @@ impl ColPrunable for LogicalAgg { }; let input_required_cols = { - let mut tmp: FixedBitSet = upstream_required_cols.clone(); + let mut tmp: FixedBitSet = upstream_required_cols; tmp.union_with(&group_key_required_cols); tmp.union_with(&agg_call_required_cols); tmp.ones().collect_vec() @@ -499,24 +499,25 @@ impl ColPrunable for LogicalAgg { self.input.prune_col(&input_required_cols), ) }; - - if group_key_required_cols.is_subset(&upstream_required_cols) { + let new_output_cols = { + let mapping = self.i2o_col_mapping(); + let mut tmp = input_required_cols + .iter() + .filter_map(|&idx| mapping.try_map(idx)) + .collect_vec(); + tmp.extend( + required_cols + .iter() + .filter(|&&index| index >= self.group_keys.len()), + ); + tmp + }; + if new_output_cols == required_cols { + // current schema perfectly fit the required columns agg.into() } else { - // Some group key columns are not needed - let new_output_cols = { - let mapping = self.i2o_col_mapping(); - let mut tmp = input_required_cols - .iter() - .filter_map(|&idx| mapping.try_map(idx)) - .collect_vec(); - tmp.extend( - required_cols - .iter() - .filter(|&&index| index >= self.group_keys.len()), - ); - tmp - }; + // some columns are not needed, or the order need to be adjusted. + // so we did a projection to remove/reorder the columns. let mapping = &ColIndexMapping::with_remaining_columns(&new_output_cols, self.schema().len()); let output_required_cols = required_cols @@ -722,6 +723,24 @@ mod tests { assert_eq!(group_keys, vec![0]); } } + /// Generate a agg call node with given [`DataType`] and fields. + /// For example, `generate_agg_call(Int32, [v1, v2, v3])` will result in: + /// ```text + /// Agg(min(input_ref(2))) group by (input_ref(1)) + /// TableScan(v1, v2, v3) + /// ``` + async fn generate_agg_call(ty: DataType, fields: Vec) -> LogicalAgg { + let ctx = OptimizerContext::mock().await; + + let values = LogicalValues::new(vec![], Schema { fields }, ctx); + let agg_call = PlanAggCall { + agg_kind: AggKind::Min, + return_type: ty.clone(), + inputs: vec![InputRef::new(2, ty.clone())], + distinct: false, + }; + LogicalAgg::new(vec![agg_call], vec![1], values.into()) + } #[tokio::test] /// Pruning @@ -733,29 +752,15 @@ mod tests { /// ```text /// Agg(min(input_ref(1))) group by (input_ref(0)) /// TableScan(v2, v3) + /// ``` async fn test_prune_all() { let ty = DataType::Int32; - let ctx = OptimizerContext::mock().await; let fields: Vec = vec![ Field::with_name(ty.clone(), "v1"), Field::with_name(ty.clone(), "v2"), Field::with_name(ty.clone(), "v3"), ]; - let values = LogicalValues::new( - vec![], - Schema { - fields: fields.clone(), - }, - ctx, - ); - let agg_call = PlanAggCall { - agg_kind: AggKind::Min, - return_type: ty.clone(), - inputs: vec![InputRef::new(2, ty.clone())], - distinct: false, - }; - let agg = LogicalAgg::new(vec![agg_call], vec![1], values.into()); - + let agg = generate_agg_call(ty.clone(), fields.clone()).await; // Perform the prune let required_cols = vec![0, 1]; let plan = agg.prune_col(&required_cols); @@ -775,6 +780,49 @@ mod tests { assert_eq!(values.schema().fields(), &fields[1..]); } + #[tokio::test] + /// Pruning + /// ```text + /// Agg(min(input_ref(2))) group by (input_ref(1)) + /// TableScan(v1, v2, v3) + /// ``` + /// with required columns [1,0] (all columns, with reversed order) will result in + /// ```text + /// Project [input_ref(1), input_ref(0)] + /// Agg(min(input_ref(1))) group by (input_ref(0)) + /// TableScan(v2, v3) + /// ``` + async fn test_prune_all_with_order_required() { + let ty = DataType::Int32; + let fields: Vec = vec![ + Field::with_name(ty.clone(), "v1"), + Field::with_name(ty.clone(), "v2"), + Field::with_name(ty.clone(), "v3"), + ]; + let agg = generate_agg_call(ty.clone(), fields.clone()).await; + // Perform the prune + let required_cols = vec![1, 0]; + let plan = agg.prune_col(&required_cols); + // Check the result + let proj = plan.as_logical_project().unwrap(); + assert_eq!(proj.exprs().len(), 2); + assert_eq!(proj.exprs()[0].as_input_ref().unwrap().index(), 1); + assert_eq!(proj.exprs()[1].as_input_ref().unwrap().index(), 0); + let proj_input = proj.input(); + let agg_new = proj_input.as_logical_agg().unwrap(); + assert_eq!(agg_new.group_keys(), vec![0]); + + assert_eq!(agg_new.agg_calls.len(), 1); + let agg_call_new = agg_new.agg_calls[0].clone(); + assert_eq!(agg_call_new.agg_kind, AggKind::Min); + assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![1]); + assert_eq!(agg_call_new.return_type, ty); + + let values = agg_new.input(); + let values = values.as_logical_values().unwrap(); + assert_eq!(values.schema().fields(), &fields[1..]); + } + #[tokio::test] /// Pruning /// ```text @@ -786,6 +834,7 @@ mod tests { /// Project(input_ref(1)) /// Agg(min(input_ref(1))) group by (input_ref(0)) /// TableScan(v2, v3) + /// ``` async fn test_prune_group_key() { let ctx = OptimizerContext::mock().await; let ty = DataType::Int32; @@ -846,6 +895,7 @@ mod tests { /// Project(input_ref(0), input_ref(2)) /// Agg(max(input_ref(0))) group by (input_ref(0), input_ref(1)) /// TableScan(v2, v3) + /// ``` async fn test_prune_agg() { let ty = DataType::Int32; let ctx = OptimizerContext::mock().await; diff --git a/src/frontend/src/optimizer/plan_node/logical_join.rs b/src/frontend/src/optimizer/plan_node/logical_join.rs index 7edef67aead71..5c7d229c32cda 100644 --- a/src/frontend/src/optimizer/plan_node/logical_join.rs +++ b/src/frontend/src/optimizer/plan_node/logical_join.rs @@ -856,4 +856,83 @@ mod tests { // let hash_join = result.as_stream_hash_join().unwrap(); // assert_eq!(hash_join.eq_join_predicate().all_cond().as_expr(), on_cond); } + /// Pruning + /// ```text + /// Join(on: input_ref(1)=input_ref(3)) + /// TableScan(v1, v2, v3) + /// TableScan(v4, v5, v6) + /// ``` + /// with required columns [3, 2] will result in + /// ```text + /// Project(input_ref(2), input_ref(1)) + /// Join(on: input_ref(0)=input_ref(2)) + /// TableScan(v2, v3) + /// TableScan(v4) + /// ``` + #[tokio::test] + async fn test_join_column_prune_with_order_required() { + let ty = DataType::Int32; + let ctx = OptimizerContext::mock().await; + let fields: Vec = (1..7) + .map(|i| Field::with_name(ty.clone(), format!("v{}", i))) + .collect(); + let left = LogicalValues::new( + vec![], + Schema { + fields: fields[0..3].to_vec(), + }, + ctx.clone(), + ); + let right = LogicalValues::new( + vec![], + Schema { + fields: fields[3..6].to_vec(), + }, + ctx, + ); + let on: ExprImpl = ExprImpl::FunctionCall(Box::new( + FunctionCall::new( + Type::Equal, + vec![ + ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))), + ExprImpl::InputRef(Box::new(InputRef::new(3, ty))), + ], + ) + .unwrap(), + )); + let join_type = JoinType::Inner; + let join = LogicalJoin::new( + left.into(), + right.into(), + join_type, + Condition::with_expr(on), + ); + + // Perform the prune + let required_cols = vec![3, 2]; + let plan = join.prune_col(&required_cols); + + // Check the result + let project = plan.as_logical_project().unwrap(); + assert_eq!(project.exprs().len(), 2); + assert_eq_input_ref!(&project.exprs()[0], 2); + assert_eq_input_ref!(&project.exprs()[1], 1); + + let join = project.input(); + let join = join.as_logical_join().unwrap(); + assert_eq!(join.schema().fields().len(), 3); + assert_eq!(join.schema().fields(), &fields[1..4]); + + let expr: ExprImpl = join.on.clone().into(); + let call = expr.as_function_call().unwrap(); + assert_eq_input_ref!(&call.inputs()[0], 0); + assert_eq_input_ref!(&call.inputs()[1], 2); + + let left = join.left(); + let left = left.as_logical_values().unwrap(); + assert_eq!(left.schema().fields(), &fields[1..3]); + let right = join.right(); + let right = right.as_logical_values().unwrap(); + assert_eq!(right.schema().fields(), &fields[3..4]); + } }