Skip to content

Commit

Permalink
fix!: ensure predicates are parsable (#1690)
Browse files Browse the repository at this point in the history
# Description
Resolves two issues that impact Datafusion implemented operators

1. When a user has an expression with a scalar built-in scalar function
we are unable parse the output predicate since the
`DummyContextProvider`'s methods are unimplemented. The provider now
uses the user provided state or a default. More work is required in the
future to allow a user provided Datafusion state to be used during the
conflict checker.

2. The string representation was not parsable by sqlparser since it was
not valid SQL. New code was written to transform an expression into a
parsable sql string. Current implementation is not exhaustive however
common use cases are covered.

The delta_datafusion.rs file is getting large so I transformed it into a
module.

This implementation makes reuse of some code from Datafusion. I've added
the Apache License at the top of the file. Let me know if any else is
required to be compliant.


# Related Issue(s)
- closes #1625

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
  • Loading branch information
Blajda and wjones127 authored Oct 3, 2023
1 parent dd1fa8c commit 4da7d66
Show file tree
Hide file tree
Showing 8 changed files with 598 additions and 40 deletions.
505 changes: 505 additions & 0 deletions rust/src/delta_datafusion/expr.rs

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ use crate::{open_table, open_table_with_storage_options, DeltaTable, Invariant,

const PATH_COLUMN: &str = "__delta_rs_path";

pub mod expr;

impl From<DeltaTableError> for DataFusionError {
fn from(err: DeltaTableError) -> Self {
match err {
Expand Down
15 changes: 12 additions & 3 deletions rust/src/operations/delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
use std::sync::Arc;
use std::time::{Instant, SystemTime, UNIX_EPOCH};

use crate::delta_datafusion::expr::fmt_expr_to_sql;
use crate::protocol::{Action, Add, Remove};
use datafusion::execution::context::{SessionContext, SessionState};
use datafusion::physical_expr::create_physical_expr;
Expand Down Expand Up @@ -263,7 +264,7 @@ async fn execute(
// Do not make a commit when there are zero updates to the state
if !actions.is_empty() {
let operation = DeltaOperation::Delete {
predicate: Some(predicate.canonical_name()),
predicate: Some(fmt_expr_to_sql(&predicate)?),
};
version = commit(
object_store.as_ref(),
Expand Down Expand Up @@ -298,7 +299,9 @@ impl std::future::IntoFuture for DeleteBuilder {
let predicate = match this.predicate {
Some(predicate) => match predicate {
Expression::DataFusion(expr) => Some(expr),
Expression::String(s) => Some(this.snapshot.parse_predicate_expression(s)?),
Expression::String(s) => {
Some(this.snapshot.parse_predicate_expression(s, &state)?)
}
},
None => None,
};
Expand Down Expand Up @@ -335,6 +338,7 @@ mod tests {
use arrow::record_batch::RecordBatch;
use datafusion::assert_batches_sorted_eq;
use datafusion::prelude::*;
use serde_json::json;
use std::sync::Arc;

async fn setup_table(partitions: Option<Vec<&str>>) -> DeltaTable {
Expand Down Expand Up @@ -456,7 +460,7 @@ mod tests {
assert_eq!(table.version(), 2);
assert_eq!(table.get_file_uris().count(), 2);

let (table, metrics) = DeltaOps(table)
let (mut table, metrics) = DeltaOps(table)
.delete()
.with_predicate(col("value").eq(lit(1)))
.await
Expand All @@ -470,6 +474,11 @@ mod tests {
assert_eq!(metrics.num_deleted_rows, Some(1));
assert_eq!(metrics.num_copied_rows, Some(3));

let commit_info = table.history(None).await.unwrap();
let last_commit = &commit_info[commit_info.len() - 1];
let parameters = last_commit.operation_parameters.clone().unwrap();
assert_eq!(parameters["predicate"], json!("value = 1"));

let expected = vec![
"+----+-------+------------+",
"| id | value | modified |",
Expand Down
32 changes: 23 additions & 9 deletions rust/src/operations/merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ use serde_json::{Map, Value};

use super::datafusion_utils::{into_expr, maybe_into_expr, Expression};
use super::transaction::commit;
use crate::delta_datafusion::expr::fmt_expr_to_sql;
use crate::delta_datafusion::{parquet_scan_from_actions, register_store};
use crate::operations::datafusion_utils::MetricObserverExec;
use crate::operations::write::write_execution_plan;
Expand Down Expand Up @@ -171,6 +172,7 @@ impl MergeBuilder {
let builder = builder(UpdateBuilder::default());
let op = MergeOperation::try_new(
&self.snapshot,
&self.state.as_ref(),
builder.predicate,
builder.updates,
OperationType::Update,
Expand Down Expand Up @@ -204,6 +206,7 @@ impl MergeBuilder {
let builder = builder(DeleteBuilder::default());
let op = MergeOperation::try_new(
&self.snapshot,
&self.state.as_ref(),
builder.predicate,
HashMap::default(),
OperationType::Delete,
Expand Down Expand Up @@ -240,6 +243,7 @@ impl MergeBuilder {
let builder = builder(InsertBuilder::default());
let op = MergeOperation::try_new(
&self.snapshot,
&self.state.as_ref(),
builder.predicate,
builder.set,
OperationType::Insert,
Expand Down Expand Up @@ -278,6 +282,7 @@ impl MergeBuilder {
let builder = builder(UpdateBuilder::default());
let op = MergeOperation::try_new(
&self.snapshot,
&self.state.as_ref(),
builder.predicate,
builder.updates,
OperationType::Update,
Expand Down Expand Up @@ -311,6 +316,7 @@ impl MergeBuilder {
let builder = builder(DeleteBuilder::default());
let op = MergeOperation::try_new(
&self.snapshot,
&self.state.as_ref(),
builder.predicate,
HashMap::default(),
OperationType::Delete,
Expand Down Expand Up @@ -448,15 +454,21 @@ struct MergeOperation {
impl MergeOperation {
pub fn try_new(
snapshot: &DeltaTableState,
state: &Option<&SessionState>,
predicate: Option<Expression>,
operations: HashMap<Column, Expression>,
r#type: OperationType,
) -> DeltaResult<Self> {
let predicate = maybe_into_expr(predicate, snapshot)?;
let context = SessionContext::new();
let mut s = &context.state();
if let Some(df_state) = state {
s = df_state;
}
let predicate = maybe_into_expr(predicate, snapshot, s)?;
let mut _operations = HashMap::new();

for (column, expr) in operations {
_operations.insert(column, into_expr(expr, snapshot)?);
_operations.insert(column, into_expr(expr, snapshot, s)?);
}

Ok(MergeOperation {
Expand Down Expand Up @@ -518,7 +530,7 @@ async fn execute(

let predicate = match predicate {
Expression::DataFusion(expr) => expr,
Expression::String(s) => snapshot.parse_predicate_expression(s)?,
Expression::String(s) => snapshot.parse_predicate_expression(s, &state)?,
};

let schema = snapshot.input_schema()?;
Expand Down Expand Up @@ -675,7 +687,10 @@ async fn execute(
};

let action_type = action_type.to_string();
let predicate = op.predicate.map(|expr| expr.display_name().unwrap());
let predicate = op
.predicate
.map(|expr| fmt_expr_to_sql(&expr))
.transpose()?;

predicates.push(MergePredicate {
action_type,
Expand Down Expand Up @@ -1035,7 +1050,7 @@ async fn execute(
// Do not make a commit when there are zero updates to the state
if !actions.is_empty() {
let operation = DeltaOperation::Merge {
predicate: Some(predicate.canonical_name()),
predicate: Some(fmt_expr_to_sql(&predicate)?),
matched_predicates: match_operations,
not_matched_predicates: not_match_target_operations,
not_matched_by_source_predicates: not_match_source_operations,
Expand Down Expand Up @@ -1222,10 +1237,9 @@ mod tests {
parameters["notMatchedPredicates"],
json!(r#"[{"actionType":"insert"}]"#)
);
// Todo: Expected this predicate to actually be 'value = 1'. Predicate should contain a valid sql expression
assert_eq!(
parameters["notMatchedBySourcePredicates"],
json!(r#"[{"actionType":"update","predicate":"value = Int32(1)"}]"#)
json!(r#"[{"actionType":"update","predicate":"value = 1"}]"#)
);

let expected = vec![
Expand Down Expand Up @@ -1447,7 +1461,7 @@ mod tests {
assert_eq!(parameters["predicate"], json!("id = source.id"));
assert_eq!(
parameters["matchedPredicates"],
json!(r#"[{"actionType":"delete","predicate":"source.value <= Int32(10)"}]"#)
json!(r#"[{"actionType":"delete","predicate":"source.value <= 10"}]"#)
);

let expected = vec![
Expand Down Expand Up @@ -1579,7 +1593,7 @@ mod tests {
assert_eq!(parameters["predicate"], json!("id = source.id"));
assert_eq!(
parameters["notMatchedBySourcePredicates"],
json!(r#"[{"actionType":"delete","predicate":"modified > Utf8(\"2021-02-01\")"}]"#)
json!(r#"[{"actionType":"delete","predicate":"modified > '2021-02-01'"}]"#)
);

let expected = vec![
Expand Down
12 changes: 9 additions & 3 deletions rust/src/operations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ mod datafusion_utils {
use arrow_schema::SchemaRef;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::Result as DataFusionResult;
use datafusion::execution::context::SessionState;
use datafusion::physical_plan::DisplayAs;
use datafusion::physical_plan::{
metrics::{ExecutionPlanMetricsSet, MetricsSet},
Expand Down Expand Up @@ -240,19 +241,24 @@ mod datafusion_utils {
}
}

pub(crate) fn into_expr(expr: Expression, snapshot: &DeltaTableState) -> DeltaResult<Expr> {
pub(crate) fn into_expr(
expr: Expression,
snapshot: &DeltaTableState,
df_state: &SessionState,
) -> DeltaResult<Expr> {
match expr {
Expression::DataFusion(expr) => Ok(expr),
Expression::String(s) => snapshot.parse_predicate_expression(s),
Expression::String(s) => snapshot.parse_predicate_expression(s, df_state),
}
}

pub(crate) fn maybe_into_expr(
expr: Option<Expression>,
snapshot: &DeltaTableState,
df_state: &SessionState,
) -> DeltaResult<Option<Expr>> {
Ok(match expr {
Some(predicate) => Some(into_expr(predicate, snapshot)?),
Some(predicate) => Some(into_expr(predicate, snapshot, df_state)?),
None => None,
})
}
Expand Down
5 changes: 4 additions & 1 deletion rust/src/operations/transaction/conflict_checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,11 @@ impl<'a> TransactionInfo<'a> {
actions: &'a Vec<Action>,
read_whole_table: bool,
) -> DeltaResult<Self> {
use datafusion::prelude::SessionContext;

let session = SessionContext::new();
let read_predicates = read_predicates
.map(|pred| read_snapshot.parse_predicate_expression(pred))
.map(|pred| read_snapshot.parse_predicate_expression(pred, &session.state()))
.transpose()?;
Ok(Self {
txn_id: "".into(),
Expand Down
47 changes: 28 additions & 19 deletions rust/src/operations/transaction/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use arrow::datatypes::{
DataType, Field as ArrowField, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef,
};
use datafusion::datasource::physical_plan::wrap_partition_type_in_dict;
use datafusion::execution::context::SessionState;
use datafusion::optimizer::utils::conjunction;
use datafusion::physical_optimizer::pruning::{PruningPredicate, PruningStatistics};
use datafusion_common::config::ConfigOptions;
Expand Down Expand Up @@ -104,7 +105,11 @@ impl DeltaTableState {
}

/// Parse an expression string into a datafusion [`Expr`]
pub fn parse_predicate_expression(&self, expr: impl AsRef<str>) -> DeltaResult<Expr> {
pub fn parse_predicate_expression(
&self,
expr: impl AsRef<str>,
df_state: &SessionState,
) -> DeltaResult<Expr> {
let dialect = &GenericDialect {};
let mut tokenizer = Tokenizer::new(dialect, expr.as_ref());
let tokens = tokenizer
Expand All @@ -121,7 +126,7 @@ impl DeltaTableState {

// TODO should we add the table name as qualifier when available?
let df_schema = DFSchema::try_from_qualified_schema("", self.arrow_schema()?.as_ref())?;
let context_provider = DummyContextProvider::default();
let context_provider = DeltaContextProvider { state: df_state };
let sql_to_rel = SqlToRel::new(&context_provider);

Ok(sql_to_rel.sql_to_expr(sql, &df_schema, &mut Default::default())?)
Expand Down Expand Up @@ -342,59 +347,63 @@ impl PruningStatistics for DeltaTableState {
}
}

#[derive(Default)]
struct DummyContextProvider {
options: ConfigOptions,
pub(crate) struct DeltaContextProvider<'a> {
state: &'a SessionState,
}

impl ContextProvider for DummyContextProvider {
impl<'a> ContextProvider for DeltaContextProvider<'a> {
fn get_table_provider(&self, _name: TableReference) -> DFResult<Arc<dyn TableSource>> {
unimplemented!()
}

fn get_function_meta(&self, _name: &str) -> Option<Arc<ScalarUDF>> {
unimplemented!()
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
self.state.scalar_functions().get(name).cloned()
}

fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
unimplemented!()
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
self.state.aggregate_functions().get(name).cloned()
}

fn get_variable_type(&self, _: &[String]) -> Option<DataType> {
fn get_variable_type(&self, _var: &[String]) -> Option<DataType> {
unimplemented!()
}

fn options(&self) -> &ConfigOptions {
&self.options
self.state.config_options()
}

fn get_window_meta(&self, _name: &str) -> Option<Arc<datafusion_expr::WindowUDF>> {
unimplemented!()
fn get_window_meta(&self, name: &str) -> Option<Arc<datafusion_expr::WindowUDF>> {
self.state.window_functions().get(name).cloned()
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::operations::transaction::test_utils::{create_add_action, init_table_actions};
use datafusion::prelude::SessionContext;
use datafusion_expr::{col, lit};

#[test]
fn test_parse_predicate_expression() {
let state = DeltaTableState::from_actions(init_table_actions(), 0).unwrap();
let snapshot = DeltaTableState::from_actions(init_table_actions(), 0).unwrap();
let session = SessionContext::new();
let state = session.state();

// parses simple expression
let parsed = state.parse_predicate_expression("value > 10").unwrap();
let parsed = snapshot
.parse_predicate_expression("value > 10", &state)
.unwrap();
let expected = col("value").gt(lit::<i64>(10));
assert_eq!(parsed, expected);

// fails for unknown column
let parsed = state.parse_predicate_expression("non_existent > 10");
let parsed = snapshot.parse_predicate_expression("non_existent > 10", &state);
assert!(parsed.is_err());

// parses complex expression
let parsed = state
.parse_predicate_expression("value > 10 OR value <= 0")
let parsed = snapshot
.parse_predicate_expression("value > 10 OR value <= 0", &state)
.unwrap();
let expected = col("value")
.gt(lit::<i64>(10))
Expand Down
Loading

0 comments on commit 4da7d66

Please sign in to comment.