diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs
index 95afed637b..3d618c41fe 100644
--- a/crates/core/src/operations/write.rs
+++ b/crates/core/src/operations/write.rs
@@ -36,8 +36,9 @@ use arrow_schema::{ArrowError, DataType, Fields, SchemaRef as ArrowSchemaRef};
 use datafusion::execution::context::{SessionContext, SessionState, TaskContext};
 use datafusion::physical_plan::filter::FilterExec;
 use datafusion::physical_plan::projection::ProjectionExec;
+use datafusion::physical_plan::union::UnionExec;
 use datafusion::physical_plan::{memory::MemoryExec, ExecutionPlan};
-use datafusion_common::{DFSchema, ScalarValue};
+use datafusion_common::DFSchema;
 use datafusion_expr::{lit, Expr};
 use datafusion_physical_expr::expressions::{self};
 use datafusion_physical_expr::PhysicalExpr;
@@ -548,16 +549,19 @@ async fn execute_non_empty_expr(
     writer_properties: Option<WriterProperties>,
     writer_stats_config: WriterStatsConfig,
     partition_scan: bool,
+    insert_plan: Arc<dyn ExecutionPlan>,
 ) -> DeltaResult<Vec<Action>> {
     // For each identified file perform a parquet scan + filter + limit (1) + count.
     // If returned count is not zero then append the file to be rewritten and removed from the log. Otherwise do nothing to the file.
     let mut actions: Vec<Action> = Vec::new();
 
-    let input_schema = snapshot.input_schema()?;
-    let input_dfschema: DFSchema = input_schema.clone().as_ref().clone().try_into()?;
+    // Take the insert plan's schema since it might have been schema-evolved;
+    // if it wasn't, it is simply the table schema
+    let df_schema = insert_plan.schema();
+    let input_dfschema: DFSchema = df_schema.as_ref().clone().try_into()?;
 
     let scan_config = DeltaScanConfigBuilder::new()
-        .with_schema(snapshot.input_schema()?)
+        .with_schema(df_schema)
         .build(snapshot)?;
 
     let scan = DeltaScanBuilder::new(snapshot, log_store.clone(), &state)
@@ -596,20 +600,26 @@
     }
 
     // CDC logic, simply filters data with predicate and adds the _change_type="delete" as literal column
-    if let Some(cdc_actions) = execute_non_empty_expr_cdc(
-        snapshot,
-        log_store,
-        state.clone(),
-        scan,
-        input_dfschema,
-        expression,
-        partition_columns,
-        writer_properties,
-        writer_stats_config,
-    )
-    .await?
-    {
-        actions.extend(cdc_actions)
+    // Only write CDC actions when it was not a partition scan; load_cdf can deduce the deletes in that case
+    // based on the remove actions if a partition got deleted
+    if !partition_scan {
+        // We only write deletions when it was not a partition scan
+        if let Some(cdc_actions) = execute_non_empty_expr_cdc(
+            snapshot,
+            log_store,
+            state.clone(),
+            scan,
+            input_dfschema,
+            expression,
+            partition_columns,
+            writer_properties,
+            writer_stats_config,
+            insert_plan,
+        )
+        .await?
+        {
+            actions.extend(cdc_actions)
+        }
     }
     Ok(actions)
 }
@@ -626,6 +636,7 @@ pub(crate) async fn execute_non_empty_expr_cdc(
     table_partition_cols: Vec<String>,
     writer_properties: Option<WriterProperties>,
     writer_stats_config: WriterStatsConfig,
+    insert_plan: Arc<dyn ExecutionPlan>,
 ) -> DeltaResult<Option<Vec<Action>>> {
     match should_write_cdc(snapshot) {
         // Create CDC scan
@@ -636,11 +647,14 @@ pub(crate) async fn execute_non_empty_expr_cdc(
                 Arc::new(FilterExec::try_new(cdc_predicate_expr, scan.clone())?);
 
             // Add literal column "_change_type"
-            let change_type_lit = lit(ScalarValue::Utf8(Some("delete".to_string())));
-            let change_type_expr = state.create_physical_expr(change_type_lit, &input_dfschema)?;
+            let delete_change_type_expr =
+                state.create_physical_expr(lit("delete"), &input_dfschema)?;
+
+            let insert_change_type_expr =
+                state.create_physical_expr(lit("insert"), &input_dfschema)?;
 
             // Project columns and lit
-            let project_expressions: Vec<(Arc<dyn PhysicalExpr>, String)> = scan
+            let mut delete_project_expressions: Vec<(Arc<dyn PhysicalExpr>, String)> = scan
                 .schema()
                 .fields()
                 .into_iter()
@@ -651,18 +665,35 @@
                         field.name().to_owned(),
                     )
                 })
-                .chain(iter::once((change_type_expr, "_change_type".to_owned())))
                 .collect();
 
-            let projected_scan: Arc<dyn ExecutionPlan> = Arc::new(ProjectionExec::try_new(
-                project_expressions,
+            let mut insert_project_expressions = delete_project_expressions.clone();
+            delete_project_expressions.insert(
+                delete_project_expressions.len(),
+                (delete_change_type_expr, "_change_type".to_owned()),
+            );
+            insert_project_expressions.insert(
+                insert_project_expressions.len(),
+                (insert_change_type_expr, "_change_type".to_owned()),
+            );
+
+            let delete_plan: Arc<dyn ExecutionPlan> = Arc::new(ProjectionExec::try_new(
+                delete_project_expressions,
                 cdc_scan.clone(),
             )?);
 
+            let insert_plan: Arc<dyn ExecutionPlan> = Arc::new(ProjectionExec::try_new(
+                insert_project_expressions,
+                insert_plan.clone(),
+            )?);
+
+            let cdc_plan: Arc<dyn ExecutionPlan> =
+                Arc::new(UnionExec::new(vec![delete_plan, insert_plan]));
+
             let cdc_actions = write_execution_plan_cdc(
                 Some(snapshot),
                 state.clone(),
-                projected_scan.clone(),
+                cdc_plan.clone(),
                 table_partition_cols.clone(),
                 log_store.object_store(),
                 Some(snapshot.table_config().target_file_size() as usize),
@@ -689,6 +720,7 @@ async fn prepare_predicate_actions(
     writer_properties: Option<WriterProperties>,
     deletion_timestamp: i64,
     writer_stats_config: WriterStatsConfig,
+    insert_plan: Arc<dyn ExecutionPlan>,
 ) -> DeltaResult<Vec<Action>> {
     let candidates =
         find_files(snapshot, log_store.clone(), &state, Some(predicate.clone())).await?;
@@ -703,6 +735,7 @@ async fn prepare_predicate_actions(
         writer_properties,
         writer_stats_config,
         candidates.partition_scan,
+        insert_plan,
     )
     .await?;
 
@@ -725,47 +758,6 @@ async fn prepare_predicate_actions(
     Ok(actions)
 }
 
-/// If CDC is enabled it writes all add add actions data as deletions into _change_data directory
-async fn execute_non_empty_expr_cdc_all_actions(
-    snapshot: &DeltaTableState,
-    log_store: LogStoreRef,
-    state: SessionState,
-    table_partition_cols: Vec<String>,
-    writer_properties: Option<WriterProperties>,
-    writer_stats_config: WriterStatsConfig,
-) -> DeltaResult<Option<Vec<Action>>> {
-    let current_state_add_actions = &snapshot.file_actions()?;
-
-    let scan_config = DeltaScanConfigBuilder::new()
-        .with_schema(snapshot.input_schema()?)
-        .build(snapshot)?;
-
-    // Since all files get removed, check to write CDC
-    let scan = DeltaScanBuilder::new(snapshot, log_store.clone(), &state)
-        .with_files(current_state_add_actions)
-        // Use input schema which doesn't wrap partition values, otherwise divide_by_partition_value won't work on UTF8 partitions
-        // Since it can't fetch a scalar from a dictionary type
-        .with_scan_config(scan_config)
-        .build()
-        .await?;
-
-    let input_schema = snapshot.input_schema()?;
-    let input_dfschema: DFSchema = input_schema.clone().as_ref().clone().try_into()?;
-
-    execute_non_empty_expr_cdc(
-        snapshot,
-        log_store,
-        state,
-        scan.into(),
-        input_dfschema,
-        &Expr::Literal(ScalarValue::Boolean(Some(true))), // Keep all data
-        table_partition_cols,
-        writer_properties,
-        writer_stats_config,
-    )
-    .await
-}
-
 impl std::future::IntoFuture for WriteBuilder {
     type Output = DeltaResult<DeltaTable>;
     type IntoFuture = BoxFuture<'static, Self::Output>;
@@ -997,7 +989,7 @@ impl std::future::IntoFuture for WriteBuilder {
                     predicate.clone(),
                     this.snapshot.as_ref(),
                     state.clone(),
-                    plan,
+                    plan.clone(),
                     partition_columns.clone(),
                     this.log_store.object_store().clone(),
                     this.target_file_size,
@@ -1066,6 +1058,7 @@ impl std::future::IntoFuture for WriteBuilder {
                             this.writer_properties,
                             deletion_timestamp,
                             writer_stats_config,
+                            plan,
                         )
                         .await?;
                         if !predicate_actions.is_empty() {
@@ -1078,21 +1071,6 @@ impl std::future::IntoFuture for WriteBuilder {
                             .into_iter()
                             .map(|p| p.remove_action(true).into());
                         actions.extend(remove_actions);
-
-                        let cdc_actions = execute_non_empty_expr_cdc_all_actions(
-                            snapshot,
-                            this.log_store.clone(),
-                            state,
-                            partition_columns.clone(),
-                            this.writer_properties,
-                            writer_stats_config,
-                        )
-                        .await?;
-
-                        // ADD CDC ACTIONS HERE
-                        if let Some(cdc_actions) = cdc_actions {
-                            actions.extend(cdc_actions);
-                        }
                     }
                 };
             }
@@ -1190,8 +1168,11 @@ fn try_cast_batch(from_fields: &Fields, to_fields: &Fields) -> Result<(), ArrowE
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::logstore::get_actions;
+    use crate::operations::load_cdf::collect_batches;
     use crate::operations::{collect_sendable_stream, DeltaOps};
     use crate::protocol::SaveMode;
+    use crate::test_utils::{TestResult, TestSchemas};
     use crate::writer::test_utils::datafusion::{get_data, get_data_sorted, write_batch};
     use crate::writer::test_utils::{
         get_arrow_schema, get_delta_schema, get_delta_schema_with_nested_struct, get_record_batch,
@@ -1202,6 +1183,7 @@ mod tests {
     use arrow_schema::{DataType, Field, Schema as ArrowSchema, TimeUnit};
     use datafusion::prelude::*;
     use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
+    use itertools::Itertools;
     use serde_json::{json, Value};
 
     #[tokio::test]
@@ -1930,4 +1912,247 @@ mod tests {
         let actual = get_data_sorted(&table, "id,value,modified").await;
         assert_batches_sorted_eq!(&expected, &actual);
     }
+
+    #[tokio::test]
+    async fn test_dont_write_cdc_with_overwrite() -> TestResult {
+        let delta_schema = TestSchemas::simple();
+        let table: DeltaTable = DeltaOps::new_in_memory()
+            .create()
+            .with_columns(delta_schema.fields().cloned())
+            .with_partition_columns(["id"])
+            .with_configuration_property(DeltaConfigKey::EnableChangeDataFeed, Some("true"))
+            .await
+            .unwrap();
+        assert_eq!(table.version(), 0);
+
+        let schema = Arc::new(ArrowSchema::try_from(delta_schema)?);
+
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(StringArray::from(vec![Some("1"), Some("2"), Some("3")])),
+                Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])),
+                Arc::new(StringArray::from(vec![
+                    Some("yes"),
+                    Some("yes"),
+                    Some("no"),
+                ])),
+            ],
+        )
+        .unwrap();
+
+        let second_batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(StringArray::from(vec![Some("3")])),
+                Arc::new(Int32Array::from(vec![Some(10)])),
+                Arc::new(StringArray::from(vec![Some("yes")])),
+            ],
+        )
+        .unwrap();
+
+        let table = DeltaOps(table)
+            .write(vec![batch])
+            .await
+            .expect("Failed to write first batch");
+        assert_eq!(table.version(), 1);
+
+        let table = DeltaOps(table)
+            .write([second_batch])
+            .with_save_mode(crate::protocol::SaveMode::Overwrite)
+            .await
+            .unwrap();
+        assert_eq!(table.version(), 2);
+
+        let snapshot_bytes = table
+            .log_store
+            .read_commit_entry(2)
+            .await?
+            .expect("failed to get snapshot bytes");
+        let version_actions = get_actions(2, snapshot_bytes).await?;
+
+        let cdc_actions = version_actions
+            .iter()
+            .filter(|action| match action {
+                &&Action::Cdc(_) => true,
+                _ => false,
+            })
+            .collect_vec();
+        assert!(cdc_actions.is_empty());
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_dont_write_cdc_with_overwrite_predicate_partitioned() -> TestResult {
+        let delta_schema = TestSchemas::simple();
+        let table: DeltaTable = DeltaOps::new_in_memory()
+            .create()
+            .with_columns(delta_schema.fields().cloned())
+            .with_partition_columns(["id"])
+            .with_configuration_property(DeltaConfigKey::EnableChangeDataFeed, Some("true"))
+            .await
+            .unwrap();
+        assert_eq!(table.version(), 0);
+
+        let schema = Arc::new(ArrowSchema::try_from(delta_schema)?);
+
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(StringArray::from(vec![Some("1"), Some("2"), Some("3")])),
+                Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])),
+                Arc::new(StringArray::from(vec![
+                    Some("yes"),
+                    Some("yes"),
+                    Some("no"),
+                ])),
+            ],
+        )
+        .unwrap();
+
+        let second_batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(StringArray::from(vec![Some("3")])),
+                Arc::new(Int32Array::from(vec![Some(10)])),
+                Arc::new(StringArray::from(vec![Some("yes")])),
+            ],
+        )
+        .unwrap();
+
+        let table = DeltaOps(table)
+            .write(vec![batch])
+            .await
+            .expect("Failed to write first batch");
+        assert_eq!(table.version(), 1);
+
+        let table = DeltaOps(table)
+            .write([second_batch])
+            .with_save_mode(crate::protocol::SaveMode::Overwrite)
+            .with_replace_where("id='3'")
+            .await
+            .unwrap();
+        assert_eq!(table.version(), 2);
+        let snapshot_bytes = table
+            .log_store
+            .read_commit_entry(2)
+            .await?
+            .expect("failed to get snapshot bytes");
+        let version_actions = get_actions(2, snapshot_bytes).await?;
+
+        let cdc_actions = version_actions
+            .iter()
+            .filter(|action| match action {
+                &&Action::Cdc(_) => true,
+                _ => false,
+            })
+            .collect_vec();
+        assert!(cdc_actions.is_empty());
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_dont_write_cdc_with_overwrite_predicate_unpartitioned() -> TestResult {
+        let delta_schema = TestSchemas::simple();
+        let table: DeltaTable = DeltaOps::new_in_memory()
+            .create()
+            .with_columns(delta_schema.fields().cloned())
+            .with_partition_columns(["id"])
+            .with_configuration_property(DeltaConfigKey::EnableChangeDataFeed, Some("true"))
+            .await
+            .unwrap();
+        assert_eq!(table.version(), 0);
+
+        let schema = Arc::new(ArrowSchema::try_from(delta_schema)?);
+
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(StringArray::from(vec![Some("1"), Some("2"), Some("3")])),
+                Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])),
+                Arc::new(StringArray::from(vec![
+                    Some("yes"),
+                    Some("yes"),
+                    Some("no"),
+                ])),
+            ],
+        )
+        .unwrap();
+
+        let second_batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(StringArray::from(vec![Some("3")])),
+                Arc::new(Int32Array::from(vec![Some(3)])),
+                Arc::new(StringArray::from(vec![Some("yes")])),
+            ],
+        )
+        .unwrap();
+
+        let table = DeltaOps(table)
+            .write(vec![batch])
+            .await
+            .expect("Failed to write first batch");
+        assert_eq!(table.version(), 1);
+
+        let table = DeltaOps(table)
+            .write([second_batch])
+            .with_save_mode(crate::protocol::SaveMode::Overwrite)
+            .with_replace_where("value=3")
+            .await
+            .unwrap();
+        assert_eq!(table.version(), 2);
+
+        let ctx = SessionContext::new();
+        let cdf_scan = DeltaOps(table.clone())
+            .load_cdf()
+            .with_session_ctx(ctx.clone())
+            .with_starting_version(0)
+            .build()
+            .await
+            .expect("Failed to load CDF");
+
+        let mut batches = collect_batches(
+            cdf_scan
+                .properties()
+                .output_partitioning()
+                .partition_count(),
+            cdf_scan,
+            ctx,
+        )
+        .await
+        .expect("Failed to collect batches");
+
+        // The batches will contain a current _commit_timestamp which shouldn't be compared, so drop that column
+        let _: Vec<_> = batches.iter_mut().map(|b| b.remove_column(4)).collect();
+
+        assert_batches_sorted_eq! {[
+            "+-------+----------+--------------+-----------------+----+",
+            "| value | modified | _change_type | _commit_version | id |",
+            "+-------+----------+--------------+-----------------+----+",
+            "| 1     | yes      | insert       | 1               | 1  |",
+            "| 2     | yes      | insert       | 1               | 2  |",
+            "| 3     | no       | delete       | 2               | 3  |",
+            "| 3     | no       | insert       | 1               | 3  |",
+            "| 3     | yes      | insert       | 2               | 3  |",
+            "+-------+----------+--------------+-----------------+----+",
+        ], &batches }
+
+        let snapshot_bytes = table
+            .log_store
+            .read_commit_entry(2)
+            .await?
+            .expect("failed to get snapshot bytes");
+        let version_actions = get_actions(2, snapshot_bytes).await?;
+
+        let cdc_actions = version_actions
+            .iter()
+            .filter(|action| match action {
+                &&Action::Cdc(_) => true,
+                _ => false,
+            })
+            .collect_vec();
+        assert!(!cdc_actions.is_empty());
+        Ok(())
+    }
 }
diff --git a/python/tests/test_cdf.py b/python/tests/test_cdf.py
index 1fdd459263..c765054ca4 100644
--- a/python/tests/test_cdf.py
+++ b/python/tests/test_cdf.py
@@ -504,18 +504,28 @@ def test_write_predicate_unpartitioned_cdf(tmp_path, sample_data: pa.Table):
         configuration={"delta.enableChangeDataFeed": "true"},
     )
 
-    expected_data = (
-        ds.dataset(sample_data)
-        .to_table(filter=(pc.field("int64") > 2))
-        .append_column(
-            field_=pa.field("_change_type", pa.string(), nullable=False),
-            column=[["delete"] * 2],
-        )
+    expected_data = pa.concat_tables(
+        [
+            ds.dataset(sample_data)
+            .to_table(filter=(pc.field("int64") > 2))
+            .append_column(
+                field_=pa.field("_change_type", pa.string(), nullable=False),
+                column=[["delete"] * 2],
+            ),
+            ds.dataset(sample_data)
+            .to_table(filter=(pc.field("int64") > 2))
+            .append_column(
+                field_=pa.field("_change_type", pa.string(), nullable=False),
+                column=[["insert"] * 2],
+            ),
+        ]
     )
 
     cdc_data = pq.read_table(cdc_path)
 
     assert os.path.exists(cdc_path), "_change_data doesn't exist"
-    assert cdc_data == expected_data
+    assert cdc_data.sort_by([("_change_type", "ascending")]) == expected_data.sort_by(
+        [("_change_type", "ascending")]
+    )
     assert dt.to_pyarrow_table().sort_by([("utf8", "ascending")]) == sample_data
 
@@ -532,21 +542,30 @@ def test_write_predicate_partitioned_cdf(tmp_path, sample_data: pa.Table):
     dt = DeltaTable(tmp_path)
     write_deltalake(
         dt,
-        data=ds.dataset(sample_data).to_table(filter=(pc.field("int64") > 2)),
+        data=ds.dataset(sample_data).to_table(filter=(pc.field("int64") > 3)),
         mode="overwrite",
-        predicate="int64 > 2",
+        predicate="int64 > 3",
         engine="rust",
         configuration={"delta.enableChangeDataFeed": "true"},
     )
 
-    expected_data = (
-        ds.dataset(sample_data)
-        .to_table(filter=(pc.field("int64") > 2))
-        .append_column(
-            field_=pa.field("_change_type", pa.string(), nullable=False),
-            column=[["delete"] * 2],
-        )
+    expected_data = pa.concat_tables(
+        [
+            ds.dataset(sample_data)
+            .to_table(filter=(pc.field("int64") > 3))
+            .append_column(
+                field_=pa.field("_change_type", pa.string(), nullable=False),
+                column=[["delete"] * 1],
+            ),
+            ds.dataset(sample_data)
+            .to_table(filter=(pc.field("int64") > 3))
+            .append_column(
+                field_=pa.field("_change_type", pa.string(), nullable=False),
+                column=[["insert"] * 1],
+            ),
+        ]
     )
+
     table_schema = dt.schema().to_pyarrow()
     table_schema = table_schema.insert(
         len(table_schema), pa.field("_change_type", pa.string(), nullable=False)
@@ -554,8 +573,13 @@ def test_write_predicate_partitioned_cdf(tmp_path, sample_data: pa.Table):
     cdc_data = pq.read_table(cdc_path, schema=table_schema)
 
     assert os.path.exists(cdc_path), "_change_data doesn't exist"
-    assert len(os.listdir(cdc_path)) == 2
-    assert cdc_data == expected_data
+    assert len(os.listdir(cdc_path)) == 1
+    expected_data = expected_data.combine_chunks().sort_by(
+        [("_change_type", "ascending")]
+    )
+    cdc_data = cdc_data.combine_chunks().sort_by([("_change_type", "ascending")])
+
+    assert expected_data == cdc_data
     assert dt.to_pyarrow_table().sort_by([("utf8", "ascending")]) == sample_data
 
 
@@ -578,19 +602,23 @@ def test_write_overwrite_unpartitioned_cdf(tmp_path, sample_data: pa.Table):
         configuration={"delta.enableChangeDataFeed": "true"},
     )
 
-    expected_data = (
-        ds.dataset(sample_data)
-        .to_table()
-        .append_column(
-            field_=pa.field("_change_type", pa.string(), nullable=False),
-            column=[["delete"] * 5],
-        )
-    )
-    cdc_data = pq.read_table(cdc_path)
+    # expected_data = (
+    #     ds.dataset(sample_data)
+    #     .to_table()
+    #     .append_column(
+    #         field_=pa.field("_change_type", pa.string(), nullable=True),
+    #         column=[["delete"] * 5],
+    #     )
+    # )
 
-    assert os.path.exists(cdc_path), "_change_data doesn't exist"
-    assert cdc_data == expected_data
-    assert dt.to_pyarrow_table().sort_by([("utf8", "ascending")]) == sample_data
+    assert not os.path.exists(
+        cdc_path
+    ), "_change_data shouldn't exist since table was overwritten"
+
+    ## TODO(ion): check if you see insert and deletes in commit version 1
+
+    # assert dt.load_cdf().read_all().drop_columns(['_commit_version', '_commit_timestamp']) == expected_data
+    # assert dt.to_pyarrow_table().sort_by([("utf8", "ascending")]) == sample_data
 
 
 def test_write_overwrite_partitioned_cdf(tmp_path, sample_data: pa.Table):
@@ -600,34 +628,39 @@ def test_write_overwrite_partitioned_cdf(tmp_path, sample_data: pa.Table):
         tmp_path,
         sample_data,
         mode="append",
-        partition_by=["utf8"],
+        partition_by=["int64"],
         configuration={"delta.enableChangeDataFeed": "true"},
     )
 
     dt = DeltaTable(tmp_path)
     write_deltalake(
         dt,
-        data=ds.dataset(sample_data).to_table(),
-        mode="overwrite",
+        data=ds.dataset(sample_data).to_table(filter=(pc.field("int64") > 3)),
         engine="rust",
-        partition_by=["utf8"],
+        mode="overwrite",
+        predicate="int64 > 3",
+        partition_by=["int64"],
        configuration={"delta.enableChangeDataFeed": "true"},
     )
 
-    expected_data = (
-        ds.dataset(sample_data)
-        .to_table()
-        .append_column(
-            field_=pa.field("_change_type", pa.string(), nullable=False),
-            column=[["delete"] * 5],
-        )
-    )
+    # expected_data = (
+    #     ds.dataset(sample_data)
+    #     .to_table()
+    #     .append_column(
+    #         field_=pa.field("_change_type", pa.string(), nullable=False),
+    #         column=[["delete"] * 5],
+    #     )
+    # )
     table_schema = dt.schema().to_pyarrow()
     table_schema = table_schema.insert(
         len(table_schema), pa.field("_change_type", pa.string(), nullable=False)
     )
-    cdc_data = pq.read_table(cdc_path, schema=table_schema)
+    # cdc_data = pq.read_table(cdc_path, schema=table_schema)
 
-    assert os.path.exists(cdc_path), "_change_data doesn't exist"
-    assert cdc_data == expected_data
-    assert dt.to_pyarrow_table().sort_by([("int64", "ascending")]) == sample_data
+    assert not os.path.exists(
+        cdc_path
+    ), "_change_data shouldn't exist since a specific partition was overwritten"
+
+    ## TODO(ion): check if you see insert and deletes in commit version 1
+    # assert cdc_data == expected_data
+    # assert dt.to_pyarrow_table().sort_by([("int64", "ascending")]) == sample_data
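
For readers who want to see the behaviour this patch targets from the Python side, here is a minimal sketch (not part of the patch; the table path and sample data are illustrative) using only APIs already exercised in the tests above — write_deltalake with an overwrite predicate and DeltaTable.load_cdf — to observe that a replaceWhere overwrite on a CDF-enabled table surfaces both "delete" and "insert" change rows:

import pyarrow as pa
from deltalake import DeltaTable, write_deltalake

table_path = "/tmp/cdf_overwrite_demo"  # illustrative path, not from the patch

# Seed a small CDF-enabled table (commit version 0).
write_deltalake(
    table_path,
    pa.table({"id": ["1", "2", "3"], "value": [1, 2, 3]}),
    mode="append",
    configuration={"delta.enableChangeDataFeed": "true"},
)

# Overwrite only the rows matching the predicate (replaceWhere, commit version 1).
write_deltalake(
    table_path,
    pa.table({"id": ["3"], "value": [10]}),
    mode="overwrite",
    predicate="id = '3'",
    engine="rust",
    configuration={"delta.enableChangeDataFeed": "true"},
)

# The change feed for the overwrite commit should contain a "delete" for the
# replaced row and an "insert" for its replacement.
dt = DeltaTable(table_path)
changes = dt.load_cdf(starting_version=1).read_all()
print(changes.column("_change_type"))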