diff --git a/crates/deltalake-core/src/delta_datafusion/expr.rs b/crates/deltalake-core/src/delta_datafusion/expr.rs index f9275832a1..49cdae4387 100644 --- a/crates/deltalake-core/src/delta_datafusion/expr.rs +++ b/crates/deltalake-core/src/delta_datafusion/expr.rs @@ -347,9 +347,10 @@ impl<'a> fmt::Display for ScalarValueFormat<'a> { mod test { use arrow_schema::DataType as ArrowDataType; use datafusion::prelude::SessionContext; - use datafusion_common::{DFSchema, ScalarValue}; + use datafusion_common::{Column, DFSchema, ScalarValue}; use datafusion_expr::{col, decode, lit, substring, Cast, Expr, ExprSchemable}; + use crate::delta_datafusion::DeltaSessionContext; use crate::kernel::{DataType, PrimitiveType, StructField, StructType}; use crate::{DeltaOps, DeltaTable}; @@ -388,6 +389,11 @@ mod test { DataType::Primitive(PrimitiveType::Integer), true, ), + StructField::new( + "Value3".to_string(), + DataType::Primitive(PrimitiveType::Integer), + true, + ), StructField::new( "modified".to_string(), DataType::Primitive(PrimitiveType::String), @@ -442,7 +448,10 @@ mod test { }), "arrow_cast(1, 'Int32')".to_string() ), - simple!(col("value").eq(lit(3_i64)), "value = 3".to_string()), + simple!( + Expr::Column(Column::from_qualified_name_ignore_case("Value3")).eq(lit(3_i64)), + "Value3 = 3".to_string() + ), simple!(col("active").is_true(), "active IS TRUE".to_string()), simple!(col("active"), "active".to_string()), simple!(col("active").eq(lit(true)), "active = true".to_string()), @@ -536,7 +545,7 @@ mod test { ), ]; - let session = SessionContext::new(); + let session: SessionContext = DeltaSessionContext::default().into(); for test in tests { let actual = fmt_expr_to_sql(&test.expr).unwrap(); diff --git a/crates/deltalake-core/src/delta_datafusion/mod.rs b/crates/deltalake-core/src/delta_datafusion/mod.rs index 83db86e8e2..9f2818de93 100644 --- a/crates/deltalake-core/src/delta_datafusion/mod.rs +++ b/crates/deltalake-core/src/delta_datafusion/mod.rs @@ -1033,7 +1033,7 
@@ impl DeltaDataChecker { Self { invariants, constraints: vec![], - ctx: SessionContext::new(), + ctx: DeltaSessionContext::default().into(), } } @@ -1042,10 +1042,16 @@ impl DeltaDataChecker { Self { constraints, invariants: vec![], - ctx: SessionContext::new(), + ctx: DeltaSessionContext::default().into(), } } + /// Specify the Datafusion context + pub fn with_session_context(mut self, context: SessionContext) -> Self { + self.ctx = context; + self + } + /// Create a new DeltaDataChecker pub fn new(snapshot: &DeltaTableState) -> Self { let metadata = snapshot.metadata(); @@ -1059,7 +1065,7 @@ impl DeltaDataChecker { Self { invariants, constraints, - ctx: SessionContext::new(), + ctx: DeltaSessionContext::default().into(), } } diff --git a/crates/deltalake-core/src/operations/constraints.rs b/crates/deltalake-core/src/operations/constraints.rs index 889e668b1a..ed5888bd13 100644 --- a/crates/deltalake-core/src/operations/constraints.rs +++ b/crates/deltalake-core/src/operations/constraints.rs @@ -8,11 +8,15 @@ use datafusion::execution::context::SessionState; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; +use datafusion_common::ToDFSchema; use futures::future::BoxFuture; use futures::StreamExt; use serde_json::json; -use crate::delta_datafusion::{register_store, DeltaDataChecker, DeltaScanBuilder}; +use crate::delta_datafusion::expr::fmt_expr_to_sql; +use crate::delta_datafusion::{ + register_store, DeltaDataChecker, DeltaScanBuilder, DeltaSessionContext, +}; use crate::kernel::{Action, CommitInfo, IsolationLevel, Metadata, Protocol}; use crate::logstore::LogStoreRef; use crate::operations::datafusion_utils::Expression; @@ -23,6 +27,8 @@ use crate::table::Constraint; use crate::DeltaTable; use crate::{DeltaResult, DeltaTableError}; +use super::datafusion_utils::into_expr; + /// Build a constraint to add to a table pub struct ConstraintBuilder { snapshot: 
DeltaTableState, @@ -47,10 +53,10 @@ impl ConstraintBuilder { /// Specify the constraint to be added pub fn with_constraint<S: Into<String>, E: Into<Expression>>( mut self, - column: S, + name: S, expression: E, ) -> Self { - self.name = Some(column.into()); + self.name = Some(name.into()); self.expr = Some(expression.into()); self } @@ -75,15 +81,10 @@ impl std::future::IntoFuture for ConstraintBuilder { Some(v) => v, None => return Err(DeltaTableError::Generic("No name provided".to_string())), }; - let expr = match this.expr { - Some(Expression::String(s)) => s, - Some(Expression::DataFusion(e)) => e.to_string(), - None => { - return Err(DeltaTableError::Generic( - "No expression provided".to_string(), - )) - } - }; + + let expr = this + .expr + .ok_or_else(|| DeltaTableError::Generic("No expression provided".to_string()))?; let mut metadata = this .snapshot @@ -94,23 +95,29 @@ impl std::future::IntoFuture for ConstraintBuilder { if metadata.configuration.contains_key(&configuration_key) { return Err(DeltaTableError::Generic(format!( - "Constraint with name: {} already exists, expr: {}", - name, expr + "Constraint with name: {} already exists", + name ))); } let state = this.state.unwrap_or_else(|| { - let session = SessionContext::new(); + let session: SessionContext = DeltaSessionContext::default().into(); register_store(this.log_store.clone(), session.runtime_env()); session.state() }); - // Checker built here with the one time constraint to check. - let checker = DeltaDataChecker::new_with_constraints(vec![Constraint::new("*", &expr)]); let scan = DeltaScanBuilder::new(&this.snapshot, this.log_store.clone(), &state) .build() .await?; + let schema = scan.schema().to_dfschema()?; + let expr = into_expr(expr, &schema, &state)?; + let expr_str = fmt_expr_to_sql(&expr)?; + + // Checker built here with the one time constraint to check. 
+ let checker = + DeltaDataChecker::new_with_constraints(vec![Constraint::new("*", &expr_str)]); + let plan: Arc<dyn ExecutionPlan> = Arc::new(scan); let mut tasks = vec![]; for p in 0..plan.output_partitioning().partition_count() { @@ -140,9 +147,10 @@ impl std::future::IntoFuture for ConstraintBuilder { // We have validated the table passes it's constraints, now to add the constraint to // the table. - metadata - .configuration - .insert(format!("delta.constraints.{}", name), Some(expr.clone())); + metadata.configuration.insert( + format!("delta.constraints.{}", name), + Some(expr_str.clone()), + ); let old_protocol = this.snapshot.protocol(); let protocol = Protocol { @@ -162,12 +170,12 @@ impl std::future::IntoFuture for ConstraintBuilder { let operational_parameters = HashMap::from_iter([ ("name".to_string(), json!(&name)), - ("expr".to_string(), json!(&expr)), + ("expr".to_string(), json!(&expr_str)), ]); let operations = DeltaOperation::AddConstraint { name: name.clone(), - expr: expr.clone(), + expr: expr_str.clone(), }; let commit_info = CommitInfo { @@ -208,11 +216,37 @@ mod tests { use std::sync::Arc; use arrow_array::{Array, Int32Array, RecordBatch, StringArray}; + use arrow_schema::{DataType as ArrowDataType, Field, Schema as ArrowSchema}; + use datafusion_expr::{col, lit}; use crate::writer::test_utils::{create_bare_table, get_arrow_schema, get_record_batch}; - use crate::{DeltaOps, DeltaResult}; + use crate::{DeltaOps, DeltaResult, DeltaTable}; + + fn get_constraint(table: &DeltaTable, name: &str) -> String { + table + .metadata() + .unwrap() + .configuration + .get(name) + .unwrap() + .clone() + .unwrap() + } + + async fn get_constraint_op_params(table: &mut DeltaTable) -> String { + let commit_info = table.history(None).await.unwrap(); + let last_commit = &commit_info[commit_info.len() - 1]; + last_commit + .operation_parameters + .as_ref() + .unwrap() + .get("expr") + .unwrap() + .as_str() + .unwrap() + .to_owned() + } - #[cfg(feature = "datafusion")] #[tokio::test] 
async fn add_constraint_with_invalid_data() -> DeltaResult<()> { let batch = get_record_batch(None, false); @@ -225,12 +259,10 @@ mod tests { .add_constraint() .with_constraint("id", "value > 5") .await; - dbg!(&constraint); assert!(constraint.is_err()); Ok(()) } - #[cfg(feature = "datafusion")] #[tokio::test] async fn add_valid_constraint() -> DeltaResult<()> { let batch = get_record_batch(None, false); @@ -239,18 +271,89 @@ mod tests { .await?; let table = DeltaOps(write); - let constraint = table + let mut table = table .add_constraint() - .with_constraint("id", "value < 1000") - .await; - dbg!(&constraint); - assert!(constraint.is_ok()); - let version = constraint?.version(); + .with_constraint("id", "value < 1000") + .await?; + let version = table.version(); + assert_eq!(version, 1); + + let expected_expr = "value < 1000"; + assert_eq!(get_constraint_op_params(&mut table).await, expected_expr); + assert_eq!( + get_constraint(&table, "delta.constraints.id"), + expected_expr + ); + Ok(()) + } + + #[tokio::test] + async fn add_constraint_datafusion() -> DeltaResult<()> { + // Add constraint by providing a datafusion expression. 
+ let batch = get_record_batch(None, false); + let write = DeltaOps(create_bare_table()) + .write(vec![batch.clone()]) + .await?; + let table = DeltaOps(write); + + let mut table = table + .add_constraint() + .with_constraint("valid_values", col("value").lt(lit(1000))) + .await?; + let version = table.version(); assert_eq!(version, 1); + + let expected_expr = "value < 1000"; + assert_eq!(get_constraint_op_params(&mut table).await, expected_expr); + assert_eq!( + get_constraint(&table, "delta.constraints.valid_values"), + expected_expr + ); + + Ok(()) + } + + #[tokio::test] + async fn test_constraint_case_sensitive() -> DeltaResult<()> { + let arrow_schema = Arc::new(ArrowSchema::new(vec![ + Field::new("Id", ArrowDataType::Utf8, true), + Field::new("vAlue", ArrowDataType::Int32, true), + Field::new("mOdifieD", ArrowDataType::Utf8, true), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&arrow_schema.clone()), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])), + Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])), + Arc::new(arrow::array::StringArray::from(vec![ + "2021-02-02", + "2023-07-04", + "2023-07-04", + ])), + ], + ) + .unwrap(); + + let table = DeltaOps::new_in_memory().write(vec![batch]).await.unwrap(); + + let mut table = DeltaOps(table) + .add_constraint() + .with_constraint("valid_values", "vAlue < 1000") + .await?; + let version = table.version(); + assert_eq!(version, 1); + + let expected_expr = "vAlue < 1000"; + assert_eq!(get_constraint_op_params(&mut table).await, expected_expr); + assert_eq!( + get_constraint(&table, "delta.constraints.valid_values"), + expected_expr + ); + Ok(()) } - #[cfg(feature = "datafusion")] #[tokio::test] async fn add_conflicting_named_constraint() -> DeltaResult<()> { let batch = get_record_batch(None, false); @@ -269,12 +372,10 @@ mod tests { .add_constraint() .with_constraint("id", "value < 10") .await; - dbg!(&second_constraint); assert!(second_constraint.is_err()); Ok(()) } - 
#[cfg(feature = "datafusion")] #[tokio::test] async fn write_data_that_violates_constraint() -> DeltaResult<()> { let batch = get_record_batch(None, false); @@ -294,7 +395,6 @@ mod tests { ]; let batch = RecordBatch::try_new(get_arrow_schema(&None), invalid_values)?; let err = table.write(vec![batch]).await; - dbg!(&err); assert!(err.is_err()); Ok(()) }