diff --git a/crates/core/src/operations/load_cdf.rs b/crates/core/src/operations/load_cdf.rs index 4f3a4bdbd6..dd074a197c 100644 --- a/crates/core/src/operations/load_cdf.rs +++ b/crates/core/src/operations/load_cdf.rs @@ -1,5 +1,9 @@ //! Module for reading the change datafeed of delta tables +use datafusion_physical_expr::{ + expressions::{self}, + PhysicalExpr, +}; use std::sync::Arc; use std::time::SystemTime; @@ -8,6 +12,7 @@ use chrono::{DateTime, Utc}; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::file_format::FileFormat; use datafusion::datasource::physical_plan::FileScanConfig; +use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::union::UnionExec; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; @@ -29,6 +34,8 @@ pub struct CdfLoadBuilder { snapshot: DeltaTableState, /// Delta object store for handling data files log_store: LogStoreRef, + /// Columns to project + columns: Option>, /// Version to read from starting_version: i64, /// Version to stop reading at @@ -47,6 +54,7 @@ impl CdfLoadBuilder { Self { snapshot, log_store, + columns: None, starting_version: 0, ending_version: None, starting_timestamp: None, @@ -85,6 +93,12 @@ impl CdfLoadBuilder { self } + /// Columns to select + pub fn with_columns(mut self, columns: Vec) -> Self { + self.columns = Some(columns); + self + } + /// This is a rust version of https://github.com/delta-io/delta/blob/master/spark/src/main/scala/org/apache/spark/sql/delta/commands/cdc/CDCReader.scala#L418 /// Which iterates through versions of the delta table collects the relevant actions / commit info and returns those /// groupings for later use. The scala implementation has a lot more edge case handling and read schema checking (and just error checking in general) @@ -293,7 +307,24 @@ impl CdfLoadBuilder { // The output batches are then unioned to create a single output. Coalesce partitions is only here for the time // being for development. I plan to parallelize the reads once the base idea is correct. - let union_scan: Arc = Arc::new(UnionExec::new(vec![cdc_scan, add_scan])); + let mut union_scan: Arc = + Arc::new(UnionExec::new(vec![cdc_scan, add_scan])); + + if let Some(columns) = &self.columns { + let expressions: Vec<(Arc, String)> = union_scan + .schema() + .fields() + .into_iter() + .enumerate() + .map(|(idx, field)| -> (Arc, String) { + let field_name = field.name(); + let expr = Arc::new(expressions::Column::new(field_name, idx)); + (expr, field_name.to_owned()) + }) + .filter(|(_, field_name)| columns.contains(field_name)) + .collect(); + union_scan = Arc::new(ProjectionExec::try_new(expressions, union_scan)?); + } Ok(DeltaCdfScan::new(union_scan)) } } diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index debe460065..252a285078 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -159,6 +159,7 @@ class RawDeltaTable: def check_can_write_timestamp_ntz(self, schema: pyarrow.Schema) -> None: ... def load_cdf( self, + columns: Optional[List[str]] = None, starting_version: int = 0, ending_version: Optional[int] = None, starting_timestamp: Optional[str] = None, diff --git a/python/deltalake/table.py b/python/deltalake/table.py index a85a31bf0b..4caf524ad6 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -585,9 +585,14 @@ def load_cdf( ending_version: Optional[int] = None, starting_timestamp: Optional[str] = None, ending_timestamp: Optional[str] = None, + columns: Optional[List[str]] = None, ) -> pyarrow.RecordBatchReader: return self._table.load_cdf( - starting_version, ending_version, starting_timestamp, ending_timestamp + columns=columns, + starting_version=starting_version, + ending_version=ending_version, + starting_timestamp=starting_timestamp, + ending_timestamp=ending_timestamp, ) @property diff --git a/python/src/lib.rs b/python/src/lib.rs index b6a90fb2f5..0f4424c725 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -580,7 +580,7 @@ impl RawDeltaTable { Ok(()) } - #[pyo3(signature = (starting_version = 0, ending_version = None, starting_timestamp = None, ending_timestamp = None))] + #[pyo3(signature = (starting_version = 0, ending_version = None, starting_timestamp = None, ending_timestamp = None, columns = None))] pub fn load_cdf( &mut self, py: Python, @@ -588,6 +588,7 @@ impl RawDeltaTable { ending_version: Option, starting_timestamp: Option, ending_timestamp: Option, + columns: Option>, ) -> PyResult> { let ctx = SessionContext::new(); let mut cdf_read = CdfLoadBuilder::new( @@ -612,6 +613,10 @@ impl RawDeltaTable { cdf_read = cdf_read.with_starting_timestamp(ending_ts); } + if let Some(columns) = columns { + cdf_read = cdf_read.with_columns(columns); + } + cdf_read = cdf_read.with_session_ctx(ctx.clone()); let plan = rt().block_on(cdf_read.build()).map_err(PythonError::from)?; diff --git a/python/tests/test_cdf.py b/python/tests/test_cdf.py index 905e3c44e2..2292d85cdb 100644 --- a/python/tests/test_cdf.py +++ b/python/tests/test_cdf.py @@ -412,3 +412,9 @@ def test_read_cdf_non_partitioned(): datetime(2024, 4, 14, 15, 58, 31, 257000), datetime(2024, 4, 14, 15, 58, 32, 495000), ] + + +def test_read_cdf_partitioned_projection(): + dt = DeltaTable("../crates/test/tests/data/cdf-table/") + columns = ["id", "_change_type", "_commit_version"] + assert columns == dt.load_cdf(0, 3, columns=columns).schema.names