Skip to content

Commit 9f19049

Browse files
KaisennHuisrabbani
andauthored
[Core] Support Zero-Copy Serialization for Read-Only Tensors (#57639)
Enable zero-copy serialization for all PyTorch tensors by setting `RAY_ENABLE_ZERO_COPY_TORCH_TENSORS=1` to accelerate serialization. Example test script: ```python import os # Must be set before `import ray` to ensure that the zero-copy tensor pickle reducer # is properly registered in driver. os.environ["RAY_ENABLE_ZERO_COPY_TORCH_TENSORS"] = "1" import ray import torch from datetime import datetime ray.init(runtime_env={"env_vars": {"RAY_ENABLE_ZERO_COPY_TORCH_TENSORS": "1"}}) @ray.remote def process(tensor): return tensor.sum() x = torch.ones(1024, 1024, 256) start_time = datetime.now() x_ref = process.remote(x) result = ray.get(x_ref) time_diff = datetime.now() - start_time print(f"result : {result}") print(f"between time: {time_diff.total_seconds()}s") print(f"result type : {type(result)}") ``` Below are the performance gains and validation results: <img width="1977" height="965" alt="zuizhongxiaoguo" src="https://github.com/user-attachments/assets/e3d5210c-142d-4ec3-908c-fe590514cfc8" /> Closes #56740 #26229 --------- Signed-off-by: Haichuan Hu <kaisennhu@gmail.com> Co-authored-by: Ibrahim Rabbani <irabbani@anyscale.com>
1 parent c01abb6 commit 9f19049

File tree

7 files changed

+438
-12
lines changed

7 files changed

+438
-12
lines changed

python/ray/_common/test_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,3 +247,14 @@ def f():
247247
assert all(
248248
[extra_usage_tags[k] == v for k, v in expected_extra_usage_tags.items()]
249249
), extra_usage_tags
250+
251+
252+
def is_named_tuple(cls):
253+
"""Return True if cls is a namedtuple and False otherwise."""
254+
b = cls.__bases__
255+
if len(b) != 1 or b[0] is not tuple:
256+
return False
257+
f = getattr(cls, "_fields", None)
258+
if not isinstance(f, tuple):
259+
return False
260+
return all(type(n) is str for n in f)

python/ray/_private/ray_constants.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,3 +590,28 @@ def gcs_actor_scheduling_enabled():
590590
RDT_FETCH_FAIL_TIMEOUT_SECONDS = (
591591
env_integer("RAY_rdt_fetch_fail_timeout_milliseconds", 60000) / 1000
592592
)
593+
594+
# Whether to enable zero-copy serialization for PyTorch tensors.
595+
# When enabled, Ray serializes PyTorch tensors by converting them to NumPy arrays
596+
# and leveraging pickle5's zero-copy buffer sharing. This avoids copying the
597+
# underlying tensor data, which can improve performance when passing large tensors
598+
# across tasks or actors. Note that this is experimental and should be used with caution
599+
# as we won't copy and allow a write to shared memory. One process changing a tensor
600+
# after ray.get could be reflected in another process.
601+
#
602+
# This feature is experimental and works best under the following conditions:
603+
# - The tensor has `requires_grad=False` (i.e., is detached from the autograd graph).
604+
# - The tensor is contiguous in memory
605+
# - Performance benefits from this are larger if the tensor resides in CPU memory
606+
# - You are not using Ray Direct Transport
607+
#
608+
# Tensors on GPU or non-contiguous tensors are still supported: Ray will
609+
# automatically move them to CPU and/or make them contiguous as needed.
610+
# While this incurs an initial copy, subsequent serialization may still benefit
611+
# from reduced overhead compared to the default path.
612+
#
613+
# Use with caution and ensure tensors meet the above criteria before enabling.
614+
# Default: False.
615+
RAY_ENABLE_ZERO_COPY_TORCH_TENSORS = env_bool(
616+
"RAY_ENABLE_ZERO_COPY_TORCH_TENSORS", False
617+
)

python/ray/_private/serialization.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
import threading
33
import traceback
4+
import warnings
45
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
56

67
if TYPE_CHECKING:
@@ -11,7 +12,10 @@
1112
import ray._private.utils
1213
import ray.cloudpickle as pickle
1314
import ray.exceptions
14-
from ray._private import ray_constants
15+
from ray._private import (
16+
ray_constants,
17+
tensor_serialization_utils,
18+
)
1519
from ray._raylet import (
1620
DynamicObjectRefGenerator,
1721
MessagePackSerializedObject,
@@ -159,6 +163,28 @@ def __init__(self, worker):
159163
# instead of the normal serialize -> object store -> deserialize codepath.
160164
self._torch_custom_serializer_registered = False
161165

166+
# Enable zero-copy serialization of tensors if the environment variable is set.
167+
self._zero_copy_tensors_enabled = (
168+
ray_constants.RAY_ENABLE_ZERO_COPY_TORCH_TENSORS
169+
)
170+
if self._zero_copy_tensors_enabled:
171+
try:
172+
import torch
173+
174+
self._register_cloudpickle_reducer(
175+
torch.Tensor, tensor_serialization_utils.zero_copy_tensors_reducer
176+
)
177+
except ImportError:
178+
# Warn and disable zero-copy tensor serialization when PyTorch is missing,
179+
# even if RAY_ENABLE_ZERO_COPY_TORCH_TENSORS is set.
180+
warnings.warn(
181+
"PyTorch is not installed. Disabling zero-copy tensor serialization "
182+
"even though RAY_ENABLE_ZERO_COPY_TORCH_TENSORS is set.",
183+
tensor_serialization_utils.ZeroCopyTensorsWarning,
184+
stacklevel=3,
185+
)
186+
self._zero_copy_tensors_enabled = False
187+
162188
def actor_handle_reducer(obj):
163189
ray._private.worker.global_worker.check_connected()
164190
serialized, actor_handle_id, weak_ref = obj._serialization_helper()
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import warnings
2+
from typing import TYPE_CHECKING, Any, Tuple
3+
4+
if TYPE_CHECKING:
5+
import numpy as np
6+
import torch
7+
8+
9+
class ZeroCopyTensorsWarning(UserWarning):
10+
"""
11+
Warning for unsafe or failed zero-copy tensor serialization/deserialization.
12+
"""
13+
14+
pass
15+
16+
17+
warnings.filterwarnings("once", category=ZeroCopyTensorsWarning)
18+
19+
20+
def _zero_copy_tensors_deserializer(
21+
np_array: "np.ndarray", dtype_str: str, shape: Tuple[int, ...], device_str: str
22+
) -> "torch.Tensor":
23+
"""
24+
Reconstructs a torch.Tensor from a zero-copy NumPy byte array.
25+
26+
Args:
27+
np_array: 1D uint8 NumPy array of the original tensor's raw bytes.
28+
dtype_str: Full string representation of the original tensor's dtype (e.g., 'torch.float32').
29+
shape: The original shape of the tensor before serialization.
30+
device_str: String representation of the original device (e.g., 'cpu', 'cuda:0').
31+
32+
Returns:
33+
Reconstructed torch.Tensor on the specified device if successful;
34+
otherwise, returns the input np_array unchanged and issues a warning.
35+
36+
Raises:
37+
ImportError/DeserializationError: If deserialization fails for any reason (e.g., missing PyTorch
38+
dtype mismatch, shape inconsistency, device error, etc.).
39+
"""
40+
try:
41+
import torch
42+
except ImportError as e:
43+
raise ImportError(
44+
"Zero-copy tensor deserialization failed: PyTorch is not installed."
45+
) from e
46+
47+
try:
48+
# Step 1: Convert uint8 numpy array back to torch tensor
49+
uint8_tensor = torch.from_numpy(np_array)
50+
51+
# Step 2: Restore original dtype
52+
dtype_name = dtype_str.split(".")[-1]
53+
if not hasattr(torch, dtype_name):
54+
raise ValueError(f"Invalid or unsupported dtype string: {dtype_str}")
55+
original_dtype = getattr(torch, dtype_name)
56+
57+
# Compute number of bytes per element
58+
dtype_size = torch.tensor([], dtype=original_dtype).element_size()
59+
if np_array.size % dtype_size != 0:
60+
raise ValueError(
61+
f"Byte array size ({np_array.size}) is not divisible by "
62+
f"dtype size ({dtype_size}) for dtype {dtype_str}"
63+
)
64+
65+
# Step 3: Reshape and reinterpret bytes as target dtype
66+
restored_tensor = uint8_tensor.view(original_dtype).reshape(shape)
67+
68+
# Step 4: Move to target device
69+
return restored_tensor.to(device=device_str)
70+
71+
except Exception as e:
72+
from ray._private.serialization import DeserializationError
73+
74+
raise DeserializationError(
75+
f"Failed to deserialize zero-copy tensor from byte array. "
76+
f"Input dtype={dtype_str}, shape={shape}, device={device_str}. "
77+
f"Underlying error: {type(e).__name__}: {e}"
78+
) from e
79+
80+
81+
def zero_copy_tensors_reducer(tensor: "torch.Tensor") -> Tuple[Any, Tuple[Any, ...]]:
82+
"""Pickle serializer for zero-copy serialization of read-only torch.Tensor.
83+
84+
This serializer aims to avoid copying tensor data by using a NumPy uint8 view,
85+
which enables pickle5's out-of-band buffer transmission. However, true zero-copy
86+
is only possible when the input tensor is already:
87+
88+
- On CPU,
89+
- Detached from the computation graph (no gradients),
90+
- Contiguous in memory.
91+
92+
If the input tensor does **not** meet these conditions, this function will:
93+
94+
- Call `.detach()` to remove gradient information,
95+
- Move the tensor to CPU (copying data if it's on GPU or another device),
96+
- Make the tensor contiguous (copying data if it's non-contiguous).
97+
98+
These operations may incur one or two full copies of the tensor data,
99+
negating zero-copy benefits. A warning is issued in such cases.
100+
101+
Args:
102+
tensor: The input torch.Tensor to serialize. Can be on any device,
103+
with or without gradients, contiguous or not — but zero-copy
104+
is only achieved if it is already CPU, detached, and contiguous.
105+
106+
Returns:
107+
A tuple (deserializer_callable, args_tuple) suitable for pickle.
108+
"""
109+
warnings.warn(
110+
"Zero-copy tensor serialization is enabled, but it only works safely for read-only tensors "
111+
"(detached, no gradients, contiguous). Modifiable or non-contiguous tensors may cause data corruption.",
112+
ZeroCopyTensorsWarning,
113+
stacklevel=3,
114+
)
115+
116+
import torch
117+
118+
# Detach the tensor from gradients and computation graph.
119+
# Move it to cpu (this is a noop if the tensor is already in main memory, but will create a copy if the
120+
# the tensor is on an accelerator).
121+
# Ensure that the tensor is contiguous. If the tensor is not contiguous, this will create a contiguous
122+
# copy.
123+
cpu_tensor = tensor.detach().cpu()
124+
if not cpu_tensor.is_contiguous():
125+
warnings.warn(
126+
"The input tensor is non-contiguous. A copy will be made to ensure contiguity. "
127+
"For zero-copy serialization, please ensure the tensor is contiguous before passing it "
128+
"(e.g., by calling `.contiguous()`).",
129+
ZeroCopyTensorsWarning,
130+
stacklevel=3,
131+
)
132+
cpu_tensor = cpu_tensor.contiguous()
133+
134+
# Flatten to 1D for safe uint8 view (handles scalars)
135+
flat_tensor = cpu_tensor.reshape(-1)
136+
# View as uint8 bytes
137+
uint8_view = flat_tensor.view(torch.uint8)
138+
np_array = uint8_view.numpy()
139+
140+
return _zero_copy_tensors_deserializer, (
141+
np_array,
142+
str(tensor.dtype),
143+
tuple(tensor.shape),
144+
str(tensor.device),
145+
)

python/ray/tests/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,7 @@ py_test_module_list(
640640
files = [
641641
"gpu_objects/test_gpu_objects_nccl.py",
642642
"gpu_objects/test_gpu_objects_nixl.py",
643+
"test_tensor_zero_copy_serialization.py",
643644
],
644645
tags = [
645646
"custom_setup",

python/ray/tests/test_serialization.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,11 @@
1616
import ray.cluster_utils
1717
import ray.exceptions
1818
from ray import cloudpickle
19+
from ray._common.test_utils import is_named_tuple
1920

2021
logger = logging.getLogger(__name__)
2122

2223

23-
def is_named_tuple(cls):
24-
"""Return True if cls is a namedtuple and False otherwise."""
25-
b = cls.__bases__
26-
if len(b) != 1 or b[0] is not tuple:
27-
return False
28-
f = getattr(cls, "_fields", None)
29-
if not isinstance(f, tuple):
30-
return False
31-
return all(type(n) is str for n in f)
32-
33-
3424
@pytest.mark.parametrize(
3525
"ray_start_regular", [{"local_mode": True}, {"local_mode": False}], indirect=True
3626
)

0 commit comments

Comments
 (0)