fix: handle nulls in file-level stats #1520

Merged 2 commits on Jul 9, 2023
4 changes: 2 additions & 2 deletions .github/workflows/python_build.yml
@@ -36,7 +36,7 @@ jobs:
         run: make check-rust

   test-minimal:
-    name: Python Build (Python 3.7 PyArrow 7.0.0)
+    name: Python Build (Python 3.7 PyArrow 8.0.0)
     runs-on: ubuntu-latest
     env:
       RUSTFLAGS: "-C debuginfo=0"
@@ -70,7 +70,7 @@ jobs:
           source venv/bin/activate
           make setup
           # Install minimum PyArrow version
-          pip install -e .[pandas,devel] pyarrow==7.0.0
+          pip install -e .[pandas,devel] pyarrow==8.0.0
         env:
           RUSTFLAGS: "-C debuginfo=0"

2 changes: 1 addition & 1 deletion python/pyproject.toml
@@ -15,7 +15,7 @@ classifiers = [
     "Programming Language :: Python :: 3 :: Only"
 ]
 dependencies = [
-    "pyarrow>=7",
+    "pyarrow>=8",
     'typing-extensions;python_version<"3.8"',
 ]

64 changes: 48 additions & 16 deletions python/src/lib.rs
@@ -581,6 +581,14 @@ fn json_value_to_py(value: &serde_json::Value, py: Python) -> PyObject {
 ///
 /// PyArrow uses this expression to determine which Dataset fragments may be
 /// skipped during a scan.
+///
+/// Partition values are translated to equality expressions (if they are valid)
+/// or an is_null expression otherwise. For example, if the partition is
+/// {"date": "2021-01-01", "x": null}, then the expression is:
+/// field(date) = "2021-01-01" AND x IS NULL
+///
+/// Statistics are translated into inequalities. If there are null values, then
+/// they must be OR'd with is_null.
 fn filestats_to_expression<'py>(
     py: Python<'py>,
     schema: &PyArrowType<ArrowSchema>,
@@ -616,10 +624,27 @@ fn filestats_to_expression<'py>(
                     .call1((column,))?
                     .call_method1("__eq__", (converted_value,)),
             );
+        } else {
+            expressions.push(field.call1((column,))?.call_method0("is_null"));
         }
     }

     if let Some(stats) = stats {
+        let mut has_nulls_set: HashSet<String> = HashSet::new();
+
+        for (col_name, null_count) in stats.null_count.iter().filter_map(|(k, v)| match v {
+            ColumnCountStat::Value(val) => Some((k, val)),
+            _ => None,
+        }) {
+            if *null_count == 0 {
+                expressions.push(field.call1((col_name,))?.call_method0("is_valid"));
+            } else if *null_count == stats.num_records {
+                expressions.push(field.call1((col_name,))?.call_method0("is_null"));
+            } else {
+                has_nulls_set.insert(col_name.clone());
+            }
+        }
+
         for (col_name, minimum) in stats.min_values.iter().filter_map(|(k, v)| match v {
             ColumnValueStat::Value(val) => Some((k.clone(), json_value_to_py(val, py))),
             // TODO(wjones127): Handle nested field statistics.
@@ -628,7 +653,17 @@
         }) {
             let maybe_minimum = cast_to_type(&col_name, minimum, &schema.0);
             if let Ok(minimum) = maybe_minimum {
-                expressions.push(field.call1((col_name,))?.call_method1("__ge__", (minimum,)));
+                let field_expr = field.call1((&col_name,))?;
+                let expr = field_expr.call_method1("__ge__", (minimum,));
+                let expr = if has_nulls_set.contains(&col_name) {
+                    // col >= min_value OR col is null
+                    let is_null_expr = field_expr.call_method0("is_null");
+                    expr?.call_method1("__or__", (is_null_expr?,))
+                } else {
+                    // col >= min_value
+                    expr
+                };
+                expressions.push(expr);
             }
         }

@@ -638,20 +673,17 @@
         }) {
             let maybe_maximum = cast_to_type(&col_name, maximum, &schema.0);
             if let Ok(maximum) = maybe_maximum {
-                expressions.push(field.call1((col_name,))?.call_method1("__le__", (maximum,)));
-            }
-        }
-
-        for (col_name, null_count) in stats.null_count.iter().filter_map(|(k, v)| match v {
-            ColumnCountStat::Value(val) => Some((k, val)),
-            _ => None,
-        }) {
-            if *null_count == stats.num_records {
-                expressions.push(field.call1((col_name.clone(),))?.call_method0("is_null"));
-            }
-
-            if *null_count == 0 {
-                expressions.push(field.call1((col_name.clone(),))?.call_method0("is_valid"));
+                let field_expr = field.call1((&col_name,))?;
+                let expr = field_expr.call_method1("__le__", (maximum,));
+                let expr = if has_nulls_set.contains(&col_name) {
+                    // col <= max_value OR col is null
+                    let is_null_expr = field_expr.call_method0("is_null");
+                    expr?.call_method1("__or__", (is_null_expr?,))
+                } else {
+                    // col <= max_value
+                    expr
+                };
+                expressions.push(expr);
             }
         }
     }
@@ -661,7 +693,7 @@ fn filestats_to_expression<'py>(
     } else {
         expressions
             .into_iter()
-            .reduce(|accum, item| accum?.getattr("__and__")?.call1((item?,)))
+            .reduce(|accum, item| accum?.call_method1("__and__", (item?,)))
             .transpose()
     }
 }
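To make the translation concrete, here is a minimal sketch (not part of the diff) of the PyArrow expression the new code builds for a fragment whose stats report, say, num_records = 6, null_count = 3, min = 1, and max = 2 for column "value". The column has some but not all nulls, so each bound is OR'd with is_null, and the pieces are then AND'ed together by the `reduce` at the end of the function:

```python
import pyarrow.dataset as ds

# Hypothetical file-level stats for one fragment:
#   num_records = 6, null_count["value"] = 3, min = 1, max = 2
# Since 0 < null_count < num_records, each inequality is OR'd with
# is_null before everything is combined with AND.
expr = ((ds.field("value") >= 1) | ds.field("value").is_null()) & (
    (ds.field("value") <= 2) | ds.field("value").is_null()
)
print(expr)
```

Without the OR, a filter like `ds.field("value").is_null()` would wrongly skip this fragment, since `value >= 1 AND value <= 2` can never hold for a null row.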
73 changes: 73 additions & 0 deletions python/tests/test_table_read.py
@@ -8,6 +8,7 @@

 from deltalake.exceptions import DeltaProtocolError
 from deltalake.table import ProtocolVersions
+from deltalake.writer import write_deltalake

 try:
     import pandas as pd
@@ -543,3 +544,75 @@ def read_table():
         t.start()
     for t in threads:
         t.join()
+
+
+def assert_num_fragments(table, predicate, count):
+    frags = table.to_pyarrow_dataset().get_fragments(filter=predicate)
+    assert len(list(frags)) == count
+
+
+def test_filter_nulls(tmp_path: Path):
+    def assert_scan_equals(table, predicate, expected):
+        data = table.to_pyarrow_dataset().to_table(filter=predicate).sort_by("part")
+        assert data == expected
+
+    # 1 all-valid part, 1 all-null part, and 1 mixed part.
+    data = pa.table(
+        {"part": ["a", "a", "b", "b", "c", "c"], "value": [1, 1, None, None, 2, None]}
+    )
+
+    write_deltalake(tmp_path, data, partition_by="part")
+
+    table = DeltaTable(tmp_path)
+
+    # Note: we assert the number of fragments returned because that verifies
+    # that file skipping is working properly.
+
+    # is valid predicate
+    predicate = ds.field("value").is_valid()
+    assert_num_fragments(table, predicate, 2)
+    expected = pa.table({"part": ["a", "a", "c"], "value": [1, 1, 2]})
+    assert_scan_equals(table, predicate, expected)
+
+    # is null predicate
+    predicate = ds.field("value").is_null()
+    assert_num_fragments(table, predicate, 2)
+    expected = pa.table(
+        {"part": ["b", "b", "c"], "value": pa.array([None, None, None], pa.int64())}
+    )
+    assert_scan_equals(table, predicate, expected)
+
+    # inequality predicate
+    predicate = ds.field("value") > 1
+    assert_num_fragments(table, predicate, 1)
+    expected = pa.table({"part": ["c"], "value": pa.array([2], pa.int64())})
+    assert_scan_equals(table, predicate, expected)
+
+    # also test nulls in partition values
+    data = pa.table({"part": pa.array([None], pa.string()), "value": [3]})
+    write_deltalake(
+        table,
+        data,
+        mode="append",
+        partition_by="part",
+    )
+
+    # null predicate
+    predicate = ds.field("part").is_null()
+    assert_num_fragments(table, predicate, 1)
+    expected = pa.table({"part": pa.array([None], pa.string()), "value": [3]})
+    assert_scan_equals(table, predicate, expected)
+
+    # valid predicate
+    predicate = ds.field("part").is_valid()
+    assert_num_fragments(table, predicate, 3)
+    expected = pa.table(
+        {"part": ["a", "a", "b", "b", "c", "c"], "value": [1, 1, None, None, 2, None]}
+    )
+    assert_scan_equals(table, predicate, expected)
+
+    # inequality predicate
+    predicate = ds.field("part") < "c"
+    assert_num_fragments(table, predicate, 2)
+    expected = pa.table({"part": ["a", "a", "b", "b"], "value": [1, 1, None, None]})
+    assert_scan_equals(table, predicate, expected)
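For reference, a minimal usage sketch of the user-facing effect (the table path is hypothetical): filtering on nulls through the PyArrow dataset now prunes fragments using the stats-derived expressions above instead of mis-skipping files that contain nulls.

```python
import pyarrow.dataset as ds
from deltalake import DeltaTable

dt = DeltaTable("path/to/table")  # hypothetical table URI
dataset = dt.to_pyarrow_dataset()

# Fragments whose stats say the column has no nulls are skipped entirely;
# mixed fragments are kept because their bounds are OR'd with is_null.
nulls = dataset.to_table(filter=ds.field("value").is_null())
```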