Add invariant enforcement support #834

Merged
merged 15 commits on Sep 28, 2022
2 changes: 1 addition & 1 deletion python/Cargo.toml
@@ -36,4 +36,4 @@ features = ["extension-module", "abi3", "abi3-py37"]
[dependencies.deltalake]
path = "../rust"
version = "0"
features = ["s3", "azure", "glue", "gcs", "python"]
features = ["s3", "azure", "glue", "gcs", "python", "datafusion-ext"]
7 changes: 6 additions & 1 deletion python/deltalake/_internal.pyi
@@ -1,5 +1,5 @@
import sys
-from typing import Any, Callable, Dict, List, Mapping, Optional, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

if sys.version_info >= (3, 8):
from typing import Literal
@@ -118,6 +118,7 @@ class StructType:
class Schema:
def __init__(self, fields: List[Field]) -> None: ...
fields: List[Field]
invariants: List[Tuple[str, str]]

def to_json(self) -> str: ...
@staticmethod
@@ -212,3 +213,7 @@ class DeltaFileSystemHandler:
self, path: str, metadata: dict[str, str] | None = None
) -> ObjectOutputStream:
"""Open an output stream for sequential writing."""

class DeltaDataChecker:
def __init__(self, invariants: List[Tuple[str, str]]) -> None: ...
def check_batch(self, batch: pa.RecordBatch) -> None: ...
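For reviewers, a minimal sketch of the new stub surface (the invariant list and error message here are illustrative, not taken from a real table):

```python
import pyarrow as pa

from deltalake._internal import DeltaDataChecker, PyDeltaTableError

# Build a checker directly from (field path, SQL expression) pairs,
# the same shape the new `Schema.invariants` property returns.
checker = DeltaDataChecker([("c1", "c1 > 3")])

ok = pa.RecordBatch.from_pydict({"c1": pa.array([4, 5], type=pa.int32())})
checker.check_batch(ok)  # passes: returns None

bad = pa.RecordBatch.from_pydict({"c1": pa.array([2], type=pa.int32())})
try:
    checker.check_batch(bad)
except PyDeltaTableError as exc:
    print(exc)  # e.g. "Invariant (c1 > 3) violated by value 2"
```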
28 changes: 26 additions & 2 deletions python/deltalake/writer.py
@@ -35,6 +35,7 @@
import pyarrow.fs as pa_fs
from pyarrow.lib import RecordBatchReader

from ._internal import DeltaDataChecker as _DeltaDataChecker
from ._internal import PyDeltaTableError
from ._internal import write_new_deltalake as _write_new_deltalake
from .table import DeltaTable
@@ -192,11 +193,11 @@ def write_deltalake(
if partition_by:
assert partition_by == table.metadata().partition_columns

-        if table.protocol().min_writer_version > 1:
+        if table.protocol().min_writer_version > 2:
raise DeltaTableProtocolError(
"This table's min_writer_version is "
f"{table.protocol().min_writer_version}, "
"but this method only supports version 1."
"but this method only supports version 2."
)
else: # creating a new table
current_version = -1
@@ -234,6 +235,29 @@ def visitor(written_file: Any) -> None:
)
)

if table is not None:
        # We don't currently provide a way to set invariants
        # (and maybe never will), so we only enforce them if they already exist.
invariants = table.schema().invariants
checker = _DeltaDataChecker(invariants)

def validate_batch(batch: pa.RecordBatch) -> pa.RecordBatch:
checker.check_batch(batch)
return batch

if isinstance(data, RecordBatchReader):
batch_iter = data
elif isinstance(data, pa.RecordBatch):
batch_iter = [data]
elif isinstance(data, pa.Table):
batch_iter = data.to_batches()
else:
batch_iter = data

data = RecordBatchReader.from_batches(
schema, (validate_batch(batch) for batch in batch_iter)
)

ds.write_dataset(
data,
base_dir="/",
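End to end, the effect of this block is that a violating batch aborts the write before the new table version is committed. A minimal sketch (the path is hypothetical; it assumes the table's schema already declares the invariant `c1 > 3`, e.g. created by Spark as in the tests below):

```python
import pyarrow as pa

from deltalake import write_deltalake
from deltalake._internal import PyDeltaTableError

bad = pa.table({"c1": pa.array([6, 2], type=pa.int32())})
try:
    # Hypothetical path to a table whose schema carries the invariant.
    write_deltalake("/tmp/table_with_invariant", bad, mode="append")
except PyDeltaTableError as exc:
    print(exc)  # the message names the violated invariant and value
```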
38 changes: 37 additions & 1 deletion python/src/lib.rs
@@ -8,14 +8,16 @@ use chrono::{DateTime, FixedOffset, Utc};
use deltalake::action::{
self, Action, ColumnCountStat, ColumnValueStat, DeltaOperation, SaveMode, Stats,
};
use deltalake::arrow::record_batch::RecordBatch;
use deltalake::arrow::{self, datatypes::Schema as ArrowSchema};
use deltalake::builder::DeltaTableBuilder;
use deltalake::delta_datafusion::DeltaDataChecker;
use deltalake::partitions::PartitionFilter;
use deltalake::DeltaDataTypeLong;
use deltalake::DeltaDataTypeTimestamp;
use deltalake::DeltaTableMetaData;
use deltalake::DeltaTransactionOptions;
-use deltalake::Schema;
+use deltalake::{Invariant, Schema};
use pyo3::create_exception;
use pyo3::exceptions::PyException;
use pyo3::exceptions::PyValueError;
@@ -585,6 +587,39 @@ fn write_new_deltalake(
Ok(())
}

#[pyclass(name = "DeltaDataChecker", text_signature = "(invariants)")]
struct PyDeltaDataChecker {
inner: DeltaDataChecker,
rt: tokio::runtime::Runtime,
}

#[pymethods]
impl PyDeltaDataChecker {
#[new]
fn new(invariants: Vec<(String, String)>) -> Self {
let invariants: Vec<Invariant> = invariants
.into_iter()
.map(|(field_name, invariant_sql)| Invariant {
field_name,
invariant_sql,
})
.collect();
Self {
inner: DeltaDataChecker::new(invariants),
rt: tokio::runtime::Runtime::new().unwrap(),
}
}

fn check_batch(&self, batch: RecordBatch) -> PyResult<()> {
self.rt.block_on(async {
self.inner
.check_batch(&batch)
.await
.map_err(PyDeltaTableError::from_raw)
})
}
}

#[pymodule]
// module name needs to match project name
fn _internal(py: Python, m: &PyModule) -> PyResult<()> {
@@ -594,6 +629,7 @@
m.add_function(pyo3::wrap_pyfunction!(write_new_deltalake, m)?)?;
m.add_class::<RawDeltaTable>()?;
m.add_class::<RawDeltaTableMetaData>()?;
m.add_class::<PyDeltaDataChecker>()?;
m.add("PyDeltaTableError", py.get_type::<PyDeltaTableError>())?;
// There are issues with submodules, so we will expose them flat for now
// See also: https://github.com/PyO3/pyo3/issues/759
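One design note on the binding above: `check_batch` is async on the Rust side, but the wrapper blocks on a dedicated tokio runtime, so from Python the checker behaves as an ordinary synchronous call. A minimal sketch:

```python
import pyarrow as pa

from deltalake._internal import DeltaDataChecker

checker = DeltaDataChecker([("c1", "c1 > 3")])
batch = pa.RecordBatch.from_pydict({"c1": pa.array([10], type=pa.int32())})

# No asyncio or event loop needed: the call returns once the Rust side
# has driven the async validation to completion on its own runtime.
checker.check_batch(batch)
```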
18 changes: 18 additions & 0 deletions python/src/schema.rs
@@ -1064,4 +1064,22 @@ impl PySchema {
Err(PyTypeError::new_err("Type is not a struct"))
}
}

/// The list of invariants on the table.
///
/// :rtype: List[Tuple[str, str]]
    /// :return: a list of tuples, one per invariant; the first string is the
    ///     field path and the second is the invariant's SQL expression.
#[getter]
fn invariants(self_: PyRef<'_, Self>) -> PyResult<Vec<(String, String)>> {
let super_ = self_.as_ref();
let invariants = super_
.inner_type
.get_invariants()
.map_err(|err| PyException::new_err(err.to_string()))?;
Ok(invariants
.into_iter()
.map(|invariant| (invariant.field_name, invariant.invariant_sql))
.collect())
}
}
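To make the getter concrete, a sketch of the round trip (hypothetical path; assumes the column was created with the `delta.invariants` metadata shown in the test below):

```python
from deltalake import DeltaTable

# For a column `c1` created with metadata
#   {"delta.invariants": '{"expression": { "expression": "c1 > 3"} }'}
# the getter flattens the nested JSON into (field path, SQL) pairs:
dt = DeltaTable("/tmp/table_with_invariant")  # hypothetical path
assert dt.schema().invariants == [("c1", "c1 > 3")]
```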
115 changes: 115 additions & 0 deletions python/tests/pyspark_integration/test_write_to_pyspark.py
@@ -0,0 +1,115 @@
"""Tests that deltalake(delta-rs) can write to tables written by PySpark"""
import pathlib

import pyarrow as pa
import pytest

from deltalake import write_deltalake
from deltalake._internal import PyDeltaTableError
from deltalake.writer import DeltaTableProtocolError

from .utils import assert_spark_read_equal, get_spark

try:
import delta
import delta.pip_utils
import delta.tables
import pyspark

spark = get_spark()
except ModuleNotFoundError:
pass


@pytest.mark.pyspark
@pytest.mark.integration
def test_write_basic(tmp_path: pathlib.Path):
# Write table in Spark
spark = get_spark()
schema = pyspark.sql.types.StructType(
[
pyspark.sql.types.StructField(
"c1",
dataType=pyspark.sql.types.IntegerType(),
nullable=True,
)
]
)
spark.createDataFrame([(4,)], schema=schema).write.save(
str(tmp_path),
mode="append",
format="delta",
)
# Overwrite table in deltalake
data = pa.table({"c1": pa.array([5, 6], type=pa.int32())})
write_deltalake(str(tmp_path), data, mode="overwrite")

# Read table in Spark
assert_spark_read_equal(data, str(tmp_path), sort_by="c1")


@pytest.mark.pyspark
@pytest.mark.integration
def test_write_invariant(tmp_path: pathlib.Path):
# Write table in Spark with invariant
spark = get_spark()

schema = pyspark.sql.types.StructType(
[
pyspark.sql.types.StructField(
"c1",
dataType=pyspark.sql.types.IntegerType(),
nullable=True,
metadata={
"delta.invariants": '{"expression": { "expression": "c1 > 3"} }'
},
)
]
)

delta.tables.DeltaTable.create(spark).location(str(tmp_path)).addColumns(
schema
).execute()

spark.createDataFrame([(4,)], schema=schema).write.save(
str(tmp_path),
mode="append",
format="delta",
)

# Cannot write invalid data to the table
invalid_data = pa.table({"c1": pa.array([6, 2], type=pa.int32())})
    with pytest.raises(
        PyDeltaTableError, match=r"Invariant \(c1 > 3\) violated by value .+2"
    ):
        write_deltalake(str(tmp_path), invalid_data, mode="overwrite")

# Can write valid data to the table
valid_data = pa.table({"c1": pa.array([5, 6], type=pa.int32())})
write_deltalake(str(tmp_path), valid_data, mode="append")

expected = pa.table({"c1": pa.array([4, 5, 6], type=pa.int32())})
assert_spark_read_equal(expected, str(tmp_path), sort_by="c1")


@pytest.mark.pyspark
@pytest.mark.integration
def test_checks_min_writer_version(tmp_path: pathlib.Path):
# Write table in Spark with constraint
spark = get_spark()

spark.createDataFrame([(4,)], schema=["c1"]).write.save(
str(tmp_path),
mode="append",
format="delta",
)

    # Adding a constraint upgrades the table's minWriterVersion
spark.sql(f"ALTER TABLE delta.`{str(tmp_path)}` ADD CONSTRAINT x CHECK (c1 > 2)")

with pytest.raises(
DeltaTableProtocolError, match="This table's min_writer_version is 3, but"
):
valid_data = pa.table({"c1": pa.array([5, 6])})
write_deltalake(str(tmp_path), valid_data, mode="append")
38 changes: 1 addition & 37 deletions python/tests/pyspark_integration/test_writer_readable.py
@@ -7,25 +7,7 @@

from deltalake import DeltaTable, write_deltalake

-try:
-    from pandas.testing import assert_frame_equal
-except ModuleNotFoundError:
-    _has_pandas = False
-else:
-    _has_pandas = True
-
-
-def get_spark():
-    builder = (
-        pyspark.sql.SparkSession.builder.appName("MyApp")
-        .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
-        .config(
-            "spark.sql.catalog.spark_catalog",
-            "org.apache.spark.sql.delta.catalog.DeltaCatalog",
-        )
-    )
-    return delta.pip_utils.configure_spark_with_delta_pip(builder).getOrCreate()
-
+from .utils import assert_spark_read_equal, get_spark

try:
import delta
@@ -38,24 +20,6 @@ def get_spark():
pass


-def assert_spark_read_equal(
-    expected: pa.Table, uri: str, sort_by: List[str] = ["int32"]
-):
-    df = spark.read.format("delta").load(uri)
-
-    # Spark and pyarrow don't convert these types to the same Pandas values
-    incompatible_types = ["timestamp", "struct"]
-
-    assert_frame_equal(
-        df.toPandas()
-        .sort_values(sort_by, ignore_index=True)
-        .drop(incompatible_types, axis="columns"),
-        expected.to_pandas()
-        .sort_values(sort_by, ignore_index=True)
-        .drop(incompatible_types, axis="columns"),
-    )
-

@pytest.mark.pyspark
@pytest.mark.integration
def test_basic_read(sample_data: pa.Table, existing_table: DeltaTable):
49 changes: 49 additions & 0 deletions python/tests/pyspark_integration/utils.py
@@ -0,0 +1,49 @@
from typing import List

import pyarrow as pa

try:
import delta
import delta.pip_utils
import delta.tables
import pyspark
except ModuleNotFoundError:
pass

try:
from pandas.testing import assert_frame_equal
except ModuleNotFoundError:
_has_pandas = False
else:
_has_pandas = True


def get_spark():
builder = (
pyspark.sql.SparkSession.builder.appName("MyApp")
.config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
.config(
"spark.sql.catalog.spark_catalog",
"org.apache.spark.sql.delta.catalog.DeltaCatalog",
)
)
return delta.pip_utils.configure_spark_with_delta_pip(builder).getOrCreate()


def assert_spark_read_equal(
expected: pa.Table, uri: str, sort_by: List[str] = ["int32"]
):
spark = get_spark()
df = spark.read.format("delta").load(uri)

# Spark and pyarrow don't convert these types to the same Pandas values
incompatible_types = ["timestamp", "struct"]

assert_frame_equal(
df.toPandas()
.sort_values(sort_by, ignore_index=True)
.drop(incompatible_types, axis="columns", errors="ignore"),
expected.to_pandas()
.sort_values(sort_by, ignore_index=True)
.drop(incompatible_types, axis="columns", errors="ignore"),
)
2 changes: 1 addition & 1 deletion python/tests/test_writer.py
@@ -420,7 +420,7 @@ def test_writer_null_stats(tmp_path: pathlib.Path):


def test_writer_fails_on_protocol(existing_table: DeltaTable, sample_data: pa.Table):
-    existing_table.protocol = Mock(return_value=ProtocolVersions(1, 2))
+    existing_table.protocol = Mock(return_value=ProtocolVersions(1, 3))
with pytest.raises(DeltaTableProtocolError):
write_deltalake(existing_table, sample_data, mode="overwrite")
