diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index a67c871e4cbf..1b8705c5c497 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -61,7 +61,7 @@ use datafusion_expr::{ expr::{Alias, ScalarFunction}, is_null, lit, utils::COUNT_STAR_EXPANSION, - SortExpr, TableProviderFilterPushDown, UNNAMED_TABLE, + Projection, SortExpr, TableProviderFilterPushDown, UNNAMED_TABLE, }; use datafusion_functions::core::coalesce; use datafusion_functions_aggregate::expr_fn::{ @@ -1963,7 +1963,7 @@ impl DataFrame { /// # } /// ``` pub fn with_column_renamed( - self, + mut self, old_name: impl Into, new_name: &str, ) -> Result { @@ -1972,41 +1972,85 @@ impl DataFrame { .config_options() .sql_parser .enable_ident_normalization; + let old_column: Column = if ident_opts { Column::from_qualified_name(old_name) } else { Column::from_qualified_name_ignore_case(old_name) }; - let (qualifier_rename, field_rename) = - match self.plan.schema().qualified_field_from_column(&old_column) { - Ok(qualifier_and_field) => qualifier_and_field, - // no-op if field not found - Err(DataFusionError::SchemaError( - SchemaError::FieldNotFound { .. }, - _, - )) => return Ok(self), - Err(err) => return Err(err), - }; - let projection = self - .plan - .schema() - .iter() - .map(|(qualifier, field)| { - if qualifier.eq(&qualifier_rename) && field.as_ref() == field_rename { - ( - col(Column::from((qualifier, field))) - .alias_qualified(qualifier.cloned(), new_name), - false, - ) - } else { - (col(Column::from((qualifier, field))), false) - } - }) - .collect::>(); - let project_plan = LogicalPlanBuilder::from(self.plan) - .project_with_validation(projection)? - .build()?; + let project_plan = if let LogicalPlan::Projection(Projection { + expr, + input, + schema, + .. + }) = self.plan + { + // special case: we already have a projection on top, so we can reuse it rather than creating a new one + let (qualifier_rename, field_rename) = + match schema.qualified_field_from_column(&old_column) { + Ok(qualifier_and_field) => qualifier_and_field, + // no-op if field not found + Err(DataFusionError::SchemaError( + SchemaError::FieldNotFound { .. }, + _, + )) => { + self.plan = LogicalPlan::Projection( + Projection::try_new_with_schema(expr, input, schema)?, + ); + return Ok(self); + } + Err(err) => return Err(err), + }; + + let expr: Vec<_> = expr + .into_iter() + .map(|e| { + let (qualifier, field) = e.qualified_name(); + + if qualifier.as_ref().eq(&qualifier_rename) + && field.as_str() == field_rename.name() + { + e.alias_qualified(qualifier, new_name.to_string()) + } else { + e + } + }) + .collect(); + LogicalPlan::Projection(Projection::try_new(expr, input)?) + } else { + let (qualifier_rename, field_rename) = + match self.plan.schema().qualified_field_from_column(&old_column) { + Ok(qualifier_and_field) => qualifier_and_field, + // no-op if field not found + Err(DataFusionError::SchemaError( + SchemaError::FieldNotFound { .. }, + _, + )) => return Ok(self), + Err(err) => return Err(err), + }; + + let projection = self + .plan + .schema() + .iter() + .map(|(qualifier, field)| { + if qualifier.eq(&qualifier_rename) && field.as_ref() == field_rename { + ( + col(Column::from((qualifier, field))) + .alias_qualified(qualifier.cloned(), new_name), + false, + ) + } else { + (col(Column::from((qualifier, field))), false) + } + }) + .collect::>(); + + LogicalPlanBuilder::from(self.plan) + .project_with_validation(projection)? + .build()? + }; Ok(DataFrame { session_state: self.session_state, plan: project_plan, diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 2a1d6426872e..d3c6ac45af0d 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -1648,6 +1648,18 @@ async fn with_column_renamed() -> Result<()> { // no-op for missing column .with_column_renamed("c4", "boom")?; + // one projection is reused for all renames + assert_snapshot!( + df_sum_renamed.logical_plan(), + @r#" + Projection: aggregate_test_100.c1 AS one, aggregate_test_100.c2 AS two, aggregate_test_100.c3, aggregate_test_100.c2 + aggregate_test_100.c3 AS sum AS total + Limit: skip=0, fetch=1 + Sort: aggregate_test_100.c1 ASC NULLS FIRST, aggregate_test_100.c2 ASC NULLS FIRST, aggregate_test_100.c3 ASC NULLS FIRST + Filter: aggregate_test_100.c2 = Int32(3) AND aggregate_test_100.c1 = Utf8("a") + Projection: aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c3 + TableScan: aggregate_test_100 + "#); + let references: Vec<_> = df_sum_renamed .schema() .iter()