11import abc
2+ import functools
23import itertools
34import json
45import logging
56import sys
7+ import threading
8+ from abc import abstractmethod
69from datetime import datetime
710from enum import Enum
811from typing import Any , Collection , Dict , Iterable , List , Optional , Tuple , Union
@@ -62,6 +65,11 @@ class _SerializationFormat(Enum):
6265 else _SerializationFormat .CLOUDPICKLE # default
6366)
6467
68+ # 100,000 entries, about 10MB in memory.
69+ # Most users tables should have less than 100K columns.
70+ ARROW_EXTENSION_SERIALIZATION_CACHE_MAXSIZE = env_integer (
71+ "RAY_EXTENSION_SERIALIZATION_CACHE_MAXSIZE" , 10 ** 5
72+ )
6573
6674logger = logging .getLogger (__name__ )
6775
@@ -85,6 +93,88 @@ def _deserialize_with_fallback(serialized: bytes, field_name: str = "data"):
8593 )
8694
8795
96+ @DeveloperAPI (stability = "beta" )
97+ class ArrowExtensionSerializeDeserializeCache (abc .ABC ):
98+ """Base class for caching Arrow extension type serialization and deserialization.
99+
100+ The deserialization and serialization of Arrow extension types is frequent,
101+ so we cache the results here to improve performance.
102+
103+ The deserialization cache uses functools.lru_cache as a classmethod. There is
104+ a single cache instance shared across all subclasses, but the cache key includes
105+ the class (cls parameter) as the first argument, so different subclasses get
106+ different cache entries even when called with the same parameters. The cache is
107+ thread-safe and has a maximum size limit to control memory usage. The cache key
108+ is (cls, *args) where args are the parameters returned by _get_deserialize_parameter().
109+
110+ Attributes:
111+ _serialize_cache: Instance-level cache for serialization results.
112+ This is a simple cached value (bytes) that is computed once per
113+ instance and reused.
114+ """
115+
116+ def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
117+ """Initialize the extension type with caching support.
118+
119+ Args:
120+ *args: Positional arguments passed to the parent class.
121+ **kwargs: Keyword arguments passed to the parent class.
122+ """
123+ # Instance-level cache for serialization results, no TTL
124+ self ._serialize_cache = None
125+ self ._cache_lock = threading .RLock ()
126+ super ().__init__ (* args , ** kwargs )
127+
128+ def __arrow_ext_serialize__ (self ) -> bytes :
129+ """Serialize the extension type using caching if enabled."""
130+ with self ._cache_lock :
131+ if self ._serialize_cache is None :
132+ self ._serialize_cache = self ._arrow_ext_serialize_compute ()
133+ return self ._serialize_cache
134+
135+ @abstractmethod
136+ def _arrow_ext_serialize_compute (self ) -> bytes :
137+ """Subclasses must implement this method to compute serialization."""
138+ ...
139+
140+ @classmethod
141+ @functools .lru_cache (maxsize = ARROW_EXTENSION_SERIALIZATION_CACHE_MAXSIZE )
142+ def _arrow_ext_deserialize_cache (cls : type , * args : Any , ** kwargs : Any ) -> Any :
143+ """Deserialize the extension type using the class-level cache.
144+
145+ This method is cached using functools.lru_cache to improve performance
146+ when deserializing extension types. The cache key includes the class (cls)
147+ as the first argument, ensuring different subclasses get separate cache entries.
148+
149+ Args:
150+ *args: Positional arguments passed to _arrow_ext_deserialize_compute.
151+ **kwargs: Keyword arguments passed to _arrow_ext_deserialize_compute.
152+
153+ Returns:
154+ The deserialized extension type instance.
155+ """
156+ return cls ._arrow_ext_deserialize_compute (* args , ** kwargs )
157+
158+ @classmethod
159+ @abstractmethod
160+ def _arrow_ext_deserialize_compute (cls , * args : Any , ** kwargs : Any ) -> Any :
161+ """Subclasses must implement this method to compute deserialization."""
162+ ...
163+
164+ @classmethod
165+ @abstractmethod
166+ def _get_deserialize_parameter (cls , storage_type , serialized ) -> Tuple :
167+ """Subclasses must implement this method to return the parameters for the deserialization cache."""
168+ ...
169+
170+ @classmethod
171+ def __arrow_ext_deserialize__ (cls , storage_type , serialized ) -> Any :
172+ """Deserialize the extension type using caching if enabled."""
173+ return cls ._arrow_ext_deserialize_cache (
174+ * cls ._get_deserialize_parameter (storage_type , serialized )
175+ )
176+
177+
88178@DeveloperAPI
89179class ArrowConversionError (Exception ):
90180 """Error raised when there is an issue converting data to Arrow."""
@@ -431,7 +521,10 @@ def get_arrow_extension_variable_shape_tensor_types():
431521 return (ArrowVariableShapedTensorType ,)
432522
433523
434- class _BaseFixedShapeArrowTensorType (pa .ExtensionType , abc .ABC ):
524+ # ArrowExtensionSerializeDeserializeCache needs to be first in the MRO to ensure the cache is used
525+ class _BaseFixedShapeArrowTensorType (
526+ ArrowExtensionSerializeDeserializeCache , pa .ExtensionType
527+ ):
435528 """
436529 Arrow ExtensionType for an array of fixed-shaped, homogeneous-typed
437530 tensors.
@@ -446,7 +539,6 @@ def __init__(
446539 self , shape : Tuple [int , ...], tensor_dtype : pa .DataType , ext_type_id : str
447540 ):
448541 self ._shape = shape
449-
450542 super ().__init__ (tensor_dtype , ext_type_id )
451543
452544 @property
@@ -478,7 +570,7 @@ def __reduce__(self):
478570 self .__arrow_ext_serialize__ (),
479571 )
480572
481- def __arrow_ext_serialize__ (self ):
573+ def _arrow_ext_serialize_compute (self ):
482574 if ARROW_EXTENSION_SERIALIZATION_FORMAT == _SerializationFormat .CLOUDPICKLE :
483575 return cloudpickle .dumps (self ._shape )
484576 elif ARROW_EXTENSION_SERIALIZATION_FORMAT == _SerializationFormat .JSON :
@@ -563,9 +655,13 @@ def __init__(self, shape: Tuple[int, ...], dtype: pa.DataType):
563655 super ().__init__ (shape , pa .list_ (dtype ), "ray.data.arrow_tensor" )
564656
565657 @classmethod
566- def __arrow_ext_deserialize__ (cls , storage_type , serialized ):
658+ def _get_deserialize_parameter (cls , storage_type , serialized ):
659+ return (serialized , storage_type .value_type )
660+
661+ @classmethod
662+ def _arrow_ext_deserialize_compute (cls , serialized , value_type ):
567663 shape = tuple (_deserialize_with_fallback (serialized , "shape" ))
568- return cls (shape , storage_type . value_type )
664+ return cls (shape , value_type )
569665
570666
571667@PublicAPI (stability = "alpha" )
@@ -586,9 +682,13 @@ def __init__(self, shape: Tuple[int, ...], dtype: pa.DataType):
586682 super ().__init__ (shape , pa .large_list (dtype ), "ray.data.arrow_tensor_v2" )
587683
588684 @classmethod
589- def __arrow_ext_deserialize__ (cls , storage_type , serialized ):
685+ def _get_deserialize_parameter (cls , storage_type , serialized ):
686+ return (serialized , storage_type .value_type )
687+
688+ @classmethod
689+ def _arrow_ext_deserialize_compute (cls , serialized , value_type ):
590690 shape = tuple (_deserialize_with_fallback (serialized , "shape" ))
591- return cls (shape , storage_type . value_type )
691+ return cls (shape , value_type )
592692
593693
594694@PublicAPI (stability = "beta" )
@@ -878,8 +978,11 @@ def to_var_shaped_tensor_array(
878978 return target_type .wrap_array (storage )
879979
880980
981+ # ArrowExtensionSerializeDeserializeCache needs to be first in the MRO to ensure the cache is used
881982@PublicAPI (stability = "alpha" )
882- class ArrowVariableShapedTensorType (pa .ExtensionType ):
983+ class ArrowVariableShapedTensorType (
984+ ArrowExtensionSerializeDeserializeCache , pa .ExtensionType
985+ ):
883986 """
884987 Arrow ExtensionType for an array of heterogeneous-shaped, homogeneous-typed
885988 tensors.
@@ -906,7 +1009,6 @@ def __init__(self, dtype: pa.DataType, ndim: int):
9061009 ndim: The number of dimensions in the tensor elements.
9071010 """
9081011 self ._ndim = ndim
909-
9101012 super ().__init__ (
9111013 pa .struct (
9121014 [("data" , pa .large_list (dtype )), ("shape" , pa .list_ (self .OFFSET_DTYPE ))]
@@ -949,7 +1051,7 @@ def __reduce__(self):
9491051 self .__arrow_ext_serialize__ (),
9501052 )
9511053
952- def __arrow_ext_serialize__ (self ):
1054+ def _arrow_ext_serialize_compute (self ):
9531055 if ARROW_EXTENSION_SERIALIZATION_FORMAT == _SerializationFormat .CLOUDPICKLE :
9541056 return cloudpickle .dumps (self ._ndim )
9551057 elif ARROW_EXTENSION_SERIALIZATION_FORMAT == _SerializationFormat .JSON :
@@ -960,10 +1062,13 @@ def __arrow_ext_serialize__(self):
9601062 )
9611063
9621064 @classmethod
963- def __arrow_ext_deserialize__ (cls , storage_type , serialized ):
1065+ def _get_deserialize_parameter (cls , storage_type , serialized ):
1066+ return (serialized , storage_type ["data" ].type .value_type )
1067+
1068+ @classmethod
1069+ def _arrow_ext_deserialize_compute (cls , serialized , value_type ):
9641070 ndim = _deserialize_with_fallback (serialized , "ndim" )
965- dtype = storage_type ["data" ].type .value_type
966- return cls (dtype , ndim )
1071+ return cls (value_type , ndim )
9671072
9681073 def __arrow_ext_class__ (self ):
9691074 """
0 commit comments