Skip to content

Commit

Permalink
add projection
Browse files Browse the repository at this point in the history
  • Loading branch information
ion-elgreco committed Jul 25, 2024
1 parent 292e3c2 commit ae8651a
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 3 deletions.
33 changes: 32 additions & 1 deletion crates/core/src/operations/load_cdf.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
//! Module for reading the change datafeed of delta tables
use datafusion_physical_expr::{
expressions::{self},
PhysicalExpr,
};
use std::sync::Arc;
use std::time::SystemTime;

Expand All @@ -8,6 +12,7 @@ use chrono::{DateTime, Utc};
use datafusion::datasource::file_format::parquet::ParquetFormat;
use datafusion::datasource::file_format::FileFormat;
use datafusion::datasource::physical_plan::FileScanConfig;
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::physical_plan::union::UnionExec;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
Expand All @@ -29,6 +34,8 @@ pub struct CdfLoadBuilder {
snapshot: DeltaTableState,
/// Delta object store for handling data files
log_store: LogStoreRef,
/// Columns to project
columns: Option<Vec<String>>,
/// Version to read from
starting_version: i64,
/// Version to stop reading at
Expand All @@ -47,6 +54,7 @@ impl CdfLoadBuilder {
Self {
snapshot,
log_store,
columns: None,
starting_version: 0,
ending_version: None,
starting_timestamp: None,
Expand Down Expand Up @@ -85,6 +93,12 @@ impl CdfLoadBuilder {
self
}

/// Columns to select
pub fn with_columns(mut self, columns: Vec<String>) -> Self {
self.columns = Some(columns);
self
}

/// This is a rust version of https://github.com/delta-io/delta/blob/master/spark/src/main/scala/org/apache/spark/sql/delta/commands/cdc/CDCReader.scala#L418
/// Which iterates through versions of the delta table collects the relevant actions / commit info and returns those
/// groupings for later use. The scala implementation has a lot more edge case handling and read schema checking (and just error checking in general)
Expand Down Expand Up @@ -293,7 +307,24 @@ impl CdfLoadBuilder {

// The output batches are then unioned to create a single output. Coalesce partitions is only here for the time
// being for development. I plan to parallelize the reads once the base idea is correct.
let union_scan: Arc<dyn ExecutionPlan> = Arc::new(UnionExec::new(vec![cdc_scan, add_scan]));
let mut union_scan: Arc<dyn ExecutionPlan> =
Arc::new(UnionExec::new(vec![cdc_scan, add_scan]));

if let Some(columns) = &self.columns {
let expressions: Vec<(Arc<dyn PhysicalExpr>, String)> = union_scan
.schema()
.fields()
.into_iter()
.enumerate()
.map(|(idx, field)| -> (Arc<dyn PhysicalExpr>, String) {
let field_name = field.name();
let expr = Arc::new(expressions::Column::new(field_name, idx));
(expr, field_name.to_owned())
})
.filter(|(_, field_name)| columns.contains(field_name))
.collect();
union_scan = Arc::new(ProjectionExec::try_new(expressions, union_scan)?);
}
Ok(DeltaCdfScan::new(union_scan))
}
}
Expand Down
1 change: 1 addition & 0 deletions python/deltalake/_internal.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ class RawDeltaTable:
def check_can_write_timestamp_ntz(self, schema: pyarrow.Schema) -> None: ...
def load_cdf(
self,
columns: Optional[List[str]] = None,
starting_version: int = 0,
ending_version: Optional[int] = None,
starting_timestamp: Optional[str] = None,
Expand Down
7 changes: 6 additions & 1 deletion python/deltalake/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,9 +585,14 @@ def load_cdf(
ending_version: Optional[int] = None,
starting_timestamp: Optional[str] = None,
ending_timestamp: Optional[str] = None,
columns: Optional[List[str]] = None,
) -> pyarrow.RecordBatchReader:
return self._table.load_cdf(
starting_version, ending_version, starting_timestamp, ending_timestamp
columns=columns,
starting_version=starting_version,
ending_version=ending_version,
starting_timestamp=starting_timestamp,
ending_timestamp=ending_timestamp,
)

@property
Expand Down
7 changes: 6 additions & 1 deletion python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -580,14 +580,15 @@ impl RawDeltaTable {
Ok(())
}

#[pyo3(signature = (starting_version = 0, ending_version = None, starting_timestamp = None, ending_timestamp = None))]
#[pyo3(signature = (starting_version = 0, ending_version = None, starting_timestamp = None, ending_timestamp = None, columns = None))]
pub fn load_cdf(
&mut self,
py: Python,
starting_version: i64,
ending_version: Option<i64>,
starting_timestamp: Option<String>,
ending_timestamp: Option<String>,
columns: Option<Vec<String>>,
) -> PyResult<PyArrowType<ArrowArrayStreamReader>> {
let ctx = SessionContext::new();
let mut cdf_read = CdfLoadBuilder::new(
Expand All @@ -612,6 +613,10 @@ impl RawDeltaTable {
cdf_read = cdf_read.with_starting_timestamp(ending_ts);
}

if let Some(columns) = columns {
cdf_read = cdf_read.with_columns(columns);
}

cdf_read = cdf_read.with_session_ctx(ctx.clone());

let plan = rt().block_on(cdf_read.build()).map_err(PythonError::from)?;
Expand Down
6 changes: 6 additions & 0 deletions python/tests/test_cdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,3 +412,9 @@ def test_read_cdf_non_partitioned():
datetime(2024, 4, 14, 15, 58, 31, 257000),
datetime(2024, 4, 14, 15, 58, 32, 495000),
]


def test_read_cdf_partitioned_projection():
dt = DeltaTable("../crates/test/tests/data/cdf-table/")
columns = ["id", "_change_type", "_commit_version"]
assert columns == dt.load_cdf(0, 3, columns=columns).schema.names

0 comments on commit ae8651a

Please sign in to comment.