Skip to content

Commit

Permalink
prevent reading to return large arrow schema
Browse files Browse the repository at this point in the history
  • Loading branch information
ion-elgreco committed Mar 9, 2024
1 parent 9a7ee6f commit ac54d9c
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 6 deletions.
14 changes: 9 additions & 5 deletions crates/core/src/operations/transaction/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use datafusion_common::{Column, DFSchema};
use datafusion_expr::{utils::conjunction, Expr};
use itertools::Either;
use object_store::ObjectStore;
use parquet::arrow::arrow_reader::ArrowReaderOptions;
use parquet::arrow::async_reader::{ParquetObjectReader, ParquetRecordBatchStreamBuilder};

use crate::delta_datafusion::expr::parse_predicate_expression;
Expand Down Expand Up @@ -125,11 +126,14 @@ impl DeltaTableState {
{
let file_meta = add.try_into()?;
let file_reader = ParquetObjectReader::new(object_store, file_meta);
let file_schema = ParquetRecordBatchStreamBuilder::new(file_reader)
.await?
.build()?
.schema()
.clone();
let file_schema = ParquetRecordBatchStreamBuilder::new_with_options(
file_reader,
ArrowReaderOptions::new().with_skip_arrow_metadata(true),
)
.await?
.build()?
.schema()
.clone();

let table_schema = Arc::new(ArrowSchema::new(
self.arrow_schema()?
Expand Down
24 changes: 24 additions & 0 deletions python/tests/test_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pyarrow as pa
import pyarrow.compute as pc
import pytest

from deltalake.table import DeltaTable
from deltalake.writer import write_deltalake
Expand Down Expand Up @@ -57,3 +58,26 @@ def test_delete_some_rows(existing_table: DeltaTable):

table = existing_table.to_pyarrow_table()
assert table.equals(expected_table)


@pytest.mark.parametrize("engine", ["pyarrow", "rust"])
def test_delete_large_dtypes(
tmp_path: pathlib.Path, sample_table: pa.table, engine: str
):
write_deltalake(tmp_path, sample_table, large_dtypes=True, engine=engine) # type: ignore

dt = DeltaTable(tmp_path)
old_version = dt.version()

existing = dt.to_pyarrow_table()
mask = pc.invert(pc.is_in(existing["id"], pa.array(["1"])))
expected_table = existing.filter(mask)

dt.delete(predicate="id = '1'")

last_action = dt.history(1)[0]
assert last_action["operation"] == "DELETE"
assert dt.version() == old_version + 1

table = dt.to_pyarrow_table()
assert table.equals(expected_table)
31 changes: 31 additions & 0 deletions python/tests/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,37 @@ def test_update_with_predicate(tmp_path: pathlib.Path, sample_table: pa.Table):
assert result == expected


def test_update_with_predicate_large_dtypes(
tmp_path: pathlib.Path, sample_table: pa.Table
):
write_deltalake(tmp_path, sample_table, mode="append", large_dtypes=True)

dt = DeltaTable(tmp_path)

nrows = 5
expected = pa.table(
{
"id": pa.array(["1", "2", "3", "4", "5"]),
"price": pa.array(list(range(nrows)), pa.int64()),
"sold": pa.array(list(range(nrows)), pa.int64()),
"price_float": pa.array(list(range(nrows)), pa.float64()),
"items_in_bucket": pa.array([["item1", "item2", "item3"]] * nrows),
"deleted": pa.array([True, False, False, False, False]),
}
)

dt.update(
updates={"deleted": "True"},
predicate="id = '1'",
)

result = dt.to_pyarrow_table()
last_action = dt.history(1)[0]

assert last_action["operation"] == "UPDATE"
assert result == expected


def test_update_wo_predicate(tmp_path: pathlib.Path, sample_table: pa.Table):
write_deltalake(tmp_path, sample_table, mode="append")

Expand Down
7 changes: 6 additions & 1 deletion python/tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,7 @@ def test_partition_overwrite_unfiltered_data_fails(
)


@pytest.mark.parametrize("large_dtypes", [True, False])
@pytest.mark.parametrize(
"value_1,value_2,value_type,filter_string",
[
Expand All @@ -1008,6 +1009,7 @@ def test_replace_where_overwrite(
value_2: Any,
value_type: pa.DataType,
filter_string: str,
large_dtypes: bool,
):
table_path = tmp_path

Expand All @@ -1018,7 +1020,9 @@ def test_replace_where_overwrite(
"val": pa.array([1, 1, 1, 1], pa.int64()),
}
)
write_deltalake(table_path, sample_data, mode="overwrite")
write_deltalake(
table_path, sample_data, mode="overwrite", large_dtypes=large_dtypes
)

delta_table = DeltaTable(table_path)
assert (
Expand Down Expand Up @@ -1049,6 +1053,7 @@ def test_replace_where_overwrite(
mode="overwrite",
predicate="p1 = '1'",
engine="rust",
large_dtypes=large_dtypes,
)

delta_table.update_incremental()
Expand Down

0 comments on commit ac54d9c

Please sign in to comment.