Skip to content

Commit

Permalink
feat: Union types coercion (#3513)
Browse files Browse the repository at this point in the history
1
  • Loading branch information
gandronchik authored Sep 19, 2022
1 parent f30fc4e commit 4b1e044
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 32 deletions.
32 changes: 30 additions & 2 deletions datafusion/core/src/physical_plan/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@

use std::{any::Any, sync::Arc};

use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
use arrow::{
datatypes::{Field, Schema, SchemaRef},
record_batch::RecordBatch,
};
use futures::StreamExt;
use itertools::Itertools;
use log::debug;

use super::{
Expand All @@ -46,14 +50,38 @@ pub struct UnionExec {
inputs: Vec<Arc<dyn ExecutionPlan>>,
/// Execution metrics
metrics: ExecutionPlanMetricsSet,
/// Schema of Union
schema: SchemaRef,
}

impl UnionExec {
/// Create a new UnionExec
pub fn new(inputs: Vec<Arc<dyn ExecutionPlan>>) -> Self {
let fields: Vec<Field> = (0..inputs[0].schema().fields().len())
.map(|i| {
inputs
.iter()
.filter_map(|input| {
if input.schema().fields().len() > i {
Some(input.schema().field(i).clone())
} else {
None
}
})
.find_or_first(|f| f.is_nullable())
.unwrap()
})
.collect();

let schema = Arc::new(Schema::new_with_metadata(
fields,
inputs[0].schema().metadata().clone(),
));

UnionExec {
inputs,
metrics: ExecutionPlanMetricsSet::new(),
schema,
}
}

Expand All @@ -70,7 +98,7 @@ impl ExecutionPlan for UnionExec {
}

/// Returns the schema stored on this `UnionExec` (the `schema` field
/// populated by `UnionExec::new`), not the schema of the first input.
fn schema(&self) -> SchemaRef {
    self.schema.clone()
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
Expand Down
38 changes: 36 additions & 2 deletions datafusion/expr/src/expr_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
//! Expression rewriter

use crate::expr::GroupingSet;
use crate::logical_plan::Aggregate;
use crate::utils::grouping_set_to_exprlist;
use crate::logical_plan::{Aggregate, Projection};
use crate::utils::{from_plan, grouping_set_to_exprlist};
use crate::{Expr, ExprSchemable, LogicalPlan};
use datafusion_common::Result;
use datafusion_common::{Column, DFSchema};
Expand Down Expand Up @@ -524,6 +524,40 @@ pub fn unnormalize_cols(exprs: impl IntoIterator<Item = Expr>) -> Vec<Expr> {
exprs.into_iter().map(unnormalize_col).collect()
}

/// Rebuilds `plan` so that each of its expressions is cast to the data
/// type of the field at the same position in `schema`.
///
/// Expressions whose current field type (looked up in `plan.schema()`)
/// already equals the target type are passed through unchanged. When
/// `plan` is a projection and the expression is an alias, the cast is
/// applied to the aliased inner expression (resolved against the
/// projection's input schema) and the alias is re-applied on top, so the
/// output column keeps its name. Every other expression is cast as a
/// whole against `plan.schema()`.
///
/// Returns the first casting error encountered, or whatever `from_plan`
/// returns when the plan is rebuilt with the new expressions and the
/// original inputs.
///
/// NOTE(review): expressions are matched to fields purely by position
/// (`field(i)` on both schemas); this assumes `plan.expressions()` lines
/// up one-to-one with the plan's output columns — confirm for callers
/// that pass non-projection plans.
pub fn coerce_plan_expr_for_schema(
    plan: &LogicalPlan,
    schema: &DFSchema,
) -> Result<LogicalPlan> {
    let original_exprs = plan.expressions();
    let mut coerced_exprs = Vec::with_capacity(original_exprs.len());

    for (idx, expr) in original_exprs.into_iter().enumerate() {
        let target_type = schema.field(idx).data_type();

        // Nothing to do when the column already has the requested type.
        if plan.schema().field(idx).data_type() == target_type {
            coerced_exprs.push(expr);
            continue;
        }

        let casted = match (plan, &expr) {
            // Cast underneath the alias so the projected name is preserved.
            (
                LogicalPlan::Projection(Projection { input, .. }),
                Expr::Alias(inner, name),
            ) => Expr::Alias(
                Box::new(inner.clone().cast_to(target_type, input.schema())?),
                name.clone(),
            ),
            _ => expr.cast_to(target_type, plan.schema())?,
        };
        coerced_exprs.push(casted);
    }

    let original_inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();

    from_plan(plan, &coerced_exprs, &original_inputs)
}

#[cfg(test)]
mod test {
use super::*;
Expand Down
73 changes: 46 additions & 27 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

//! This module provides a builder for creating LogicalPlans

use crate::expr_rewriter::{normalize_col, normalize_cols, rewrite_sort_cols_by_aggs};
use crate::binary_rule::comparison_coercion;
use crate::expr_rewriter::{
coerce_plan_expr_for_schema, normalize_col, normalize_cols, rewrite_sort_cols_by_aggs,
};
use crate::utils::{
columnize_expr, exprlist_to_fields, from_plan, grouping_set_to_exprlist,
};
Expand Down Expand Up @@ -882,43 +885,59 @@ pub fn union_with_alias(
right_plan: LogicalPlan,
alias: Option<String>,
) -> Result<LogicalPlan> {
let union_schema = left_plan.schema().clone();
let inputs_iter = vec![left_plan, right_plan]
let union_schema = (0..left_plan.schema().fields().len())
.map(|i| {
let left_field = left_plan.schema().field(i);
let right_field = right_plan.schema().field(i);
let nullable = left_field.is_nullable() || right_field.is_nullable();
let data_type =
comparison_coercion(left_field.data_type(), right_field.data_type())
.ok_or_else(|| {
DataFusionError::Plan(format!(
"UNION Column {} (type: {}) is not compatible with column {} (type: {})",
right_field.name(),
right_field.data_type(),
left_field.name(),
left_field.data_type()
))
})?;

Ok(DFField::new(
alias.as_deref(),
left_field.name(),
data_type,
nullable,
))
})
.collect::<Result<Vec<_>>>()?
.to_dfschema()?;

let inputs = vec![left_plan, right_plan]
.into_iter()
.flat_map(|p| match p {
LogicalPlan::Union(Union { inputs, .. }) => inputs,
x => vec![Arc::new(x)],
});

inputs_iter
.clone()
.skip(1)
.try_for_each(|input_plan| -> Result<()> {
union_schema.check_arrow_schema_type_compatible(
&((**input_plan.schema()).clone().into()),
)
})?;

let inputs = inputs_iter
.map(|p| match p.as_ref() {
LogicalPlan::Projection(Projection {
expr, input, alias, ..
}) => Ok(Arc::new(project_with_column_index_alias(
expr.to_vec(),
input.clone(),
union_schema.clone(),
alias.clone(),
)?)),
x => Ok(Arc::new(x.clone())),
})
.into_iter()
.map(|p| {
let plan = coerce_plan_expr_for_schema(&p, &union_schema)?;
match plan {
LogicalPlan::Projection(Projection {
expr, input, alias, ..
}) => Ok(Arc::new(project_with_column_index_alias(
expr.to_vec(),
input,
Arc::new(union_schema.clone()),
alias,
)?)),
x => Ok(Arc::new(x)),
}
})
.collect::<Result<Vec<_>>>()?;

if inputs.is_empty() {
return Err(DataFusionError::Plan("Empty UNION".to_string()));
}

let union_schema = (**inputs[0].schema()).clone();
let union_schema = Arc::new(match alias {
Some(ref alias) => union_schema.replace_qualifier(alias.as_str()),
None => union_schema.strip_qualifiers(),
Expand Down
108 changes: 107 additions & 1 deletion datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4163,14 +4163,120 @@ mod tests {
let sql = "SELECT interval '1 year 1 day' UNION ALL SELECT 1";
let err = logical_plan(sql).expect_err("query should have failed");
assert_eq!(
"Plan(\"Column Int64(1) (type: Int64) is \
"Plan(\"UNION Column Int64(1) (type: Int64) is \
not compatible with column IntervalMonthDayNano\
(\\\"950737950189618795196236955648\\\") \
(type: Interval(MonthDayNano))\")",
format!("{:?}", err)
);
}

// UNION ALL of an integer literal and a float literal: the expected plan
// casts the Int64 branch to Float64 and leaves the Float64 branch as is.
#[test]
fn union_with_different_decimal_data_types() {
let sql = "SELECT 1 a UNION ALL SELECT 1.1 a";
// Expected logical plan text; presumably compared verbatim by
// `quick_test` — confirm against that helper.
let expected = "Union\
\n Projection: CAST(Int64(1) AS Float64) AS a\
\n EmptyRelation\
\n Projection: Float64(1.1) AS a\
\n EmptyRelation";
quick_test(sql, expected);
}

// UNION ALL of a NULL literal and a float literal: the expected plan casts
// the NULL branch to Float64 so both branches share one type.
#[test]
fn union_with_null() {
let sql = "SELECT NULL a UNION ALL SELECT 1.1 a";
// Only the first (NULL) projection is expected to carry a CAST.
let expected = "Union\
\n Projection: CAST(NULL AS Float64) AS a\
\n EmptyRelation\
\n Projection: Float64(1.1) AS a\
\n EmptyRelation";
quick_test(sql, expected);
}

// UNION ALL of a string literal and a float literal: the expected plan
// keeps the Utf8 branch untouched and casts the Float64 branch to Utf8.
#[test]
fn union_with_float_and_string() {
let sql = "SELECT 'a' a UNION ALL SELECT 1.1 a";
// Here the CAST is expected on the second branch rather than the first.
let expected = "Union\
\n Projection: Utf8(\"a\") AS a\
\n EmptyRelation\
\n Projection: CAST(Float64(1.1) AS Utf8) AS a\
\n EmptyRelation";
quick_test(sql, expected);
}

// Two-column UNION ALL where each column needs a cast on a different
// branch: column `a` is cast to Utf8 on the second branch, column `b` is
// cast to Float64 on the first branch.
#[test]
fn union_with_multiply_cols() {
let sql = "SELECT 'a' a, 1 b UNION ALL SELECT 1.1 a, 1.1 b";
// Each projection is expected to cast only the column whose type differs
// from the other branch.
let expected = "Union\
\n Projection: Utf8(\"a\") AS a, CAST(Int64(1) AS Float64) AS b\
\n EmptyRelation\
\n Projection: CAST(Float64(1.1) AS Utf8) AS a, Float64(1.1) AS b\
\n EmptyRelation";
quick_test(sql, expected);
}

// UNION ALL of two grouped subqueries (Int64 vs Float64 column `a`)
// followed by ORDER BY: the expected plan has a Sort over the Union, and
// the Int64 branch's outer projection casts `#x.a` to Float64 while the
// Float64 branch projects `#x.a` unchanged.
#[test]
fn sorted_union_with_different_types_and_group_by() {
let sql = "SELECT a FROM (select 1 a) x GROUP BY 1 UNION ALL (SELECT a FROM (select 1.1 a) x GROUP BY 1) ORDER BY 1";
// The cast is expected on the projection above the Aggregate, not inside
// the aggregate's group-by expression.
let expected = "Sort: #a ASC NULLS LAST\
\n Union\
\n Projection: CAST(#x.a AS Float64) AS a\
\n Aggregate: groupBy=[[#x.a]], aggr=[[]]\
\n Projection: #x.a, alias=x\
\n Projection: Int64(1) AS a, alias=x\
\n EmptyRelation\
\n Projection: #x.a\
\n Aggregate: groupBy=[[#x.a]], aggr=[[]]\
\n Projection: #x.a, alias=x\
\n Projection: Float64(1.1) AS a, alias=x\
\n EmptyRelation";
quick_test(sql, expected);
}

// UNION ALL where the first branch selects `cast(0.0 + a as integer)` and
// the second selects `2.1 + a`, both grouped: the expected plan wraps the
// first branch's projected column in a CAST to Float64 and leaves the
// second branch's projection without a cast.
#[test]
fn union_with_binary_expr_and_cast() {
let sql = "SELECT cast(0.0 + a as integer) FROM (select 1 a) x GROUP BY 1 UNION ALL (SELECT 2.1 + a FROM (select 1 a) x GROUP BY 1)";
// The first branch's Aggregate keeps the Int32 cast in its group-by
// expression; the Float64 cast is expected only on the projection above it.
let expected = "Union\
\n Projection: CAST(#Float64(0) + x.a AS Float64) AS Float64(0) + x.a\
\n Aggregate: groupBy=[[CAST(Float64(0) + #x.a AS Int32)]], aggr=[[]]\
\n Projection: #x.a, alias=x\
\n Projection: Int64(1) AS a, alias=x\
\n EmptyRelation\
\n Projection: #Float64(2.1) + x.a\
\n Aggregate: groupBy=[[Float64(2.1) + #x.a]], aggr=[[]]\
\n Projection: #x.a, alias=x\
\n Projection: Int64(1) AS a, alias=x\
\n EmptyRelation";
quick_test(sql, expected);
}

// UNION ALL of two grouped subqueries whose output column is aliased as
// `a1` (Int64 vs Float64): the expected plan casts the column underneath
// the alias on the Int64 branch (`CAST(#x.a AS Float64) AS a1`) and keeps
// the plain alias on the Float64 branch.
#[test]
fn union_with_aliases() {
let sql = "SELECT a as a1 FROM (select 1 a) x GROUP BY 1 UNION ALL (SELECT a as a1 FROM (select 1.1 a) x GROUP BY 1)";
// Both branches are expected to expose the column under the name `a1`.
let expected = "Union\
\n Projection: CAST(#x.a AS Float64) AS a1\
\n Aggregate: groupBy=[[#x.a]], aggr=[[]]\
\n Projection: #x.a, alias=x\
\n Projection: Int64(1) AS a, alias=x\
\n EmptyRelation\
\n Projection: #x.a AS a1\
\n Aggregate: groupBy=[[#x.a]], aggr=[[]]\
\n Projection: #x.a, alias=x\
\n Projection: Float64(1.1) AS a, alias=x\
\n EmptyRelation";
quick_test(sql, expected);
}

// UNION ALL of a string literal and a boolean literal: planning is
// expected to fail with a plan error naming both columns and their types
// (Boolean vs Utf8).
#[test]
fn union_with_incompatible_data_types() {
let sql = "SELECT 'a' a UNION ALL SELECT true a";
let err = logical_plan(sql).expect_err("query should have failed");
// The error is compared through its `Debug` rendering, so the expected
// text includes the `Plan(...)` wrapper and escaped quotes.
assert_eq!(
"Plan(\"UNION Column a (type: Boolean) is not compatible with column a (type: Utf8)\")",
format!("{:?}", err)
);
}

#[test]
fn empty_over() {
let sql = "SELECT order_id, MAX(order_id) OVER () from orders";
Expand Down

0 comments on commit 4b1e044

Please sign in to comment.