From 59e8d4f60d5674033bd6d03f80cfab8c01c64b5c Mon Sep 17 00:00:00 2001 From: Goutam Date: Mon, 27 Oct 2025 13:39:44 -0700 Subject: [PATCH] [Data] - [1/n] Add Temporal, list, tensor, struct datatype support to RD Datatype Signed-off-by: Goutam --- python/ray/data/datatype.py | 762 +++++++++++++++++++- python/ray/data/tests/unit/test_datatype.py | 486 ++++++++++++- 2 files changed, 1214 insertions(+), 34 deletions(-) diff --git a/python/ray/data/datatype.py b/python/ray/data/datatype.py index 4c9fb79defce..1cc05b9f26e7 100644 --- a/python/ray/data/datatype.py +++ b/python/ray/data/datatype.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -9,6 +10,29 @@ ) from ray.util.annotations import PublicAPI + +class _LogicalDataType(str, Enum): + """DataType logical types for pattern matching. + + These are used when _physical_dtype is None to represent categories of types + rather than concrete types. For example, _LogicalDataType.LIST matches any list + type regardless of element type. + + Note: _LogicalDataType.ANY is exposed as DataType.ANY and used as the default + parameter in factory methods (e.g., DataType.list(DataType.ANY)) to explicitly + request pattern-matching types. When _logical_dtype field is None, that represents + matching "any type at all" (completely unspecified). + """ + + ANY = "any" # Sentinel for method parameters; not stored in _logical_dtype field + LIST = "list" + LARGE_LIST = "large_list" + STRUCT = "struct" + MAP = "map" + TENSOR = "tensor" + TEMPORAL = "temporal" + + PYARROW_TYPE_DEFINITIONS: Dict[str, Tuple[callable, str]] = { "int8": (pa.int8, "an 8-bit signed integer"), "int16": (pa.int16, "a 16-bit signed integer"), @@ -88,28 +112,67 @@ def factory_method(cls): class DataType: """A simplified Ray Data DataType supporting Arrow, NumPy, and Python types.""" - _internal_type: Union[pa.DataType, np.dtype, type] + # Physical dtype: The concrete type implementation (e.g., pa.list_(pa.int64()), np.float64, str) + # Logical dtype: Used for pattern matching to represent a category of types + # - When _physical_dtype is set: _logical_dtype is ANY (not used, indicates concrete type) + # - When _physical_dtype is None: _logical_dtype specifies the pattern (LIST, STRUCT, MAP, etc.) + _physical_dtype: Optional[Union[pa.DataType, np.dtype, type]] + _logical_dtype: _LogicalDataType = _LogicalDataType.ANY + + # Sentinel value for creating pattern-matching types. + # Used as default in factory methods to allow both DataType.list(DataType.ANY) and DataType.list(). + ANY = _LogicalDataType.ANY def __post_init__(self): - """Validate the _internal_type after initialization.""" + """Validate the _physical_dtype after initialization.""" + # Allow None for pattern-matching types + if self._physical_dtype is None: + return + # TODO: Support Pandas extension types if not isinstance( - self._internal_type, + self._physical_dtype, (pa.DataType, np.dtype, type), ): raise TypeError( - f"DataType supports only PyArrow DataType, NumPy dtype, or Python type, but was given type {type(self._internal_type)}." + f"DataType supports only PyArrow DataType, NumPy dtype, or Python type, but was given type {type(self._physical_dtype)}." ) # Type checking methods def is_arrow_type(self) -> bool: - return isinstance(self._internal_type, pa.DataType) + """Check if this DataType is backed by a PyArrow DataType. + + Returns: + bool: True if the internal type is a PyArrow DataType + """ + return isinstance(self._physical_dtype, pa.DataType) def is_numpy_type(self) -> bool: - return isinstance(self._internal_type, np.dtype) + """Check if this DataType is backed by a NumPy dtype. + + Returns: + bool: True if the internal type is a NumPy dtype + """ + return isinstance(self._physical_dtype, np.dtype) def is_python_type(self) -> bool: - return isinstance(self._internal_type, type) + """Check if this DataType is backed by a Python type. + + Returns: + bool: True if the internal type is a Python type + """ + return isinstance(self._physical_dtype, type) + + def is_pattern_matching(self) -> bool: + """Check if this DataType is a pattern-matching type. + + Pattern-matching types have _physical_dtype=None and are used to match + categories of types (e.g., any list, any struct) rather than concrete types. + + Returns: + bool: True if this is a pattern-matching type + """ + return self._physical_dtype is None # Conversion methods def to_arrow_dtype(self, values: Optional[List[Any]] = None) -> pa.DataType: @@ -121,12 +184,22 @@ def to_arrow_dtype(self, values: Optional[List[Any]] = None) -> pa.DataType: Returns: A PyArrow DataType + + Raises: + ValueError: If called on a pattern-matching type (where _physical_dtype is None) """ + if self.is_pattern_matching(): + raise ValueError( + f"Cannot convert pattern-matching type {self} to a concrete Arrow type. " + "Pattern-matching types represent abstract type categories (e.g., 'any list') " + "and do not have a concrete Arrow representation." + ) + if self.is_arrow_type(): - return self._internal_type + return self._physical_dtype else: - if isinstance(self._internal_type, np.dtype): - return pa.from_numpy_dtype(self._internal_type) + if isinstance(self._physical_dtype, np.dtype): + return pa.from_numpy_dtype(self._physical_dtype) else: assert ( values is not None and len(values) > 0 @@ -134,12 +207,37 @@ def to_arrow_dtype(self, values: Optional[List[Any]] = None) -> pa.DataType: return _infer_pyarrow_type(values) def to_numpy_dtype(self) -> np.dtype: + """Convert the DataType to a NumPy dtype. + + For PyArrow types, attempts to convert via pandas dtype. + For Python types, returns object dtype. + + Returns: + np.dtype: A NumPy dtype representation + + Raises: + ValueError: If called on a pattern-matching type (where _physical_dtype is None) + + Examples: + >>> import numpy as np + >>> DataType.from_numpy(np.dtype('int64')).to_numpy_dtype() + dtype('int64') + >>> DataType.from_numpy(np.dtype('float32')).to_numpy_dtype() + dtype('float32') + """ + if self.is_pattern_matching(): + raise ValueError( + f"Cannot convert pattern-matching type {self} to a concrete NumPy dtype. " + "Pattern-matching types represent abstract type categories (e.g., 'any list') " + "and do not have a concrete NumPy representation." + ) + if self.is_numpy_type(): - return self._internal_type + return self._physical_dtype elif self.is_arrow_type(): try: # For most basic arrow types, this will work - pandas_dtype = self._internal_type.to_pandas_dtype() + pandas_dtype = self._physical_dtype.to_pandas_dtype() if isinstance(pandas_dtype, np.dtype): return pandas_dtype else: @@ -151,10 +249,31 @@ def to_numpy_dtype(self) -> np.dtype: return np.dtype("object") def to_python_type(self) -> type: + """Get the internal type if it's a Python type. + + This method doesn't perform conversion, it only returns the internal + type if it's already a Python type. + + Returns: + type: The internal Python type + + Raises: + ValueError: If the DataType is not backed by a Python type + + Examples: + >>> dt = DataType(int) + >>> dt.to_python_type() + + >>> DataType.int64().to_python_type() # doctest: +SKIP + ValueError: DataType is not backed by a Python type + """ if self.is_python_type(): - return self._internal_type + return self._physical_dtype else: - raise ValueError(f"DataType {self} is not a Python type") + raise ValueError( + f"DataType {self} is not backed by a Python type. " + f"Use to_arrow_dtype() or to_numpy_dtype() for conversion." + ) # Factory methods from external systems @classmethod @@ -175,7 +294,7 @@ def from_arrow(cls, arrow_type: pa.DataType) -> "DataType": >>> DataType.from_arrow(pa.int64()) DataType(arrow:int64) """ - return cls(_internal_type=arrow_type) + return cls(_physical_dtype=arrow_type) @classmethod def from_numpy(cls, numpy_dtype: Union[np.dtype, str]) -> "DataType": @@ -197,7 +316,7 @@ def from_numpy(cls, numpy_dtype: Union[np.dtype, str]) -> "DataType": """ if isinstance(numpy_dtype, str): numpy_dtype = np.dtype(numpy_dtype) - return cls(_internal_type=numpy_dtype) + return cls(_physical_dtype=numpy_dtype) @classmethod def infer_dtype(cls, value: Any) -> "DataType": @@ -222,7 +341,7 @@ def infer_dtype(cls, value: Any) -> "DataType": # 1. Handle numpy arrays and scalars if isinstance(value, (np.ndarray, np.generic)): return cls.from_numpy(value.dtype) - # 3. Try PyArrow type inference for regular Python values + # 2. Try PyArrow type inference for regular Python values try: inferred_arrow_type = _infer_pyarrow_type([value]) if inferred_arrow_type is not None: @@ -231,25 +350,618 @@ def infer_dtype(cls, value: Any) -> "DataType": return cls(type(value)) def __repr__(self) -> str: - if self.is_arrow_type(): - return f"DataType(arrow:{self._internal_type})" + if self._physical_dtype is None: + return f"DataType(logical_dtype:{self._logical_dtype.name})" + elif self.is_arrow_type(): + return f"DataType(arrow:{self._physical_dtype})" elif self.is_numpy_type(): - return f"DataType(numpy:{self._internal_type})" + return f"DataType(numpy:{self._physical_dtype})" else: - return f"DataType(python:{self._internal_type.__name__})" + return f"DataType(python:{self._physical_dtype.__name__})" - def __eq__(self, other) -> bool: + def __eq__(self, other: "DataType") -> bool: if not isinstance(other, DataType): return False + # Handle pattern-matching types (None internal type) + self_is_pattern = self._physical_dtype is None + other_is_pattern = other._physical_dtype is None + + if self_is_pattern or other_is_pattern: + return ( + self_is_pattern + and other_is_pattern + and self._logical_dtype == other._logical_dtype + ) + # Ensure they're from the same type system by checking the actual type # of the internal type object, not just the value - if type(self._internal_type) is not type(other._internal_type): + if type(self._physical_dtype) is not type(other._physical_dtype): return False - return self._internal_type == other._internal_type + return self._physical_dtype == other._physical_dtype def __hash__(self) -> int: + # Handle pattern-matching types + if self._physical_dtype is None: + return hash(("PATTERN", None, self._logical_dtype)) # Include the type of the internal type in the hash to ensure # different type systems don't collide - return hash((type(self._internal_type), self._internal_type)) + return hash((type(self._physical_dtype), self._physical_dtype)) + + @classmethod + def _is_pattern_matching_arg(cls, arg: Union["DataType", _LogicalDataType]) -> bool: + """Check if an argument should be treated as pattern-matching. + + Args: + arg: Either a _LogicalDataType enum or a DataType instance + + Returns: + True if the argument represents a pattern-matching type + """ + return isinstance(arg, _LogicalDataType) or ( + isinstance(arg, DataType) and arg.is_pattern_matching() + ) + + @classmethod + def list( + cls, value_type: Union["DataType", _LogicalDataType] = _LogicalDataType.ANY + ) -> "DataType": + """Create a DataType representing a list with the given element type. + + Pass DataType.ANY (or omit the argument) to create a pattern that matches any list type. + + Args: + value_type: The DataType of the list elements, or DataType.ANY to match any list. + Defaults to DataType.ANY. + + Returns: + DataType: A DataType with PyArrow list type or a pattern-matching DataType + + Examples: + >>> from ray.data.datatype import DataType + >>> DataType.list(DataType.int64()) # Exact match: list + DataType(arrow:list) + >>> DataType.list(DataType.ANY) # Pattern: matches any list (explicit) + DataType(logical_dtype:LIST) + >>> DataType.list() # Same as above (terse) + DataType(logical_dtype:LIST) + """ + if cls._is_pattern_matching_arg(value_type): + return cls(_physical_dtype=None, _logical_dtype=_LogicalDataType.LIST) + + value_arrow_type = value_type.to_arrow_dtype() + return cls.from_arrow(pa.list_(value_arrow_type)) + + @classmethod + def large_list( + cls, value_type: Union["DataType", _LogicalDataType] = _LogicalDataType.ANY + ) -> "DataType": + """Create a DataType representing a large_list with the given element type. + + Pass DataType.ANY (or omit the argument) to create a pattern that matches any large_list type. + + Args: + value_type: The DataType of the list elements, or DataType.ANY to match any large_list. + Defaults to DataType.ANY. + + Returns: + DataType: A DataType with PyArrow large_list type or a pattern-matching DataType + + Examples: + >>> DataType.large_list(DataType.int64()) # Exact match + DataType(arrow:large_list) + >>> DataType.large_list(DataType.ANY) # Pattern: matches any large_list (explicit) + DataType(logical_dtype:LARGE_LIST) + >>> DataType.large_list() # Same as above (terse) + DataType(logical_dtype:LARGE_LIST) + """ + if cls._is_pattern_matching_arg(value_type): + return cls( + _physical_dtype=None, + _logical_dtype=_LogicalDataType.LARGE_LIST, + ) + + value_arrow_type = value_type.to_arrow_dtype() + return cls.from_arrow(pa.large_list(value_arrow_type)) + + @classmethod + def fixed_size_list(cls, value_type: "DataType", list_size: int) -> "DataType": + """Create a DataType representing a fixed-size list. + + Args: + value_type: The DataType of the list elements + list_size: The fixed size of the list + + Returns: + DataType: A DataType with PyArrow fixed_size_list type + + Examples: + >>> from ray.data.datatype import DataType + >>> DataType.fixed_size_list(DataType.float32(), 3) + DataType(arrow:fixed_size_list[3]) + """ + value_arrow_type = value_type.to_arrow_dtype() + return cls.from_arrow(pa.list_(value_arrow_type, list_size)) + + @classmethod + def struct( + cls, + fields: Union[ + List[Tuple[str, "DataType"]], _LogicalDataType + ] = _LogicalDataType.ANY, + ) -> "DataType": + """Create a DataType representing a struct with the given fields. + + Pass DataType.ANY (or omit the argument) to create a pattern that matches any struct type. + + Args: + fields: List of (field_name, field_type) tuples, or DataType.ANY to match any struct. + Defaults to DataType.ANY. + + Returns: + DataType: A DataType with PyArrow struct type or a pattern-matching DataType + + Examples: + >>> from ray.data.datatype import DataType + >>> DataType.struct([("x", DataType.int64()), ("y", DataType.float64())]) + DataType(arrow:struct) + >>> DataType.struct(DataType.ANY) # Pattern: matches any struct (explicit) + DataType(logical_dtype:STRUCT) + >>> DataType.struct() # Same as above (terse) + DataType(logical_dtype:STRUCT) + """ + if isinstance(fields, _LogicalDataType): + return cls(_physical_dtype=None, _logical_dtype=_LogicalDataType.STRUCT) + + # Check if any field type is pattern-matching + if any(cls._is_pattern_matching_arg(dtype) for _, dtype in fields): + return cls(_physical_dtype=None, _logical_dtype=_LogicalDataType.STRUCT) + + arrow_fields = [(name, dtype.to_arrow_dtype()) for name, dtype in fields] + return cls.from_arrow(pa.struct(arrow_fields)) + + @classmethod + def map( + cls, + key_type: Union["DataType", _LogicalDataType] = _LogicalDataType.ANY, + value_type: Union["DataType", _LogicalDataType] = _LogicalDataType.ANY, + ) -> "DataType": + """Create a DataType representing a map with the given key and value types. + + Pass DataType.ANY for either argument (or omit them) to create a pattern that matches any map type. + + Args: + key_type: The DataType of the map keys, or DataType.ANY to match any map. + Defaults to DataType.ANY. + value_type: The DataType of the map values, or DataType.ANY to match any map. + Defaults to DataType.ANY. + + Returns: + DataType: A DataType with PyArrow map type or a pattern-matching DataType + + Examples: + >>> from ray.data.datatype import DataType + >>> DataType.map(DataType.string(), DataType.int64()) + DataType(arrow:map) + >>> DataType.map(DataType.ANY, DataType.ANY) # Pattern: matches any map (explicit) + DataType(logical_dtype:MAP) + >>> DataType.map() # Same as above (terse) + DataType(logical_dtype:MAP) + >>> DataType.map(DataType.string(), DataType.ANY) # Also pattern (partial spec) + DataType(logical_dtype:MAP) + """ + if cls._is_pattern_matching_arg(key_type) or cls._is_pattern_matching_arg( + value_type + ): + return cls(_physical_dtype=None, _logical_dtype=_LogicalDataType.MAP) + + key_arrow_type = key_type.to_arrow_dtype() + value_arrow_type = value_type.to_arrow_dtype() + return cls.from_arrow(pa.map_(key_arrow_type, value_arrow_type)) + + @classmethod + def tensor( + cls, + shape: Union[Tuple[int, ...], _LogicalDataType] = _LogicalDataType.ANY, + dtype: Union["DataType", _LogicalDataType] = _LogicalDataType.ANY, + ) -> "DataType": + """Create a DataType representing a fixed-shape tensor. + + Pass DataType.ANY for arguments (or omit them) to create a pattern that matches any tensor type. + + Args: + shape: The fixed shape of the tensor, or DataType.ANY to match any tensor. + Defaults to DataType.ANY. + dtype: The DataType of the tensor elements, or DataType.ANY to match any tensor. + Defaults to DataType.ANY. + + Returns: + DataType: A DataType with Ray's ArrowTensorType or a pattern-matching DataType + + Examples: + >>> from ray.data.datatype import DataType + >>> DataType.tensor(shape=(3, 4), dtype=DataType.float32()) # doctest: +ELLIPSIS + DataType(arrow:ArrowTensorType(...)) + >>> DataType.tensor(DataType.ANY, DataType.ANY) # Pattern: matches any tensor (explicit) + DataType(logical_dtype:TENSOR) + >>> DataType.tensor() # Same as above (terse) + DataType(logical_dtype:TENSOR) + >>> DataType.tensor(shape=(3, 4), dtype=DataType.ANY) # Also pattern (partial spec) + DataType(logical_dtype:TENSOR) + """ + if isinstance(shape, _LogicalDataType) or cls._is_pattern_matching_arg(dtype): + return cls(_physical_dtype=None, _logical_dtype=_LogicalDataType.TENSOR) + + from ray.air.util.tensor_extensions.arrow import ArrowTensorType + + element_arrow_type = dtype.to_arrow_dtype() + return cls.from_arrow(ArrowTensorType(shape, element_arrow_type)) + + @classmethod + def variable_shaped_tensor( + cls, + dtype: Union["DataType", _LogicalDataType] = _LogicalDataType.ANY, + ndim: Optional[int] = None, + ) -> "DataType": + """Create a DataType representing a variable-shaped tensor. + + Pass DataType.ANY (or omit the argument) to create a pattern that matches any variable-shaped tensor. + + Args: + dtype: The DataType of the tensor elements, or DataType.ANY to match any tensor. + Defaults to DataType.ANY. + ndim: The number of dimensions of the tensor + + Returns: + DataType: A DataType with Ray's ArrowVariableShapedTensorType or pattern-matching DataType + + Examples: + >>> from ray.data.datatype import DataType + >>> DataType.variable_shaped_tensor(dtype=DataType.float32(), ndim=2) # doctest: +ELLIPSIS + DataType(arrow:ArrowVariableShapedTensorType(...)) + >>> DataType.variable_shaped_tensor(DataType.ANY) # Pattern: matches any var tensor (explicit) + DataType(logical_dtype:TENSOR) + >>> DataType.variable_shaped_tensor() # Same as above (terse) + DataType(logical_dtype:TENSOR) + """ + if cls._is_pattern_matching_arg(dtype): + return cls(_physical_dtype=None, _logical_dtype=_LogicalDataType.TENSOR) + + if ndim is None: + ndim = 2 + + from ray.air.util.tensor_extensions.arrow import ArrowVariableShapedTensorType + + element_arrow_type = dtype.to_arrow_dtype() + return cls.from_arrow(ArrowVariableShapedTensorType(element_arrow_type, ndim)) + + @classmethod + def temporal( + cls, + temporal_type: Union[str, _LogicalDataType] = _LogicalDataType.ANY, + unit: Optional[str] = None, + tz: Optional[str] = None, + ) -> "DataType": + """Create a DataType representing a temporal type. + + Pass DataType.ANY (or omit the argument) to create a pattern that matches any temporal type. + + Args: + temporal_type: Type of temporal value - one of: + - "timestamp": Timestamp with optional unit and timezone + - "date32": 32-bit date (days since UNIX epoch) + - "date64": 64-bit date (milliseconds since UNIX epoch) + - "time32": 32-bit time of day (s or ms precision) + - "time64": 64-bit time of day (us or ns precision) + - "duration": Time duration with unit + - DataType.ANY: Pattern to match any temporal type (default) + unit: Time unit for timestamp/time/duration types: + - timestamp: "s", "ms", "us", "ns" (default: "us") + - time32: "s", "ms" (default: "s") + - time64: "us", "ns" (default: "us") + - duration: "s", "ms", "us", "ns" (default: "us") + tz: Optional timezone string for timestamp types (e.g., "UTC", "America/New_York") + + Returns: + DataType: A DataType with PyArrow temporal type or a pattern-matching DataType + + Examples: + >>> from ray.data.datatype import DataType + >>> DataType.temporal("timestamp", unit="s") + DataType(arrow:timestamp[s]) + >>> DataType.temporal("timestamp", unit="us", tz="UTC") + DataType(arrow:timestamp[us, tz=UTC]) + >>> DataType.temporal("date32") + DataType(arrow:date32[day]) + >>> DataType.temporal("time64", unit="ns") + DataType(arrow:time64[ns]) + >>> DataType.temporal("duration", unit="ms") + DataType(arrow:duration[ms]) + >>> DataType.temporal(DataType.ANY) # Pattern: matches any temporal (explicit) + DataType(logical_dtype:TEMPORAL) + >>> DataType.temporal() # Same as above (terse) + DataType(logical_dtype:TEMPORAL) + """ + if isinstance(temporal_type, _LogicalDataType): + return cls(_physical_dtype=None, _logical_dtype=_LogicalDataType.TEMPORAL) + + temporal_type_lower = temporal_type.lower() + + if temporal_type_lower == "timestamp": + unit = unit or "us" + return cls.from_arrow(pa.timestamp(unit, tz=tz)) + elif temporal_type_lower == "date32": + return cls.from_arrow(pa.date32()) + elif temporal_type_lower == "date64": + return cls.from_arrow(pa.date64()) + elif temporal_type_lower == "time32": + unit = unit or "s" + if unit not in ("s", "ms"): + raise ValueError(f"time32 unit must be 's' or 'ms', got {unit}") + return cls.from_arrow(pa.time32(unit)) + elif temporal_type_lower == "time64": + unit = unit or "us" + if unit not in ("us", "ns"): + raise ValueError(f"time64 unit must be 'us' or 'ns', got {unit}") + return cls.from_arrow(pa.time64(unit)) + elif temporal_type_lower == "duration": + unit = unit or "us" + return cls.from_arrow(pa.duration(unit)) + else: + raise ValueError( + f"Invalid temporal_type '{temporal_type}'. Must be one of: " + f"'timestamp', 'date32', 'date64', 'time32', 'time64', 'duration'" + ) + + def is_list_type(self) -> bool: + """Check if this DataType represents a list type + + Returns: + True if this is any list variant (list, large_list, fixed_size_list) + + Examples: + >>> DataType.list(DataType.int64()).is_list_type() + True + >>> DataType.int64().is_list_type() + False + """ + if not self.is_arrow_type(): + return False + + pa_type = self._physical_dtype + return ( + pa.types.is_list(pa_type) + or pa.types.is_large_list(pa_type) + or pa.types.is_fixed_size_list(pa_type) + # Pyarrow 16.0.0+ supports list views + or (hasattr(pa.types, "is_list_view") and pa.types.is_list_view(pa_type)) + or ( + hasattr(pa.types, "is_large_list_view") + and pa.types.is_large_list_view(pa_type) + ) + ) + + def is_tensor_type(self) -> bool: + """Check if this DataType represents a tensor type. + + Returns: + True if this is a tensor type + """ + if not self.is_arrow_type(): + return False + + from ray.air.util.tensor_extensions.arrow import ( + get_arrow_extension_tensor_types, + ) + + return isinstance(self._physical_dtype, get_arrow_extension_tensor_types()) + + def is_struct_type(self) -> bool: + """Check if this DataType represents a struct type. + + Returns: + True if this is a struct type + + Examples: + >>> DataType.struct([("x", DataType.int64())]).is_struct_type() + True + >>> DataType.int64().is_struct_type() + False + """ + if not self.is_arrow_type(): + return False + return pa.types.is_struct(self._physical_dtype) + + def is_map_type(self) -> bool: + """Check if this DataType represents a map type. + + Returns: + True if this is a map type + + Examples: + >>> DataType.map(DataType.string(), DataType.int64()).is_map_type() + True + >>> DataType.int64().is_map_type() + False + """ + if not self.is_arrow_type(): + return False + return pa.types.is_map(self._physical_dtype) + + def is_nested_type(self) -> bool: + """Check if this DataType represents a nested type. + + Nested types include: lists, structs, maps, unions + + Returns: + True if this is any nested type + + Examples: + >>> DataType.list(DataType.int64()).is_nested_type() + True + >>> DataType.struct([("x", DataType.int64())]).is_nested_type() + True + >>> DataType.int64().is_nested_type() + False + """ + if not self.is_arrow_type(): + return False + return pa.types.is_nested(self._physical_dtype) + + def _get_underlying_arrow_type(self) -> pa.DataType: + """Get the underlying Arrow type, handling dictionary and run-end encoding. + + Returns: + The underlying PyArrow type, unwrapping dictionary/run-end encoding + + Raises: + ValueError: If called on a non-Arrow type (pattern-matching, NumPy, or Python types) + """ + if self.is_pattern_matching(): + raise ValueError( + f"Cannot get Arrow type for pattern-matching type {self}. " + "Pattern-matching types do not have a concrete Arrow representation." + ) + if not self.is_arrow_type(): + raise ValueError( + f"Cannot get Arrow type for non-Arrow DataType {self}. " + f"Type is: {type(self._physical_dtype)}" + ) + + pa_type = self._physical_dtype + if pa.types.is_dictionary(pa_type): + return pa_type.value_type + elif pa.types.is_run_end_encoded(pa_type): + return pa_type.value_type + return pa_type + + def is_numerical_type(self) -> bool: + """Check if this DataType represents a numerical type. + + Numerical types support arithmetic operations and include: + integers, floats, decimals + + Returns: + True if this is a numerical type + + Examples: + >>> DataType.int64().is_numerical_type() + True + >>> DataType.float32().is_numerical_type() + True + >>> DataType.string().is_numerical_type() + False + """ + if self.is_arrow_type(): + underlying = self._get_underlying_arrow_type() + return ( + pa.types.is_integer(underlying) + or pa.types.is_floating(underlying) + or pa.types.is_decimal(underlying) + ) + elif self.is_numpy_type(): + return ( + np.issubdtype(self._physical_dtype, np.integer) + or np.issubdtype(self._physical_dtype, np.floating) + or np.issubdtype(self._physical_dtype, np.complexfloating) + ) + elif self.is_python_type(): + return self._physical_dtype in (int, float, complex) + return False + + def is_string_type(self) -> bool: + """Check if this DataType represents a string type. + + Includes: string, large_string, string_view + + Returns: + True if this is a string type + + Examples: + >>> DataType.string().is_string_type() + True + >>> DataType.int64().is_string_type() + False + """ + if self.is_arrow_type(): + underlying = self._get_underlying_arrow_type() + return ( + pa.types.is_string(underlying) + or pa.types.is_large_string(underlying) + or ( + hasattr(pa.types, "is_string_view") + and pa.types.is_string_view(underlying) + ) + ) + elif self.is_numpy_type(): + # Check for Unicode (U) or byte string (S) types + return self._physical_dtype.kind in ("U", "S") + elif self.is_python_type(): + return self._physical_dtype is str + return False + + def is_binary_type(self) -> bool: + """Check if this DataType represents a binary type. + + Includes: binary, large_binary, binary_view, fixed_size_binary + + Returns: + True if this is a binary type + + Examples: + >>> DataType.binary().is_binary_type() + True + >>> DataType.string().is_binary_type() + False + """ + if self.is_arrow_type(): + underlying = self._get_underlying_arrow_type() + return ( + pa.types.is_binary(underlying) + or pa.types.is_large_binary(underlying) + or ( + hasattr(pa.types, "is_binary_view") + and pa.types.is_binary_view(underlying) + ) + or pa.types.is_fixed_size_binary(underlying) + ) + elif self.is_numpy_type(): + # NumPy doesn't have a specific binary type, but void or object dtypes might contain bytes + return self._physical_dtype.kind == "V" # void type (raw bytes) + elif self.is_python_type(): + return self._physical_dtype in (bytes, bytearray) + return False + + def is_temporal_type(self) -> bool: + """Check if this DataType represents a temporal type. + + Includes: date, time, timestamp, duration, interval + + Returns: + True if this is a temporal type + + Examples: + >>> import pyarrow as pa + >>> DataType.from_arrow(pa.timestamp('s')).is_temporal_type() + True + >>> DataType.int64().is_temporal_type() + False + """ + if self.is_arrow_type(): + underlying = self._get_underlying_arrow_type() + return pa.types.is_temporal(underlying) + elif self.is_numpy_type(): + return np.issubdtype(self._physical_dtype, np.datetime64) or np.issubdtype( + self._physical_dtype, np.timedelta64 + ) + elif self.is_python_type(): + import datetime + + return self._physical_dtype in ( + datetime.datetime, + datetime.date, + datetime.time, + datetime.timedelta, + ) + return False diff --git a/python/ray/data/tests/unit/test_datatype.py b/python/ray/data/tests/unit/test_datatype.py index ceb3a2650941..97b1ce3984b9 100644 --- a/python/ray/data/tests/unit/test_datatype.py +++ b/python/ray/data/tests/unit/test_datatype.py @@ -35,7 +35,7 @@ def test_factory_method_creates_correct_type( assert isinstance(result, DataType) assert result.is_arrow_type() - assert result._internal_type == pa_type + assert result._physical_dtype == pa_type @pytest.mark.parametrize( "method_name", @@ -85,7 +85,7 @@ def test_post_init_accepts_valid_types(self, valid_type): """Test that __post_init__ accepts valid type objects.""" # Should not raise dt = DataType(valid_type) - assert dt._internal_type == valid_type + assert dt._physical_dtype == valid_type @pytest.mark.parametrize( "invalid_type", @@ -94,7 +94,6 @@ def test_post_init_accepts_valid_types(self, valid_type): 123, [1, 2, 3], {"key": "value"}, - None, object(), ], ) @@ -147,7 +146,7 @@ def test_from_arrow(self, pa_type): assert isinstance(dt, DataType) assert dt.is_arrow_type() - assert dt._internal_type == pa_type + assert dt._physical_dtype == pa_type @pytest.mark.parametrize( "numpy_input,expected_dtype", @@ -164,7 +163,7 @@ def test_from_numpy(self, numpy_input, expected_dtype): assert isinstance(dt, DataType) assert dt.is_numpy_type() - assert dt._internal_type == expected_dtype + assert dt._physical_dtype == expected_dtype class TestDataTypeConversions: @@ -243,7 +242,7 @@ def test_to_python_type_success(self, python_type): ) def test_to_python_type_failure(self, non_python_dt): """Test to_python_type raises ValueError for non-Python types.""" - with pytest.raises(ValueError, match="is not a Python type"): + with pytest.raises(ValueError, match="is not backed by a Python type"): non_python_dt.to_python_type() @@ -264,7 +263,7 @@ def test_infer_dtype_numpy_values(self, numpy_value, expected_dtype): dt = DataType.infer_dtype(numpy_value) assert dt.is_numpy_type() - assert dt._internal_type == expected_dtype + assert dt._physical_dtype == expected_dtype # Removed test_infer_dtype_pyarrow_scalar - no longer works with current implementation @@ -348,8 +347,8 @@ def test_numpy_vs_python_inequality(self): # so they should not be equal # First verify they have different internal types - assert type(numpy_dt._internal_type) is not type(python_dt._internal_type) - assert numpy_dt._internal_type is not python_dt._internal_type + assert type(numpy_dt._physical_dtype) is not type(python_dt._physical_dtype) + assert numpy_dt._physical_dtype is not python_dt._physical_dtype # Test the type checkers return different results assert numpy_dt.is_numpy_type() and not python_dt.is_numpy_type() @@ -388,5 +387,474 @@ def test_hashability(self): assert dt_dict[dt2] == "first" # dt2 should match dt1 +class TestLogicalDataTypes: + """Test pattern-matching DataTypes with _LogicalDataType enum.""" + + @pytest.mark.parametrize( + "factory_method,logical_dtype_value", + [ + (lambda: DataType.list(), "list"), + (lambda: DataType.large_list(), "large_list"), + (lambda: DataType.struct(), "struct"), + (lambda: DataType.map(), "map"), + (lambda: DataType.tensor(), "tensor"), + (lambda: DataType.variable_shaped_tensor(), "tensor"), + (lambda: DataType.temporal(), "temporal"), + ], + ) + def test_logical_dtype_creation(self, factory_method, logical_dtype_value): + """Test that logical DataTypes have correct _logical_dtype.""" + from ray.data.datatype import _LogicalDataType + + dt = factory_method() + assert dt._physical_dtype is None + assert dt._logical_dtype == _LogicalDataType(logical_dtype_value) + assert isinstance(dt._logical_dtype, _LogicalDataType) + + @pytest.mark.parametrize( + "factory_method,expected_repr", + [ + (lambda: DataType.list(), "DataType(logical_dtype:LIST)"), + (lambda: DataType.large_list(), "DataType(logical_dtype:LARGE_LIST)"), + (lambda: DataType.struct(), "DataType(logical_dtype:STRUCT)"), + (lambda: DataType.map(), "DataType(logical_dtype:MAP)"), + (lambda: DataType.tensor(), "DataType(logical_dtype:TENSOR)"), + ( + lambda: DataType.variable_shaped_tensor(), + "DataType(logical_dtype:TENSOR)", + ), + (lambda: DataType.temporal(), "DataType(logical_dtype:TEMPORAL)"), + ], + ) + def test_logical_dtype_repr(self, factory_method, expected_repr): + """Test __repr__ for logical DataTypes.""" + dt = factory_method() + assert repr(dt) == expected_repr + + @pytest.mark.parametrize( + "dt1_factory,dt2_factory,should_be_equal", + [ + # Same logical DataTypes should be equal (including explicit ANY form) + (lambda: DataType.list(), lambda: DataType.list(DataType.ANY), True), + (lambda: DataType.list(), lambda: DataType.list(), True), + (lambda: DataType.struct(), lambda: DataType.struct(DataType.ANY), True), + ( + lambda: DataType.tensor(), + lambda: DataType.variable_shaped_tensor(), + True, + ), + # Different logical DataTypes should not be equal + (lambda: DataType.list(), lambda: DataType.large_list(), False), + (lambda: DataType.list(), lambda: DataType.struct(), False), + (lambda: DataType.map(), lambda: DataType.temporal(), False), + ], + ) + def test_logical_dtype_equality(self, dt1_factory, dt2_factory, should_be_equal): + """Test equality between logical DataTypes.""" + dt1 = dt1_factory() + dt2 = dt2_factory() + + if should_be_equal: + assert dt1 == dt2 + assert hash(dt1) == hash(dt2) + else: + assert dt1 != dt2 + + +class TestNestedTypeFactories: + """Test factory methods for nested types (list, struct, map, etc.).""" + + @pytest.mark.parametrize( + "factory_call,expected_arrow_type", + [ + (lambda: DataType.list(DataType.int64()), pa.list_(pa.int64())), + (lambda: DataType.list(DataType.string()), pa.list_(pa.string())), + ( + lambda: DataType.large_list(DataType.float32()), + pa.large_list(pa.float32()), + ), + ( + lambda: DataType.fixed_size_list(DataType.int32(), 5), + pa.list_(pa.int32(), 5), + ), + ], + ) + def test_list_type_factories(self, factory_call, expected_arrow_type): + """Test list-type factory methods create correct Arrow types.""" + dt = factory_call() + assert dt.is_arrow_type() + assert dt._physical_dtype == expected_arrow_type + + @pytest.mark.parametrize( + "fields,expected_arrow_type", + [ + ( + [("x", DataType.int64()), ("y", DataType.float64())], + pa.struct([("x", pa.int64()), ("y", pa.float64())]), + ), + ( + [("name", DataType.string()), ("age", DataType.int32())], + pa.struct([("name", pa.string()), ("age", pa.int32())]), + ), + ], + ) + def test_struct_factory(self, fields, expected_arrow_type): + """Test struct factory method creates correct Arrow types.""" + dt = DataType.struct(fields) + assert dt.is_arrow_type() + assert dt._physical_dtype == expected_arrow_type + + @pytest.mark.parametrize( + "key_type,value_type,expected_arrow_type", + [ + (DataType.string(), DataType.int64(), pa.map_(pa.string(), pa.int64())), + (DataType.int32(), DataType.float32(), pa.map_(pa.int32(), pa.float32())), + ], + ) + def test_map_factory(self, key_type, value_type, expected_arrow_type): + """Test map factory method creates correct Arrow types.""" + dt = DataType.map(key_type, value_type) + assert dt.is_arrow_type() + assert dt._physical_dtype == expected_arrow_type + + @pytest.mark.parametrize( + "temporal_type,unit,tz,expected_type", + [ + ("timestamp", "s", None, pa.timestamp("s")), + ("timestamp", "us", "UTC", pa.timestamp("us", tz="UTC")), + ("date32", None, None, pa.date32()), + ("date64", None, None, pa.date64()), + ("time32", "s", None, pa.time32("s")), + ("time64", "us", None, pa.time64("us")), + ("duration", "ms", None, pa.duration("ms")), + ], + ) + def test_temporal_factory(self, temporal_type, unit, tz, expected_type): + """Test temporal factory method creates correct Arrow types.""" + if tz is not None: + dt = DataType.temporal(temporal_type, unit=unit, tz=tz) + elif unit is not None: + dt = DataType.temporal(temporal_type, unit=unit) + else: + dt = DataType.temporal(temporal_type) + + assert dt.is_arrow_type() + assert dt._physical_dtype == expected_type + + @pytest.mark.parametrize( + "temporal_type,unit,error_msg", + [ + ("time32", "us", "time32 unit must be 's' or 'ms'"), + ("time64", "ms", "time64 unit must be 'us' or 'ns'"), + ("invalid", None, "Invalid temporal_type"), + ], + ) + def test_temporal_factory_validation(self, temporal_type, unit, error_msg): + """Test temporal factory validates inputs correctly.""" + with pytest.raises(ValueError, match=error_msg): + DataType.temporal(temporal_type, unit=unit) + + +class TestTypePredicates: + """Test type predicate methods (is_list_type, is_struct_type, etc.).""" + + @pytest.mark.parametrize( + "datatype,expected_result", + [ + # List types + (DataType.list(DataType.int64()), True), + (DataType.large_list(DataType.string()), True), + (DataType.fixed_size_list(DataType.float32(), 3), True), + # Tensor types (should return False) + (DataType.tensor(shape=(3, 4), dtype=DataType.float32()), False), + (DataType.variable_shaped_tensor(dtype=DataType.float64(), ndim=2), False), + # Non-list types + (DataType.int64(), False), + (DataType.string(), False), + (DataType.struct([("x", DataType.int32())]), False), + ], + ) + def test_is_list_type(self, datatype, expected_result): + """Test is_list_type predicate.""" + assert datatype.is_list_type() == expected_result + + @pytest.mark.parametrize( + "datatype,expected_result", + [ + (DataType.tensor(shape=(3, 4), dtype=DataType.float32()), True), + (DataType.variable_shaped_tensor(dtype=DataType.float64(), ndim=2), True), + ], + ) + def test_is_tensor_type(self, datatype, expected_result): + """Test is_tensor_type predicate.""" + assert datatype.is_tensor_type() == expected_result + + @pytest.mark.parametrize( + "datatype,expected_result", + [ + (DataType.struct([("x", DataType.int64())]), True), + ( + DataType.struct([("a", DataType.string()), ("b", DataType.float32())]), + True, + ), + (DataType.list(DataType.int64()), False), + (DataType.int64(), False), + ], + ) + def test_is_struct_type(self, datatype, expected_result): + """Test is_struct_type predicate.""" + assert datatype.is_struct_type() == expected_result + + @pytest.mark.parametrize( + "datatype,expected_result", + [ + (DataType.map(DataType.string(), DataType.int64()), True), + (DataType.map(DataType.int32(), DataType.float32()), True), + (DataType.list(DataType.int64()), False), + (DataType.int64(), False), + ], + ) + def test_is_map_type(self, datatype, expected_result): + """Test is_map_type predicate.""" + assert datatype.is_map_type() == expected_result + + @pytest.mark.parametrize( + "datatype,expected_result", + [ + # Nested types + (DataType.list(DataType.int64()), True), + (DataType.struct([("x", DataType.int32())]), True), + (DataType.map(DataType.string(), DataType.int64()), True), + # Non-nested types + (DataType.int64(), False), + (DataType.string(), False), + (DataType.float32(), False), + ], + ) + def test_is_nested_type(self, datatype, expected_result): + """Test is_nested_type predicate.""" + assert datatype.is_nested_type() == expected_result + + @pytest.mark.parametrize( + "datatype,expected_result", + [ + # Numerical Arrow types + (DataType.int64(), True), + (DataType.int32(), True), + (DataType.float32(), True), + (DataType.float64(), True), + # Numerical NumPy types + (DataType.from_numpy(np.dtype("int32")), True), + (DataType.from_numpy(np.dtype("float64")), True), + # Numerical Python types + (DataType(int), True), + (DataType(float), True), + # Non-numerical types + (DataType.string(), False), + (DataType.binary(), False), + (DataType(str), False), + ], + ) + def test_is_numerical_type(self, datatype, expected_result): + """Test is_numerical_type predicate.""" + assert datatype.is_numerical_type() == expected_result + + @pytest.mark.parametrize( + "datatype,expected_result", + [ + # String Arrow types + (DataType.string(), True), + (DataType.from_arrow(pa.large_string()), True), + # String NumPy types + (DataType.from_numpy(np.dtype("U10")), True), + # String Python types + (DataType(str), True), + # Non-string types + (DataType.int64(), False), + (DataType.binary(), False), + ], + ) + def test_is_string_type(self, datatype, expected_result): + """Test is_string_type predicate.""" + assert datatype.is_string_type() == expected_result + + @pytest.mark.parametrize( + "datatype,expected_result", + [ + # Binary Arrow types + (DataType.binary(), True), + (DataType.from_arrow(pa.large_binary()), True), + (DataType.from_arrow(pa.binary(10)), True), # fixed_size_binary + # Binary Python types + (DataType(bytes), True), + (DataType(bytearray), True), + # Non-binary types + (DataType.string(), False), + (DataType.int64(), False), + ], + ) + def test_is_binary_type(self, datatype, expected_result): + """Test is_binary_type predicate.""" + assert datatype.is_binary_type() == expected_result + + @pytest.mark.parametrize( + "datatype,expected_result", + [ + # Temporal Arrow types + (DataType.temporal("timestamp", unit="s"), True), + (DataType.temporal("date32"), True), + (DataType.temporal("time64", unit="us"), True), + (DataType.temporal("duration", unit="ms"), True), + # Temporal NumPy types + (DataType.from_numpy(np.dtype("datetime64[D]")), True), + (DataType.from_numpy(np.dtype("timedelta64[s]")), True), + # Non-temporal types + (DataType.int64(), False), + (DataType.string(), False), + ], + ) + def test_is_temporal_type(self, datatype, expected_result): + """Test is_temporal_type predicate.""" + assert datatype.is_temporal_type() == expected_result + + +class TestNestedPatternMatching: + """Test that pattern-matching DataTypes can be used as arguments to factory methods.""" + + @pytest.mark.parametrize( + "factory_call,expected_logical_dtype", + [ + # list with pattern-matching element type + (lambda: DataType.list(DataType.list()), "list"), + (lambda: DataType.list(DataType.struct()), "list"), + (lambda: DataType.list(DataType.map()), "list"), + # large_list with pattern-matching element type + (lambda: DataType.large_list(DataType.list()), "large_list"), + (lambda: DataType.large_list(DataType.tensor()), "large_list"), + # struct with pattern-matching field types + ( + lambda: DataType.struct( + [("a", DataType.list()), ("b", DataType.int64())] + ), + "struct", + ), + ( + lambda: DataType.struct( + [("x", DataType.tensor()), ("y", DataType.map())] + ), + "struct", + ), + # map with pattern-matching key/value types + (lambda: DataType.map(DataType.list(), DataType.int64()), "map"), + (lambda: DataType.map(DataType.string(), DataType.struct()), "map"), + (lambda: DataType.map(DataType.list(), DataType.map()), "map"), + # tensor with pattern-matching dtype + (lambda: DataType.tensor((3, 4), DataType.list()), "tensor"), + (lambda: DataType.tensor((2, 2), DataType.struct()), "tensor"), + # variable_shaped_tensor with pattern-matching dtype + ( + lambda: DataType.variable_shaped_tensor(DataType.list(), ndim=2), + "tensor", + ), + (lambda: DataType.variable_shaped_tensor(DataType.map(), ndim=3), "tensor"), + ], + ) + def test_nested_pattern_matching_types(self, factory_call, expected_logical_dtype): + """Test that pattern-matching DataTypes work as arguments to factory methods. + + When a pattern-matching DataType (one with _physical_dtype=None) is passed + as an argument to another factory method, it should result in a pattern-matching + type, not try to call to_arrow_dtype() on it. + """ + from ray.data.datatype import _LogicalDataType + + dt = factory_call() + # Should create a pattern-matching type, not a concrete type + assert dt._physical_dtype is None + assert dt._logical_dtype == _LogicalDataType(expected_logical_dtype) + + def test_list_with_nested_pattern(self): + """Test DataType.list(DataType.list()) returns pattern-matching LIST.""" + from ray.data.datatype import _LogicalDataType + + dt = DataType.list(DataType.list()) + assert dt._physical_dtype is None + assert dt._logical_dtype == _LogicalDataType.LIST + assert repr(dt) == "DataType(logical_dtype:LIST)" + + def test_struct_with_pattern_fields(self): + """Test DataType.struct with pattern-matching field types.""" + from ray.data.datatype import _LogicalDataType + + dt = DataType.struct([("a", DataType.list()), ("b", DataType.tensor())]) + assert dt._physical_dtype is None + assert dt._logical_dtype == _LogicalDataType.STRUCT + + +class TestPatternMatchingToArrowDtype: + """Test that pattern-matching types cannot be converted to concrete Arrow types.""" + + @pytest.mark.parametrize( + "pattern_type_factory", + [ + lambda: DataType.list(), + lambda: DataType.large_list(), + lambda: DataType.struct(), + lambda: DataType.map(), + lambda: DataType.tensor(), + lambda: DataType.variable_shaped_tensor(), + lambda: DataType.temporal(), + ], + ) + def test_pattern_matching_to_arrow_dtype_raises(self, pattern_type_factory): + """Test that calling to_arrow_dtype on pattern-matching types raises an error. + + Pattern-matching types represent abstract type categories (e.g., "any list") + and cannot be converted to concrete Arrow types. + """ + dt = pattern_type_factory() + assert dt.is_pattern_matching() + + with pytest.raises(ValueError, match="Cannot convert pattern-matching type"): + dt.to_arrow_dtype() + + def test_pattern_matching_to_arrow_dtype_with_values_still_raises(self): + """Test that even with values, pattern-matching types cannot be converted.""" + dt = DataType.list() + assert dt.is_pattern_matching() + + # Even with values provided, pattern-matching types shouldn't convert + with pytest.raises(ValueError, match="Cannot convert pattern-matching type"): + dt.to_arrow_dtype(values=[1, 2, 3]) + + +class TestPatternMatchingToNumpyDtype: + """Test that pattern-matching types cannot be converted to concrete NumPy dtypes.""" + + @pytest.mark.parametrize( + "pattern_type_factory", + [ + lambda: DataType.list(), + lambda: DataType.large_list(), + lambda: DataType.struct(), + lambda: DataType.map(), + lambda: DataType.tensor(), + lambda: DataType.variable_shaped_tensor(), + lambda: DataType.temporal(), + ], + ) + def test_pattern_matching_to_numpy_dtype_raises(self, pattern_type_factory): + """Test that calling to_numpy_dtype on pattern-matching types raises an error. + + Pattern-matching types represent abstract type categories (e.g., "any list") + and cannot be converted to concrete NumPy dtypes. + """ + dt = pattern_type_factory() + assert dt.is_pattern_matching() + + with pytest.raises(ValueError, match="Cannot convert pattern-matching type"): + dt.to_numpy_dtype() + + if __name__ == "__main__": pytest.main(["-v", __file__])