Skip to content

Commit

Permalink
inlist: move type coercion to logical phase (#3472)
Browse files Browse the repository at this point in the history
  • Loading branch information
liukun4515 authored Sep 14, 2022
1 parent 0388682 commit f3bb84f
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 69 deletions.
32 changes: 19 additions & 13 deletions datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1696,9 +1696,11 @@ mod tests {
async fn plan(logical_plan: &LogicalPlan) -> Result<Arc<dyn ExecutionPlan>> {
let mut session_state = make_session_state();
session_state.config.target_partitions = 4;
// optimize the logical plan
let logical_plan = session_state.optimize(logical_plan)?;
let planner = DefaultPhysicalPlanner::default();
planner
.create_physical_plan(logical_plan, &session_state)
.create_physical_plan(&logical_plan, &session_state)
.await
}

Expand All @@ -1714,12 +1716,12 @@ mod tests {
.limit(3, Some(10))?
.build()?;

let plan = plan(&logical_plan).await?;
let exec_plan = plan(&logical_plan).await?;

// verify that the plan correctly casts u8 to i64
// the cast here is implicit so has CastOptions with safe=true
let expected = "BinaryExpr { left: Column { name: \"c7\", index: 6 }, op: Lt, right: TryCastExpr { expr: Literal { value: UInt8(5) }, cast_type: Int64 } }";
assert!(format!("{:?}", plan).contains(expected));
let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) } }";
assert!(format!("{:?}", exec_plan).contains(expected));

Ok(())
}
Expand Down Expand Up @@ -1821,8 +1823,7 @@ mod tests {
async fn test_with_zero_offset_plan() -> Result<()> {
let logical_plan = test_csv_scan().await?.limit(0, None)?.build()?;
let plan = plan(&logical_plan).await?;
assert!(format!("{:?}", plan).contains("GlobalLimitExec"));
assert!(format!("{:?}", plan).contains("skip: 0"));
assert!(format!("{:?}", plan).contains("limit: None"));
Ok(())
}

Expand Down Expand Up @@ -1952,8 +1953,8 @@ mod tests {
.project(vec![col("c1").in_list(list, false)])?
.build()?;
let execution_plan = plan(&logical_plan).await?;
// verify that the plan correctly adds cast from Int64(1) to Utf8
let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, TryCastExpr { expr: Literal { value: Int64(1) }, cast_type: Utf8 }], negated: false, set: None }";
// verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated.
let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"1\") }], negated: false, set: None }";
assert!(format!("{:?}", execution_plan).contains(expected));

// expression: "a in (struct::null, 'a')"
Expand All @@ -1965,10 +1966,9 @@ mod tests {
.filter(col("c12").lt(lit(0.05)))?
.project(vec![col("c12").lt_eq(lit(0.025)).in_list(list, false)])?
.build()?;
let execution_plan = plan(&logical_plan).await;
let e = plan(&logical_plan).await.unwrap_err().to_string();

let e = execution_plan.unwrap_err().to_string();
assert_contains!(&e, "Can not find compatible types to compare Boolean with [Struct([Field { name: \"foo\", data_type: Boolean, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: None }]), Utf8]");
assert_contains!(&e, "The data type inlist should be same, the value type is Boolean, one of list expr type is Struct([Field { name: \"foo\", data_type: Boolean, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: None }])");

Ok(())
}
Expand Down Expand Up @@ -1996,7 +1996,10 @@ mod tests {
.project(vec![col("c1").in_list(list, false)])?
.build()?;
let execution_plan = plan(&logical_plan).await?;
let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, TryCastExpr { expr: Literal { value: Int64(1) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(2) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(3) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(4) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(5) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(6) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(7) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(8) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(9) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(10) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(11) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(12) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(13) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(14) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(15) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(16) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(17) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(18) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(19) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(20) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(21) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(22) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(23) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(24) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(25) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(26) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(27) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(28) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(29) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(30) }, cast_type: Utf8 }], negated: false, set: Some(InSet { set: ";
let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"1\") }, Literal { value: Utf8(\"2\") },";
assert!(format!("{:?}", execution_plan).contains(expected));
let expected =
"Literal { value: Utf8(\"30\") }], negated: false, set: Some(InSet { set: ";
assert!(format!("{:?}", execution_plan).contains(expected));
Ok(())
}
Expand All @@ -2015,7 +2018,10 @@ mod tests {
.project(vec![col("c1").in_list(list, false)])?
.build()?;
let execution_plan = plan(&logical_plan).await?;
let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [TryCastExpr { expr: Literal { value: Int64(NULL) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(1) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(2) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(3) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(4) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(5) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(6) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(7) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(8) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(9) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(10) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(11) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(12) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(13) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(14) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(15) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(16) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(17) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(18) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(19) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(20) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(21) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(22) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(23) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(24) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(25) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(26) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(27) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(28) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(29) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(30) }, cast_type: Utf8 }], negated: false, set: Some(InSet {";
let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(NULL) }, Literal { value: Utf8(\"1\") }, Literal { value: Utf8(\"2\") }";
assert!(format!("{:?}", execution_plan).contains(expected));
let expected =
"Literal { value: Utf8(\"30\") }], negated: false, set: Some(InSet";
assert!(format!("{:?}", execution_plan).contains(expected));
Ok(())
}
Expand Down
94 changes: 93 additions & 1 deletion datafusion/optimizer/src/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ impl ExprRewriter for TypeCoercionRewriter<'_> {
right.clone().cast_to(&coerced_type, &self.schema)?,
),
};

expr.rewrite(&mut self.const_evaluator)
}
}
Expand Down Expand Up @@ -164,11 +163,61 @@ impl ExprRewriter for TypeCoercionRewriter<'_> {
};
expr.rewrite(&mut self.const_evaluator)
}
Expr::InList {
expr,
list,
negated,
} => {
let expr_data_type = expr.get_type(&self.schema)?;
let list_data_types = list
.iter()
.map(|list_expr| list_expr.get_type(&self.schema))
.collect::<Result<Vec<_>>>()?;
let result_type =
get_coerce_type_for_list(&expr_data_type, &list_data_types);
match result_type {
None => Err(DataFusionError::Plan(format!(
"Can not find compatible types to compare {:?} with {:?}",
expr_data_type, list_data_types
))),
Some(coerced_type) => {
// find the coerced type
let cast_expr = expr.cast_to(&coerced_type, &self.schema)?;
let cast_list_expr = list
.into_iter()
.map(|list_expr| {
list_expr.cast_to(&coerced_type, &self.schema)
})
.collect::<Result<Vec<_>>>()?;
let expr = Expr::InList {
expr: Box::new(cast_expr),
list: cast_list_expr,
negated,
};
expr.rewrite(&mut self.const_evaluator)
}
}
}
expr => Ok(expr),
}
}
}

/// Attempts to coerce the types of `list_types` to be comparable with the
/// `expr_type`.
/// Returns the common data type for `expr_type` and `list_types`
fn get_coerce_type_for_list(
expr_type: &DataType,
list_types: &[DataType],
) -> Option<DataType> {
list_types
.iter()
.fold(Some(expr_type.clone()), |left, right_type| match left {
None => None,
Some(left_type) => comparison_coercion(&left_type, right_type),
})
}

/// Returns `expressions` coerced to types compatible with
/// `signature`, if possible.
///
Expand Down Expand Up @@ -348,6 +397,49 @@ mod test {
Ok(())
}

#[test]
fn inlist_case() -> Result<()> {
// a in (1,4,8), a is int64
let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::new(
DFSchema::new_with_metadata(
vec![DFField::new(None, "a", DataType::Int64, true)],
std::collections::HashMap::new(),
)
.unwrap(),
),
}));
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty, None)?);
let rule = TypeCoercion::new();
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!(
"Projection: #a IN ([Int64(1), Int64(4), Int64(8)])\n EmptyRelation",
&format!("{:?}", plan)
);
// a in (1,4,8), a is decimal
let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::new(
DFSchema::new_with_metadata(
vec![DFField::new(None, "a", DataType::Decimal128(12, 4), true)],
std::collections::HashMap::new(),
)
.unwrap(),
),
}));
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty, None)?);
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!(
"Projection: CAST(#a AS Decimal128(24, 4)) IN ([Decimal128(Some(10000),24,4), Decimal128(Some(40000),24,4), Decimal128(Some(80000),24,4)])\n EmptyRelation",
&format!("{:?}", plan)
);
Ok(())
}

fn empty() -> Arc<LogicalPlan> {
Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
Expand Down
60 changes: 58 additions & 2 deletions datafusion/physical-expr/src/expressions/in_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,17 @@ pub fn in_list(
negated: &bool,
schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>> {
// check the data type
let expr_data_type = expr.data_type(schema)?;
for list_expr in list.iter() {
let list_expr_data_type = list_expr.data_type(schema)?;
if !expr_data_type.eq(&list_expr_data_type) {
return Err(DataFusionError::Internal(format!(
"The data type inlist should be same, the value type is {}, one of list expr type is {}",
expr_data_type, list_expr_data_type
)));
}
}
Ok(Arc::new(InListExpr::new(expr, list, *negated, schema)))
}

Expand All @@ -969,9 +980,54 @@ mod tests {

use super::*;
use crate::expressions;
use crate::expressions::{col, lit};
use crate::planner::in_list_cast;
use crate::expressions::{col, lit, try_cast};
use datafusion_common::Result;
use datafusion_expr::binary_rule::comparison_coercion;

type InListCastResult = (Arc<dyn PhysicalExpr>, Vec<Arc<dyn PhysicalExpr>>);

// Try to do the type coercion for list physical expr.
// It's just used in the test
fn in_list_cast(
expr: Arc<dyn PhysicalExpr>,
list: Vec<Arc<dyn PhysicalExpr>>,
input_schema: &Schema,
) -> Result<InListCastResult> {
let expr_type = &expr.data_type(input_schema)?;
let list_types: Vec<DataType> = list
.iter()
.map(|list_expr| list_expr.data_type(input_schema).unwrap())
.collect();
let result_type = get_coerce_type(expr_type, &list_types);
match result_type {
None => Err(DataFusionError::Plan(format!(
"Can not find compatible types to compare {:?} with {:?}",
expr_type, list_types
))),
Some(data_type) => {
// find the coerced type
let cast_expr = try_cast(expr, input_schema, data_type.clone())?;
let cast_list_expr = list
.into_iter()
.map(|list_expr| {
try_cast(list_expr, input_schema, data_type.clone()).unwrap()
})
.collect();
Ok((cast_expr, cast_list_expr))
}
}
}

// Attempts to coerce the types of `list_type` to be comparable with the
// `expr_type`
fn get_coerce_type(expr_type: &DataType, list_type: &[DataType]) -> Option<DataType> {
list_type
.iter()
.fold(Some(expr_type.clone()), |left, right_type| match left {
None => None,
Some(left_type) => comparison_coercion(&left_type, right_type),
})
}

// applies the in_list expr to an input batch and list
macro_rules! in_list {
Expand Down
54 changes: 1 addition & 53 deletions datafusion/physical-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
// specific language governing permissions and limitations
// under the License.

use crate::expressions::try_cast;
use crate::var_provider::is_system_variables;
use crate::{
execution_props::ExecutionProps,
Expand All @@ -28,7 +27,6 @@ use crate::{
};
use arrow::datatypes::{DataType, Schema};
use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue};
use datafusion_expr::binary_rule::comparison_coercion;
use datafusion_expr::{binary_expr, Expr, Operator};
use std::sync::Arc;

Expand Down Expand Up @@ -410,10 +408,7 @@ pub fn create_physical_expr(
)
})
.collect::<Result<Vec<_>>>()?;

let (cast_expr, cast_list_exprs) =
in_list_cast(value_expr, list_exprs, input_schema)?;
expressions::in_list(cast_expr, cast_list_exprs, negated, input_schema)
expressions::in_list(value_expr, list_exprs, negated, input_schema)
}
},
other => Err(DataFusionError::NotImplemented(format!(
Expand All @@ -422,50 +417,3 @@ pub fn create_physical_expr(
))),
}
}

type InListCastResult = (Arc<dyn PhysicalExpr>, Vec<Arc<dyn PhysicalExpr>>);

pub(crate) fn in_list_cast(
expr: Arc<dyn PhysicalExpr>,
list: Vec<Arc<dyn PhysicalExpr>>,
input_schema: &Schema,
) -> Result<InListCastResult> {
let expr_type = &expr.data_type(input_schema)?;
let list_types: Vec<DataType> = list
.iter()
.map(|list_expr| list_expr.data_type(input_schema).unwrap())
.collect();
let result_type = get_coerce_type(expr_type, &list_types);
match result_type {
None => Err(DataFusionError::Plan(format!(
"Can not find compatible types to compare {:?} with {:?}",
expr_type, list_types
))),
Some(data_type) => {
// find the coerced type
let cast_expr = try_cast(expr, input_schema, data_type.clone())?;
let cast_list_expr = list
.into_iter()
.map(|list_expr| {
try_cast(list_expr, input_schema, data_type.clone()).unwrap()
})
.collect();
Ok((cast_expr, cast_list_expr))
}
}
}

/// Attempts to coerce the types of `list_type` to be comparable with the
/// `expr_type`
fn get_coerce_type(expr_type: &DataType, list_type: &[DataType]) -> Option<DataType> {
// get the equal coerced data type
list_type
.iter()
.fold(Some(expr_type.clone()), |left, right_type| {
match left {
None => None,
// TODO refactor a framework to do the data type coercion
Some(left_type) => comparison_coercion(&left_type, right_type),
}
})
}

0 comments on commit f3bb84f

Please sign in to comment.