
Commit 01b21f4

xinyuangui2 and SheldonTsen authored and committed
Cache PyArrow schema operations (ray-project#58583)
## Description

This PR adds caching for PyArrow schema operations to improve performance during batching, especially for tables with a large number of columns.

### Main Changes

- **Caching for tensor type serialization/deserialization**: Added a cache for tensor type serialization and deserialization operations. This significantly reduces overhead for frequently accessed tensor types during schema operations.

### Performance Impact

This optimization is particularly beneficial when batching tables with a large number of columns. In one of our tests with 200 columns, the batching time per batch decreased from **0.30s to 0.11s** (a ~63% improvement).

#### Without cache

![Profile without cache](https://github.com/user-attachments/assets/46122634-dd09-40ed-a2a8-725d14f85728)

`__arrow_ext_deserialize__` and `__arrow_ext_serialize__` show up in several places in the profile. Each call to `__arrow_ext_deserialize__` creates a new object, and each call to `__arrow_ext_serialize__` performs an expensive pickle.

#### With cache

![Profile with cache](https://github.com/user-attachments/assets/50e77253-d69d-40d9-9e1f-56e9341bc131)

The time spent in `__arrow_ext_deserialize__` and `__arrow_ext_serialize__` is no longer a bottleneck.

---------

Signed-off-by: xgui <xgui@anyscale.com>
Signed-off-by: Xinyuan <43737116+xinyuangui2@users.noreply.github.com>
1 parent c25a789 commit 01b21f4
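To make the mechanism concrete, here is a minimal, self-contained sketch of the caching pattern this commit applies. The `CachedSerializeDeserialize` and `ShapeType` names and the plain-pickle payload are illustrative only, not code from the diff; the real classes below mix into `pa.ExtensionType` subclasses:

```python
import functools
import pickle
import threading


class CachedSerializeDeserialize:
    """Toy mixin: memoize serialization per instance, deserialization per (class, args)."""

    def __init__(self):
        self._serialize_cache = None          # bytes, computed at most once
        self._cache_lock = threading.RLock()  # guards the first computation

    def serialize(self) -> bytes:
        with self._cache_lock:
            if self._serialize_cache is None:
                self._serialize_cache = self._serialize_compute()
            return self._serialize_cache

    def _serialize_compute(self) -> bytes:
        raise NotImplementedError

    @classmethod
    @functools.lru_cache(maxsize=100_000)
    def _deserialize_cached(cls, serialized: bytes):
        # lru_cache keys on (cls, serialized), so different subclasses
        # sharing this single cache never collide.
        return cls._deserialize_compute(serialized)


class ShapeType(CachedSerializeDeserialize):
    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def _serialize_compute(self) -> bytes:
        return pickle.dumps(self.shape)  # the expensive step, now run once per instance

    @classmethod
    def _deserialize_compute(cls, serialized: bytes):
        return cls(pickle.loads(serialized))


t = ShapeType((2, 3))
blob = t.serialize()
assert t.serialize() is blob  # second call hits the instance-level cache
assert ShapeType._deserialize_cached(blob) is ShapeType._deserialize_cached(blob)
```

The same two ideas appear in the diff: a per-instance memo behind `__arrow_ext_serialize__`, and a single class-keyed `functools.lru_cache` behind `__arrow_ext_deserialize__`.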

File tree

4 files changed: +386 −21 lines changed


python/ray/air/util/tensor_extensions/arrow.py

Lines changed: 118 additions & 13 deletions
```diff
@@ -1,8 +1,11 @@
 import abc
+import functools
 import itertools
 import json
 import logging
 import sys
+import threading
+from abc import abstractmethod
 from datetime import datetime
 from enum import Enum
 from typing import Any, Collection, Dict, Iterable, List, Optional, Tuple, Union
@@ -62,6 +65,11 @@ class _SerializationFormat(Enum):
     else _SerializationFormat.CLOUDPICKLE  # default
 )
 
+# 100,000 entries, about 10MB in memory.
+# Most user tables should have fewer than 100K columns.
+ARROW_EXTENSION_SERIALIZATION_CACHE_MAXSIZE = env_integer(
+    "RAY_EXTENSION_SERIALIZATION_CACHE_MAXSIZE", 10**5
+)
 
 logger = logging.getLogger(__name__)
 
@@ -85,6 +93,88 @@ def _deserialize_with_fallback(serialized: bytes, field_name: str = "data"):
     )
 
 
+@DeveloperAPI(stability="beta")
+class ArrowExtensionSerializeDeserializeCache(abc.ABC):
+    """Base class for caching Arrow extension type serialization and deserialization.
+
+    Serialization and deserialization of Arrow extension types happen frequently,
+    so we cache the results here to improve performance.
+
+    The deserialization cache uses functools.lru_cache on a classmethod. There is
+    a single cache instance shared across all subclasses, but the cache key includes
+    the class (the cls parameter) as the first argument, so different subclasses get
+    different cache entries even when called with the same parameters. The cache is
+    thread-safe and has a maximum size limit to control memory usage. The cache key
+    is (cls, *args), where args are the parameters returned by
+    _get_deserialize_parameter().
+
+    Attributes:
+        _serialize_cache: Instance-level cache for serialization results.
+            This is a simple cached value (bytes) that is computed once per
+            instance and reused.
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        """Initialize the extension type with caching support.
+
+        Args:
+            *args: Positional arguments passed to the parent class.
+            **kwargs: Keyword arguments passed to the parent class.
+        """
+        # Instance-level cache for serialization results, no TTL
+        self._serialize_cache = None
+        self._cache_lock = threading.RLock()
+        super().__init__(*args, **kwargs)
+
+    def __arrow_ext_serialize__(self) -> bytes:
+        """Serialize the extension type, returning the cached value when available."""
+        with self._cache_lock:
+            if self._serialize_cache is None:
+                self._serialize_cache = self._arrow_ext_serialize_compute()
+            return self._serialize_cache
+
+    @abstractmethod
+    def _arrow_ext_serialize_compute(self) -> bytes:
+        """Subclasses must implement this method to compute the serialization."""
+        ...
+
+    @classmethod
+    @functools.lru_cache(maxsize=ARROW_EXTENSION_SERIALIZATION_CACHE_MAXSIZE)
+    def _arrow_ext_deserialize_cache(cls: type, *args: Any, **kwargs: Any) -> Any:
+        """Deserialize the extension type using the class-level cache.
+
+        This method is cached with functools.lru_cache to improve performance
+        when deserializing extension types. The cache key includes the class (cls)
+        as the first argument, ensuring different subclasses get separate cache
+        entries.
+
+        Args:
+            *args: Positional arguments passed to _arrow_ext_deserialize_compute.
+            **kwargs: Keyword arguments passed to _arrow_ext_deserialize_compute.
+
+        Returns:
+            The deserialized extension type instance.
+        """
+        return cls._arrow_ext_deserialize_compute(*args, **kwargs)
+
+    @classmethod
+    @abstractmethod
+    def _arrow_ext_deserialize_compute(cls, *args: Any, **kwargs: Any) -> Any:
+        """Subclasses must implement this method to compute the deserialization."""
+        ...
+
+    @classmethod
+    @abstractmethod
+    def _get_deserialize_parameter(cls, storage_type, serialized) -> Tuple:
+        """Subclasses must implement this method to return the parameters used as
+        the deserialization cache key."""
+        ...
+
+    @classmethod
+    def __arrow_ext_deserialize__(cls, storage_type, serialized) -> Any:
+        """Deserialize the extension type, consulting the class-level cache."""
+        return cls._arrow_ext_deserialize_cache(
+            *cls._get_deserialize_parameter(storage_type, serialized)
+        )
+
+
 @DeveloperAPI
 class ArrowConversionError(Exception):
     """Error raised when there is an issue converting data to Arrow."""
@@ -431,7 +521,10 @@ def get_arrow_extension_variable_shape_tensor_types():
     return (ArrowVariableShapedTensorType,)
 
 
-class _BaseFixedShapeArrowTensorType(pa.ExtensionType, abc.ABC):
+# ArrowExtensionSerializeDeserializeCache needs to be first in the MRO to ensure the cache is used
+class _BaseFixedShapeArrowTensorType(
+    ArrowExtensionSerializeDeserializeCache, pa.ExtensionType
+):
     """
     Arrow ExtensionType for an array of fixed-shaped, homogeneous-typed
     tensors.
@@ -446,7 +539,6 @@ def __init__(
         self, shape: Tuple[int, ...], tensor_dtype: pa.DataType, ext_type_id: str
     ):
         self._shape = shape
-
         super().__init__(tensor_dtype, ext_type_id)
 
     @property
@@ -478,7 +570,7 @@ def __reduce__(self):
             self.__arrow_ext_serialize__(),
         )
 
-    def __arrow_ext_serialize__(self):
+    def _arrow_ext_serialize_compute(self):
         if ARROW_EXTENSION_SERIALIZATION_FORMAT == _SerializationFormat.CLOUDPICKLE:
             return cloudpickle.dumps(self._shape)
         elif ARROW_EXTENSION_SERIALIZATION_FORMAT == _SerializationFormat.JSON:
@@ -563,9 +655,13 @@ def __init__(self, shape: Tuple[int, ...], dtype: pa.DataType):
         super().__init__(shape, pa.list_(dtype), "ray.data.arrow_tensor")
 
     @classmethod
-    def __arrow_ext_deserialize__(cls, storage_type, serialized):
+    def _get_deserialize_parameter(cls, storage_type, serialized):
+        return (serialized, storage_type.value_type)
+
+    @classmethod
+    def _arrow_ext_deserialize_compute(cls, serialized, value_type):
         shape = tuple(_deserialize_with_fallback(serialized, "shape"))
-        return cls(shape, storage_type.value_type)
+        return cls(shape, value_type)
 
 
 @PublicAPI(stability="alpha")
@@ -586,9 +682,13 @@ def __init__(self, shape: Tuple[int, ...], dtype: pa.DataType):
         super().__init__(shape, pa.large_list(dtype), "ray.data.arrow_tensor_v2")
 
     @classmethod
-    def __arrow_ext_deserialize__(cls, storage_type, serialized):
+    def _get_deserialize_parameter(cls, storage_type, serialized):
+        return (serialized, storage_type.value_type)
+
+    @classmethod
+    def _arrow_ext_deserialize_compute(cls, serialized, value_type):
         shape = tuple(_deserialize_with_fallback(serialized, "shape"))
-        return cls(shape, storage_type.value_type)
+        return cls(shape, value_type)
 
 
 @PublicAPI(stability="beta")
@@ -878,8 +978,11 @@ def to_var_shaped_tensor_array(
         return target_type.wrap_array(storage)
 
 
+# ArrowExtensionSerializeDeserializeCache needs to be first in the MRO to ensure the cache is used
 @PublicAPI(stability="alpha")
-class ArrowVariableShapedTensorType(pa.ExtensionType):
+class ArrowVariableShapedTensorType(
+    ArrowExtensionSerializeDeserializeCache, pa.ExtensionType
+):
     """
     Arrow ExtensionType for an array of heterogeneous-shaped, homogeneous-typed
     tensors.
@@ -906,7 +1009,6 @@ def __init__(self, dtype: pa.DataType, ndim: int):
             ndim: The number of dimensions in the tensor elements.
         """
         self._ndim = ndim
-
         super().__init__(
             pa.struct(
                 [("data", pa.large_list(dtype)), ("shape", pa.list_(self.OFFSET_DTYPE))]
@@ -949,7 +1051,7 @@ def __reduce__(self):
             self.__arrow_ext_serialize__(),
         )
 
-    def __arrow_ext_serialize__(self):
+    def _arrow_ext_serialize_compute(self):
         if ARROW_EXTENSION_SERIALIZATION_FORMAT == _SerializationFormat.CLOUDPICKLE:
             return cloudpickle.dumps(self._ndim)
         elif ARROW_EXTENSION_SERIALIZATION_FORMAT == _SerializationFormat.JSON:
@@ -960,10 +1062,13 @@ def __arrow_ext_serialize__(self):
         )
 
     @classmethod
-    def __arrow_ext_deserialize__(cls, storage_type, serialized):
+    def _get_deserialize_parameter(cls, storage_type, serialized):
+        return (serialized, storage_type["data"].type.value_type)
+
+    @classmethod
+    def _arrow_ext_deserialize_compute(cls, serialized, value_type):
         ndim = _deserialize_with_fallback(serialized, "ndim")
-        dtype = storage_type["data"].type.value_type
-        return cls(dtype, ndim)
+        return cls(value_type, ndim)
 
     def __arrow_ext_class__(self):
         """
```

python/ray/air/util/tensor_extensions/utils.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -1,5 +1,12 @@
 import warnings
-from typing import TYPE_CHECKING, Any, List, Protocol, Sequence, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    List,
+    Protocol,
+    Sequence,
+    Union,
+)
 
 import numpy as np
```

python/ray/data/_internal/arrow_ops/transform_pyarrow.py

Lines changed: 4 additions & 6 deletions
```diff
@@ -306,11 +306,10 @@ def unify_schemas(
     # Deduplicate schemas. Calling this before PyArrow's unify_schemas is more efficient (100x faster).
 
     # Remove metadata for hashability
-    schemas[0].remove_metadata()
+    schema_to_compare = schemas[0].remove_metadata()
     schemas_to_unify = [schemas[0]]
     for schema in schemas[1:]:
-        schema.remove_metadata()
-        if not schema.equals(schemas[0]):
+        if not schema.remove_metadata().equals(schema_to_compare):
             schemas_to_unify.append(schema)
 
     pyarrow_exception = None
@@ -670,9 +669,8 @@ def _concat_cols_with_native_pyarrow_types(
     # NOTE: Type promotions aren't available in Arrow < 14.0
     subset_blocks = []
     for block in blocks:
-        cols_to_select = [
-            col_name for col_name in col_names if col_name in block.schema.names
-        ]
+        block_cols = set(block.schema.names)
+        cols_to_select = [col_name for col_name in col_names if col_name in block_cols]
         subset_blocks.append(block.select(cols_to_select))
     if get_pyarrow_version() < parse_version("14.0.0"):
         table = pa.concat_tables(subset_blocks, promote=True)
```
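The `unify_schemas` hunk above fixes a subtle bug: `pa.Schema.remove_metadata()` returns a new schema rather than mutating in place, so the old code discarded the result and kept comparing schemas with their metadata still attached. A small demonstration with a made-up schema:

```python
import pyarrow as pa

schema = pa.schema([pa.field("x", pa.int64())]).with_metadata({b"k": b"v"})

schema.remove_metadata()                # the old bug: the returned schema is discarded,
assert schema.metadata == {b"k": b"v"}  # so `schema` still carries its metadata

stripped = schema.remove_metadata()     # the fix: bind the new, metadata-free schema
assert stripped.metadata is None
```

The `_concat_cols_with_native_pyarrow_types` hunk is a standard micro-optimization: membership tests against `set(block.schema.names)` are O(1) per column, versus an O(n) list scan inside the comprehension.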
