[Dataset] Validate sort key in Sort LogicalOperator #34282

Merged Apr 14, 2023 (13 commits)
27 changes: 4 additions & 23 deletions python/ray/data/_internal/plan.py
@@ -16,10 +16,10 @@
)

import ray
from ray.data._internal.util import unify_block_metadata_schema
from ray.data.block import BlockMetadata
from ray.data._internal.util import capitalize
from ray.types import ObjectRef
from ray.data._internal.arrow_ops.transform_pyarrow import unify_schemas
from ray.data._internal.block_list import BlockList
from ray.data._internal.compute import (
UDF,
@@ -419,29 +419,10 @@ def _get_unified_blocks_schema(
blocks.ensure_metadata_for_first_block()

metadata = blocks.get_metadata(fetch_if_missing=False)
# Some blocks could be empty, in which case we cannot get their schema.
# TODO(ekl) validate schema is the same across different blocks.

# First check if there are blocks with computed schemas, then unify
# valid schemas from all such blocks.
schemas_to_unify = []
for m in metadata:
if m.schema is not None and (m.num_rows is None or m.num_rows > 0):
schemas_to_unify.append(m.schema)
if schemas_to_unify:
# Check valid pyarrow installation before attempting schema unification
try:
import pyarrow as pa
except ImportError:
pa = None
# If the result contains PyArrow schemas, unify them
if pa is not None and any(
isinstance(s, pa.Schema) for s in schemas_to_unify
):
return unify_schemas(schemas_to_unify)
# Otherwise, if the resulting schemas are simple types (e.g. int),
# return the first schema.
return schemas_to_unify[0]
unified_schema = unify_block_metadata_schema(metadata)
if unified_schema is not None:
return unified_schema
if not fetch_if_missing:
return None
# Synchronously fetch the schema.
9 changes: 7 additions & 2 deletions python/ray/data/_internal/planner/sort.py
@@ -14,6 +14,8 @@
)
from ray.data._internal.planner.exchange.sort_task_spec import SortKeyT, SortTaskSpec
from ray.data._internal.stats import StatsDict
from ray.data._internal.util import unify_block_metadata_schema
from ray.data.block import _validate_key_fn
from ray.data.context import DataContext


@@ -22,7 +24,6 @@ def generate_sort_fn(
descending: bool,
) -> AllToAllTransformFn:
"""Generate function to sort blocks by the specified key column or key function."""
# TODO: validate key with block._validate_key_fn.

def fn(
key: SortKeyT,
@@ -31,11 +32,15 @@ def fn(
ctx: TaskContext,
) -> Tuple[List[RefBundle], StatsDict]:
blocks = []
metadata = []
for ref_bundle in refs:
for block, _ in ref_bundle.blocks:
for block, block_metadata in ref_bundle.blocks:
blocks.append(block)
metadata.append(block_metadata)
if len(blocks) == 0:
return (blocks, {})
unified_schema = unify_block_metadata_schema(metadata)
_validate_key_fn(unified_schema, key)

if isinstance(key, str):
key = [(key, "descending" if descending else "ascending")]
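Note: with this change the sort transform gathers block metadata alongside the block refs, unifies the metadata schemas, and validates the sort key against that unified schema before any shuffle work starts. A minimal sketch of the string-key case, using plain pyarrow instead of Ray internals (validate_string_key is illustrative, not a Ray API):

import pyarrow as pa

def validate_string_key(schema: pa.Schema, key: str) -> None:
    # Mirrors the "column does not exist in the schema" error exercised in the tests below.
    if key not in schema.names:
        raise ValueError(f"The column '{key}' does not exist in the schema")

schema = pa.schema([("col1", pa.int64()), ("col2", pa.int64())])
validate_string_key(schema, "col1")               # passes
# validate_string_key(schema, "invalid_column")   # raises ValueError
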
5 changes: 3 additions & 2 deletions python/ray/data/_internal/stage_impl.py
@@ -328,13 +328,14 @@ def do_sort(
block_list.clear()
else:
blocks = block_list
schema = ds.schema(fetch_if_missing=True)
if isinstance(key, list):
if not key:
raise ValueError("`key` must be a list of non-zero length")
for subkey in key:
_validate_key_fn(ds, subkey)
_validate_key_fn(schema, subkey)
else:
_validate_key_fn(ds, key)
_validate_key_fn(schema, key)
return sort_impl(blocks, clear_input_blocks, key, descending, ctx)

super().__init__(
31 changes: 31 additions & 0 deletions python/ray/data/_internal/util.py
@@ -9,6 +9,7 @@

import ray
from ray.air.constants import TENSOR_COLUMN_NAME
from ray.data._internal.arrow_ops.transform_pyarrow import unify_schemas
from ray.data.context import DataContext
from ray._private.utils import _get_pyarrow_version

@@ -462,3 +463,33 @@ def get_table_block_metadata(
return BlockAccessor.for_block(table).get_metadata(
input_files=None, exec_stats=stats.build()
)


def unify_block_metadata_schema(
metadata: List["BlockMetadata"],
) -> Optional[Union[type, "pyarrow.lib.Schema"]]:
"""For the input list of BlockMetadata, return a unified schema of the
corresponding blocks. If the metadata have no valid schema, returns None.
"""
# Some blocks could be empty, in which case we cannot get their schema.
# TODO(ekl) validate schema is the same across different blocks.

# First check if there are blocks with computed schemas, then unify
# valid schemas from all such blocks.
schemas_to_unify = []
for m in metadata:
if m.schema is not None and (m.num_rows is None or m.num_rows > 0):
schemas_to_unify.append(m.schema)
if schemas_to_unify:
# Check valid pyarrow installation before attempting schema unification
try:
import pyarrow as pa
except ImportError:
pa = None
# If the result contains PyArrow schemas, unify them
if pa is not None and any(isinstance(s, pa.Schema) for s in schemas_to_unify):
return unify_schemas(schemas_to_unify)
# Otherwise, if the resulting schemas are simple types (e.g. int),
# return the first schema.
return schemas_to_unify[0]
return None
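
The new helper collects schemas only from metadata of non-empty blocks, unifies them when they are Arrow schemas, and otherwise returns the first simple-type schema. A rough sketch of that behavior using plain pyarrow rather than Ray's transform_pyarrow helper (MetaStub is a stand-in for BlockMetadata, not a Ray type):

from collections import namedtuple
import pyarrow as pa

MetaStub = namedtuple("MetaStub", ["schema", "num_rows"])

metadata = [
    MetaStub(schema=pa.schema([("col1", pa.int64())]), num_rows=4),
    MetaStub(schema=pa.schema([("col1", pa.int64()), ("col2", pa.float64())]), num_rows=2),
    MetaStub(schema=None, num_rows=0),  # empty block: contributes no schema
]

schemas = [
    m.schema
    for m in metadata
    if m.schema is not None and (m.num_rows is None or m.num_rows > 0)
]
unified = pa.unify_schemas(schemas) if schemas else None
print(unified)  # fields: col1 (int64), col2 (double)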
2 changes: 1 addition & 1 deletion python/ray/data/aggregate.py
@@ -91,7 +91,7 @@ def _set_key_fn(self, on: KeyFn):
self._key_fn = on

def _validate(self, ds: "Datastream") -> None:
_validate_key_fn(ds, self._key_fn)
_validate_key_fn(ds.schema(fetch_if_missing=True), self._key_fn)


@PublicAPI
9 changes: 5 additions & 4 deletions python/ray/data/block.py
@@ -40,7 +40,6 @@
import pandas
import pyarrow

from ray.data import Datastream
from ray.data._internal.block_builder import BlockBuilder
from ray.data.aggregate import AggregateFn

@@ -58,9 +57,11 @@
KeyFn = Union[None, str, Callable[[T], Any]]


def _validate_key_fn(ds: "Datastream", key: KeyFn) -> None:
"""Check the key function is valid on the given datastream."""
schema = ds.schema(fetch_if_missing=True)
def _validate_key_fn(
schema: Optional[Union[type, "pyarrow.lib.Schema"]],
key: KeyFn,
) -> None:
"""Check the key function is valid on the given schema."""
if schema is None:
# Datastream is empty/cleared, validation not possible.
return
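Callers now resolve the schema themselves (typically ds.schema(fetch_if_missing=True), as in the aggregate.py and dataset.py hunks) and pass it in, so validation no longer needs a Datastream handle and the Datastream type-checking import above can be dropped. A simplified sketch of the dispatch this signature enables; the checks and messages are abbreviated, not the exact Ray implementation:

from typing import Optional, Union
import pyarrow as pa

def validate_key(schema: Optional[Union[type, pa.Schema]], key) -> None:
    if schema is None:
        # Empty or cleared datastream: nothing to validate against.
        return
    tabular = isinstance(schema, pa.Schema)
    if isinstance(key, str) and not tabular:
        raise ValueError(f"String key '{key}' requires an 'arrow' or 'pandas' datastream")
    if callable(key) and tabular:
        raise ValueError(f"Callable key '{key}' requires a 'simple' datastream")

validate_key(None, "col1")                               # no-op: nothing to check
validate_key(pa.schema([("col1", pa.int64())]), "col1")  # ok
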
2 changes: 1 addition & 1 deletion python/ray/data/dataset.py
@@ -1718,7 +1718,7 @@ def groupby(self, key: Optional[KeyFn]) -> "GroupedData[T]":
# Always allow None since groupby interprets that as grouping all
# records into a single global group.
if key is not None:
_validate_key_fn(self, key)
_validate_key_fn(self.schema(fetch_if_missing=True), key)

return GroupedData(self, key)

47 changes: 47 additions & 0 deletions python/ray/data/tests/test_execution_optimizer.py
@@ -643,6 +643,53 @@ def test_sort_e2e(
# assert [d["one"] for d in r2] == list(reversed(range(100)))


def test_sort_validate_keys(
ray_start_regular_shared,
enable_optimizer,
):
ds = ray.data.range(10)
assert ds.sort().take_all() == list(range(10))

invalid_col_name = "invalid_column"
with pytest.raises(
ValueError,
match=f"String key '{invalid_col_name}' requires datastream format to be "
"'arrow' or 'pandas', was 'simple'",
):
ds.sort(invalid_col_name).take_all()

ds_named = ray.data.from_items(
[
{"col1": 1, "col2": 2},
{"col1": 3, "col2": 4},
{"col1": 5, "col2": 6},
{"col1": 7, "col2": 8},
]
)

ds_sorted_col1 = ds_named.sort("col1", descending=True)
r1 = ds_sorted_col1.select_columns(["col1"]).take_all()
r2 = ds_sorted_col1.select_columns(["col2"]).take_all()
assert [d["col1"] for d in r1] == [7, 5, 3, 1]
assert [d["col2"] for d in r2] == [8, 6, 4, 2]

with pytest.raises(
ValueError,
match=f"The column '{invalid_col_name}' does not exist in the schema",
):
ds_named.sort(invalid_col_name).take_all()

def dummy_sort_fn(x):
return x

with pytest.raises(
ValueError,
match=f"Callable key '{dummy_sort_fn}' requires datastream format to be "
"'simple'",
):
ds_named.sort(dummy_sort_fn).take_all()


def test_aggregate_operator(ray_start_regular_shared, enable_optimizer):
planner = Planner()
read_op = Read(ParquetDatasource())