[Dataset] Validate sort key in Sort LogicalOperator (ray-project#34282)

As a follow-up to ray-project#32133, validate the sort key with block.py:_validate_key_fn() in generate_sort_fn() before performing the sort.

Signed-off-by: Scott Lee <sjl@anyscale.com>
Signed-off-by: Jack He <jackhe2345@gmail.com>
scottjlee authored and ProjectsByJackHe committed May 4, 2023
1 parent cb32a57 commit e3bed7a
Showing 8 changed files with 99 additions and 33 deletions.
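In user-facing terms, the change makes an invalid sort key fail fast with a ValueError instead of surfacing mid-shuffle. A minimal sketch of that behavior, assuming a Ray 2.4-era ray.data API and the error messages asserted in the new test at the bottom of this diff:

import ray
import pytest

ds = ray.data.from_items(
    [{"col1": 3, "col2": 4}, {"col1": 1, "col2": 2}]
)

# Sorting on an existing column works as before.
assert [r["col1"] for r in ds.sort("col1").take_all()] == [1, 3]

# Sorting on a column that is not in the schema is rejected with a
# ValueError naming the missing column.
with pytest.raises(ValueError, match="does not exist in the schema"):
    ds.sort("missing_column").take_all()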
27 changes: 4 additions & 23 deletions python/ray/data/_internal/plan.py
@@ -16,10 +16,10 @@
)

import ray
from ray.data._internal.util import unify_block_metadata_schema
from ray.data.block import BlockMetadata
from ray.data._internal.util import capitalize
from ray.types import ObjectRef
from ray.data._internal.arrow_ops.transform_pyarrow import unify_schemas
from ray.data._internal.block_list import BlockList
from ray.data._internal.compute import (
UDF,
@@ -419,29 +419,10 @@ def _get_unified_blocks_schema(
blocks.ensure_metadata_for_first_block()

metadata = blocks.get_metadata(fetch_if_missing=False)
# Some blocks could be empty, in which case we cannot get their schema.
# TODO(ekl) validate schema is the same across different blocks.

# First check if there are blocks with computed schemas, then unify
# valid schemas from all such blocks.
schemas_to_unify = []
for m in metadata:
if m.schema is not None and (m.num_rows is None or m.num_rows > 0):
schemas_to_unify.append(m.schema)
if schemas_to_unify:
# Check valid pyarrow installation before attempting schema unification
try:
import pyarrow as pa
except ImportError:
pa = None
# If the result contains PyArrow schemas, unify them
if pa is not None and any(
isinstance(s, pa.Schema) for s in schemas_to_unify
):
return unify_schemas(schemas_to_unify)
# Otherwise, if the resulting schemas are simple types (e.g. int),
# return the first schema.
return schemas_to_unify[0]
unified_schema = unify_block_metadata_schema(metadata)
if unified_schema is not None:
return unified_schema
if not fetch_if_missing:
return None
# Synchronously fetch the schema.
9 changes: 7 additions & 2 deletions python/ray/data/_internal/planner/sort.py
@@ -14,6 +14,8 @@
)
from ray.data._internal.planner.exchange.sort_task_spec import SortKeyT, SortTaskSpec
from ray.data._internal.stats import StatsDict
from ray.data._internal.util import unify_block_metadata_schema
from ray.data.block import _validate_key_fn
from ray.data.context import DataContext


@@ -22,7 +24,6 @@ def generate_sort_fn(
descending: bool,
) -> AllToAllTransformFn:
"""Generate function to sort blocks by the specified key column or key function."""
# TODO: validate key with block._validate_key_fn.

def fn(
key: SortKeyT,
@@ -31,11 +32,15 @@ def fn(
ctx: TaskContext,
) -> Tuple[List[RefBundle], StatsDict]:
blocks = []
metadata = []
for ref_bundle in refs:
for block, _ in ref_bundle.blocks:
for block, block_metadata in ref_bundle.blocks:
blocks.append(block)
metadata.append(block_metadata)
if len(blocks) == 0:
return (blocks, {})
unified_schema = unify_block_metadata_schema(metadata)
_validate_key_fn(unified_schema, key)

if isinstance(key, str):
key = [(key, "descending" if descending else "ascending")]
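A rough, self-contained sketch of the new planner-side check in isolation; unify_block_metadata_schema and _validate_key_fn are the private helpers introduced or changed in this diff, so treat the import paths as internal and subject to change:

from typing import List

import pyarrow as pa

from ray.data._internal.util import unify_block_metadata_schema
from ray.data.block import BlockMetadata, _validate_key_fn

# Stand-ins for the metadata carried by the incoming RefBundles.
metadata: List[BlockMetadata] = [
    BlockMetadata(
        num_rows=2,
        size_bytes=32,
        schema=pa.schema([("col1", pa.int64())]),
        input_files=None,
        exec_stats=None,
    ),
]

# Unify the per-block schemas, then validate the sort key against the
# unified schema before any sort tasks are launched.
schema = unify_block_metadata_schema(metadata)
_validate_key_fn(schema, "col1")  # OK: the column exists.
try:
    _validate_key_fn(schema, "missing")
except ValueError as e:
    print(f"rejected before sorting: {e}")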
5 changes: 3 additions & 2 deletions python/ray/data/_internal/stage_impl.py
@@ -328,13 +328,14 @@ def do_sort(
block_list.clear()
else:
blocks = block_list
schema = ds.schema(fetch_if_missing=True)
if isinstance(key, list):
if not key:
raise ValueError("`key` must be a list of non-zero length")
for subkey in key:
_validate_key_fn(ds, subkey)
_validate_key_fn(schema, subkey)
else:
_validate_key_fn(ds, key)
_validate_key_fn(schema, key)
return sort_impl(blocks, clear_input_blocks, key, descending, ctx)

super().__init__(
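The legacy path now fetches the schema once and reuses it for every subkey, instead of passing the whole datastream into each validation call. A compressed sketch of that pattern (illustrative keys; _validate_key_fn is private, as above):

import ray
from ray.data.block import _validate_key_fn

ds = ray.data.from_items([{"col1": 1, "col2": 2}])

# One schema fetch, then validate each subkey against it.
schema = ds.schema(fetch_if_missing=True)
for subkey in ["col1", "col2"]:
    _validate_key_fn(schema, subkey)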
31 changes: 31 additions & 0 deletions python/ray/data/_internal/util.py
@@ -9,6 +9,7 @@

import ray
from ray.air.constants import TENSOR_COLUMN_NAME
from ray.data._internal.arrow_ops.transform_pyarrow import unify_schemas
from ray.data.context import DataContext
from ray._private.utils import _get_pyarrow_version

@@ -462,3 +463,33 @@ def get_table_block_metadata(
return BlockAccessor.for_block(table).get_metadata(
input_files=None, exec_stats=stats.build()
)


def unify_block_metadata_schema(
metadata: List["BlockMetadata"],
) -> Optional[Union[type, "pyarrow.lib.Schema"]]:
"""For the input list of BlockMetadata, return a unified schema of the
corresponding blocks. If the metadata have no valid schema, returns None.
"""
# Some blocks could be empty, in which case we cannot get their schema.
# TODO(ekl) validate schema is the same across different blocks.

# First check if there are blocks with computed schemas, then unify
# valid schemas from all such blocks.
schemas_to_unify = []
for m in metadata:
if m.schema is not None and (m.num_rows is None or m.num_rows > 0):
schemas_to_unify.append(m.schema)
if schemas_to_unify:
# Check valid pyarrow installation before attempting schema unification
try:
import pyarrow as pa
except ImportError:
pa = None
# If the result contains PyArrow schemas, unify them
if pa is not None and any(isinstance(s, pa.Schema) for s in schemas_to_unify):
return unify_schemas(schemas_to_unify)
# Otherwise, if the resulting schemas are simple types (e.g. int),
# return the first schema.
return schemas_to_unify[0]
return None
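A minimal sketch of the new helper on hand-built metadata (BlockMetadata fields as defined in ray.data.block; the exact unified schema depends on PyArrow's schema merging):

import pyarrow as pa

from ray.data._internal.util import unify_block_metadata_schema
from ray.data.block import BlockMetadata

schema_a = pa.schema([("col1", pa.int64())])
schema_b = pa.schema([("col1", pa.int64()), ("col2", pa.string())])

def meta(num_rows, schema):
    # Only num_rows and schema matter for unification.
    return BlockMetadata(num_rows=num_rows, size_bytes=0, schema=schema,
                         input_files=None, exec_stats=None)

metadata = [
    meta(0, schema_a),   # empty block: skipped by the num_rows > 0 check
    meta(3, schema_a),
    meta(2, schema_b),
]

print(unify_block_metadata_schema(metadata))  # expected: merged pyarrow schema (col1, col2)
print(unify_block_metadata_schema([]))        # None: nothing to unify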
2 changes: 1 addition & 1 deletion python/ray/data/aggregate.py
@@ -91,7 +91,7 @@ def _set_key_fn(self, on: KeyFn):
self._key_fn = on

def _validate(self, ds: "Datastream") -> None:
_validate_key_fn(ds, self._key_fn)
_validate_key_fn(ds.schema(fetch_if_missing=True), self._key_fn)


@PublicAPI
9 changes: 5 additions & 4 deletions python/ray/data/block.py
@@ -40,7 +40,6 @@
import pandas
import pyarrow

from ray.data import Datastream
from ray.data._internal.block_builder import BlockBuilder
from ray.data.aggregate import AggregateFn

@@ -58,9 +57,11 @@
KeyFn = Union[None, str, Callable[[T], Any]]


def _validate_key_fn(ds: "Datastream", key: KeyFn) -> None:
"""Check the key function is valid on the given datastream."""
schema = ds.schema(fetch_if_missing=True)
def _validate_key_fn(
schema: Optional[Union[type, "pyarrow.lib.Schema"]],
key: KeyFn,
) -> None:
"""Check the key function is valid on the given schema."""
if schema is None:
# Datastream is empty/cleared, validation not possible.
return
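The rules _validate_key_fn enforces are unchanged by this commit; a rough paraphrase of them, written against the error messages asserted in the new test below (not the actual implementation):

from typing import Any, Optional, Union

import pyarrow as pa


def sketch_validate_key(schema: Optional[Union[type, pa.Schema]], key: Any) -> None:
    """Paraphrase of the checks; the real logic lives in ray.data.block._validate_key_fn."""
    if schema is None:
        # Empty or cleared datastream: nothing to validate against.
        return
    is_tabular = isinstance(schema, pa.Schema)  # pandas schemas behave analogously
    if isinstance(key, str):
        if not is_tabular:
            raise ValueError(
                f"String key '{key}' requires datastream format to be "
                "'arrow' or 'pandas', was 'simple'."
            )
        if key not in schema.names:
            raise ValueError(f"The column '{key}' does not exist in the schema.")
    elif callable(key):
        if is_tabular:
            raise ValueError(
                f"Callable key '{key}' requires datastream format to be 'simple'."
            )

# e.g. sketch_validate_key(pa.schema([("col1", pa.int64())]), "col2") raises ValueError.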
2 changes: 1 addition & 1 deletion python/ray/data/dataset.py
@@ -1718,7 +1718,7 @@ def groupby(self, key: Optional[KeyFn]) -> "GroupedData[T]":
# Always allow None since groupby interprets that as grouping all
# records into a single global group.
if key is not None:
_validate_key_fn(self, key)
_validate_key_fn(self.schema(fetch_if_missing=True), key)

return GroupedData(self, key)

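groupby goes through the same validation; a short public-API sketch (column names are illustrative):

import ray

ds = ray.data.from_items([{"col1": 1, "col2": 2}, {"col1": 1, "col2": 4}])

# Grouping on an existing column validates cleanly.
print(ds.groupby("col1").count().take_all())

# Grouping on a missing column now raises at groupby() time, before any
# aggregation is built.
try:
    ds.groupby("missing")
except ValueError as e:
    print(e)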
47 changes: 47 additions & 0 deletions python/ray/data/tests/test_execution_optimizer.py
@@ -643,6 +643,53 @@ def test_sort_e2e(
# assert [d["one"] for d in r2] == list(reversed(range(100)))


def test_sort_validate_keys(
ray_start_regular_shared,
enable_optimizer,
):
ds = ray.data.range(10)
assert ds.sort().take_all() == list(range(10))

invalid_col_name = "invalid_column"
with pytest.raises(
ValueError,
match=f"String key '{invalid_col_name}' requires datastream format to be "
"'arrow' or 'pandas', was 'simple'",
):
ds.sort(invalid_col_name).take_all()

ds_named = ray.data.from_items(
[
{"col1": 1, "col2": 2},
{"col1": 3, "col2": 4},
{"col1": 5, "col2": 6},
{"col1": 7, "col2": 8},
]
)

ds_sorted_col1 = ds_named.sort("col1", descending=True)
r1 = ds_sorted_col1.select_columns(["col1"]).take_all()
r2 = ds_sorted_col1.select_columns(["col2"]).take_all()
assert [d["col1"] for d in r1] == [7, 5, 3, 1]
assert [d["col2"] for d in r2] == [8, 6, 4, 2]

with pytest.raises(
ValueError,
match=f"The column '{invalid_col_name}' does not exist in the schema",
):
ds_named.sort(invalid_col_name).take_all()

def dummy_sort_fn(x):
return x

with pytest.raises(
ValueError,
match=f"Callable key '{dummy_sort_fn}' requires datastream format to be "
"'simple'",
):
ds_named.sort(dummy_sort_fn).take_all()


def test_aggregate_operator(ray_start_regular_shared, enable_optimizer):
planner = Planner()
read_op = Read(ParquetDatasource())
