Skip to content

Commit

Permalink
fix(python): add support for pyarrow 13+ (#1804)
Browse files Browse the repository at this point in the history
# Description

I build on top of the branch of @wjones127
#1602. In pyarrow v13+ the
ParquetWriter by default uses the `compliant_nested_types = True` (see
related PR: https://github.com/apache/arrow/pull/35146/files)and the
docs:
https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetWriter.html).

In arrow/parquet-rs it fails when it compares schemas because it
expected the old non-compliant ones. For now we can have pyarrow 13+
supported by disabling it or updating the file options provided by a
user.

# Related Issue(s)
<!---
For example:

- closes #106
--->

- Closes #1744

# Documentation

<!---
Share links to useful documentation
--->

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
Co-authored-by: R. Tyler Croy <rtyler@brokenco.de>
  • Loading branch information
3 people authored Nov 5, 2023
1 parent 6e9894f commit d98a0c2
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 14 deletions.
1 change: 0 additions & 1 deletion crates/deltalake-core/src/operations/optimize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1432,7 +1432,6 @@ pub(super) mod zorder {
assert_eq!(result.null_count(), 0);

let data: &BinaryArray = as_generic_binary_array(result.as_ref());
dbg!(data);
assert_eq!(data.value_data().len(), 3 * 16 * 3);
assert!(data.iter().all(|x| x.unwrap().len() == 3 * 16));
}
Expand Down
7 changes: 7 additions & 0 deletions python/deltalake/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,13 @@ def validate_batch(batch: pa.RecordBatch) -> pa.RecordBatch:
schema, (validate_batch(batch) for batch in batch_iter)
)

if file_options is not None:
file_options.update(use_compliant_nested_type=False)
else:
file_options = ds.ParquetFileFormat().make_write_options(
use_compliant_nested_type=False
)

ds.write_dataset(
data,
base_dir="/",
Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ classifiers = [
"Programming Language :: Python :: 3.11"
]
dependencies = [
"pyarrow>=8,<13",
"pyarrow>=8",
'typing-extensions;python_version<"3.8"',
]

Expand Down
36 changes: 26 additions & 10 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use std::time;
use std::time::{SystemTime, UNIX_EPOCH};

use arrow::pyarrow::PyArrowType;
use arrow_schema::DataType;
use chrono::{DateTime, Duration, FixedOffset, Utc};
use deltalake::arrow::compute::concat_batches;
use deltalake::arrow::ffi_stream::ArrowArrayStreamReader;
Expand Down Expand Up @@ -947,16 +948,31 @@ fn filestats_to_expression<'py>(
let mut expressions: Vec<PyResult<&PyAny>> = Vec::new();

let cast_to_type = |column_name: &String, value: PyObject, schema: &ArrowSchema| {
let column_type = PyArrowType(
schema
.field_with_name(column_name)
.map_err(|_| {
PyValueError::new_err(format!("Column not found in schema: {column_name}"))
})?
.data_type()
.clone(),
)
.into_py(py);
let column_type = schema
.field_with_name(column_name)
.map_err(|_| {
PyValueError::new_err(format!("Column not found in schema: {column_name}"))
})?
.data_type()
.clone();

let value = match column_type {
// Since PyArrow 13.0.0, casting string -> timestamp fails if it ends with "Z"
// and the target type is timezone naive.
DataType::Timestamp(_, _) if value.extract::<String>(py).is_ok() => {
value.call_method1(py, "rstrip", ("Z",))?
}
// PyArrow 13.0.0 lost the ability to cast from string to date32, so
// we have to implement that manually.
DataType::Date32 if value.extract::<String>(py).is_ok() => {
let date = Python::import(py, "datetime")?.getattr("date")?;
let date = date.call_method1("fromisoformat", (value,))?;
date.to_object(py)
}
_ => value,
};

let column_type = PyArrowType(column_type).into_py(py);
pa.call_method1("scalar", (value,))?
.call_method1("cast", (column_type,))
};
Expand Down
6 changes: 4 additions & 2 deletions python/tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,9 @@ def test_write_recordbatchreader(
tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table
):
batches = existing_table.to_pyarrow_dataset().to_batches()
reader = RecordBatchReader.from_batches(sample_data.schema, batches)
reader = RecordBatchReader.from_batches(
existing_table.to_pyarrow_dataset().schema, batches
)

write_deltalake(tmp_path, reader, mode="overwrite")
assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data
Expand Down Expand Up @@ -892,7 +894,7 @@ def comp():
# concurrently, then this will fail.
assert data.num_rows == sample_data.num_rows
try:
write_deltalake(dt.table_uri, data, mode="overwrite")
write_deltalake(dt.table_uri, sample_data, mode="overwrite")
except Exception as e:
exception = e

Expand Down

0 comments on commit d98a0c2

Please sign in to comment.