fix: don't flatten fields during cdf read
ion-elgreco committed Aug 12, 2024
1 parent 85ffddf commit 5a5f515
Showing 3 changed files with 41 additions and 28 deletions.
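For context on the fix: building the CDF scan schema from flattened fields only matters once a table has nested struct columns, since the flattened list no longer matches the table's top-level columns. Below is a minimal sketch of the difference between the two calls swapped in this commit, assuming a recent arrow-schema release where `Schema::flattened_fields` is available; the `point` column is purely illustrative.

```rust
use arrow_schema::{DataType, Field, Fields, Schema};

fn main() {
    // A table schema with one nested (struct) column, standing in for a
    // Delta table that has struct fields.
    let point = Field::new(
        "point",
        DataType::Struct(Fields::from(vec![
            Field::new("x", DataType::Int64, true),
            Field::new("y", DataType::Int64, true),
        ])),
        true,
    );
    let schema = Schema::new(vec![Field::new("id", DataType::Int64, false), point]);

    // fields() keeps the top-level shape: ["id", "point"]
    let top_level: Vec<_> = schema.fields().iter().map(|f| f.name()).collect();
    println!("{top_level:?}");

    // flattened_fields() also walks into nested children, so the resulting
    // field list no longer lines up one-to-one with the table's top-level
    // columns, which is what the CDF scan needs.
    let flattened: Vec<_> = schema.flattened_fields().iter().map(|f| f.name()).collect();
    println!("{flattened:?}");
}
```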
crates/core/src/delta_datafusion/cdf/scan_utils.rs (4 changes: 2 additions & 2 deletions)
@@ -92,9 +92,9 @@ pub fn create_partition_values<F: FileAction>(
     Ok(file_groups)
 }
 
-pub fn create_cdc_schema(mut schema_fields: Vec<Field>, include_type: bool) -> SchemaRef {
+pub fn create_cdc_schema(mut schema_fields: Vec<Arc<Field>>, include_type: bool) -> SchemaRef {
     if include_type {
-        schema_fields.push(Field::new(CHANGE_TYPE_COL, DataType::Utf8, true));
+        schema_fields.push(Field::new(CHANGE_TYPE_COL, DataType::Utf8, true).into());
     }
     Arc::new(Schema::new(schema_fields))
 }
crates/core/src/operations/load_cdf.rs (4 changes: 2 additions & 2 deletions)
@@ -260,10 +260,10 @@ impl CdfLoadBuilder {
 
         let partition_values = self.snapshot.metadata().partition_columns.clone();
         let schema = self.snapshot.input_schema()?;
-        let schema_fields: Vec<Field> = self
+        let schema_fields: Vec<Arc<Field>> = self
             .snapshot
             .input_schema()?
-            .flattened_fields()
+            .fields()
             .into_iter()
             .filter(|f| !partition_values.contains(f.name()))
             .cloned()
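Taken together, the two Rust hunks derive the CDF read schema from the un-flattened top-level fields, drop the partition columns, and append the `_change_type` column. The following is a condensed, standalone sketch of that pipeline under the same arrow-schema assumption as above; `cdf_read_schema` is a hypothetical helper for illustration, not the crate's actual API.

```rust
use std::sync::Arc;

use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef};

// Hypothetical standalone equivalent of the changed code path; the real
// logic lives in CdfLoadBuilder and create_cdc_schema.
fn cdf_read_schema(input: &Schema, partition_columns: &[String]) -> SchemaRef {
    let mut fields: Vec<Arc<Field>> = input
        .fields()
        .iter()
        // Partition columns are not stored in the data files, so they are
        // filtered out here and re-attached as partition values by the scan.
        .filter(|f| !partition_columns.contains(f.name()))
        .cloned()
        .collect();
    // Same shape as create_cdc_schema(fields, true): append a nullable Utf8
    // `_change_type` column to the un-flattened field list.
    fields.push(Field::new("_change_type", DataType::Utf8, true).into());
    Arc::new(Schema::new(fields))
}

fn main() {
    let input = Schema::new(vec![
        Field::new("id", DataType::Int64, false),
        Field::new(
            "point",
            DataType::Struct(Fields::from(vec![
                Field::new("x", DataType::Int64, true),
                Field::new("y", DataType::Int64, true),
            ])),
            true,
        ),
        Field::new("year", DataType::Int32, true),
    ]);
    let schema = cdf_read_schema(&input, &["year".to_string()]);
    // Expected fields: id, point (still a struct), _change_type.
    println!("{schema:#?}");
}
```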
python/tests/test_cdf.py (61 changes: 37 additions & 24 deletions)
@@ -601,24 +601,30 @@ def test_write_overwrite_unpartitioned_cdf(tmp_path, sample_data: pa.Table):
         engine="rust",
         configuration={"delta.enableChangeDataFeed": "true"},
     )
 
-    # expected_data = (
-    #     ds.dataset(sample_data)
-    #     .to_table()
-    #     .append_column(
-    #         field_=pa.field("_change_type", pa.string(), nullable=True),
-    #         column=[["delete"] * 5],
-    #     )
-    # )
+    sort_values = [("_change_type", "ascending"), ("utf8", "ascending")]
+    expected_data = (
+        ds.dataset(pa.concat_tables([sample_data] * 3))
+        .to_table()
+        .append_column(
+            field_=pa.field("_change_type", pa.string(), nullable=True),
+            column=[["delete"] * 5 + ["insert"] * 10],
+        )
+    ).sort_by(sort_values)
 
     assert not os.path.exists(
         cdc_path
     ), "_change_data shouldn't exist since table was overwritten"
 
-    ## TODO(ion): check if you see insert and deletes in commit version 1
 
-    # assert dt.load_cdf().read_all().drop_columns(['_commit_version', '_commit_timestamp']) == expected_data
-    # assert dt.to_pyarrow_table().sort_by([("utf8", "ascending")]) == sample_data
+    assert (
+        dt.load_cdf()
+        .read_all()
+        .drop_columns(["_commit_version", "_commit_timestamp"])
+        .sort_by(sort_values)
+        == expected_data
+    )
+    assert dt.to_pyarrow_table().sort_by([("utf8", "ascending")]) == sample_data
 
 
 def test_write_overwrite_partitioned_cdf(tmp_path, sample_data: pa.Table):
@@ -632,35 +638,42 @@ def test_write_overwrite_partitioned_cdf(tmp_path, sample_data: pa.Table):
         configuration={"delta.enableChangeDataFeed": "true"},
     )
 
+    batch2 = ds.dataset(sample_data).to_table(filter=(pc.field("int64") > 3))
+
     dt = DeltaTable(tmp_path)
     write_deltalake(
         dt,
-        data=ds.dataset(sample_data).to_table(filter=(pc.field("int64") > 3)),
+        data=batch2,
         engine="rust",
         mode="overwrite",
         predicate="int64 > 3",
         partition_by=["int64"],
         configuration={"delta.enableChangeDataFeed": "true"},
     )
 
-    # expected_data = (
-    #     ds.dataset(sample_data)
-    #     .to_table()
-    #     .append_column(
-    #         field_=pa.field("_change_type", pa.string(), nullable=False),
-    #         column=[["delete"] * 5],
-    #     )
-    # )
     table_schema = dt.schema().to_pyarrow()
     table_schema = table_schema.insert(
         len(table_schema), pa.field("_change_type", pa.string(), nullable=False)
     )
-    # cdc_data = pq.read_table(cdc_path, schema=table_schema)
 
+    sort_values = [("_change_type", "ascending"), ("utf8", "ascending")]
+
+    first_batch = sample_data.append_column(
+        field_=pa.field("_change_type", pa.string(), nullable=True),
+        column=[["insert"] * 5],
+    )
+
+    expected_data = pa.concat_tables([batch2] * 2).append_column(
+        field_=pa.field("_change_type", pa.string(), nullable=True),
+        column=[["delete", "insert"]],
+    )
+
     assert not os.path.exists(
         cdc_path
     ), "_change_data shouldn't exist since a specific partition was overwritten"
 
-    ## TODO(ion): check if you see insert and deletes in commit version 1
-    # assert cdc_data == expected_data
-    # assert dt.to_pyarrow_table().sort_by([("int64", "ascending")]) == sample_data
+    assert dt.load_cdf().read_all().drop_columns(
+        ["_commit_version", "_commit_timestamp"]
+    ).sort_by(sort_values).select(expected_data.column_names) == pa.concat_tables(
+        [first_batch, expected_data]
+    ).sort_by(sort_values)
