From f0c7e60eba85776fb63c23af54810510aca5c95e Mon Sep 17 00:00:00 2001 From: Will Jones Date: Sat, 9 Jul 2022 18:54:05 -0700 Subject: [PATCH 01/15] chore: start thinking about getting invariants --- Cargo.lock | 4 +- rust/src/schema.rs | 115 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9559797840..7ba2ab3f3c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1848,9 +1848,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-src" -version = "111.22.0+1.1.1q" +version = "111.21.0+1.1.1p" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f31f0d509d1c1ae9cada2f9539ff8f37933831fd5098879e482aa687d659853" +checksum = "6d0a8313729211913936f1b95ca47a5fc7f2e04cd658c115388287f8a8361008" dependencies = [ "cc", ] diff --git a/rust/src/schema.rs b/rust/src/schema.rs index 9bbb9b0807..e71d997347 100644 --- a/rust/src/schema.rs +++ b/rust/src/schema.rs @@ -62,6 +62,55 @@ impl SchemaTypeStruct { name, valid_fields ))) } + + /// Get all invariants in the schemas + pub fn get_invariants(&self) -> Result<Vec<String>, crate::DeltaTableError> { + let mut remaining_fields: Vec<SchemaField> = self.get_fields().clone(); + let mut invariants: Vec<String> = Vec::new(); + + while let Some(field) = remaining_fields.pop() { + match field.r#type { + SchemaDataType::r#struct(inner) => { + remaining_fields.extend(inner.get_fields().clone()); + } + SchemaDataType::array(inner) => { + remaining_fields.push(SchemaField::new( + "dummy".to_string(), + *inner.elementType, + false, + HashMap::new(), + )); + } + SchemaDataType::map(inner) => { + remaining_fields.push(SchemaField::new( + "dummy".to_string(), + *inner.keyType, + false, + HashMap::new(), + )); + remaining_fields.push(SchemaField::new( + "dummy".to_string(), + *inner.valueType, + false, + HashMap::new(), + )); + } + _ => {} + } + // JSON format: {"expression": {"expression": "<SQL STRING>"} } + if let Some(Value::String(invariant_json)) = field.metadata.get("delta.invariants") { + let json: Value = serde_json::from_str(invariant_json)?; + if let Value::Object(json) = json { + if let Some(Value::Object(expr1)) = json.get("expression") { + if let Some(Value::String(sql)) = expr1.get("expression") { + invariants.push(sql.clone()); + } + } + } + } + } + Ok(invariants) + } } /// Describes a specific field of the Delta table schema. @@ -219,3 +268,69 @@ pub enum SchemaDataType { /// Represents the schema of the delta table. 
pub type Schema = SchemaTypeStruct; + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_get_invariants() { + let schema: Schema = serde_json::from_value(json!({ + "type": "struct", + "fields": [{"name": "x", "type": "string", "nullable": true, "metadata": {}}] + })) + .unwrap(); + let invariants = schema.get_invariants().unwrap(); + assert_eq!(invariants.len(), 0); + + let schema: Schema = serde_json::from_value(json!({ + "type": "struct", + "fields": [ + {"name": "x", "type": "integer", "nullable": true, "metadata": { + "delta.invariants": "{\"expression\": { \"expression\": \"x > 2\"} }" + }}, + {"name": "y", "type": "integer", "nullable": true, "metadata": { + "delta.invariants": "{\"expression\": { \"expression\": \"y < 4\"} }" + }} + ] + })) + .unwrap(); + let invariants = schema.get_invariants().unwrap(); + assert_eq!(invariants.len(), 2); + assert!(invariants.contains(&"x > 2".to_string())); + assert!(invariants.contains(&"y < 4".to_string())); + + let schema: Schema = serde_json::from_value(json!({ + "type": "struct", + "fields": [{ + "name": "a_map", + "type": { + "type": "map", + "keyType": "string", + "valueType": { + "type": "array", + "elementType": { + "type": "struct", + "fields": [{ + "name": "d", + "type": "integer", + "metadata": { + "delta.invariants": "{\"expression\": { \"expression\": \"a_map.value.element.d < 4\"} }" + }, + "nullable": false + }] + }, + "containsNull": false + }, + "valueContainsNull": false + }, + "nullable": false, + "metadata": {} + }] + })).unwrap(); + let invariants = schema.get_invariants().unwrap(); + assert_eq!(invariants.len(), 1); + assert_eq!(invariants[0], "a_map.value.element.d < 4"); + } +} From b11004e77d657d183b19756b9b81b292b7810c19 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Wed, 21 Sep 2022 21:06:30 -0700 Subject: [PATCH 02/15] Start parsing nested --- rust/src/delta_datafusion.rs | 12 +++++ rust/src/schema.rs | 86 ++++++++++++++++++++++++++---------- 2 files changed, 74 insertions(+), 24 deletions(-) diff --git a/rust/src/delta_datafusion.rs b/rust/src/delta_datafusion.rs index 8fd4fd6d5b..b3f76886f2 100644 --- a/rust/src/delta_datafusion.rs +++ b/rust/src/delta_datafusion.rs @@ -603,6 +603,18 @@ fn left_larger_than_right(left: ScalarValue, right: ScalarValue) -> Option<bool> } } +fn enforce_invariants( + record_batch: RecordBatch, + invariants: &Vec<(String, String)>, +) -> Result<(), DeltaTableError> { + let ctx = SessionContext::new(); + // TODO: How does one query a record batch in data fusion?? 
+ for invariant in invariants.iter() { + let sql = format!("SELECT {} FROM data WHERE {} LIMIT 1", name, invariant); + } + Ok(()) +} + #[cfg(test)] mod tests { use super::*; diff --git a/rust/src/schema.rs b/rust/src/schema.rs index e71d997347..aa2df0ceb2 100644 --- a/rust/src/schema.rs +++ b/rust/src/schema.rs @@ -64,35 +64,66 @@ impl SchemaTypeStruct { } /// Get all invariants in the schemas - pub fn get_invariants(&self) -> Result<Vec<String>, crate::DeltaTableError> { - let mut remaining_fields: Vec<SchemaField> = self.get_fields().clone(); - let mut invariants: Vec<String> = Vec::new(); + pub fn get_invariants(&self) -> Result<Vec<(String, String)>, crate::DeltaTableError> { + let mut remaining_fields: Vec<(String, SchemaField)> = self + .get_fields() + .iter() + .map(|field| ("".to_string(), field.clone())) + .collect(); + let mut invariants: Vec<(String, String)> = Vec::new(); - while let Some(field) = remaining_fields.pop() { + let add_segment = |prefix: &str, segment: &str| -> String { + if prefix.is_empty() { + segment.to_owned() + } else { + format!("{}.{}", prefix, segment) + } + }; + + while let Some((prefix, field)) = remaining_fields.pop() { match field.r#type { SchemaDataType::r#struct(inner) => { - remaining_fields.extend(inner.get_fields().clone()); + let new_prefix = add_segment(&prefix, &field.name); + remaining_fields.extend( + inner + .get_fields() + .iter() + .map(|field| (new_prefix.clone(), field.clone())) + .collect::<Vec<(String, SchemaField)>>(), + ); } SchemaDataType::array(inner) => { - remaining_fields.push(SchemaField::new( - "dummy".to_string(), - *inner.elementType, - false, - HashMap::new(), + let new_prefix = add_segment(&prefix, "elements"); + remaining_fields.push(( + new_prefix, + SchemaField::new( + "dummy".to_string(), + *inner.elementType, + false, + HashMap::new(), + ), )); } SchemaDataType::map(inner) => { - remaining_fields.push(SchemaField::new( - "dummy".to_string(), - *inner.keyType, - false, - HashMap::new(), + let new_prefix = add_segment(&prefix, "keys"); + remaining_fields.push(( + new_prefix, + SchemaField::new( + "dummy".to_string(), + *inner.keyType, + false, + HashMap::new(), + ), )); - remaining_fields.push(SchemaField::new( - "dummy".to_string(), - *inner.valueType, - false, - HashMap::new(), + let new_prefix = add_segment(&prefix, "values"); + remaining_fields.push(( + new_prefix, + SchemaField::new( + "dummy".to_string(), + *inner.valueType, + false, + HashMap::new(), + ), )); } _ => {} @@ -103,7 +134,8 @@ impl SchemaTypeStruct { if let Value::Object(json) = json { if let Some(Value::Object(expr1)) = json.get("expression") { if let Some(Value::String(sql)) = expr1.get("expression") { - invariants.push(sql.clone()); + let full_field_name = add_segment(&prefix,& field.name); + invariants.push((full_field_name, sql.clone())); } } } @@ -298,8 +330,8 @@ mod tests { .unwrap(); let invariants = schema.get_invariants().unwrap(); assert_eq!(invariants.len(), 2); - assert!(invariants.contains(&"x > 2".to_string())); - assert!(invariants.contains(&"y < 4".to_string())); + assert!(invariants.contains(&("x".to_string(), "x > 2".to_string()))); + assert!(invariants.contains(&("y".to_string(), "y < 4".to_string()))); let schema: Schema = serde_json::from_value(json!({ "type": "struct", @@ -331,6 +363,12 @@ mod tests { })).unwrap(); let invariants = schema.get_invariants().unwrap(); assert_eq!(invariants.len(), 1); - assert_eq!(invariants[0], "a_map.value.element.d < 4"); + assert_eq!( + invariants[0], + ( + "a_map.value.element.d".to_string(), + "a_map.value.element.d < 4".to_string() + ) + ); } } From 
3f8a5013f7475923ee5f29a12e6d9f7e6a5cafa5 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Thu, 22 Sep 2022 17:53:44 -0700 Subject: [PATCH 03/15] fix: Resolve nested fields correctly --- rust/src/schema.rs | 46 ++++++++++++++++------------------------------ 1 file changed, 16 insertions(+), 30 deletions(-) diff --git a/rust/src/schema.rs b/rust/src/schema.rs index aa2df0ceb2..6250e86444 100644 --- a/rust/src/schema.rs +++ b/rust/src/schema.rs @@ -68,7 +68,7 @@ impl SchemaTypeStruct { let mut remaining_fields: Vec<(String, SchemaField)> = self .get_fields() .iter() - .map(|field| ("".to_string(), field.clone())) + .map(|field| (field.name.clone(), field.clone())) .collect(); let mut invariants: Vec<(String, String)> = Vec::new(); @@ -80,50 +80,37 @@ impl SchemaTypeStruct { } }; - while let Some((prefix, field)) = remaining_fields.pop() { + while let Some((field_path, field)) = remaining_fields.pop() { match field.r#type { SchemaDataType::r#struct(inner) => { - let new_prefix = add_segment(&prefix, &field.name); remaining_fields.extend( inner .get_fields() .iter() - .map(|field| (new_prefix.clone(), field.clone())) + .map(|field| { + let new_prefix = add_segment(&field_path, &field.name); + (new_prefix, field.clone()) + }) .collect::<Vec<(String, SchemaField)>>(), ); } SchemaDataType::array(inner) => { - let new_prefix = add_segment(&prefix, "elements"); + let element_field_name = add_segment(&field_path, "element"); remaining_fields.push(( - new_prefix, - SchemaField::new( - "dummy".to_string(), - *inner.elementType, - false, - HashMap::new(), - ), + element_field_name, + SchemaField::new("".to_string(), *inner.elementType, false, HashMap::new()), )); } SchemaDataType::map(inner) => { - let new_prefix = add_segment(&prefix, "keys"); + let key_field_name = add_segment(&field_path, "key"); remaining_fields.push(( - new_prefix, - SchemaField::new( - "dummy".to_string(), - *inner.keyType, - false, - HashMap::new(), - ), + key_field_name, + SchemaField::new("".to_string(), *inner.keyType, false, HashMap::new()), )); - let new_prefix = add_segment(&prefix, "values"); + let value_field_name = add_segment(&field_path, "value"); remaining_fields.push(( - new_prefix, - SchemaField::new( - "dummy".to_string(), - *inner.valueType, - false, - HashMap::new(), - ), + value_field_name, + SchemaField::new("".to_string(), *inner.valueType, false, HashMap::new()), )); } _ => {} @@ -134,8 +121,7 @@ impl SchemaTypeStruct { if let Value::Object(json) = json { if let Some(Value::Object(expr1)) = json.get("expression") { if let Some(Value::String(sql)) = expr1.get("expression") { - let full_field_name = add_segment(&prefix,& field.name); - invariants.push((full_field_name, sql.clone())); + invariants.push((field_path, sql.clone())); } } } From 5c5d3bf8250d69a277bba8da0f403e7088636cf8 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Thu, 22 Sep 2022 19:55:31 -0700 Subject: [PATCH 04/15] feat: implement enforcement --- Cargo.lock | 4 +- rust/src/delta.rs | 6 ++ rust/src/delta_datafusion.rs | 112 ++++++++++++++++++++++++++++++++--- 3 files changed, 111 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7ba2ab3f3c..9559797840 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1848,9 +1848,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-src" -version = "111.21.0+1.1.1p" +version = "111.22.0+1.1.1q" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d0a8313729211913936f1b95ca47a5fc7f2e04cd658c115388287f8a8361008" +checksum = 
"8f31f0d509d1c1ae9cada2f9539ff8f37933831fd5098879e482aa687d659853" dependencies = [ "cc", ] diff --git a/rust/src/delta.rs b/rust/src/delta.rs index 658d6de8c4..60ccf64032 100644 --- a/rust/src/delta.rs +++ b/rust/src/delta.rs @@ -141,6 +141,12 @@ pub enum DeltaTableError { #[from] source: action::ActionError, }, + /// Error returned when attempting to write bad data to the table + #[error("Attempted to write invalid data to the table: {:#?}", violations)] + InvalidData { + /// Action error details returned of the invalid action. + violations: Vec, + }, /// Error returned when it is not a DeltaTable. #[error("Not a Delta table: {0}")] NotATable(String), diff --git a/rust/src/delta_datafusion.rs b/rust/src/delta_datafusion.rs index b3f76886f2..391e7c0f9a 100644 --- a/rust/src/delta_datafusion.rs +++ b/rust/src/delta_datafusion.rs @@ -32,11 +32,12 @@ use crate::{DeltaTable, DeltaTableError}; use arrow::array::ArrayRef; use arrow::compute::{cast_with_options, CastOptions}; use arrow::datatypes::{DataType as ArrowDataType, Schema as ArrowSchema, TimeUnit}; +use arrow::record_batch::RecordBatch; use async_trait::async_trait; use chrono::{DateTime, NaiveDateTime, Utc}; use datafusion::datasource::file_format::{parquet::ParquetFormat, FileFormat}; -use datafusion::datasource::{listing::PartitionedFile, TableProvider, TableType}; -use datafusion::execution::context::SessionState; +use datafusion::datasource::{listing::PartitionedFile, MemTable, TableProvider, TableType}; +use datafusion::execution::context::{SessionContext, SessionState}; use datafusion::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; use datafusion::physical_plan::file_format::FileScanConfig; use datafusion::physical_plan::{ColumnStatistics, ExecutionPlan, Statistics}; @@ -603,23 +604,58 @@ fn left_larger_than_right(left: ScalarValue, right: ScalarValue) -> Option } } -fn enforce_invariants( - record_batch: RecordBatch, +/// Checks that the record batch adheres to the given invariants. +pub fn enforce_invariants( + record_batch: &RecordBatch, invariants: &Vec<(String, String)>, ) -> Result<(), DeltaTableError> { + // Invariants are deprecated, so let's not pay the overhead for any of this + // if we can avoid it. + if invariants.is_empty() { + return Ok(()); + } + let ctx = SessionContext::new(); - // TODO: How does one query a record batch in data fusion?? 
- for invariant in invariants.iter() { - let sql = format!("SELECT {} FROM data WHERE {} LIMIT 1", name, invariant); + let table = MemTable::try_new(record_batch.schema(), vec![vec![record_batch.clone()]])?; + ctx.register_table("data", Arc::new(table))?; + let rt = tokio::runtime::Runtime::new().unwrap(); + + let mut violations: Vec<String> = Vec::new(); + + for (column_name, invariant) in invariants.iter() { + if column_name.contains(".") { + return Err(DeltaTableError::Generic( + "Support for column invariants on nested columns is not supported.".to_string(), + )); + } + + let sql = format!( + "SELECT {} FROM data WHERE not ({}) LIMIT 1", + column_name, invariant + ); + + let dfs: Vec<RecordBatch> = rt.block_on(async { ctx.sql(&sql).await?.collect().await })?; + if !dfs.is_empty() && dfs[0].num_rows() > 0 { + let value = format!("{:?}", dfs[0].column(0)); + let msg = format!("Invariant ({}) violated by value {}", invariant, value); + violations.push(msg); + } + } + + if !violations.is_empty() { + Err(DeltaTableError::InvalidData { violations }) + } else { + Ok(()) } - Ok(()) } #[cfg(test)] mod tests { use super::*; - use arrow::datatypes::Field; + use arrow::array::StructArray; + use arrow::datatypes::{DataType, Field, Schema}; use chrono::{TimeZone, Utc}; + use datafusion::from_slice::FromSlice; use serde_json::json; // test deserialization of serialized partition values. @@ -730,4 +766,62 @@ mod tests { }; assert_eq!(file.partition_values, ref_file.partition_values) } + + #[test] + fn test_enforce_invariants() { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Int32, false), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from_slice(&["a", "b", "c", "d"])), + Arc::new(arrow::array::Int32Array::from_slice(&[1, 10, 10, 100])), + ], + ) + .unwrap(); + // Empty invariants is okay + let invariants: Vec<(String, String)> = vec![]; + assert!(enforce_invariants(&batch, &invariants).is_ok()); + + // Valid invariants return Ok(()) + let invariants = vec![ + ("a".to_string(), "a is not null".to_string()), + ("b".to_string(), "b < 1000".to_string()), + ]; + assert!(enforce_invariants(&batch, &invariants).is_ok()); + + // Violated invariants returns an error with list of violations + let invariants = vec![ + ("a".to_string(), "a is null".to_string()), + ("b".to_string(), "b < 100".to_string()), + ]; + let result = enforce_invariants(&batch, &invariants); + assert!(result.is_err()); + assert!(matches!(result, Err(DeltaTableError::InvalidData { .. }))); + if let Err(DeltaTableError::InvalidData { violations }) = result { + assert_eq!(violations.len(), 2); + } + + // Irrelevant invariants return a different error + let invariants = vec![("c".to_string(), "c > 2000".to_string())]; + let result = enforce_invariants(&batch, &invariants); + assert!(result.is_err()); + + // Nested invariants are unsupported + let struct_fields = schema.fields().clone(); + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::Struct(struct_fields), + false, + )])); + let inner = Arc::new(StructArray::from(batch)); + let batch = RecordBatch::try_new(schema, vec![inner]).unwrap(); + + let invariants = vec![("x.b".to_string(), "x.b < 1000".to_string())]; + let result = enforce_invariants(&batch, &invariants); + assert!(result.is_err()); + assert!(matches!(result, Err(DeltaTableError::Generic { .. 
}))); + } } From c29f5769234f3f799a0f752e50f332f590c28ea7 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Thu, 22 Sep 2022 20:14:24 -0700 Subject: [PATCH 05/15] feat: Move invariant into documented struct --- rust/src/delta_datafusion.rs | 33 +++++++++++++++++++++------------ rust/src/schema.rs | 34 +++++++++++++++++++++++++--------- 2 files changed, 46 insertions(+), 21 deletions(-) diff --git a/rust/src/delta_datafusion.rs b/rust/src/delta_datafusion.rs index 391e7c0f9a..7172189c33 100644 --- a/rust/src/delta_datafusion.rs +++ b/rust/src/delta_datafusion.rs @@ -47,6 +47,12 @@ use datafusion_expr::{combine_filters, Expr}; use object_store::{path::Path, ObjectMeta}; use url::Url; +use crate::action; +use crate::delta; +use crate::schema; +use crate::DeltaTableError; +use crate::Invariant; + impl From<DeltaTableError> for DataFusionError { fn from(err: DeltaTableError) -> Self { match err { @@ -607,7 +613,7 @@ fn left_larger_than_right(left: ScalarValue, right: ScalarValue) -> Option<bool> /// Checks that the record batch adheres to the given invariants. pub fn enforce_invariants( record_batch: &RecordBatch, - invariants: &Vec<(String, String)>, + invariants: &Vec<Invariant>, ) -> Result<(), DeltaTableError> { // Invariants are deprecated, so let's not pay the overhead for any of this // if we can avoid it. @@ -622,8 +628,8 @@ pub fn enforce_invariants( let mut violations: Vec<String> = Vec::new(); - for (column_name, invariant) in invariants.iter() { - if column_name.contains(".") { + for invariant in invariants.iter() { + if invariant.field_name.contains('.') { return Err(DeltaTableError::Generic( "Support for column invariants on nested columns is not supported.".to_string(), )); } let sql = format!( "SELECT {} FROM data WHERE not ({}) LIMIT 1", - column_name, invariant + invariant.field_name, invariant.invariant_sql ); let dfs: Vec<RecordBatch> = rt.block_on(async { ctx.sql(&sql).await?.collect().await })?; if !dfs.is_empty() && dfs[0].num_rows() > 0 { let value = format!("{:?}", dfs[0].column(0)); - let msg = format!("Invariant ({}) violated by value {}", invariant, value); + let msg = format!( + "Invariant ({}) violated by value {}", + invariant.invariant_sql, value + ); violations.push(msg); } } @@ -782,20 +791,20 @@ mod tests { ) .unwrap(); // Empty invariants is okay - let invariants: Vec<(String, String)> = vec![]; + let invariants: Vec<Invariant> = vec![]; assert!(enforce_invariants(&batch, &invariants).is_ok()); // Valid invariants return Ok(()) let invariants = vec![ - ("a".to_string(), "a is not null".to_string()), - ("b".to_string(), "b < 1000".to_string()), + Invariant::new("a", "a is not null"), + Invariant::new("b", "b < 1000"), ]; assert!(enforce_invariants(&batch, &invariants).is_ok()); // Violated invariants returns an error with list of violations let invariants = vec![ - ("a".to_string(), "a is null".to_string()), - ("b".to_string(), "b < 100".to_string()), + Invariant::new("a", "a is null"), + Invariant::new("b", "b < 100"), ]; let result = enforce_invariants(&batch, &invariants); assert!(result.is_err()); @@ -805,7 +814,7 @@ } // Irrelevant invariants return a different error - let invariants = vec![("c".to_string(), "c > 2000".to_string())]; + let invariants = vec![Invariant::new("c", "c > 2000")]; let result = enforce_invariants(&batch, &invariants); assert!(result.is_err()); // Nested invariants are unsupported let struct_fields = schema.fields().clone(); let schema = Arc::new(Schema::new(vec![Field::new( "x", DataType::Struct(struct_fields), false, )])); let inner = Arc::new(StructArray::from(batch)); let batch = RecordBatch::try_new(schema, vec![inner]).unwrap(); - let invariants = vec![("x.b".to_string(), "x.b < 
1000".to_string())]; + let invariants = vec![Invariant::new("x.b", "x.b < 1000")]; let result = enforce_invariants(&batch, &invariants); assert!(result.is_err()); assert!(matches!(result, Err(DeltaTableError::Generic { .. }))); diff --git a/rust/src/schema.rs b/rust/src/schema.rs index 6250e86444..da30de6e63 100644 --- a/rust/src/schema.rs +++ b/rust/src/schema.rs @@ -21,6 +21,25 @@ static STRUCT_TAG: &str = "struct"; static ARRAY_TAG: &str = "array"; static MAP_TAG: &str = "map"; +/// An invariant for a column that is enforced on all writes to a Delta table. +#[derive(PartialEq, Debug, Default, Clone)] +pub struct Invariant { + /// The full path to the field. + pub field_name: String, + /// The SQL string that must always evaluate to true. + pub invariant_sql: String, +} + +impl Invariant { + /// Create a new invariant + pub fn new(field_name: &str, invariant_sql: &str) -> Self { + Invariant { + field_name: field_name.to_string(), + invariant_sql: invariant_sql.to_string(), + } + } +} + /// Represents a struct field defined in the Delta table schema. // https://github.com/delta-io/delta/blob/master/PROTOCOL.md#Schema-Serialization-Format #[derive(Serialize, Deserialize, PartialEq, Debug, Default, Clone)] @@ -64,13 +83,13 @@ impl SchemaTypeStruct { } /// Get all invariants in the schemas - pub fn get_invariants(&self) -> Result, crate::DeltaTableError> { + pub fn get_invariants(&self) -> Result, crate::DeltaTableError> { let mut remaining_fields: Vec<(String, SchemaField)> = self .get_fields() .iter() .map(|field| (field.name.clone(), field.clone())) .collect(); - let mut invariants: Vec<(String, String)> = Vec::new(); + let mut invariants: Vec = Vec::new(); let add_segment = |prefix: &str, segment: &str| -> String { if prefix.is_empty() { @@ -121,7 +140,7 @@ impl SchemaTypeStruct { if let Value::Object(json) = json { if let Some(Value::Object(expr1)) = json.get("expression") { if let Some(Value::String(sql)) = expr1.get("expression") { - invariants.push((field_path, sql.clone())); + invariants.push(Invariant::new(&field_path, sql)); } } } @@ -316,8 +335,8 @@ mod tests { .unwrap(); let invariants = schema.get_invariants().unwrap(); assert_eq!(invariants.len(), 2); - assert!(invariants.contains(&("x".to_string(), "x > 2".to_string()))); - assert!(invariants.contains(&("y".to_string(), "y < 4".to_string()))); + assert!(invariants.contains(&Invariant::new("x", "x > 2"))); + assert!(invariants.contains(&Invariant::new("y", "y < 4"))); let schema: Schema = serde_json::from_value(json!({ "type": "struct", @@ -351,10 +370,7 @@ mod tests { assert_eq!(invariants.len(), 1); assert_eq!( invariants[0], - ( - "a_map.value.element.d".to_string(), - "a_map.value.element.d < 4".to_string() - ) + Invariant::new("a_map.value.element.d", "a_map.value.element.d < 4") ); } } From 7843116b8c67d911dcba4ce05b2b6ce49b23a5a3 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Thu, 22 Sep 2022 20:29:51 -0700 Subject: [PATCH 06/15] feat: put data checker in nice struct --- rust/src/delta_datafusion.rs | 109 ++++++++++++++++++++++------------- rust/src/schema.rs | 2 +- 2 files changed, 69 insertions(+), 42 deletions(-) diff --git a/rust/src/delta_datafusion.rs b/rust/src/delta_datafusion.rs index 7172189c33..2e6e5fa711 100644 --- a/rust/src/delta_datafusion.rs +++ b/rust/src/delta_datafusion.rs @@ -610,51 +610,74 @@ fn left_larger_than_right(left: ScalarValue, right: ScalarValue) -> Option } } -/// Checks that the record batch adheres to the given invariants. 
-pub fn enforce_invariants( - record_batch: &RecordBatch, - invariants: &Vec<Invariant>, -) -> Result<(), DeltaTableError> { - // Invariants are deprecated, so let's not pay the overhead for any of this - // if we can avoid it. - if invariants.is_empty() { - return Ok(()); - } +/// Responsible for checking batches of data conform to table's invariants. +pub struct DeltaDataChecker { + invariants: Vec<Invariant>, + ctx: SessionContext, + rt: tokio::runtime::Runtime, +} - let ctx = SessionContext::new(); - let table = MemTable::try_new(record_batch.schema(), vec![vec![record_batch.clone()]])?; - ctx.register_table("data", Arc::new(table))?; - let rt = tokio::runtime::Runtime::new().unwrap(); +impl DeltaDataChecker { + /// Create a new DeltaDataChecker + pub fn new(invariants: Vec<Invariant>) -> Self { + Self { + invariants, + ctx: SessionContext::new(), + rt: tokio::runtime::Runtime::new().unwrap(), + } + } - let mut violations: Vec<String> = Vec::new(); + /// Check that a record batch conforms to table's invariants. + /// + /// If it does not, it will return [DeltaTableError::InvalidData] with a list + /// of values that violated each invariant. + pub fn check_batch(&self, record_batch: &RecordBatch) -> Result<(), DeltaTableError> { + self.enforce_invariants(record_batch) + // TODO: for support for Protocol V3, check constraints + } - for invariant in invariants.iter() { - if invariant.field_name.contains('.') { - return Err(DeltaTableError::Generic( - "Support for column invariants on nested columns is not supported.".to_string(), - )); + fn enforce_invariants(&self, record_batch: &RecordBatch) -> Result<(), DeltaTableError> { + // Invariants are deprecated, so let's not pay the overhead for any of this + // if we can avoid it. + if self.invariants.is_empty() { + return Ok(()); } - let sql = format!( - "SELECT {} FROM data WHERE not ({}) LIMIT 1", - invariant.field_name, invariant.invariant_sql - ); + let table = MemTable::try_new(record_batch.schema(), vec![vec![record_batch.clone()]])?; + self.ctx.register_table("data", Arc::new(table))?; + + let mut violations: Vec<String> = Vec::new(); - let dfs: Vec<RecordBatch> = rt.block_on(async { ctx.sql(&sql).await?.collect().await })?; - if !dfs.is_empty() && dfs[0].num_rows() > 0 { - let value = format!("{:?}", dfs[0].column(0)); - let msg = format!( - "Invariant ({}) violated by value {}", - invariant.invariant_sql, value + for invariant in self.invariants.iter() { + if invariant.field_name.contains('.') { + return Err(DeltaTableError::Generic( + "Support for column invariants on nested columns is not supported.".to_string(), + )); + } + + let sql = format!( + "SELECT {} FROM data WHERE not ({}) LIMIT 1", + invariant.field_name, invariant.invariant_sql ); - violations.push(msg); + + let dfs: Vec<RecordBatch> = self + .rt + .block_on(async { self.ctx.sql(&sql).await?.collect().await })?; + if !dfs.is_empty() && dfs[0].num_rows() > 0 { + let value = format!("{:?}", dfs[0].column(0)); + let msg = format!( + "Invariant ({}) violated by value {}", + invariant.invariant_sql, value + ); + violations.push(msg); + } } - } - if !violations.is_empty() { - Err(DeltaTableError::InvalidData { violations }) - } else { - Ok(()) + if !violations.is_empty() { + Err(DeltaTableError::InvalidData { violations }) + } else { + Ok(()) + } } } @@ -792,21 +815,25 @@ mod tests { .unwrap(); // Empty invariants is okay let invariants: Vec<Invariant> = vec![]; - assert!(enforce_invariants(&batch, &invariants).is_ok()); + assert!(DeltaDataChecker::new(invariants) + .check_batch(&batch) + .is_ok()); // Valid invariants return Ok(()) let invariants = vec![ 
Invariant::new("a", "a is not null"), Invariant::new("b", "b < 1000"), ]; - assert!(enforce_invariants(&batch, &invariants).is_ok()); + assert!(DeltaDataChecker::new(invariants) + .check_batch(&batch) + .is_ok()); // Violated invariants returns an error with list of violations let invariants = vec![ Invariant::new("a", "a is null"), Invariant::new("b", "b < 100"), ]; - let result = enforce_invariants(&batch, &invariants); + let result = DeltaDataChecker::new(invariants).check_batch(&batch); assert!(result.is_err()); assert!(matches!(result, Err(DeltaTableError::InvalidData { .. }))); if let Err(DeltaTableError::InvalidData { violations }) = result { @@ -815,7 +842,7 @@ mod tests { // Irrelevant invariants return a different error let invariants = vec![Invariant::new("c", "c > 2000")]; - let result = enforce_invariants(&batch, &invariants); + let result = DeltaDataChecker::new(invariants).check_batch(&batch); assert!(result.is_err()); // Nested invariants are unsupported @@ -829,7 +856,7 @@ mod tests { let batch = RecordBatch::try_new(schema, vec![inner]).unwrap(); let invariants = vec![Invariant::new("x.b", "x.b < 1000")]; - let result = enforce_invariants(&batch, &invariants); + let result = DeltaDataChecker::new(invariants).check_batch(&batch); assert!(result.is_err()); assert!(matches!(result, Err(DeltaTableError::Generic { .. }))); } diff --git a/rust/src/schema.rs b/rust/src/schema.rs index da30de6e63..eb8183ffa8 100644 --- a/rust/src/schema.rs +++ b/rust/src/schema.rs @@ -33,7 +33,7 @@ pub struct Invariant { impl Invariant { /// Create a new invariant pub fn new(field_name: &str, invariant_sql: &str) -> Self { - Invariant { + Self { field_name: field_name.to_string(), invariant_sql: invariant_sql.to_string(), } From a3cb00bc762c0d4a398de19e55fc743350d001bf Mon Sep 17 00:00:00 2001 From: Will Jones Date: Sat, 24 Sep 2022 10:20:37 -0700 Subject: [PATCH 07/15] Add writer integration tests --- .../test_write_to_pyspark.py | 107 ++++++++++++++++++ .../test_writer_readable.py | 38 +------ python/tests/pyspark_integration/utils.py | 49 ++++++++ 3 files changed, 157 insertions(+), 37 deletions(-) create mode 100644 python/tests/pyspark_integration/test_write_to_pyspark.py create mode 100644 python/tests/pyspark_integration/utils.py diff --git a/python/tests/pyspark_integration/test_write_to_pyspark.py b/python/tests/pyspark_integration/test_write_to_pyspark.py new file mode 100644 index 0000000000..ca07c8c86a --- /dev/null +++ b/python/tests/pyspark_integration/test_write_to_pyspark.py @@ -0,0 +1,107 @@ +"""Tests that deltalake(delta-rs) can write to tables written by PySpark""" +import pathlib + +import pyarrow as pa +import pytest + +from deltalake import write_deltalake +from deltalake.deltalake import PyDeltaTableError + +from .utils import assert_spark_read_equal, get_spark + +try: + from pandas.testing import assert_frame_equal +except ModuleNotFoundError: + _has_pandas = False +else: + _has_pandas = True + + +try: + import delta + import delta.pip_utils + import delta.tables + import pyspark + + spark = get_spark() +except ModuleNotFoundError: + pass + + +@pytest.mark.pyspark +@pytest.mark.integration +def test_write_basic(tmp_path: pathlib.Path, sample_data: pa.Table): + # Write table in Spark + spark = get_spark() + schema = pyspark.sql.types.StructType( + [ + pyspark.sql.types.StructField( + "c1", + dataType=pyspark.sql.types.IntegerType(), + nullable=True, + ) + ] + ) + spark.createDataFrame([(4,)], schema=schema).write.save( + str(tmp_path), + mode="append", + format="delta", + ) 
+ # Overwrite table in deltalake + data = pa.table({"c1": pa.array([5, 6], type=pa.int32())}) + write_deltalake(str(tmp_path), data, mode="overwrite") + + # Read table in Spark + assert_spark_read_equal(data, str(tmp_path)) + + +@pytest.mark.pyspark +@pytest.mark.integration +def test_write_invariant(tmp_path: pathlib.Path): + # Write table in Spark with invariant + spark = get_spark() + + schema = pyspark.sql.types.StructType( + [ + pyspark.sql.types.StructField( + "c1", + dataType=pyspark.sql.types.IntegerType(), + nullable=True, + metadata={ + "delta.invariants": '{"expression": { "expression": "c1 > 3"} }' + }, + ) + ] + ) + + delta.tables.DeltaTable.create(spark).location(str(tmp_path)).addColumns( + schema + ).execute() + + spark.createDataFrame([(4,)], schema=schema).write.save( + str(tmp_path), + mode="append", + format="delta", + ) + + # Cannot write invalid data to the table + invalid_data = pa.table({"c1": pa.array([6, 2], type=pa.int32())}) + with pytest.raises( + PyDeltaTableError, match="Invariant (c1 > 3) violated by value .+2" + ): + write_deltalake(str(tmp_path), invalid_data, mode="overwrite") + + # Can write valid data to the table + valid_data = pa.table({"c1": pa.array([5, 6], type=pa.int32())}) + write_deltalake(str(tmp_path), valid_data, mode="append") + + expected = pa.table({"c1": pa.array([4, 5, 6], type=pa.int32())}) + assert_spark_read_equal(expected, str(tmp_path)) + + +@pytest.mark.pyspark +@pytest.mark.integration +def test_checks_min_writer_version(): + # Write table in Spark with constraint + # assert we fail to write any data to it + pass diff --git a/python/tests/pyspark_integration/test_writer_readable.py b/python/tests/pyspark_integration/test_writer_readable.py index c4f15721de..83e1172766 100644 --- a/python/tests/pyspark_integration/test_writer_readable.py +++ b/python/tests/pyspark_integration/test_writer_readable.py @@ -7,25 +7,7 @@ from deltalake import DeltaTable, write_deltalake -try: - from pandas.testing import assert_frame_equal -except ModuleNotFoundError: - _has_pandas = False -else: - _has_pandas = True - - -def get_spark(): - builder = ( - pyspark.sql.SparkSession.builder.appName("MyApp") - .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") - .config( - "spark.sql.catalog.spark_catalog", - "org.apache.spark.sql.delta.catalog.DeltaCatalog", - ) - ) - return delta.pip_utils.configure_spark_with_delta_pip(builder).getOrCreate() - +from .utils import assert_spark_read_equal, get_spark try: import delta @@ -38,24 +20,6 @@ def get_spark(): pass -def assert_spark_read_equal( - expected: pa.Table, uri: str, sort_by: List[str] = ["int32"] -): - df = spark.read.format("delta").load(uri) - - # Spark and pyarrow don't convert these types to the same Pandas values - incompatible_types = ["timestamp", "struct"] - - assert_frame_equal( - df.toPandas() - .sort_values(sort_by, ignore_index=True) - .drop(incompatible_types, axis="columns"), - expected.to_pandas() - .sort_values(sort_by, ignore_index=True) - .drop(incompatible_types, axis="columns"), - ) - - @pytest.mark.pyspark @pytest.mark.integration def test_basic_read(sample_data: pa.Table, existing_table: DeltaTable): diff --git a/python/tests/pyspark_integration/utils.py b/python/tests/pyspark_integration/utils.py new file mode 100644 index 0000000000..9d9e43929a --- /dev/null +++ b/python/tests/pyspark_integration/utils.py @@ -0,0 +1,49 @@ +from typing import List + +import pyarrow as pa + +try: + import delta + import delta.pip_utils + import delta.tables + import pyspark 
+except ModuleNotFoundError: + pass + +try: + from pandas.testing import assert_frame_equal +except ModuleNotFoundError: + _has_pandas = False +else: + _has_pandas = True + + +def get_spark(): + builder = ( + pyspark.sql.SparkSession.builder.appName("MyApp") + .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") + .config( + "spark.sql.catalog.spark_catalog", + "org.apache.spark.sql.delta.catalog.DeltaCatalog", + ) + ) + return delta.pip_utils.configure_spark_with_delta_pip(builder).getOrCreate() + + +def assert_spark_read_equal( + expected: pa.Table, uri: str, sort_by: List[str] = ["int32"] +): + spark = get_spark() + df = spark.read.format("delta").load(uri) + + # Spark and pyarrow don't convert these types to the same Pandas values + incompatible_types = ["timestamp", "struct"] + + assert_frame_equal( + df.toPandas() + .sort_values(sort_by, ignore_index=True) + .drop(incompatible_types, axis="columns"), + expected.to_pandas() + .sort_values(sort_by, ignore_index=True) + .drop(incompatible_types, axis="columns"), + ) From fcba790c7da8a055db8b392a545f8c929d7b8bdb Mon Sep 17 00:00:00 2001 From: Will Jones Date: Sat, 24 Sep 2022 12:51:49 -0700 Subject: [PATCH 08/15] Implement Python support for checking invariants --- python/Cargo.toml | 2 +- python/deltalake/_internal.pyi | 7 +++- python/deltalake/writer.py | 28 ++++++++++++++-- python/src/lib.rs | 33 ++++++++++++++++++- python/src/schema.rs | 18 ++++++++++ .../test_write_to_pyspark.py | 28 ++++++++++++---- python/tests/pyspark_integration/utils.py | 4 +-- rust/src/schema.rs | 2 +- 8 files changed, 107 insertions(+), 15 deletions(-) diff --git a/python/Cargo.toml b/python/Cargo.toml index 3d4aa11f2c..39344f54fb 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -36,4 +36,4 @@ features = ["extension-module", "abi3", "abi3-py37"] [dependencies.deltalake] path = "../rust" version = "0" -features = ["s3", "azure", "glue", "gcs", "python"] +features = ["s3", "azure", "glue", "gcs", "python", "datafusion-ext"] diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index ac9ba3db91..23e26fbf3a 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -1,5 +1,5 @@ import sys -from typing import Any, Callable, Dict, List, Mapping, Optional, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union if sys.version_info >= (3, 8): from typing import Literal @@ -118,6 +118,7 @@ class StructType: class Schema: def __init__(self, fields: List[Field]) -> None: ... fields: List[Field] + invariants: List[Tuple[str, str]] def to_json(self) -> str: ... @staticmethod @@ -212,3 +213,7 @@ class DeltaFileSystemHandler: self, path: str, metadata: dict[str, str] | None = None ) -> ObjectOutputStream: """Open an output stream for sequential writing.""" + +class DeltaDataChecker: + def __init__(self, invariants: List[Tuple[str, str]]) -> None: ... + def check_batch(self, batch: pa.RecordBatch) -> None: ... 
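Before the writer change below, a note on how the DeltaDataChecker binding declared in _internal.pyi above is meant to be driven. This is a minimal usage sketch, not part of the patch; the table path and sample batch are hypothetical, and it relies only on names the series itself introduces (Schema.invariants and DeltaDataChecker.check_batch):

import pyarrow as pa
from deltalake import DeltaTable
from deltalake._internal import DeltaDataChecker

dt = DeltaTable("path/to/table")  # hypothetical existing table
# Schema.invariants (added in this patch) is a List[Tuple[str, str]] of
# (field path, SQL expression) pairs parsed from delta.invariants metadata.
checker = DeltaDataChecker(dt.schema().invariants)

batch = pa.RecordBatch.from_pydict({"c1": pa.array([4, 5], type=pa.int32())})
checker.check_batch(batch)  # raises PyDeltaTableError if an invariant fails

The writer hunk that follows wires exactly this check in front of every batch handed to ds.write_dataset.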
diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py index 1d61862d6b..41269ffccf 100644 --- a/python/deltalake/writer.py +++ b/python/deltalake/writer.py @@ -35,6 +35,7 @@ import pyarrow.fs as pa_fs from pyarrow.lib import RecordBatchReader +from ._internal import DeltaDataChecker as _DeltaDataChecker from ._internal import PyDeltaTableError from ._internal import write_new_deltalake as _write_new_deltalake from .table import DeltaTable @@ -192,11 +193,11 @@ def write_deltalake( if partition_by: assert partition_by == table.metadata().partition_columns - if table.protocol().min_writer_version > 1: + if table.protocol().min_writer_version > 2: raise DeltaTableProtocolError( "This table's min_writer_version is " f"{table.protocol().min_writer_version}, " - "but this method only supports version 1." + "but this method only supports version 2." ) else: # creating a new table current_version = -1 @@ -234,6 +235,29 @@ def visitor(written_file: Any) -> None: ) ) + if table is not None: + # We don't currently provide a way to set invariants + # (and maybe never will), so only enforce them if they already exist. + invariants = table.schema().invariants + checker = _DeltaDataChecker(invariants) + + def validate_batch(batch: pa.RecordBatch) -> pa.RecordBatch: + checker.check_batch(batch) + return batch + + if isinstance(data, RecordBatchReader): + batch_iter = data + elif isinstance(data, pa.RecordBatch): + batch_iter = [data] + elif isinstance(data, pa.Table): + batch_iter = data.to_batches() + else: + batch_iter = data + + data = RecordBatchReader.from_batches( + schema, (validate_batch(batch) for batch in batch_iter) + ) + ds.write_dataset( data, base_dir="/", diff --git a/python/src/lib.rs b/python/src/lib.rs index aa9a9477c5..85c0ec7522 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -8,14 +8,16 @@ use chrono::{DateTime, FixedOffset, Utc}; use deltalake::action::{ self, Action, ColumnCountStat, ColumnValueStat, DeltaOperation, SaveMode, Stats, }; +use deltalake::arrow::record_batch::RecordBatch; use deltalake::arrow::{self, datatypes::Schema as ArrowSchema}; use deltalake::builder::DeltaTableBuilder; +use deltalake::delta_datafusion::DeltaDataChecker; use deltalake::partitions::PartitionFilter; use deltalake::DeltaDataTypeLong; use deltalake::DeltaDataTypeTimestamp; use deltalake::DeltaTableMetaData; use deltalake::DeltaTransactionOptions; -use deltalake::Schema; +use deltalake::{Invariant, Schema}; use pyo3::create_exception; use pyo3::exceptions::PyException; use pyo3::exceptions::PyValueError; @@ -585,6 +587,34 @@ fn write_new_deltalake( Ok(()) } +#[pyclass(name = "DeltaDataChecker", text_signature = "(invariants)")] +struct PyDeltaDataChecker { + inner: DeltaDataChecker, +} + +#[pymethods] +impl PyDeltaDataChecker { + #[new] + fn new(invariants: Vec<(String, String)>) -> Self { + let invariants: Vec<Invariant> = invariants + .into_iter() + .map(|(field_name, invariant_sql)| Invariant { + field_name, + invariant_sql, + }) + .collect(); + Self { + inner: DeltaDataChecker::new(invariants), + } + } + + fn check_batch(&self, batch: RecordBatch) -> PyResult<()> { + self.inner + .check_batch(&batch) + .map_err(PyDeltaTableError::from_raw) + } +} + #[pymodule] // module name needs to match project name fn _internal(py: Python, m: &PyModule) -> PyResult<()> { @@ -594,6 +624,7 @@ fn _internal(py: Python, m: &PyModule) -> PyResult<()> { m.add_function(pyo3::wrap_pyfunction!(write_new_deltalake, m)?)?; m.add_class::<RawDeltaTable>()?; m.add_class::<RawDeltaTableMetaData>()?; + m.add_class::<PyDeltaDataChecker>()?; m.add("PyDeltaTableError", 
py.get_type::<PyDeltaTableError>())?; // There are issues with submodules, so we will expose them flat for now // See also: https://github.com/PyO3/pyo3/issues/759 diff --git a/python/src/schema.rs b/python/src/schema.rs index e4c976cf43..d87e136965 100644 --- a/python/src/schema.rs +++ b/python/src/schema.rs @@ -1064,4 +1064,22 @@ impl PySchema { Err(PyTypeError::new_err("Type is not a struct")) } } + + /// The list of invariants on the table. + /// + /// :rtype: List[Tuple[str, str]] + /// :return: a tuple of strings for each invariant. The first string is the + /// field path and the second is the SQL of the invariant. + #[getter] + fn invariants(self_: PyRef<'_, Self>) -> PyResult<Vec<(String, String)>> { + let super_ = self_.as_ref(); + let invariants = super_ + .inner_type + .get_invariants() + .map_err(|err| PyException::new_err(err.to_string()))?; + Ok(invariants + .into_iter() + .map(|invariant| (invariant.field_name, invariant.invariant_sql)) + .collect()) + } } diff --git a/python/tests/pyspark_integration/test_write_to_pyspark.py b/python/tests/pyspark_integration/test_write_to_pyspark.py index ca07c8c86a..20995147fb 100644 --- a/python/tests/pyspark_integration/test_write_to_pyspark.py +++ b/python/tests/pyspark_integration/test_write_to_pyspark.py @@ -30,7 +30,7 @@ @pytest.mark.pyspark @pytest.mark.integration -def test_write_basic(tmp_path: pathlib.Path, sample_data: pa.Table): +def test_write_basic(tmp_path: pathlib.Path): # Write table in Spark spark = get_spark() schema = pyspark.sql.types.StructType( @@ -52,7 +52,7 @@ def test_write_basic(tmp_path: pathlib.Path, sample_data: pa.Table): write_deltalake(str(tmp_path), data, mode="overwrite") # Read table in Spark - assert_spark_read_equal(data, str(tmp_path)) + assert_spark_read_equal(data, str(tmp_path), sort_by="c1") @pytest.mark.pyspark @pytest.mark.integration @@ -87,7 +87,7 @@ def test_write_invariant(tmp_path: pathlib.Path): # Cannot write invalid data to the table invalid_data = pa.table({"c1": pa.array([6, 2], type=pa.int32())}) with pytest.raises( - PyDeltaTableError, match="Invariant (c1 > 3) violated by value .+2" + PyDeltaTableError, match="Invariant \(c1 > 3\) violated by value .+2" ): write_deltalake(str(tmp_path), invalid_data, mode="overwrite") @@ -96,12 +96,26 @@ def test_write_invariant(tmp_path: pathlib.Path): write_deltalake(str(tmp_path), valid_data, mode="append") expected = pa.table({"c1": pa.array([4, 5, 6], type=pa.int32())}) - assert_spark_read_equal(expected, str(tmp_path)) + assert_spark_read_equal(expected, str(tmp_path), sort_by="c1") @pytest.mark.pyspark @pytest.mark.integration -def test_checks_min_writer_version(): +def test_checks_min_writer_version(tmp_path: pathlib.Path): # Write table in Spark with constraint - # assert we fail to write any data to it - pass + spark = get_spark() + + spark.createDataFrame([(4,)], schema=["c1"]).write.save( + str(tmp_path), + mode="append", + format="delta", + ) + + # Adding a constraint upgrades the minWriterProtocol + spark.sql(f"ALTER TABLE delta.{str(tmp_path)} ADD CONSTRAINT x CHECK c1 > 2") + + with pytest.raises( + PyDeltaTableError, match="The table's min_writer_version is 3 but" + ): + valid_data = pa.table({"c1": pa.array([5, 6], type=pa.int32())}) + write_deltalake(str(tmp_path), valid_data, mode="append") diff --git a/python/tests/pyspark_integration/utils.py b/python/tests/pyspark_integration/utils.py index 9d9e43929a..5ec23317a0 100644 --- a/python/tests/pyspark_integration/utils.py +++ b/python/tests/pyspark_integration/utils.py @@ -42,8 +42,8 @@ def assert_spark_read_equal( assert_frame_equal( 
df.toPandas() .sort_values(sort_by, ignore_index=True) - .drop(incompatible_types, axis="columns"), + .drop(incompatible_types, axis="columns", errors="ignore"), expected.to_pandas() .sort_values(sort_by, ignore_index=True) - .drop(incompatible_types, axis="columns"), + .drop(incompatible_types, axis="columns", errors="ignore"), ) diff --git a/rust/src/schema.rs b/rust/src/schema.rs index eb8183ffa8..0449cadf3f 100644 --- a/rust/src/schema.rs +++ b/rust/src/schema.rs @@ -22,7 +22,7 @@ static ARRAY_TAG: &str = "array"; static MAP_TAG: &str = "map"; /// An invariant for a column that is enforced on all writes to a Delta table. -#[derive(PartialEq, Debug, Default, Clone)] +#[derive(Eq, PartialEq, Debug, Default, Clone)] pub struct Invariant { /// The full path to the field. pub field_name: String, From 7d8d25bb642b6e154b5de92e8974172f07d038dd Mon Sep 17 00:00:00 2001 From: Will Jones Date: Sat, 24 Sep 2022 13:22:48 -0700 Subject: [PATCH 09/15] Fix final test --- python/tests/pyspark_integration/test_write_to_pyspark.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tests/pyspark_integration/test_write_to_pyspark.py b/python/tests/pyspark_integration/test_write_to_pyspark.py index 20995147fb..167f3dfd32 100644 --- a/python/tests/pyspark_integration/test_write_to_pyspark.py +++ b/python/tests/pyspark_integration/test_write_to_pyspark.py @@ -112,10 +112,10 @@ def test_checks_min_writer_version(tmp_path: pathlib.Path): ) # Add a constraint upgrades the minWriterProtocol - spark.sql(f"ALTER TABLE delta.{str(tmp_path)} ADD CONSTRAINT x CHECK c1 > 2") + spark.sql(f"ALTER TABLE delta.`{str(tmp_path)}` ADD CONSTRAINT x CHECK (c1 > 2)") with pytest.raises( - PyDeltaTableError, match="The table's min_writer_version is 3 but" + PyDeltaTableError, match="This table's min_writer_version is 3, but" ): - valid_data = pa.table({"c1": pa.array([5, 6], type=pa.int32())}) + valid_data = pa.table({"c1": pa.array([5, 6])}) write_deltalake(str(tmp_path), valid_data, mode="append") From 9c7d3a19c55feff559657f42d95b11671a32462d Mon Sep 17 00:00:00 2001 From: Will Jones Date: Sat, 24 Sep 2022 13:57:15 -0700 Subject: [PATCH 10/15] Fix imports --- rust/src/delta_datafusion.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/rust/src/delta_datafusion.rs b/rust/src/delta_datafusion.rs index 2e6e5fa711..74c43fba66 100644 --- a/rust/src/delta_datafusion.rs +++ b/rust/src/delta_datafusion.rs @@ -47,10 +47,6 @@ use datafusion_expr::{combine_filters, Expr}; use object_store::{path::Path, ObjectMeta}; use url::Url; -use crate::action; -use crate::delta; -use crate::schema; -use crate::DeltaTableError; use crate::Invariant; impl From for DataFusionError { From a01368159278846e8d7d13c314cfe19731db2230 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Sun, 25 Sep 2022 11:38:47 -0700 Subject: [PATCH 11/15] Update tests --- .../pyspark_integration/test_write_to_pyspark.py | 14 ++++---------- python/tests/test_writer.py | 2 +- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/python/tests/pyspark_integration/test_write_to_pyspark.py b/python/tests/pyspark_integration/test_write_to_pyspark.py index 167f3dfd32..897746f694 100644 --- a/python/tests/pyspark_integration/test_write_to_pyspark.py +++ b/python/tests/pyspark_integration/test_write_to_pyspark.py @@ -5,18 +5,11 @@ import pytest from deltalake import write_deltalake -from deltalake.deltalake import PyDeltaTableError +from deltalake._internal import PyDeltaTableError +from deltalake.writer import DeltaTableProtocolError from .utils 
import assert_spark_read_equal, get_spark -try: - from pandas.testing import assert_frame_equal -except ModuleNotFoundError: - _has_pandas = False -else: - _has_pandas = True - - try: import delta import delta.pip_utils @@ -89,6 +82,7 @@ def test_write_invariant(tmp_path: pathlib.Path): with pytest.raises( PyDeltaTableError, match="Invariant \(c1 > 3\) violated by value .+2" ): + # raise PyDeltaTableError("test") write_deltalake(str(tmp_path), invalid_data, mode="overwrite") # Can write valid data to the table @@ -115,7 +109,7 @@ def test_checks_min_writer_version(tmp_path: pathlib.Path): spark.sql(f"ALTER TABLE delta.`{str(tmp_path)}` ADD CONSTRAINT x CHECK (c1 > 2)") with pytest.raises( - PyDeltaTableError, match="This table's min_writer_version is 3, but" + DeltaTableProtocolError, match="This table's min_writer_version is 3, but" ): valid_data = pa.table({"c1": pa.array([5, 6])}) write_deltalake(str(tmp_path), valid_data, mode="append") diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 23f6c37cec..cda1c3601e 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -420,7 +420,7 @@ def test_writer_null_stats(tmp_path: pathlib.Path): def test_writer_fails_on_protocol(existing_table: DeltaTable, sample_data: pa.Table): - existing_table.protocol = Mock(return_value=ProtocolVersions(1, 2)) + existing_table.protocol = Mock(return_value=ProtocolVersions(1, 3)) with pytest.raises(DeltaTableProtocolError): write_deltalake(existing_table, sample_data, mode="overwrite") From c91d58e97f1b023b64faacc9d04268e708e1b513 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Sun, 25 Sep 2022 12:07:01 -0700 Subject: [PATCH 12/15] fix: minor issue in docstring --- python/src/schema.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/src/schema.rs b/python/src/schema.rs index d87e136965..be561d6bf4 100644 --- a/python/src/schema.rs +++ b/python/src/schema.rs @@ -1069,7 +1069,7 @@ impl PySchema { /// /// :rtype: List[Tuple[str, str]] /// :return: a tuple of strings for each invariant. The first string is the - /// field path and the second is the SQL of the invariant. + /// field path and the second is the SQL of the invariant. 
#[getter] fn invariants(self_: PyRef<'_, Self>) -> PyResult<Vec<(String, String)>> { let super_ = self_.as_ref(); let invariants = super_ From 6ef1086a398c245b27aa18c75e014b8f8a803050 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Tue, 27 Sep 2022 19:10:05 -0700 Subject: [PATCH 13/15] Don't make an async runtime within the API --- python/src/lib.rs | 11 ++++++++--- rust/src/delta_datafusion.rs | 12 ++++-------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/python/src/lib.rs b/python/src/lib.rs index 85c0ec7522..7a7fc657ad 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -590,6 +590,7 @@ fn write_new_deltalake( #[pyclass(name = "DeltaDataChecker", text_signature = "(invariants)")] struct PyDeltaDataChecker { inner: DeltaDataChecker, + rt: tokio::runtime::Runtime, } #[pymethods] impl PyDeltaDataChecker { #[new] fn new(invariants: Vec<(String, String)>) -> Self { @@ -605,13 +606,17 @@ impl PyDeltaDataChecker { .collect(); Self { inner: DeltaDataChecker::new(invariants), + rt: tokio::runtime::Runtime::new().unwrap(), } } fn check_batch(&self, batch: RecordBatch) -> PyResult<()> { - self.inner - .check_batch(&batch) - .map_err(PyDeltaTableError::from_raw) + self.rt.block_on(async { + self.inner + .check_batch(&batch) + .await + .map_err(PyDeltaTableError::from_raw) + }) } } diff --git a/rust/src/delta_datafusion.rs b/rust/src/delta_datafusion.rs index 74c43fba66..7681f6251a 100644 --- a/rust/src/delta_datafusion.rs +++ b/rust/src/delta_datafusion.rs @@ -610,7 +610,6 @@ fn left_larger_than_right(left: ScalarValue, right: ScalarValue) -> Option<bool> pub struct DeltaDataChecker { invariants: Vec<Invariant>, ctx: SessionContext, - rt: tokio::runtime::Runtime, } impl DeltaDataChecker { @@ -619,7 +618,6 @@ impl DeltaDataChecker { Self { invariants, ctx: SessionContext::new(), - rt: tokio::runtime::Runtime::new().unwrap(), } } @@ -627,12 +625,12 @@ impl DeltaDataChecker { /// /// If it does not, it will return [DeltaTableError::InvalidData] with a list /// of values that violated each invariant. - pub fn check_batch(&self, record_batch: &RecordBatch) -> Result<(), DeltaTableError> { - self.enforce_invariants(record_batch) + pub async fn check_batch(&self, record_batch: &RecordBatch) -> Result<(), DeltaTableError> { + self.enforce_invariants(record_batch).await // TODO: for support for Protocol V3, check constraints } - fn enforce_invariants(&self, record_batch: &RecordBatch) -> Result<(), DeltaTableError> { + async fn enforce_invariants(&self, record_batch: &RecordBatch) -> Result<(), DeltaTableError> { // Invariants are deprecated, so let's not pay the overhead for any of this // if we can avoid it. 
if self.invariants.is_empty() { @@ -656,9 +654,7 @@ impl DeltaDataChecker { invariant.field_name, invariant.invariant_sql ); - let dfs: Vec<RecordBatch> = self - .rt - .block_on(async { self.ctx.sql(&sql).await?.collect().await })?; + let dfs: Vec<RecordBatch> = self.ctx.sql(&sql).await?.collect().await?; if !dfs.is_empty() && dfs[0].num_rows() > 0 { let value = format!("{:?}", dfs[0].column(0)); let msg = format!( From bcd2c5a1f33287d14a0ec8a94704bcb744ad0b9b Mon Sep 17 00:00:00 2001 From: Will Jones Date: Tue, 27 Sep 2022 20:10:00 -0700 Subject: [PATCH 14/15] format --- rust/src/delta_datafusion.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/src/delta_datafusion.rs b/rust/src/delta_datafusion.rs index 7681f6251a..beea0fea20 100644 --- a/rust/src/delta_datafusion.rs +++ b/rust/src/delta_datafusion.rs @@ -654,7 +654,7 @@ impl DeltaDataChecker { invariant.field_name, invariant.invariant_sql ); - let dfs: Vec<RecordBatch> = self.ctx.sql(&sql).await?.collect().await?; + let dfs: Vec<RecordBatch> = self.ctx.sql(&sql).await?.collect().await?; if !dfs.is_empty() && dfs[0].num_rows() > 0 { let value = format!("{:?}", dfs[0].column(0)); let msg = format!( From d72d1b00506c0745c6120300dcfad00f7878c677 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Tue, 27 Sep 2022 20:35:20 -0700 Subject: [PATCH 15/15] fix: arrange test for async --- rust/src/delta_datafusion.rs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/rust/src/delta_datafusion.rs b/rust/src/delta_datafusion.rs index beea0fea20..9a7dda42c6 100644 --- a/rust/src/delta_datafusion.rs +++ b/rust/src/delta_datafusion.rs @@ -791,8 +791,8 @@ mod tests { assert_eq!(file.partition_values, ref_file.partition_values) } - #[test] - fn test_enforce_invariants() { + #[tokio::test] + async fn test_enforce_invariants() { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), Field::new("b", DataType::Int32, false), @@ -809,6 +809,7 @@ async fn test_enforce_invariants() { let invariants: Vec<Invariant> = vec![]; assert!(DeltaDataChecker::new(invariants) .check_batch(&batch) + .await .is_ok()); // Valid invariants return Ok(()) @@ -818,6 +819,7 @@ async fn test_enforce_invariants() { ]; assert!(DeltaDataChecker::new(invariants) .check_batch(&batch) + .await .is_ok()); // Violated invariants returns an error with list of violations @@ -825,7 +827,7 @@ async fn test_enforce_invariants() { Invariant::new("a", "a is null"), Invariant::new("b", "b < 100"), ]; - let result = DeltaDataChecker::new(invariants).check_batch(&batch); + let result = DeltaDataChecker::new(invariants).check_batch(&batch).await; assert!(result.is_err()); assert!(matches!(result, Err(DeltaTableError::InvalidData { .. }))); if let Err(DeltaTableError::InvalidData { violations }) = result { @@ -834,7 +836,7 @@ async fn test_enforce_invariants() { // Irrelevant invariants return a different error let invariants = vec![Invariant::new("c", "c > 2000")]; - let result = DeltaDataChecker::new(invariants).check_batch(&batch); + let result = DeltaDataChecker::new(invariants).check_batch(&batch).await; assert!(result.is_err()); // Nested invariants are unsupported @@ -848,7 +850,7 @@ async fn test_enforce_invariants() { let batch = RecordBatch::try_new(schema, vec![inner]).unwrap(); let invariants = vec![Invariant::new("x.b", "x.b < 1000")]; - let result = DeltaDataChecker::new(invariants).check_batch(&batch); + let result = DeltaDataChecker::new(invariants).check_batch(&batch).await; assert!(result.is_err()); assert!(matches!(result, Err(DeltaTableError::Generic { .. }))); }
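With the full series applied, enforcement is wired end to end: Schema::get_invariants extracts (field path, SQL) pairs from delta.invariants column metadata, DeltaDataChecker::check_batch evaluates each invariant against a batch with a DataFusion query of the form SELECT field FROM data WHERE not (expr) LIMIT 1, and the Python writer validates every record batch before it is written. A condensed sketch of the resulting behavior that the pyspark integration tests above exercise; the path is illustrative, and the table is assumed to already carry an invariant c1 > 3 on column c1:

import pyarrow as pa
from deltalake import write_deltalake
from deltalake._internal import PyDeltaTableError

path = "/tmp/invariant_table"  # hypothetical table created with invariant "c1 > 3"

good = pa.table({"c1": pa.array([5, 6], type=pa.int32())})
write_deltalake(path, good, mode="append")  # every batch passes the check

bad = pa.table({"c1": pa.array([6, 2], type=pa.int32())})
try:
    write_deltalake(path, bad, mode="append")
except PyDeltaTableError as exc:
    print(exc)  # e.g. "Invariant (c1 > 3) violated by value ..."

Tables whose protocol requires min_writer_version 3 (CHECK constraints) are still rejected up front with DeltaTableProtocolError rather than being written without enforcement.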