From 599275b143011401f469e0844b4a2c8c57226d2a Mon Sep 17 00:00:00 2001
From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com>
Date: Mon, 12 Aug 2024 20:09:23 +0200
Subject: [PATCH] fix: don't flatten fields during cdf read

---
 .../src/delta_datafusion/cdf/scan_utils.rs |  4 +-
 crates/core/src/operations/load_cdf.rs     |  4 +-
 python/tests/test_cdf.py                   | 61 +++++++++++--------
 3 files changed, 41 insertions(+), 28 deletions(-)

diff --git a/crates/core/src/delta_datafusion/cdf/scan_utils.rs b/crates/core/src/delta_datafusion/cdf/scan_utils.rs
index a05efadd44..27285179f6 100644
--- a/crates/core/src/delta_datafusion/cdf/scan_utils.rs
+++ b/crates/core/src/delta_datafusion/cdf/scan_utils.rs
@@ -92,9 +92,9 @@ pub fn create_partition_values(
     Ok(file_groups)
 }
 
-pub fn create_cdc_schema(mut schema_fields: Vec<Field>, include_type: bool) -> SchemaRef {
+pub fn create_cdc_schema(mut schema_fields: Vec<Arc<Field>>, include_type: bool) -> SchemaRef {
     if include_type {
-        schema_fields.push(Field::new(CHANGE_TYPE_COL, DataType::Utf8, true));
+        schema_fields.push(Field::new(CHANGE_TYPE_COL, DataType::Utf8, true).into());
     }
     Arc::new(Schema::new(schema_fields))
 }
diff --git a/crates/core/src/operations/load_cdf.rs b/crates/core/src/operations/load_cdf.rs
index 4b1bea9a80..77a06170ce 100644
--- a/crates/core/src/operations/load_cdf.rs
+++ b/crates/core/src/operations/load_cdf.rs
@@ -260,10 +260,10 @@ impl CdfLoadBuilder {
 
         let partition_values = self.snapshot.metadata().partition_columns.clone();
         let schema = self.snapshot.input_schema()?;
-        let schema_fields: Vec<Field> = self
+        let schema_fields: Vec<Arc<Field>> = self
             .snapshot
             .input_schema()?
-            .flattened_fields()
+            .fields()
             .into_iter()
             .filter(|f| !partition_values.contains(f.name()))
             .cloned()
diff --git a/python/tests/test_cdf.py b/python/tests/test_cdf.py
index c765054ca4..36d94c9f99 100644
--- a/python/tests/test_cdf.py
+++ b/python/tests/test_cdf.py
@@ -601,15 +601,15 @@ def test_write_overwrite_unpartitioned_cdf(tmp_path, sample_data: pa.Table):
         engine="rust",
         configuration={"delta.enableChangeDataFeed": "true"},
     )
-
-    # expected_data = (
-    #     ds.dataset(sample_data)
-    #     .to_table()
-    #     .append_column(
-    #         field_=pa.field("_change_type", pa.string(), nullable=True),
-    #         column=[["delete"] * 5],
-    #     )
-    # )
+    sort_values = [("_change_type", "ascending"), ("utf8", "ascending")]
+    expected_data = (
+        ds.dataset(pa.concat_tables([sample_data] * 3))
+        .to_table()
+        .append_column(
+            field_=pa.field("_change_type", pa.string(), nullable=True),
+            column=[["delete"] * 5 + ["insert"] * 10],
+        )
+    ).sort_by(sort_values)
 
     assert not os.path.exists(
         cdc_path
@@ -617,8 +617,14 @@
 
     ## TODO(ion): check if you see insert and deletes in commit version 1
 
-    # assert dt.load_cdf().read_all().drop_columns(['_commit_version', '_commit_timestamp']) == expected_data
-    # assert dt.to_pyarrow_table().sort_by([("utf8", "ascending")]) == sample_data
+    assert (
+        dt.load_cdf()
+        .read_all()
+        .drop_columns(["_commit_version", "_commit_timestamp"])
+        .sort_by(sort_values)
+        == expected_data
+    )
+    assert dt.to_pyarrow_table().sort_by([("utf8", "ascending")]) == sample_data
 
 
 def test_write_overwrite_partitioned_cdf(tmp_path, sample_data: pa.Table):
@@ -632,10 +638,12 @@ def test_write_overwrite_partitioned_cdf(tmp_path, sample_data: pa.Table):
         configuration={"delta.enableChangeDataFeed": "true"},
     )
 
+    batch2 = ds.dataset(sample_data).to_table(filter=(pc.field("int64") > 3))
+
     dt = DeltaTable(tmp_path)
     write_deltalake(
         dt,
-        data=ds.dataset(sample_data).to_table(filter=(pc.field("int64") > 3)),
+        data=batch2,
         engine="rust",
         mode="overwrite",
         predicate="int64 > 3",
@@ -643,24 +651,29 @@ def test_write_overwrite_partitioned_cdf(tmp_path, sample_data: pa.Table):
         configuration={"delta.enableChangeDataFeed": "true"},
     )
 
-    # expected_data = (
-    #     ds.dataset(sample_data)
-    #     .to_table()
-    #     .append_column(
-    #         field_=pa.field("_change_type", pa.string(), nullable=False),
-    #         column=[["delete"] * 5],
-    #     )
-    # )
     table_schema = dt.schema().to_pyarrow()
     table_schema = table_schema.insert(
         len(table_schema), pa.field("_change_type", pa.string(), nullable=False)
     )
-    # cdc_data = pq.read_table(cdc_path, schema=table_schema)
+
+    sort_values = [("_change_type", "ascending"), ("utf8", "ascending")]
+
+    first_batch = sample_data.append_column(
+        field_=pa.field("_change_type", pa.string(), nullable=True),
+        column=[["insert"] * 5],
+    )
+
+    expected_data = pa.concat_tables([batch2] * 2).append_column(
+        field_=pa.field("_change_type", pa.string(), nullable=True),
+        column=[["delete", "insert"]],
+    )
 
     assert not os.path.exists(
         cdc_path
     ), "_change_data shouldn't exist since a specific partition was overwritten"
 
-    ## TODO(ion): check if you see insert and deletes in commit version 1
-    # assert cdc_data == expected_data
-    # assert dt.to_pyarrow_table().sort_by([("int64", "ascending")]) == sample_data
+    assert dt.load_cdf().read_all().drop_columns(
+        ["_commit_version", "_commit_timestamp"]
+    ).sort_by(sort_values).select(expected_data.column_names) == pa.concat_tables(
+        [first_batch, expected_data]
+    ).sort_by(sort_values)
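
Reviewer note (not part of the patch): the behavioral difference is easiest to see with a nested column, since flattened_fields() expands a struct into its leaf fields while fields() keeps it intact. Below is a minimal, untested sketch against the Python bindings built from this branch; the table path and data are placeholders.

    import pyarrow as pa
    from deltalake import DeltaTable, write_deltalake

    # A table with a nested struct column. Previously the CDF scan schema was
    # built from flattened leaf fields, which mangled nested columns.
    nested = pa.table(
        {
            "id": pa.array([1, 2], pa.int64()),
            "props": pa.array(
                [{"a": 1, "b": "x"}, {"a": 2, "b": "y"}],
                pa.struct([("a", pa.int64()), ("b", pa.string())]),
            ),
        }
    )

    write_deltalake(
        "some_table_path",  # placeholder path
        nested,
        engine="rust",
        configuration={"delta.enableChangeDataFeed": "true"},
    )

    dt = DeltaTable("some_table_path")
    # With fields() instead of flattened_fields(), `props` stays a single
    # struct field in the CDF schema rather than separate `a` / `b` leaves.
    print(dt.load_cdf(starting_version=0).read_all().schema)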