Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions datafusion/core/src/execution/session_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1951,6 +1951,10 @@ impl SimplifyInfo for SessionSimplifyProvider<'_> {
fn get_data_type(&self, expr: &Expr) -> datafusion_common::Result<DataType> {
expr.get_type(self.df_schema)
}

/// Returns the schema this provider uses to resolve expression types.
///
/// This implementation always has a schema, so the result is never `None`.
fn get_schema(&self) -> Option<&DFSchema> {
    let schema: &DFSchema = self.df_schema;
    Some(schema)
}
}

#[derive(Debug)]
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2238,7 +2238,7 @@ mod tests {
// verify that the plan correctly casts u8 to i64
// the cast from u8 to i64 for literal will be simplified, and get lit(int64(5))
// the cast here is implicit so has CastOptions with safe=true
let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) }, fail_on_overflow: false }";
let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5), metadata: None }, fail_on_overflow: false }";
assert!(format!("{exec_plan:?}").contains(expected));
Ok(())
}
Expand All @@ -2263,7 +2263,7 @@ mod tests {
&session_state,
);

let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#;
let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL), metadata: None }, "c1"), (Literal { value: Int64(NULL), metadata: None }, "c2"), (Literal { value: Int64(NULL), metadata: None }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#;

assert_eq!(format!("{cube:?}"), expected);

Expand All @@ -2290,7 +2290,7 @@ mod tests {
&session_state,
);

let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#;
let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL), metadata: None }, "c1"), (Literal { value: Int64(NULL), metadata: None }, "c2"), (Literal { value: Int64(NULL), metadata: None }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#;

assert_eq!(format!("{rollup:?}"), expected);

Expand Down Expand Up @@ -2474,7 +2474,7 @@ mod tests {
let execution_plan = plan(&logical_plan).await?;
// verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated.

let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") }, fail_on_overflow: false }, fail_on_overflow: false }";
let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\"), metadata: None }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\"), metadata: None }, fail_on_overflow: false }, fail_on_overflow: false }";

let actual = format!("{execution_plan:?}");
assert!(actual.contains(expected), "{}", actual);
Expand Down
6 changes: 5 additions & 1 deletion datafusion/core/tests/expr_api/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use arrow::datatypes::{DataType, Field, Schema};
use chrono::{DateTime, TimeZone, Utc};
use datafusion::{error::Result, execution::context::ExecutionProps, prelude::*};
use datafusion_common::cast::as_int32_array;
use datafusion_common::ScalarValue;
use datafusion_common::{DFSchema, ScalarValue};
use datafusion_common::{DFSchemaRef, ToDFSchema};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::logical_plan::builder::table_scan_with_filters;
Expand Down Expand Up @@ -71,6 +71,10 @@ impl SimplifyInfo for MyInfo {
fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
expr.get_type(self.schema.as_ref())
}

/// Returns the schema held by this `MyInfo`.
///
/// The schema is always present here, so the result is never `None`.
fn get_schema(&self) -> Option<&DFSchema> {
    let schema: &DFSchema = self.schema.as_ref();
    Some(schema)
}
}

impl From<DFSchemaRef> for MyInfo {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;

use arrow::array::{as_string_array, record_batch, Int8Array, UInt64Array};
use arrow::array::{as_string_array, create_array, record_batch, Int8Array, UInt64Array};
use arrow::array::{
builder::BooleanBuilder, cast::AsArray, Array, ArrayRef, Float32Array, Float64Array,
Int32Array, RecordBatch, StringArray,
Expand Down Expand Up @@ -1527,6 +1527,57 @@ async fn test_metadata_based_udf() -> Result<()> {
Ok(())
}

/// Runs `MetadataBasedUdf` over two literal columns: one aliased with the
/// `modify_values = double_output` metadata and one aliased without metadata.
///
/// The expected batch holds `10` for the literal that carries the metadata and
/// `5` for the plain one, and both output fields carry `output_metadata`.
#[tokio::test]
async fn test_metadata_based_udf_with_literal() -> Result<()> {
    let ctx = SessionContext::new();

    // Only the first literal gets metadata attached through its alias.
    let doubling_metadata = [("modify_values".to_string(), "double_output".to_string())]
        .into_iter()
        .collect();
    let literals = vec![
        lit(5u64).alias_with_metadata("lit_with_doubling", Some(doubling_metadata)),
        lit(5u64).alias("lit_no_doubling"),
    ];
    let df = ctx.sql("select 0;").await?.select(literals)?;

    let output_metadata: HashMap<String, String> =
        [("output_metatype".to_string(), "custom_value".to_string())]
            .into_iter()
            .collect();
    let custom_udf = ScalarUDF::from(MetadataBasedUdf::new(output_metadata.clone()));

    // Build the projection on top of the already optimized plan, so the test
    // observes what is left of the aliases after optimization.
    let projection = vec![
        custom_udf
            .call(vec![col("lit_with_doubling")])
            .alias("doubled_output"),
        custom_udf
            .call(vec![col("lit_no_doubling")])
            .alias("not_doubled_output"),
    ];
    let plan = LogicalPlanBuilder::from(df.into_optimized_plan()?)
        .project(projection)?
        .build()?;

    let batches = DataFrame::new(ctx.state(), plan).collect().await?;

    let expected_fields = vec![
        Field::new("doubled_output", DataType::UInt64, true)
            .with_metadata(output_metadata.clone()),
        Field::new("not_doubled_output", DataType::UInt64, true)
            .with_metadata(output_metadata.clone()),
    ];
    let expected_columns = vec![create_array!(UInt64, [10]), create_array!(UInt64, [5])];
    let expected =
        RecordBatch::try_new(Arc::new(Schema::new(expected_fields)), expected_columns)?;

    assert_eq!(expected, batches[0]);

    Ok(())
}

/// This UDF is to test extension handling, both on the input and output
/// sides. For the input, we will handle the data differently if there is
/// the canonical extension type Bool8. For the output we will add a
Expand Down
12 changes: 10 additions & 2 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1512,8 +1512,16 @@ impl Expr {
|expr| {
// f_up: unalias on up so we can remove nested aliases like
// `(x as foo) as bar`
if let Expr::Alias(Alias { expr, .. }) = expr {
Ok(Transformed::yes(*expr))
if let Expr::Alias(alias) = expr {
match alias
.metadata
.as_ref()
.map(|h| h.is_empty())
.unwrap_or(true)
{
true => Ok(Transformed::yes(*alias.expr)),
false => Ok(Transformed::no(Expr::Alias(alias))),
}
} else {
Ok(Transformed::no(expr))
}
Expand Down
9 changes: 8 additions & 1 deletion datafusion/expr/src/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
//! Structs and traits to provide the information needed for expression simplification.

use arrow::datatypes::DataType;
use datafusion_common::{DFSchemaRef, DataFusionError, Result};
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};

use crate::{execution_props::ExecutionProps, Expr, ExprSchemable};

Expand All @@ -40,6 +40,9 @@ pub trait SimplifyInfo {

/// Returns data type of this expr needed for determining optimized int type of a value
fn get_data_type(&self, expr: &Expr) -> Result<DataType>;

/// Returns the Schema which may be needed to evaluate an Expr
fn get_schema(&self) -> Option<&DFSchema>;
}

/// Provides simplification information based on DFSchema and
Expand Down Expand Up @@ -106,6 +109,10 @@ impl SimplifyInfo for SimplifyContext<'_> {
fn execution_props(&self) -> &ExecutionProps {
self.props
}

/// Returns a borrowed view of the schema stored in this context, or `None`
/// when the context holds no schema.
fn get_schema(&self) -> Option<&DFSchema> {
    self.schema.as_deref()
}
}

/// Was the expression simplified?
Expand Down
14 changes: 12 additions & 2 deletions datafusion/optimizer/src/optimize_projections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -583,8 +583,18 @@ fn is_expr_trivial(expr: &Expr) -> bool {
fn rewrite_expr(expr: Expr, input: &Projection) -> Result<Transformed<Expr>> {
expr.transform_up(|expr| {
match expr {
// remove any intermediate aliases
Expr::Alias(alias) => Ok(Transformed::yes(*alias.expr)),
// remove any intermediate aliases if they do not carry metadata
Expr::Alias(alias) => {
match alias
.metadata
.as_ref()
.map(|h| h.is_empty())
.unwrap_or(true)
{
true => Ok(Transformed::yes(*alias.expr)),
false => Ok(Transformed::no(Expr::Alias(alias))),
}
}
Expr::Column(col) => {
// Find index of column:
let idx = input.schema.index_of_column(&col)?;
Expand Down
70 changes: 55 additions & 15 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ use datafusion_common::{
};
use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
use datafusion_expr::{
and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like,
Operator, Volatility, WindowFunctionDefinition,
and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr,
ExprSchemable, Like, Operator, Volatility, WindowFunctionDefinition,
};
use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval};
use datafusion_expr::{
Expand Down Expand Up @@ -172,6 +172,9 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
/// fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
/// Ok(DataType::Int32)
/// }
/// fn get_schema(&self) -> Option<&DFSchema> {
/// None
/// }
/// }
///
/// // Create the simplifier
Expand Down Expand Up @@ -227,12 +230,17 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
mut expr: Expr,
) -> Result<(Transformed<Expr>, u32)> {
let mut simplifier = Simplifier::new(&self.info);
let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?;

let empty_schema = DFSchema::empty();
let mut const_evaluator = ConstEvaluator::try_new(
self.info.execution_props(),
self.info.get_schema().unwrap_or(&empty_schema),
)?;
let mut shorten_in_list_simplifier = ShortenInListSimplifier::new();
let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees);

if self.canonicalize {
expr = expr.rewrite(&mut Canonicalizer::new()).data()?
expr = expr.rewrite(&mut Canonicalizer::new()).data()?;
}

// Evaluating constants can enable new simplifications and
Expand All @@ -249,14 +257,17 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
.transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?;
expr = data;
num_cycles += 1;

// Track if any transformation occurred
has_transformed = has_transformed || transformed;
if !transformed || num_cycles >= self.max_simplifier_cycles {
break;
}
}

// shorten inlist should be started after other inlist rules are applied
expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?;

Ok((
Transformed::new_transformed(expr, has_transformed),
num_cycles,
Expand Down Expand Up @@ -540,7 +551,7 @@ impl TreeNodeRewriter for ConstEvaluator<'_> {
// stack as not ok (as all parents have at least one child or
// descendant that can not be evaluated)

if !Self::can_evaluate(&expr) {
if !self.can_evaluate(&expr) {
// walk back up stack, marking first parent that is not mutable
let parent_iter = self.can_evaluate.iter_mut().rev();
for p in parent_iter {
Expand Down Expand Up @@ -586,12 +597,16 @@ impl<'a> ConstEvaluator<'a> {
/// Create a new `ConstEvaluator`. Session constants (such as
/// the time for `now()`) are taken from the passed
/// `execution_props`.
pub fn try_new(execution_props: &'a ExecutionProps) -> Result<Self> {
pub fn try_new(
execution_props: &'a ExecutionProps,
input_schema: &DFSchema,
) -> Result<Self> {
// The dummy column name is unused and doesn't matter as only
// expressions without column references can be evaluated
static DUMMY_COL_NAME: &str = ".";

let schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Null, true)]);
let input_schema = DFSchema::try_from(schema.clone())?;
let input_schema = input_schema.clone();
// Need a single "input" row to produce a single output row
let col = new_null_array(&DataType::Null, 1);
let input_batch = RecordBatch::try_new(std::sync::Arc::new(schema), vec![col])?;
Expand All @@ -616,7 +631,17 @@ impl<'a> ConstEvaluator<'a> {

/// Can the expression be evaluated at plan time, (assuming all of
/// its children can also be evaluated)?
fn can_evaluate(expr: &Expr) -> bool {
fn can_evaluate(&self, expr: &Expr) -> bool {
// Any time we have metadata on the return value we cannot
// simplify to a literal because we will lose it.
let false = expr
.metadata(&self.input_schema)
.map(|metadata| !metadata.is_empty())
.unwrap_or(false)
else {
return false;
};

// check for reasons we can't evaluate this node
//
// NOTE all expr types are listed here so when new ones are
Expand All @@ -640,8 +665,8 @@ impl<'a> ConstEvaluator<'a> {
Self::volatility_ok(func.signature().volatility)
}
Expr::Literal(_)
| Expr::Alias(..)
| Expr::Unnest(_)
| Expr::Alias(_)
| Expr::BinaryExpr { .. }
| Expr::Not(_)
| Expr::IsNotNull(_)
Expand Down Expand Up @@ -4326,16 +4351,23 @@ mod tests {
/// Mock aggregate UDF used by the simplifier tests.
#[derive(Debug, Clone)]
struct SimplifyMockUdaf {
// When true, the simplify method returns a new expression; when false, no change.
simplify: bool,
// Signature handed back by `signature()`.
signature: Signature,
}

impl SimplifyMockUdaf {
/// make simplify method return new expression
fn new_with_simplify() -> Self {
Self { simplify: true }
Self {
simplify: true,
signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
}
}
/// make simplify method return no change
fn new_without_simplify() -> Self {
Self { simplify: false }
Self {
simplify: false,
signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
}
}
}

Expand All @@ -4349,7 +4381,7 @@ mod tests {
}

fn signature(&self) -> &Signature {
unimplemented!()
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Expand Down Expand Up @@ -4409,16 +4441,24 @@ mod tests {
/// Mock window UDF used by the simplifier tests.
#[derive(Debug, Clone)]
struct SimplifyMockUdwf {
// When true, the simplify method returns a new expression; when false, no change.
simplify: bool,
// Signature handed back by `signature()`.
signature: Signature,
}

impl SimplifyMockUdwf {
/// make simplify method return new expression
fn new_with_simplify() -> Self {
Self { simplify: true }
Self {
simplify: true,
signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
}
}

/// make simplify method return no change
fn new_without_simplify() -> Self {
Self { simplify: false }
Self {
simplify: false,
signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
}
}
}

Expand All @@ -4432,7 +4472,7 @@ mod tests {
}

fn signature(&self) -> &Signature {
unimplemented!()
&self.signature
}

fn simplify(&self) -> Option<WindowFunctionSimplification> {
Expand Down
Loading
Loading