Skip to content

fix: Dialect requires derived table alias #12994

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions datafusion/sql/src/unparser/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ pub trait Dialect: Send + Sync {
true
}

/// Whether the dialect requires a table alias for any subquery in the FROM clause.
///
/// When this returns `true`, the unparser attaches a generated alias (for example
/// `derived_sort` or `derived_limit`) to the subqueries it derives for Projection,
/// Limit, Sort, Distinct and Union plans. The default implementation returns `false`,
/// which leaves derived subqueries unaliased.
fn requires_derived_table_alias(&self) -> bool {
false
}

/// Allows the dialect to override scalar function unparsing if the dialect has specific rules.
/// Returns None if the default unparsing should be used, or Some(ast::Expr) if there is
/// a custom implementation for the function.
Expand Down Expand Up @@ -300,6 +306,10 @@ impl Dialect for MySqlDialect {
ast::DataType::Datetime(None)
}

// For MySQL, a subquery in the FROM clause needs an alias, so derived
// subqueries are always emitted with one (e.g. `... ) AS `derived_sort``).
fn requires_derived_table_alias(&self) -> bool {
true
}

fn scalar_function_to_sql_overrides(
&self,
unparser: &Unparser,
Expand Down Expand Up @@ -362,6 +372,7 @@ pub struct CustomDialect {
timestamp_tz_cast_dtype: ast::DataType,
date32_cast_dtype: sqlparser::ast::DataType,
supports_column_alias_in_table_alias: bool,
requires_derived_table_alias: bool,
}

impl Default for CustomDialect {
Expand All @@ -384,6 +395,7 @@ impl Default for CustomDialect {
),
date32_cast_dtype: sqlparser::ast::DataType::Date,
supports_column_alias_in_table_alias: true,
requires_derived_table_alias: false,
}
}
}
Expand Down Expand Up @@ -472,6 +484,10 @@ impl Dialect for CustomDialect {

Ok(None)
}

// Reports the value stored on this dialect; it is `false` by default and can be
// set through `CustomDialectBuilder::with_requires_derived_table_alias`.
fn requires_derived_table_alias(&self) -> bool {
self.requires_derived_table_alias
}
}

/// `CustomDialectBuilder` to build `CustomDialect` using builder pattern
Expand Down Expand Up @@ -503,6 +519,7 @@ pub struct CustomDialectBuilder {
timestamp_tz_cast_dtype: ast::DataType,
date32_cast_dtype: ast::DataType,
supports_column_alias_in_table_alias: bool,
requires_derived_table_alias: bool,
}

impl Default for CustomDialectBuilder {
Expand Down Expand Up @@ -531,6 +548,7 @@ impl CustomDialectBuilder {
),
date32_cast_dtype: sqlparser::ast::DataType::Date,
supports_column_alias_in_table_alias: true,
requires_derived_table_alias: false,
}
}

Expand All @@ -551,6 +569,7 @@ impl CustomDialectBuilder {
date32_cast_dtype: self.date32_cast_dtype,
supports_column_alias_in_table_alias: self
.supports_column_alias_in_table_alias,
requires_derived_table_alias: self.requires_derived_table_alias,
}
}

Expand Down Expand Up @@ -653,4 +672,12 @@ impl CustomDialectBuilder {
self.supports_column_alias_in_table_alias = supports_column_alias_in_table_alias;
self
}

/// Customize the dialect to require a table alias for any subquery in the FROM clause.
///
/// Pass `true` to have derived subqueries emitted with an alias; the builder's
/// default is `false`. Returns the builder so calls can be chained.
pub fn with_requires_derived_table_alias(
mut self,
requires_derived_table_alias: bool,
) -> Self {
self.requires_derived_table_alias = requires_derived_table_alias;
self
}
}
57 changes: 50 additions & 7 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,14 @@ impl Unparser<'_> {
Ok(())
}

fn derive(&self, plan: &LogicalPlan, relation: &mut RelationBuilder) -> Result<()> {
fn derive(
&self,
plan: &LogicalPlan,
relation: &mut RelationBuilder,
alias: Option<ast::TableAlias>,
) -> Result<()> {
let mut derived_builder = DerivedRelationBuilder::default();
derived_builder.lateral(false).alias(None).subquery({
derived_builder.lateral(false).alias(alias).subquery({
let inner_statement = self.plan_to_sql(plan)?;
if let ast::Statement::Query(inner_query) = inner_statement {
inner_query
Expand All @@ -239,6 +244,23 @@ impl Unparser<'_> {
Ok(())
}

/// Derives `plan` into a subquery relation, attaching `alias` as the table alias
/// only when the unparser's dialect reports that derived tables must be aliased.
///
/// For dialects that do not require an alias, the subquery is derived without one.
fn derive_with_dialect_alias(
    &self,
    alias: &str,
    plan: &LogicalPlan,
    relation: &mut RelationBuilder,
) -> Result<()> {
    // Build the alias lazily: the closure only runs when the dialect asks for it.
    let table_alias = self
        .dialect
        .requires_derived_table_alias()
        .then(|| self.new_table_alias(alias.to_string(), vec![]));

    self.derive(plan, relation, table_alias)
}

fn select_to_sql_recursively(
&self,
plan: &LogicalPlan,
Expand Down Expand Up @@ -284,7 +306,11 @@ impl Unparser<'_> {

// Projection can be top-level plan for derived table
if select.already_projected() {
return self.derive(plan, relation);
return self.derive_with_dialect_alias(
"derived_projection",
plan,
relation,
);
}
self.reconstruct_select_statement(plan, p, select)?;
self.select_to_sql_recursively(p.input.as_ref(), query, select, relation)
Expand All @@ -311,8 +337,13 @@ impl Unparser<'_> {
LogicalPlan::Limit(limit) => {
// Limit can be top-level plan for derived table
if select.already_projected() {
return self.derive(plan, relation);
return self.derive_with_dialect_alias(
"derived_limit",
plan,
relation,
);
}

if let Some(fetch) = limit.fetch {
let Some(query) = query.as_mut() else {
return internal_err!(
Expand Down Expand Up @@ -350,7 +381,11 @@ impl Unparser<'_> {
LogicalPlan::Sort(sort) => {
// Sort can be top-level plan for derived table
if select.already_projected() {
return self.derive(plan, relation);
return self.derive_with_dialect_alias(
"derived_sort",
plan,
relation,
);
}
let Some(query_ref) = query else {
return internal_err!(
Expand Down Expand Up @@ -396,7 +431,11 @@ impl Unparser<'_> {
LogicalPlan::Distinct(distinct) => {
// Distinct can be top-level plan for derived table
if select.already_projected() {
return self.derive(plan, relation);
return self.derive_with_dialect_alias(
"derived_distinct",
plan,
relation,
);
}
let (select_distinct, input) = match distinct {
Distinct::All(input) => (ast::Distinct::Distinct, input.as_ref()),
Expand Down Expand Up @@ -559,7 +598,11 @@ impl Unparser<'_> {

// Covers cases where the UNION is a subquery and the projection is at the top level
if select.already_projected() {
return self.derive(plan, relation);
return self.derive_with_dialect_alias(
"derived_union",
plan,
relation,
);
}

let input_exprs: Vec<SetExpr> = union
Expand Down
39 changes: 39 additions & 0 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,45 @@ fn roundtrip_statement_with_dialect() -> Result<()> {
unparser_dialect: Box<dyn UnparserDialect>,
}
let tests: Vec<TestStatementWithDialect> = vec![
TestStatementWithDialect {
sql: "select min(ta.j1_id) as j1_min from j1 ta order by min(ta.j1_id) limit 10;",
expected:
// top projection sort gets derived into a subquery
// for MySQL, this subquery needs an alias
"SELECT `j1_min` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, min(`ta`.`j1_id`) FROM `j1` AS `ta` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test only covers the derived_sort case; other cases, such as derived_projection and derived_limit, aren't covered.
Could you add more tests for them?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added another test case for a more complex SQL that results in 2 nested derives for derived_sort and derived_distinct in the same query.

I'm not too sure what SQL to use to trigger derived_limit or derived_projection though... any ideas that I could include?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For derived_projection, I found that the following SQL triggers it:

select j1_id from (select 1 as j1_id)
 -> SELECT `j1_id` FROM (SELECT 1 AS `j1_id`) AS `derived_projection`

However, I guess it's not valid SQL for MySQL, right? I tried it in DataFusion and DuckDB, and they accept it, but Postgres doesn't allow it. Maybe it could be covered as a DataFusion-dialect-to-MySQL-dialect case.

For derived_limit, a similar case can trigger it:

select * from (select * from j1 limit 10) 
  -> SELECT * FROM (SELECT * FROM `j1` LIMIT 10) AS `derived_limit`

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is valid for MySQL, I just tested it out in my CLI. I've added both as new cases, thanks!

parser_dialect: Box::new(MySqlDialect {}),
unparser_dialect: Box::new(UnparserMySqlDialect {}),
},
TestStatementWithDialect {
sql: "select min(ta.j1_id) as j1_min from j1 ta order by min(ta.j1_id) limit 10;",
expected:
// top projection sort still gets derived into a subquery in default dialect
// except for the default dialect, the subquery is left non-aliased
"SELECT j1_min FROM (SELECT min(ta.j1_id) AS j1_min, min(ta.j1_id) FROM j1 AS ta ORDER BY min(ta.j1_id) ASC NULLS LAST) LIMIT 10",
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
TestStatementWithDialect {
sql: "select min(ta.j1_id) as j1_min, max(tb.j1_max) from j1 ta, (select distinct max(ta.j1_id) as j1_max from j1 ta order by max(ta.j1_id)) tb order by min(ta.j1_id) limit 10;",
expected:
"SELECT `j1_min`, `max(tb.j1_max)` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, max(`tb`.`j1_max`), min(`ta`.`j1_id`) FROM `j1` AS `ta` JOIN (SELECT `j1_max` FROM (SELECT DISTINCT max(`ta`.`j1_id`) AS `j1_max` FROM `j1` AS `ta`) AS `derived_distinct`) AS `tb` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10",
parser_dialect: Box::new(MySqlDialect {}),
unparser_dialect: Box::new(UnparserMySqlDialect {}),
},
TestStatementWithDialect {
sql: "select j1_id from (select 1 as j1_id);",
expected:
"SELECT `j1_id` FROM (SELECT 1 AS `j1_id`) AS `derived_projection`",
parser_dialect: Box::new(MySqlDialect {}),
unparser_dialect: Box::new(UnparserMySqlDialect {}),
},
TestStatementWithDialect {
sql: "select * from (select * from j1 limit 10);",
expected:
"SELECT * FROM (SELECT * FROM `j1` LIMIT 10) AS `derived_limit`",
parser_dialect: Box::new(MySqlDialect {}),
unparser_dialect: Box::new(UnparserMySqlDialect {}),
},
TestStatementWithDialect {
sql: "select ta.j1_id from j1 ta order by j1_id limit 10;",
expected:
Expand Down