diff --git a/python/ray/air/util/tensor_extensions/arrow.py b/python/ray/air/util/tensor_extensions/arrow.py
index 91b5daecbefe9..7412e2d30c23e 100644
--- a/python/ray/air/util/tensor_extensions/arrow.py
+++ b/python/ray/air/util/tensor_extensions/arrow.py
@@ -79,6 +79,11 @@ def shape(self):
         """
         return self._shape
 
+    @property
+    def scalar_type(self):
+        """Returns the type of the underlying tensor elements."""
+        return self.storage_type.value_type
+
     def to_pandas_dtype(self):
         """
         Convert Arrow extension type to corresponding Pandas dtype.
@@ -132,6 +137,48 @@ def __str__(self) -> str:
     def __repr__(self) -> str:
         return str(self)
 
+    @classmethod
+    def _need_variable_shaped_tensor_array(
+        cls,
+        array_types: Sequence[
+            Union["ArrowTensorType", "ArrowVariableShapedTensorType"]
+        ],
+    ) -> bool:
+        """
+        Whether the provided list of tensor types needs a variable-shaped
+        representation (i.e. `ArrowVariableShapedTensorType`) when concatenating
+        or chunking. If one or more of the tensor types in `array_types` are
+        variable-shaped and/or any of the tensor types have a different shape
+        than the others, a variable-shaped tensor array representation will be
+        required, and this method will return True.
+
+        Args:
+            array_types: List of tensor types to check for whether a
+                variable-shaped representation is required for concatenation.
+
+        Returns:
+            True if concatenating arrays with types `array_types` requires
+            a variable-shaped representation.
+        """
+        shape = None
+        for arr_type in array_types:
+            # If at least one of the types is variable-shaped, we can immediately
+            # short-circuit since we require a variable-shaped representation.
+            if isinstance(arr_type, ArrowVariableShapedTensorType):
+                return True
+            if not isinstance(arr_type, ArrowTensorType):
+                raise ValueError(
+                    "All provided array types must be an instance of either "
+                    "ArrowTensorType or ArrowVariableShapedTensorType, but "
+                    f"got {arr_type}"
+                )
+            # We need a variable-shaped representation if any of the tensor types
+            # have different shapes.
+            if shape is not None and arr_type.shape != shape:
+                return True
+            shape = arr_type.shape
+        return False
+
 
 if _arrow_extension_scalars_are_subclassable():
     # TODO(Clark): Remove this version guard once we only support Arrow 9.0.0+.
@@ -410,7 +457,8 @@ def _concat_same_type(
         of the tensor arrays have a different shape than the others, a
         variable-shaped tensor array will be returned.
         """
-        if cls._need_variable_shaped_tensor_array(to_concat):
+        to_concat_types = [arr.type for arr in to_concat]
+        if ArrowTensorType._need_variable_shaped_tensor_array(to_concat_types):
             # Need variable-shaped tensor array.
             # TODO(Clark): Eliminate this NumPy roundtrip by directly constructing the
             # underlying storage array buffers (NumPy roundtrip will not be zero-copy
@@ -432,7 +480,8 @@ def _chunk_tensor_arrays(
         """
         Create a ChunkedArray from multiple tensor arrays.
         """
-        if cls._need_variable_shaped_tensor_array(arrs):
+        arrs_types = [arr.type for arr in arrs]
+        if ArrowTensorType._need_variable_shaped_tensor_array(arrs_types):
             new_arrs = []
             for a in arrs:
                 if isinstance(a.type, ArrowTensorType):
@@ -442,31 +491,6 @@ def _chunk_tensor_arrays(
             arrs = new_arrs
         return pa.chunked_array(arrs)
 
-    @classmethod
-    def _need_variable_shaped_tensor_array(
-        cls, arrs: Sequence[Union["ArrowTensorArray", "ArrowVariableShapedTensorArray"]]
-    ) -> bool:
-        """
-        Whether the provided tensor arrays need a variable-shaped representation when
-        concatenating or chunking.
-
-        If one or more of the tensor arrays in arrs are variable-shaped and/or any of
-        the tensor arrays have a different shape than the others, a variable-shaped
-        tensor array representation will be required and this method will return True.
-        """
-        needs_variable_shaped = False
-        shape = None
-        for a in arrs:
-            a_type = a.type
-            if isinstance(a_type, ArrowVariableShapedTensorType) or (
-                shape is not None and a_type.shape != shape
-            ):
-                needs_variable_shaped = True
-                break
-            if shape is None:
-                shape = a_type.shape
-        return needs_variable_shaped
-
     def to_variable_shaped_tensor_array(self) -> "ArrowVariableShapedTensorArray":
         """
         Convert this tensor array to a variable-shaped tensor array.
@@ -529,6 +553,12 @@ def ndim(self) -> int:
         """Return the number of dimensions in the tensor elements."""
         return self._ndim
 
+    @property
+    def scalar_type(self):
+        """Returns the type of the underlying tensor elements."""
+        data_field_index = self.storage_type.get_field_index("data")
+        return self.storage_type[data_field_index].type.value_type
+
     def __reduce__(self):
         return (
             ArrowVariableShapedTensorType,
diff --git a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py
index 2bb61e496090e..8a168e04cef15 100644
--- a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py
+++ b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py
@@ -45,6 +45,54 @@ def take_table(
     return table
 
 
+def unify_schemas(
+    schemas: List["pyarrow.Schema"],
+) -> "pyarrow.Schema":
+    """Version of `pyarrow.unify_schemas()` which also handles tensor extension
+    types, promoting fixed-shape tensor columns to variable-shaped as needed."""
+    from ray.air.util.tensor_extensions.arrow import (
+        ArrowTensorType,
+        ArrowVariableShapedTensorType,
+    )
+
+    schemas_to_unify = []
+    schema_tensor_field_overrides = {}
+
+    if any(isinstance(type_, pyarrow.ExtensionType) for type_ in schemas[0].types):
+        # At least one column holds a pyarrow extension type that may be variable
+        # shaped; examine all schemas to gather the columns needing type promotion.
+        for col_idx in range(len(schemas[0].types)):
+            tensor_array_types = [
+                s.types[col_idx]
+                for s in schemas
+                if isinstance(s.types[col_idx], pyarrow.ExtensionType)
+            ]
+            if ArrowTensorType._need_variable_shaped_tensor_array(tensor_array_types):
+                if isinstance(tensor_array_types[0], ArrowVariableShapedTensorType):
+                    new_type = tensor_array_types[0]
+                elif isinstance(tensor_array_types[0], ArrowTensorType):
+                    new_type = ArrowVariableShapedTensorType(
+                        dtype=tensor_array_types[0].scalar_type,
+                        ndim=len(tensor_array_types[0].shape),
+                    )
+                else:
+                    raise ValueError(
+                        "Detected need for a variable-shaped tensor representation, "
+                        f"but column type is not ArrowTensorType: {tensor_array_types[0]}"
+                    )
+                schema_tensor_field_overrides[col_idx] = new_type
+        # Go through all schemas and update the types of columns from the above loop.
+        for schema in schemas:
+            for col_idx, col_new_type in schema_tensor_field_overrides.items():
+                var_shaped_col = schema.field(col_idx).with_type(col_new_type)
+                schema = schema.set(col_idx, var_shaped_col)
+            schemas_to_unify.append(schema)
+    else:
+        schemas_to_unify = schemas
+    # Let Arrow unify the schema of non-tensor extension type columns.
+    return pyarrow.unify_schemas(schemas_to_unify)
+
+
 def _concatenate_chunked_arrays(arrs: "pyarrow.ChunkedArray") -> "pyarrow.ChunkedArray":
     """
     Concatenate provided chunked arrays into a single chunked array.
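Illustrative sketch (not part of the patch): the promotion rule that the `unify_schemas()` hunk above applies when fixed-shape tensor columns disagree on shape, assuming `pyarrow` and the tensor extension types from this diff are importable. The `s1`/`s2` names are hypothetical; shapes and dtype mirror the `test_ragged_tensors` case below.

```python
import pyarrow as pa

from ray.air.util.tensor_extensions.arrow import (
    ArrowTensorType,
    ArrowVariableShapedTensorType,
)
from ray.data._internal.arrow_ops.transform_pyarrow import unify_schemas

# Two block schemas whose "spam" tensor columns have different fixed shapes.
s1 = pa.schema([("spam", ArrowTensorType((32, 32, 5), pa.float64()))])
s2 = pa.schema([("spam", ArrowTensorType((64, 64, 5), pa.float64()))])

# Differing fixed shapes are promoted to a variable-shaped tensor type that
# preserves the scalar dtype and the number of dimensions.
assert unify_schemas([s1, s2]) == pa.schema(
    [("spam", ArrowVariableShapedTensorType(pa.float64(), 3))]
)
```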
@@ -94,7 +142,6 @@ def concat(blocks: List["pyarrow.Table"]) -> "pyarrow.Table":
     if any(isinstance(type_, pyarrow.ExtensionType) for type_ in schema.types):
         # Custom handling for extension array columns.
         cols = []
-        schema_tensor_field_overrides = {}
         for col_name in schema.names:
             col_chunked_arrays = []
             for block in blocks:
@@ -111,28 +158,27 @@ def concat(blocks: List["pyarrow.Table"]) -> "pyarrow.Table":
                 col = ArrowTensorArray._chunk_tensor_arrays(
                     [chunk for ca in col_chunked_arrays for chunk in ca.chunks]
                 )
-                if schema.field(col_name).type != col.type:
-                    # Ensure that the field's type is properly updated in the schema if
-                    # a collection of homogeneous-shaped columns resulted in a
-                    # variable-shaped tensor column once concatenated.
-                    new_field = schema.field(col_name).with_type(col.type)
-                    field_idx = schema.get_field_index(col_name)
-                    schema_tensor_field_overrides[field_idx] = new_field
             else:
                 col = _concatenate_chunked_arrays(col_chunked_arrays)
             cols.append(col)
-        # Unify schemas.
-        schemas = []
-        for block in blocks:
-            schema = block.schema
-            # If concatenating uniform tensor columns results in a variable-shaped
-            # tensor columns, override the column type for all blocks.
-            if schema_tensor_field_overrides:
-                for idx, field in schema_tensor_field_overrides.items():
-                    schema = schema.set(idx, field)
-            schemas.append(schema)
-        # Let Arrow unify the schema of non-tensor extension type columns.
-        schema = pyarrow.unify_schemas(schemas)
+
+        # If the blocks have pyarrow schemas, unify them.
+        schemas_to_unify = [b.schema for b in blocks]
+        if pyarrow is not None and any(
+            isinstance(s, pyarrow.Schema) for s in schemas_to_unify
+        ):
+            schema = unify_schemas(schemas_to_unify)
+        else:
+            # Otherwise, if the resulting schemas are simple types (e.g. int),
+            # check that all blocks with valid schemas have the same type.
+            schema = schemas_to_unify[0]
+            if schema is not None:
+                for s in schemas_to_unify:
+                    if s is not None and s != schema:
+                        raise ValueError(
+                            "Found blocks with different types "
+                            f"in schemas: {schemas_to_unify}"
+                        )
     # Build the concatenated table.
     table = pyarrow.Table.from_arrays(cols, schema=schema)
     # Validate table schema (this is a cheap check by default).
diff --git a/python/ray/data/_internal/plan.py b/python/ray/data/_internal/plan.py
index 8a2648a937129..592abf3486bcc 100644
--- a/python/ray/data/_internal/plan.py
+++ b/python/ray/data/_internal/plan.py
@@ -16,6 +16,7 @@
 )
 
 import ray
+from ray.data._internal.arrow_ops.transform_pyarrow import unify_schemas
 from ray.data._internal.block_list import BlockList
 from ray.data._internal.compute import (
     UDF,
@@ -261,14 +262,33 @@ def schema(
         metadata = blocks.get_metadata(fetch_if_missing=False)
         # Some blocks could be empty, in which case we cannot get their schema.
         # TODO(ekl) validate schema is the same across different blocks.
+
+        # First check if there are blocks with computed schemas, then unify the
+        # valid schemas from all such blocks.
+        schemas_to_unify = []
         for m in metadata:
             if m.schema is not None and (m.num_rows is None or m.num_rows > 0):
-                return m.schema
+                schemas_to_unify.append(m.schema)
+        if schemas_to_unify:
+            # Check for a valid pyarrow installation before attempting unification.
+            try:
+                import pyarrow as pa
+            except ImportError:
+                pa = None
+            # If the result contains PyArrow schemas, unify them.
+            if pa is not None and any(
+                isinstance(s, pa.Schema) for s in schemas_to_unify
+            ):
+                return unify_schemas(schemas_to_unify)
+            # Otherwise, if the resulting schemas are simple types (e.g. int),
+            # return the first schema.
+            return schemas_to_unify[0]
         if not fetch_if_missing:
             return None
         # Synchronously fetch the schema.
         # For lazy block lists, this launches read tasks and fetches block metadata
-        # until we find valid block schema.
+        # until we find the first valid block schema. This minimizes the amount of
+        # computation needed to fetch the schema.
         for _, m in blocks.iter_blocks_with_metadata():
             if m.schema is not None and (m.num_rows is None or m.num_rows > 0):
                 return m.schema
diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py
index fce280e59d3b8..b46de7d3ee59f 100644
--- a/python/ray/data/tests/test_dataset.py
+++ b/python/ray/data/tests/test_dataset.py
@@ -13,6 +13,7 @@
 
 import ray
 from ray._private.test_utils import wait_for_condition
+from ray.air.util.tensor_extensions.arrow import ArrowVariableShapedTensorType
 from ray.data._internal.stats import _StatsActor
 from ray.data._internal.arrow_block import ArrowRow
 from ray.data._internal.block_builder import BlockBuilder
@@ -5398,6 +5399,23 @@ def test_dataset_schema_after_read_stats(ray_start_cluster):
     assert schema == ds.schema()
 
 
+def test_ragged_tensors(ray_start_regular_shared):
+    """Test Arrow type promotion between ArrowTensorType and
+    ArrowVariableShapedTensorType when a column contains ragged tensors."""
+    import numpy as np
+
+    ds = ray.data.from_items(
+        [
+            {"spam": np.zeros((32, 32, 5))},
+            {"spam": np.zeros((64, 64, 5))},
+        ]
+    )
+    new_type = ds.schema().types[0].scalar_type
+    assert ds.schema().types == [
+        ArrowVariableShapedTensorType(dtype=new_type, ndim=3),
+    ]
+
+
 if __name__ == "__main__":
     import sys
diff --git a/python/ray/data/tests/test_transform_pyarrow.py b/python/ray/data/tests/test_transform_pyarrow.py
index 99c47c04bbc83..68b5a91594a0f 100644
--- a/python/ray/data/tests/test_transform_pyarrow.py
+++ b/python/ray/data/tests/test_transform_pyarrow.py
@@ -7,7 +7,7 @@
     ArrowTensorType,
     ArrowVariableShapedTensorType,
 )
-from ray.data._internal.arrow_ops.transform_pyarrow import concat
+from ray.data._internal.arrow_ops.transform_pyarrow import concat, unify_schemas
 
 
 def test_arrow_concat_empty():
@@ -176,6 +176,116 @@
     # fails for this case.
 
 
+def test_unify_schemas():
+    # Unifying a schema with the same schema as itself
+    tensor_arr_1 = pa.schema([("tensor_arr", ArrowTensorType((3, 5), pa.int32()))])
+    assert unify_schemas([tensor_arr_1, tensor_arr_1]) == tensor_arr_1
+
+    # Single columns with different shapes
+    tensor_arr_2 = pa.schema([("tensor_arr", ArrowTensorType((2, 1), pa.int32()))])
+    contains_diff_shaped = [tensor_arr_1, tensor_arr_2]
+    assert unify_schemas(contains_diff_shaped) == pa.schema(
+        [
+            ("tensor_arr", ArrowVariableShapedTensorType(pa.int32(), 2)),
+        ]
+    )
+
+    # Single columns with the same shape
+    tensor_arr_3 = pa.schema([("tensor_arr", ArrowTensorType((3, 5), pa.int32()))])
+    contains_same_shape = [tensor_arr_1, tensor_arr_3]
+    assert unify_schemas(contains_same_shape) == pa.schema(
+        [
+            ("tensor_arr", ArrowTensorType((3, 5), pa.int32())),
+        ]
+    )
+
+    # Single columns with a variable-shaped tensor, same ndim
+    var_tensor_arr = pa.schema(
+        [
+            ("tensor_arr", ArrowVariableShapedTensorType(pa.int32(), 2)),
+        ]
+    )
+    contains_var_shaped = [tensor_arr_1, var_tensor_arr]
+    assert unify_schemas(contains_var_shaped) == pa.schema(
+        [
+            ("tensor_arr", ArrowVariableShapedTensorType(pa.int32(), 2)),
+        ]
+    )
+
+    # Single columns with a variable-shaped tensor, different ndim
+    var_tensor_arr_1d = pa.schema(
+        [
+            ("tensor_arr", ArrowVariableShapedTensorType(pa.int32(), 1)),
+        ]
+    )
+    var_tensor_arr_3d = pa.schema(
+        [
+            ("tensor_arr", ArrowVariableShapedTensorType(pa.int32(), 3)),
+        ]
+    )
+    contains_1d2d = [tensor_arr_1, var_tensor_arr_1d]
+    assert unify_schemas(contains_1d2d) == pa.schema(
+        [
+            ("tensor_arr", ArrowVariableShapedTensorType(pa.int32(), 2)),
+        ]
+    )
+    contains_2d3d = [tensor_arr_1, var_tensor_arr_3d]
+    assert unify_schemas(contains_2d3d) == pa.schema(
+        [
+            ("tensor_arr", ArrowVariableShapedTensorType(pa.int32(), 3)),
+        ]
+    )
+
+    # Multi-column schemas
+    multicol_schema_1 = pa.schema(
+        [
+            ("col_int", pa.int32()),
+            ("col_fixed_tensor", ArrowTensorType((4, 2), pa.int32())),
+            ("col_var_tensor", ArrowVariableShapedTensorType(pa.int16(), 5)),
+        ]
+    )
+    multicol_schema_2 = pa.schema(
+        [
+            ("col_int", pa.int32()),
+            ("col_fixed_tensor", ArrowTensorType((4, 2), pa.int32())),
+            ("col_var_tensor", ArrowTensorType((9, 4, 1, 0, 5), pa.int16())),
+        ]
+    )
+    assert unify_schemas([multicol_schema_1, multicol_schema_2]) == pa.schema(
+        [
+            ("col_int", pa.int32()),
+            ("col_fixed_tensor", ArrowTensorType((4, 2), pa.int32())),
+            ("col_var_tensor", ArrowVariableShapedTensorType(pa.int16(), 5)),
+        ]
+    )
+
+    multicol_schema_3 = pa.schema(
+        [
+            ("col_int", pa.int32()),
+            ("col_fixed_tensor", ArrowVariableShapedTensorType(pa.int32(), 3)),
+            ("col_var_tensor", ArrowVariableShapedTensorType(pa.int16(), 5)),
+        ]
+    )
+    assert unify_schemas([multicol_schema_1, multicol_schema_3]) == pa.schema(
+        [
+            ("col_int", pa.int32()),
+            ("col_fixed_tensor", ArrowVariableShapedTensorType(pa.int32(), 3)),
+            ("col_var_tensor", ArrowVariableShapedTensorType(pa.int16(), 5)),
+        ]
+    )
+
+    # Unifying >2 schemas together
+    assert unify_schemas(
+        [multicol_schema_1, multicol_schema_2, multicol_schema_3]
+    ) == pa.schema(
+        [
+            ("col_int", pa.int32()),
+            ("col_fixed_tensor", ArrowVariableShapedTensorType(pa.int32(), 3)),
+            ("col_var_tensor", ArrowVariableShapedTensorType(pa.int16(), 5)),
+        ]
+    )
+
+
 if __name__ == "__main__":
     import sys
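Illustrative sketch (not part of the patch): a minimal summary of the contract of the new `ArrowTensorType._need_variable_shaped_tensor_array()` classmethod introduced in arrow.py above, assuming the tensor extension types are importable. The cases mirror the behavior exercised by `test_unify_schemas`.

```python
import pyarrow as pa

from ray.air.util.tensor_extensions.arrow import (
    ArrowTensorType,
    ArrowVariableShapedTensorType,
)

# Identical fixed shapes: a fixed-shape representation suffices.
assert not ArrowTensorType._need_variable_shaped_tensor_array(
    [ArrowTensorType((3, 5), pa.int32()), ArrowTensorType((3, 5), pa.int32())]
)

# Mismatched fixed shapes, or any variable-shaped input, require promotion.
assert ArrowTensorType._need_variable_shaped_tensor_array(
    [ArrowTensorType((3, 5), pa.int32()), ArrowTensorType((2, 1), pa.int32())]
)
assert ArrowTensorType._need_variable_shaped_tensor_array(
    [ArrowVariableShapedTensorType(pa.int32(), 2)]
)
```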