Skip to content

Commit

Permalink
feat(python): add pyarrow to delta compatible schema conversion in wr…
Browse files Browse the repository at this point in the history
…iter/merge (delta-io#1820)

This ports some functionality that @stinodego and I had worked on in
Polars. Where we converted a pyarrow schema to a compatible delta
schema. It converts the following:

- uint -> int
- timestamp(any timeunit) -> timestamp(us)

I adjusted the functionality to do schema conversion from large to
normal when necessary, which is still needed in MERGE as workaround
delta-io#1753.

Additional things I've added:

- Schema conversion for every input in write_deltalake/merge
- Add Pandas dataframe conversion
- Add Pandas dataframe as input in merge

- closes delta-io#686
- closes delta-io#1467

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
  • Loading branch information
ion-elgreco and wjones127 committed Nov 25, 2023
1 parent c38b518 commit 3ed7df0
Show file tree
Hide file tree
Showing 10 changed files with 483 additions and 69 deletions.
135 changes: 104 additions & 31 deletions python/deltalake/schema.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from typing import TYPE_CHECKING, Tuple, Union
from typing import Generator, Union

import pyarrow as pa

if TYPE_CHECKING:
import pandas as pd
import pyarrow.dataset as ds

from ._internal import ArrayType as ArrayType
from ._internal import Field as Field
Expand All @@ -17,34 +15,109 @@
DataType = Union["PrimitiveType", "MapType", "StructType", "ArrayType"]


def delta_arrow_schema_from_pandas(
data: "pd.DataFrame",
) -> Tuple[pa.Table, pa.Schema]:
"""
Infers the schema for the delta table from the Pandas DataFrame.
Necessary because of issues such as: https://github.com/delta-io/delta-rs/issues/686
Args:
data: Data to write.
### Inspired from Pola-rs repo - licensed with MIT License, see license in python/licenses/polars_license.txt.###
def _convert_pa_schema_to_delta(
schema: pa.schema, large_dtypes: bool = False
) -> pa.schema:
"""Convert a PyArrow schema to a schema compatible with Delta Lake. Converts unsigned to signed equivalent, and
converts all timestamps to `us` timestamps. With the boolean flag large_dtypes you can control if the schema
should keep cast normal to large types in the schema, or from large to normal.
Returns:
A PyArrow Table and the inferred schema for the Delta Table
Args
schema: Source schema
large_dtypes: If True, the pyarrow schema is casted to large_dtypes
"""
dtype_map = {
pa.uint8(): pa.int8(),
pa.uint16(): pa.int16(),
pa.uint32(): pa.int32(),
pa.uint64(): pa.int64(),
}
if large_dtypes:
dtype_map = {
**dtype_map,
**{pa.string(): pa.large_string(), pa.binary(): pa.large_binary()},
}
else:
dtype_map = {
**dtype_map,
**{pa.large_string(): pa.string(), pa.large_binary(): pa.binary()},
}

table = pa.Table.from_pandas(data)
schema = table.schema
schema_out = []
for field in schema:
if isinstance(field.type, pa.TimestampType):
f = pa.field(
name=field.name,
type=pa.timestamp("us"),
nullable=field.nullable,
metadata=field.metadata,
)
schema_out.append(f)
def dtype_to_delta_dtype(dtype: pa.DataType) -> pa.DataType:
# Handle nested types
if isinstance(dtype, (pa.LargeListType, pa.ListType)):
return list_to_delta_dtype(dtype)
elif isinstance(dtype, pa.StructType):
return struct_to_delta_dtype(dtype)
elif isinstance(dtype, pa.TimestampType):
return pa.timestamp(
"us"
) # TODO(ion): propagate also timezone information during writeonce we can properly read TZ in delta schema
try:
return dtype_map[dtype]
except KeyError:
return dtype

def list_to_delta_dtype(
dtype: Union[pa.LargeListType, pa.ListType],
) -> Union[pa.LargeListType, pa.ListType]:
nested_dtype = dtype.value_type
nested_dtype_cast = dtype_to_delta_dtype(nested_dtype)
if large_dtypes:
return pa.large_list(nested_dtype_cast)
else:
schema_out.append(field)
schema = pa.schema(schema_out, metadata=schema.metadata)
table = table.cast(target_schema=schema)
return table, schema
return pa.list_(nested_dtype_cast)

def struct_to_delta_dtype(dtype: pa.StructType) -> pa.StructType:
fields = [dtype[i] for i in range(dtype.num_fields)]
fields_cast = [f.with_type(dtype_to_delta_dtype(f.type)) for f in fields]
return pa.struct(fields_cast)

return pa.schema([f.with_type(dtype_to_delta_dtype(f.type)) for f in schema])


def _cast_schema_to_recordbatchreader(
reader: pa.RecordBatchReader, schema: pa.schema
) -> Generator[pa.RecordBatch, None, None]:
"""Creates recordbatch generator."""
for batch in reader:
yield pa.Table.from_batches([batch]).cast(schema).to_batches()[0]


def convert_pyarrow_recordbatchreader(
data: pa.RecordBatchReader, large_dtypes: bool
) -> pa.RecordBatchReader:
"""Converts a PyArrow RecordBatchReader to a PyArrow RecordBatchReader with a compatible delta schema"""
schema = _convert_pa_schema_to_delta(data.schema, large_dtypes=large_dtypes)

data = pa.RecordBatchReader.from_batches(
schema,
_cast_schema_to_recordbatchreader(data, schema),
)
return data


def convert_pyarrow_recordbatch(
data: pa.RecordBatch, large_dtypes: bool
) -> pa.RecordBatchReader:
"""Converts a PyArrow RecordBatch to a PyArrow RecordBatchReader with a compatible delta schema"""
schema = _convert_pa_schema_to_delta(data.schema, large_dtypes=large_dtypes)
data = pa.Table.from_batches([data]).cast(schema).to_reader()
return data


def convert_pyarrow_table(data: pa.Table, large_dtypes: bool) -> pa.RecordBatchReader:
"""Converts a PyArrow table to a PyArrow RecordBatchReader with a compatible delta schema"""
schema = _convert_pa_schema_to_delta(data.schema, large_dtypes=large_dtypes)
data = data.cast(schema).to_reader()
return data


def convert_pyarrow_dataset(
data: ds.Dataset, large_dtypes: bool
) -> pa.RecordBatchReader:
"""Converts a PyArrow dataset to a PyArrow RecordBatchReader with a compatible delta schema"""
data = data.scanner().to_reader()
data = convert_pyarrow_recordbatchreader(data, large_dtypes)
return data
34 changes: 26 additions & 8 deletions python/deltalake/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)

import pyarrow
import pyarrow.dataset as ds
import pyarrow.fs as pa_fs
from pyarrow.dataset import (
Expression,
Expand Down Expand Up @@ -596,7 +597,13 @@ def optimize(

def merge(
self,
source: Union[pyarrow.Table, pyarrow.RecordBatch, pyarrow.RecordBatchReader],
source: Union[
pyarrow.Table,
pyarrow.RecordBatch,
pyarrow.RecordBatchReader,
ds.Dataset,
"pandas.DataFrame",
],
predicate: str,
source_alias: Optional[str] = None,
target_alias: Optional[str] = None,
Expand All @@ -619,25 +626,36 @@ def merge(
invariants = self.schema().invariants
checker = _DeltaDataChecker(invariants)

from .schema import (
convert_pyarrow_dataset,
convert_pyarrow_recordbatch,
convert_pyarrow_recordbatchreader,
convert_pyarrow_table,
)

if isinstance(source, pyarrow.RecordBatchReader):
schema = source.schema
source = convert_pyarrow_recordbatchreader(source, large_dtypes=True)
elif isinstance(source, pyarrow.RecordBatch):
schema = source.schema
source = [source]
source = convert_pyarrow_recordbatch(source, large_dtypes=True)
elif isinstance(source, pyarrow.Table):
schema = source.schema
source = source.to_reader()
source = convert_pyarrow_table(source, large_dtypes=True)
elif isinstance(source, ds.Dataset):
source = convert_pyarrow_dataset(source, large_dtypes=True)
elif isinstance(source, pandas.DataFrame):
source = convert_pyarrow_table(
pyarrow.Table.from_pandas(source), large_dtypes=True
)
else:
raise TypeError(
f"{type(source).__name__} is not a valid input. Only PyArrow RecordBatchReader, RecordBatch or Table are valid inputs for source."
f"{type(source).__name__} is not a valid input. Only PyArrow RecordBatchReader, RecordBatch, Table or Pandas DataFrame are valid inputs for source."
)

def validate_batch(batch: pyarrow.RecordBatch) -> pyarrow.RecordBatch:
checker.check_batch(batch)
return batch

source = pyarrow.RecordBatchReader.from_batches(
schema, (validate_batch(batch) for batch in source)
source.schema, (validate_batch(batch) for batch in source)
)

return TableMerger(
Expand Down
63 changes: 39 additions & 24 deletions python/deltalake/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,18 @@
import pyarrow.fs as pa_fs
from pyarrow.lib import RecordBatchReader

from deltalake.schema import delta_arrow_schema_from_pandas

from ._internal import DeltaDataChecker as _DeltaDataChecker
from ._internal import batch_distinct
from ._internal import convert_to_deltalake as _convert_to_deltalake
from ._internal import write_new_deltalake as write_deltalake_pyarrow
from ._internal import write_to_deltalake as write_deltalake_rust
from .exceptions import DeltaProtocolError, TableNotFoundError
from .schema import (
convert_pyarrow_dataset,
convert_pyarrow_recordbatch,
convert_pyarrow_recordbatchreader,
convert_pyarrow_table,
)
from .table import MAX_SUPPORTED_WRITER_VERSION, DeltaTable

try:
Expand Down Expand Up @@ -161,7 +165,7 @@ def write_deltalake(
overwrite_schema: If True, allows updating the schema of the table.
storage_options: options passed to the native delta filesystem. Unused if 'filesystem' is defined.
partition_filters: the partition filters that will be used for partition overwrite. Only used in pyarrow engine.
large_dtypes: If True, the table schema is checked against large_dtypes
large_dtypes: If True, the data schema is kept in large_dtypes, has no effect on pandas dataframe input
"""
table, table_uri = try_get_table_and_table_uri(table_or_uri, storage_options)
if table is not None:
Expand Down Expand Up @@ -230,13 +234,35 @@ def write_deltalake(
else:
data, schema = delta_arrow_schema_from_pandas(data)

table, table_uri = try_get_table_and_table_uri(table_or_uri, storage_options)

# We need to write against the latest table version
if table:
table.update_incremental()

if isinstance(data, RecordBatchReader):
data = convert_pyarrow_recordbatchreader(data, large_dtypes)
elif isinstance(data, pa.RecordBatch):
data = convert_pyarrow_recordbatch(data, large_dtypes)
elif isinstance(data, pa.Table):
data = convert_pyarrow_table(data, large_dtypes)
elif isinstance(data, ds.Dataset):
data = convert_pyarrow_dataset(data, large_dtypes)
elif _has_pandas and isinstance(data, pd.DataFrame):
if schema is not None:
data = pa.Table.from_pandas(data, schema=schema)
else:
data = convert_pyarrow_table(pa.Table.from_pandas(data), False)
elif isinstance(data, Iterable):
if schema is None:
if isinstance(data, RecordBatchReader):
schema = data.schema
elif isinstance(data, Iterable):
raise ValueError("You must provide schema if data is Iterable")
else:
schema = data.schema
raise ValueError("You must provide schema if data is Iterable")
else:
raise TypeError(
f"{type(data).__name__} is not a valid input. Only PyArrow RecordBatchReader, RecordBatch, Iterable[RecordBatch], Table, Dataset or Pandas DataFrame are valid inputs for source."
)

if schema is None:
schema = data.schema

if filesystem is not None:
raise NotImplementedError(
Expand Down Expand Up @@ -269,7 +295,7 @@ def write_deltalake(
current_version = -1

dtype_map = {
pa.large_string(): pa.string(), # type: ignore
pa.large_string(): pa.string(),
}

def _large_to_normal_dtype(dtype: pa.DataType) -> pa.DataType:
Expand Down Expand Up @@ -373,20 +399,9 @@ def validate_batch(batch: pa.RecordBatch) -> pa.RecordBatch:

return batch

if isinstance(data, RecordBatchReader):
batch_iter = data
elif isinstance(data, pa.RecordBatch):
batch_iter = [data]
elif isinstance(data, pa.Table):
batch_iter = data.to_batches()
elif isinstance(data, ds.Dataset):
batch_iter = data.to_batches()
else:
batch_iter = data

data = RecordBatchReader.from_batches(
schema, (validate_batch(batch) for batch in batch_iter)
)
data = RecordBatchReader.from_batches(
schema, (validate_batch(batch) for batch in data)
)

if file_options is not None:
file_options.update(use_compliant_nested_type=False)
Expand Down
8 changes: 8 additions & 0 deletions python/licenses/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Licenses
Below are described which licenses apply to the deltalake package and to which areas of the source code.

### deltalake_license.txt (APACHE 2.0 License)
Applies to the full deltalake package source code.

### polars_license.txt (MIT License)
Applies solely to the `_convert_pa_schema_to_delta` function in `deltalake/schema.py`.
File renamed without changes.
19 changes: 19 additions & 0 deletions python/licenses/polars_license.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
Copyright (c) 2020 Ritchie Vink

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "maturin"
name = "deltalake"
description = "Native Delta Lake Python binding based on delta-rs with Pandas integration"
readme = "README.md"
license = {file = "LICENSE.txt"}
license = {file = "licenses/deltalake_license.txt"}
requires-python = ">=3.8"
keywords = ["deltalake", "delta", "datalake", "pandas", "arrow"]
classifiers = [
Expand Down
13 changes: 13 additions & 0 deletions python/stubs/pyarrow/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,23 @@ type_for_alias: Any
date32: Any
date64: Any
decimal128: Any
int8: Any
int16: Any
int32: Any
int64: Any
uint8: Any
uint16: Any
uint32: Any
uint64: Any
float16: Any
float32: Any
float64: Any
large_string: Any
string: Any
large_binary: Any
binary: Any
large_list: Any
LargeListType: Any
dictionary: Any
timestamp: Any
TimestampType: Any
Expand Down
Loading

0 comments on commit 3ed7df0

Please sign in to comment.