Commit

[Datasets] Fix schema unification for Datasets with ragged Arrow arrays (ray-project#31076)

When creating Datasets with ragged arrays (tensor elements that share a number of dimensions but not a shape), the resulting Dataset incorrectly uses ArrowTensorType instead of ArrowVariableShapedTensorType as the underlying schema type. This PR refactors the existing schema-unification logic into a separate function, which is now called during Arrow table concatenation and schema fetching so that type promotion involving ragged arrays is handled correctly; see the sketch below.

Signed-off-by: Scott Lee <sjl@anyscale.com>
Signed-off-by: tmynn <hovhannes.tamoyan@gmail.com>
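
A minimal repro of the fixed behavior, adapted from the new test added in this commit (the exact printed representation may vary with the installed Arrow version):

    import numpy as np
    import ray

    # Two rows whose tensors share ndim (3) but not shape: a "ragged" tensor column.
    ds = ray.data.from_items(
        [
            {"spam": np.zeros((32, 32, 5))},
            {"spam": np.zeros((64, 64, 5))},
        ]
    )
    # Previously the column was reported as ArrowTensorType; with this fix it is
    # promoted to ArrowVariableShapedTensorType(dtype=double, ndim=3).
    print(ds.schema())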
scottjlee authored and tamohannes committed Jan 25, 2023
1 parent 49b8674 commit c697b4e
Showing 5 changed files with 274 additions and 50 deletions.
84 changes: 57 additions & 27 deletions python/ray/air/util/tensor_extensions/arrow.py
@@ -79,6 +79,11 @@ def shape(self):
"""
return self._shape

@property
def scalar_type(self):
"""Returns the type of the underlying tensor elements."""
return self.storage_type.value_type

def to_pandas_dtype(self):
"""
Convert Arrow extension type to corresponding Pandas dtype.
@@ -132,6 +137,48 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return str(self)

@classmethod
def _need_variable_shaped_tensor_array(
cls,
array_types: Sequence[
Union["ArrowTensorType", "ArrowVariableShapedTensorType"]
],
) -> bool:
"""
Whether the provided list of tensor types needs a variable-shaped
representation (i.e. `ArrowVariableShapedTensorType`) when concatenating
or chunking. If one or more of the tensor types in `array_types` are
variable-shaped and/or any of the tensor arrays have a different shape
than the others, a variable-shaped tensor array representation will be
required and this method will return True.
Args:
array_types: List of tensor types to check if a variable-shaped
representation is required for concatenation
Returns:
True if concatenating arrays with types `array_types` requires
a variable-shaped representation
"""
shape = None
for arr_type in array_types:
# If at least one of the arrays is variable-shaped, we can immediately
# short-circuit since we require a variable-shaped representation.
if isinstance(arr_type, ArrowVariableShapedTensorType):
return True
if not isinstance(arr_type, ArrowTensorType):
raise ValueError(
"All provided array types must be an instance of either "
"ArrowTensorType or ArrowVariableShapedTensorType, but "
f"got {arr_type}"
)
# We need variable-shaped representation if any of the tensor arrays have
# different shapes.
if shape is not None and arr_type.shape != shape:
return True
shape = arr_type.shape
return False


if _arrow_extension_scalars_are_subclassable():
# TODO(Clark): Remove this version guard once we only support Arrow 9.0.0+.
@@ -410,7 +457,8 @@ def _concat_same_type(
of the tensor arrays have a different shape than the others, a variable-shaped
tensor array will be returned.
"""
if cls._need_variable_shaped_tensor_array(to_concat):
to_concat_types = [arr.type for arr in to_concat]
if ArrowTensorType._need_variable_shaped_tensor_array(to_concat_types):
# Need variable-shaped tensor array.
# TODO(Clark): Eliminate this NumPy roundtrip by directly constructing the
# underlying storage array buffers (NumPy roundtrip will not be zero-copy
@@ -432,7 +480,8 @@ def _chunk_tensor_arrays(
"""
Create a ChunkedArray from multiple tensor arrays.
"""
if cls._need_variable_shaped_tensor_array(arrs):
arrs_types = [arr.type for arr in arrs]
if ArrowTensorType._need_variable_shaped_tensor_array(arrs_types):
new_arrs = []
for a in arrs:
if isinstance(a.type, ArrowTensorType):
@@ -442,31 +491,6 @@
arrs = new_arrs
return pa.chunked_array(arrs)

@classmethod
def _need_variable_shaped_tensor_array(
cls, arrs: Sequence[Union["ArrowTensorArray", "ArrowVariableShapedTensorArray"]]
) -> bool:
"""
Whether the provided tensor arrays need a variable-shaped representation when
concatenating or chunking.
If one or more of the tensor arrays in arrs are variable-shaped and/or any of
the tensor arrays have a different shape than the others, a variable-shaped
tensor array representation will be required and this method will return True.
"""
needs_variable_shaped = False
shape = None
for a in arrs:
a_type = a.type
if isinstance(a_type, ArrowVariableShapedTensorType) or (
shape is not None and a_type.shape != shape
):
needs_variable_shaped = True
break
if shape is None:
shape = a_type.shape
return needs_variable_shaped

def to_variable_shaped_tensor_array(self) -> "ArrowVariableShapedTensorArray":
"""
Convert this tensor array to a variable-shaped tensor array.
@@ -529,6 +553,12 @@ def ndim(self) -> int:
"""Return the number of dimensions in the tensor elements."""
return self._ndim

@property
def scalar_type(self):
"""Returns the type of the underlying tensor elements."""
data_field_index = self.storage_type.get_field_index("data")
return self.storage_type[data_field_index].type.value_type

def __reduce__(self):
return (
ArrowVariableShapedTensorType,
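A quick sketch of the relocated classmethod's contract, assuming the constructor signatures used elsewhere in this diff (ArrowTensorType(shape, dtype); ArrowVariableShapedTensorType(dtype=..., ndim=...)):

    import pyarrow as pa
    from ray.air.util.tensor_extensions.arrow import (
        ArrowTensorType,
        ArrowVariableShapedTensorType,
    )

    uniform = [ArrowTensorType((32, 32, 5), pa.float64())] * 2
    ragged = [
        ArrowTensorType((32, 32, 5), pa.float64()),
        ArrowTensorType((64, 64, 5), pa.float64()),
    ]
    # Identical fixed shapes need no promotion; differing shapes do.
    assert not ArrowTensorType._need_variable_shaped_tensor_array(uniform)
    assert ArrowTensorType._need_variable_shaped_tensor_array(ragged)
    # A single variable-shaped input short-circuits to True.
    assert ArrowTensorType._need_variable_shaped_tensor_array(
        [ArrowVariableShapedTensorType(dtype=pa.float64(), ndim=3)]
    )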
86 changes: 66 additions & 20 deletions python/ray/data/_internal/arrow_ops/transform_pyarrow.py
@@ -45,6 +45,54 @@ def take_table(
return table


def unify_schemas(
schemas: List["pyarrow.Schema"],
) -> "pyarrow.Schema":
"""Version of `pyarrow.unify_schemas()` which also handles checks for
variable-shaped tensors in the given schemas."""
from ray.air.util.tensor_extensions.arrow import (
ArrowTensorType,
ArrowVariableShapedTensorType,
)

schemas_to_unify = []
schema_tensor_field_overrides = {}

if any(isinstance(type_, pyarrow.ExtensionType) for type_ in schemas[0].types):
# If we have pyarrow extension types that may potentially be variable shaped,
# examine the first schema to gather the columns that need type conversions.
for col_idx in range(len(schemas[0].types)):
tensor_array_types = [
s.types[col_idx]
for s in schemas
if isinstance(s.types[col_idx], pyarrow.ExtensionType)
]
if ArrowTensorType._need_variable_shaped_tensor_array(tensor_array_types):
if isinstance(tensor_array_types[0], ArrowVariableShapedTensorType):
new_type = tensor_array_types[0]
elif isinstance(tensor_array_types[0], ArrowTensorType):
new_type = ArrowVariableShapedTensorType(
dtype=tensor_array_types[0].scalar_type,
ndim=len(tensor_array_types[0].shape),
)
else:
raise ValueError(
"Detected need for variable shaped tensor representation, "
f"but schema is not ArrayTensorType: {tensor_array_types[0]}"
)
schema_tensor_field_overrides[col_idx] = new_type
# Go through all schemas and update the types of columns from the above loop.
for schema in schemas:
for col_idx, col_new_type in schema_tensor_field_overrides.items():
var_shaped_col = schema.field(col_idx).with_type(col_new_type)
schema = schema.set(col_idx, var_shaped_col)
schemas_to_unify.append(schema)
else:
schemas_to_unify = schemas
# Let Arrow unify the schema of non-tensor extension type columns.
return pyarrow.unify_schemas(schemas_to_unify)


def _concatenate_chunked_arrays(arrs: "pyarrow.ChunkedArray") -> "pyarrow.ChunkedArray":
"""
Concatenate provided chunked arrays into a single chunked array.
@@ -94,7 +142,6 @@ def concat(blocks: List["pyarrow.Table"]) -> "pyarrow.Table":
if any(isinstance(type_, pyarrow.ExtensionType) for type_ in schema.types):
# Custom handling for extension array columns.
cols = []
schema_tensor_field_overrides = {}
for col_name in schema.names:
col_chunked_arrays = []
for block in blocks:
@@ -111,28 +158,27 @@ def concat(blocks: List["pyarrow.Table"]) -> "pyarrow.Table":
col = ArrowTensorArray._chunk_tensor_arrays(
[chunk for ca in col_chunked_arrays for chunk in ca.chunks]
)
if schema.field(col_name).type != col.type:
# Ensure that the field's type is properly updated in the schema if
# a collection of homogeneous-shaped columns resulted in a
# variable-shaped tensor column once concatenated.
new_field = schema.field(col_name).with_type(col.type)
field_idx = schema.get_field_index(col_name)
schema_tensor_field_overrides[field_idx] = new_field
else:
col = _concatenate_chunked_arrays(col_chunked_arrays)
cols.append(col)
# Unify schemas.
schemas = []
for block in blocks:
schema = block.schema
# If concatenating uniform tensor columns results in a variable-shaped
# tensor columns, override the column type for all blocks.
if schema_tensor_field_overrides:
for idx, field in schema_tensor_field_overrides.items():
schema = schema.set(idx, field)
schemas.append(schema)
# Let Arrow unify the schema of non-tensor extension type columns.
schema = pyarrow.unify_schemas(schemas)

# If the result contains pyarrow schemas, unify them
schemas_to_unify = [b.schema for b in blocks]
if pyarrow is not None and any(
isinstance(s, pyarrow.Schema) for s in schemas_to_unify
):
schema = unify_schemas(schemas_to_unify)
else:
# Otherwise, if the resulting schemas are simple types (e.g. int),
# check that all blocks with valid schemas have the same type.
schema = schemas_to_unify[0]
if schema is not None:
for s in schemas_to_unify:
if s is not None and s != schema:
raise ValueError(
"Found blocks with different types "
f"in schemas: {schemas_to_unify}"
)
# Build the concatenated table.
table = pyarrow.Table.from_arrays(cols, schema=schema)
# Validate table schema (this is a cheap check by default).
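The new helper can also be exercised in isolation; a sketch under the same constructor assumptions as above (the tensor column is promoted, while pyarrow.unify_schemas() handles the plain columns):

    import pyarrow as pa
    from ray.air.util.tensor_extensions.arrow import ArrowTensorType
    from ray.data._internal.arrow_ops.transform_pyarrow import unify_schemas

    s1 = pa.schema(
        [("id", pa.int64()), ("spam", ArrowTensorType((32, 32, 5), pa.float64()))]
    )
    s2 = pa.schema(
        [("id", pa.int64()), ("spam", ArrowTensorType((64, 64, 5), pa.float64()))]
    )
    unified = unify_schemas([s1, s2])
    # "spam" becomes ArrowVariableShapedTensorType(dtype=double, ndim=3);
    # "id" remains int64.
    print(unified.field("spam").type)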
24 changes: 22 additions & 2 deletions python/ray/data/_internal/plan.py
@@ -16,6 +16,7 @@
)

import ray
from ray.data._internal.arrow_ops.transform_pyarrow import unify_schemas
from ray.data._internal.block_list import BlockList
from ray.data._internal.compute import (
UDF,
@@ -261,14 +262,33 @@ def schema(
metadata = blocks.get_metadata(fetch_if_missing=False)
# Some blocks could be empty, in which case we cannot get their schema.
# TODO(ekl) validate schema is the same across different blocks.

# First check if there are blocks with computed schemas, then unify
# valid schemas from all such blocks.
schemas_to_unify = []
for m in metadata:
if m.schema is not None and (m.num_rows is None or m.num_rows > 0):
return m.schema
schemas_to_unify.append(m.schema)
if schemas_to_unify:
# Check for a valid pyarrow installation before attempting schema unification
try:
import pyarrow as pa
except ImportError:
pa = None
# If the result contains PyArrow schemas, unify them
if pa is not None and any(
isinstance(s, pa.Schema) for s in schemas_to_unify
):
return unify_schemas(schemas_to_unify)
# Otherwise, if the resulting schemas are simple types (e.g. int),
# return the first schema.
return schemas_to_unify[0]
if not fetch_if_missing:
return None
# Synchronously fetch the schema.
# For lazy block lists, this launches read tasks and fetches block metadata
# until we find valid block schema.
# until we find the first valid block schema. This is to minimize new
# computations when fetching the schema.
for _, m in blocks.iter_blocks_with_metadata():
if m.schema is not None and (m.num_rows is None or m.num_rows > 0):
return m.schema
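The user-visible effect of the plan.py change, as a hedged sketch (parallelism=2 is an assumption here, used to split the ragged rows across blocks so that schema() must unify their block schemas from metadata):

    import numpy as np
    import ray

    ds = ray.data.from_items(
        [
            {"spam": np.zeros((32, 32, 5))},
            {"spam": np.zeros((64, 64, 5))},
        ],
        parallelism=2,
    )
    # schema() now unifies all valid block schemas rather than returning the
    # first non-empty block's schema, so the ragged column is reported as
    # ArrowVariableShapedTensorType without triggering extra reads.
    print(ds.schema())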
18 changes: 18 additions & 0 deletions python/ray/data/tests/test_dataset.py
@@ -13,6 +13,7 @@

import ray
from ray._private.test_utils import wait_for_condition
from ray.air.util.tensor_extensions.arrow import ArrowVariableShapedTensorType
from ray.data._internal.stats import _StatsActor
from ray.data._internal.arrow_block import ArrowRow
from ray.data._internal.block_builder import BlockBuilder
@@ -5398,6 +5399,23 @@ def test_dataset_schema_after_read_stats(ray_start_cluster):
assert schema == ds.schema()


def test_ragged_tensors(ray_start_regular_shared):
"""Test Arrow type promotion between ArrowTensorType and
ArrowVariableShapedTensorType when a column contains ragged tensors."""
import numpy as np

ds = ray.data.from_items(
[
{"spam": np.zeros((32, 32, 5))},
{"spam": np.zeros((64, 64, 5))},
]
)
new_type = ds.schema().types[0].scalar_type
assert ds.schema().types == [
ArrowVariableShapedTensorType(dtype=new_type, ndim=3),
]


if __name__ == "__main__":
import sys
