Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rust): raise schema mismatch when decimal is not subset #2330

Merged
merged 9 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 35 additions & 8 deletions crates/core/src/operations/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -840,15 +840,42 @@ fn try_cast_batch(from_fields: &Fields, to_fields: &Fields) -> Result<(), ArrowE
(f.data_type(), target_field.data_type())
{
try_cast_batch(fields0, fields1)
} else if !can_cast_types(f.data_type(), target_field.data_type()) {
Err(ArrowError::SchemaError(format!(
"Cannot cast field {} from {} to {}",
f.name(),
f.data_type(),
target_field.data_type()
)))
} else {
Ok(())
match (f.data_type(), target_field.data_type()) {
(
DataType::Decimal128(left_precision, left_scale) | DataType::Decimal256(left_precision, left_scale),
DataType::Decimal128(right_precision, right_scale)
) => {
if left_precision <= right_precision && left_scale <= right_scale {
Ok(())
} else {
Err(ArrowError::SchemaError(format!(
"Cannot cast field {} from {} to {}",
f.name(),
f.data_type(),
target_field.data_type()
)))
}
},
(
_,
DataType::Decimal256(_, _),
) => {
unreachable!("Target field can never be Decimal 256. According to the protocol: 'The precision and scale can be up to 38.'")
},
(left, right) => {
if !can_cast_types(left, right) {
Err(ArrowError::SchemaError(format!(
"Cannot cast field {} from {} to {}",
f.name(),
f.data_type(),
target_field.data_type()
)))
} else {
Ok(())
}
}
}
}
} else {
Err(ArrowError::SchemaError(format!(
Expand Down
30 changes: 30 additions & 0 deletions python/tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1498,3 +1498,33 @@ def test_empty(existing_table: DeltaTable):
empty_table = pa.Table.from_pylist([], schema=schema)
with pytest.raises(DeltaError, match="No data source supplied to write command"):
write_deltalake(existing_table, empty_table, mode="append", engine="rust")


def test_rust_decimal_cast(tmp_path: pathlib.Path):
    """Check how the rust engine handles decimal values appended to an existing table.

    A value that fits the existing column type is appended successfully, while
    a value needing more precision is rejected with ``SchemaMismatchError`` --
    both for a plain append and for an append with ``schema_mode="merge"``.
    """
    import re
    from decimal import Decimal

    def single_decimal_table(literal):
        # One-row table with a single decimal column "x"; the decimal type is
        # left to pyarrow's inference from the value.
        return pa.table({"x": pa.array([Decimal(literal)])})

    # The first write creates the table. Per the error message expected further
    # down, the resulting column type is Decimal128(4, 1).
    initial = single_decimal_table("100.1")
    write_deltalake(tmp_path, initial, mode="append", engine="rust")

    stored = DeltaTable(tmp_path).to_pyarrow_table()
    assert stored["x"][0].as_py() == Decimal("100.1")

    # A smaller decimal fits in the existing precision and scale, so the
    # append is accepted.
    narrower = single_decimal_table("10.1")
    write_deltalake(tmp_path, narrower, mode="append", engine="rust")

    # This value needs more precision than the table's column type allows.
    wider = single_decimal_table("1000.1")
    expected_message = re.escape(
        "Cannot cast field x from Decimal128(5, 1) to Decimal128(4, 1)"
    )
    with pytest.raises(SchemaMismatchError, match=expected_message):
        write_deltalake(tmp_path, wider, mode="append", engine="rust")

    # Schema merging does not accept the wider decimal either.
    with pytest.raises(SchemaMismatchError, match="Cannot merge types decimal"):
        write_deltalake(
            tmp_path, wider, mode="append", schema_mode="merge", engine="rust"
        )
Loading