diff --git a/src/query/sql/src/planner/binder/select.rs b/src/query/sql/src/planner/binder/select.rs index ccdf0c7132c6b..fed0fd07f5797 100644 --- a/src/query/sql/src/planner/binder/select.rs +++ b/src/query/sql/src/planner/binder/select.rs @@ -39,6 +39,9 @@ use common_ast::Visitor; use common_exception::ErrorCode; use common_exception::Result; use common_exception::Span; +use common_expression::type_check::common_super_type; +use common_expression::types::DataType; +use common_functions::BUILTIN_FUNCTIONS; use tracing::warn; use super::sort::OrderItem; @@ -52,6 +55,8 @@ use crate::planner::binder::scalar::ScalarBinder; use crate::planner::binder::BindContext; use crate::planner::binder::Binder; use crate::plans::BoundColumnRef; +use crate::plans::CastExpr; +use crate::plans::EvalScalar; use crate::plans::Filter; use crate::plans::JoinType; use crate::plans::ScalarExpr; @@ -60,6 +65,7 @@ use crate::plans::UnionAll; use crate::ColumnBinding; use crate::ColumnEntry; use crate::IndexType; +use crate::Visibility; // A normalized IR for `SELECT` clause. 
#[derive(Debug, Default)] @@ -464,19 +470,48 @@ impl Binder { pub fn bind_union( &mut self, left_span: Span, - _right_span: Span, + right_span: Span, left_context: BindContext, right_context: BindContext, left_expr: SExpr, right_expr: SExpr, distinct: bool, ) -> Result<(SExpr, BindContext)> { - let pairs = left_context + let mut coercion_types = Vec::with_capacity(left_context.columns.len()); + for (left_col, right_col) in left_context .columns .iter() .zip(right_context.columns.iter()) - .map(|(l, r)| (l.index, r.index)) - .collect(); + { + if left_col.data_type != right_col.data_type { + if let Some(data_type) = common_super_type( + *left_col.data_type.clone(), + *right_col.data_type.clone(), + &BUILTIN_FUNCTIONS.default_cast_rules, + ) { + coercion_types.push(data_type); + } else { + return Err(ErrorCode::SemanticError(format!( + "SetOperation's types cannot be matched, left column {:?}, type: {:?}, right column {:?}, type: {:?}", + left_col.column_name, + left_col.data_type, + right_col.column_name, + right_col.data_type + ))); + } + } else { + coercion_types.push(*left_col.data_type.clone()); + } + } + let (new_bind_context, pairs, left_expr, right_expr) = self.coercion_union_type( + left_span, + right_span, + left_context, + right_context, + left_expr, + right_expr, + coercion_types, + )?; let union_plan = UnionAll { pairs }; let mut new_expr = SExpr::create_binary( @@ -488,14 +523,14 @@ impl Binder { if distinct { new_expr = self.bind_distinct( left_span, - &left_context, - left_context.all_column_bindings(), + &new_bind_context, + new_bind_context.all_column_bindings(), &mut HashMap::new(), new_expr, )?; } - Ok((new_expr, left_context)) + Ok((new_expr, new_bind_context)) } pub fn bind_intersect( @@ -589,6 +624,118 @@ impl Binder { Ok((s_expr, left_context)) } + #[allow(clippy::type_complexity)] + #[allow(clippy::too_many_arguments)] + fn coercion_union_type( + &self, + left_span: Span, + right_span: Span, + left_bind_context: BindContext, + 
right_bind_context: BindContext, + mut left_expr: SExpr, + mut right_expr: SExpr, + coercion_types: Vec<DataType>, + ) -> Result<(BindContext, Vec<(IndexType, IndexType)>, SExpr, SExpr)> { + let mut left_scalar_items = Vec::with_capacity(left_bind_context.columns.len()); + let mut right_scalar_items = Vec::with_capacity(right_bind_context.columns.len()); + let mut new_bind_context = BindContext::new(); + let mut pairs = Vec::with_capacity(left_bind_context.columns.len()); + for (idx, (left_col, right_col)) in left_bind_context + .columns + .iter() + .zip(right_bind_context.columns.iter()) + .enumerate() + { + let left_index = if *left_col.data_type != coercion_types[idx] { + let new_column_index = self + .metadata + .write() + .add_derived_column(left_col.column_name.clone(), coercion_types[idx].clone()); + let column_binding = ColumnBinding { + database_name: None, + table_name: None, + column_position: None, + table_index: None, + column_name: left_col.column_name.clone(), + index: new_column_index, + data_type: Box::new(coercion_types[idx].clone()), + visibility: Visibility::Visible, + virtual_computed_expr: None, + }; + let left_coercion_expr = CastExpr { + span: left_span, + is_try: false, + argument: Box::new( + BoundColumnRef { + span: left_span, + column: left_col.clone(), + } + .into(), + ), + target_type: Box::new(coercion_types[idx].clone()), + }; + left_scalar_items.push(ScalarItem { + scalar: left_coercion_expr.into(), + index: new_column_index, + }); + new_bind_context.add_column_binding(column_binding); + new_column_index + } else { + new_bind_context.add_column_binding(left_col.clone()); + left_col.index + }; + let right_index = if *right_col.data_type != coercion_types[idx] { + let new_column_index = self + .metadata + .write() + .add_derived_column(right_col.column_name.clone(), coercion_types[idx].clone()); + let right_coercion_expr = CastExpr { + span: right_span, + is_try: false, + argument: Box::new( + BoundColumnRef { + span: right_span, + column: 
right_col.clone(), + } + .into(), + ), + target_type: Box::new(coercion_types[idx].clone()), + }; + right_scalar_items.push(ScalarItem { + scalar: right_coercion_expr.into(), + index: new_column_index, + }); + new_column_index + } else { + right_col.index + }; + pairs.push((left_index, right_index)); + } + if !left_scalar_items.is_empty() { + left_expr = SExpr::create_unary( + Arc::new( + EvalScalar { + items: left_scalar_items, + } + .into(), + ), + Arc::new(left_expr), + ); + } + if !right_scalar_items.is_empty() { + right_expr = SExpr::create_unary( + Arc::new( + EvalScalar { + items: right_scalar_items, + } + .into(), + ), + Arc::new(right_expr), + ); + } + Ok((new_bind_context, pairs, left_expr, right_expr)) + } + #[allow(clippy::too_many_arguments)] fn analyze_lazy_materialization( &self, diff --git a/tests/sqllogictests/suites/query/union.test b/tests/sqllogictests/suites/query/union.test index 3373b51455790..922a571d4cb7b 100644 --- a/tests/sqllogictests/suites/query/union.test +++ b/tests/sqllogictests/suites/query/union.test @@ -177,3 +177,10 @@ drop table t1 statement error 1065 select [1,2,3] union all select 2 + +# type coercion +query R rowsort +select 1 as c union all select 3.3::Double; +---- +1.0 +3.3