|
| 1 | +import warnings |
| 2 | +from typing import TYPE_CHECKING, Any, Tuple |
| 3 | + |
| 4 | +if TYPE_CHECKING: |
| 5 | + import numpy as np |
| 6 | + import torch |
| 7 | + |
| 8 | + |
class ZeroCopyTensorsWarning(UserWarning):
    """Warning raised when zero-copy tensor (de)serialization is unsafe or fails."""
| 15 | + |
| 16 | + |
# Deduplicate our warnings: with the "once" action, each distinct
# ZeroCopyTensorsWarning message is shown only a single time per process
# instead of on every serialization/deserialization call.
warnings.filterwarnings("once", category=ZeroCopyTensorsWarning)
| 18 | + |
| 19 | + |
| 20 | +def _zero_copy_tensors_deserializer( |
| 21 | + np_array: "np.ndarray", dtype_str: str, shape: Tuple[int, ...], device_str: str |
| 22 | +) -> "torch.Tensor": |
| 23 | + """ |
| 24 | + Reconstructs a torch.Tensor from a zero-copy NumPy byte array. |
| 25 | +
|
| 26 | + Args: |
| 27 | + np_array: 1D uint8 NumPy array of the original tensor's raw bytes. |
| 28 | + dtype_str: Full string representation of the original tensor's dtype (e.g., 'torch.float32'). |
| 29 | + shape: The original shape of the tensor before serialization. |
| 30 | + device_str: String representation of the original device (e.g., 'cpu', 'cuda:0'). |
| 31 | +
|
| 32 | + Returns: |
| 33 | + Reconstructed torch.Tensor on the specified device if successful; |
| 34 | + otherwise, returns the input np_array unchanged and issues a warning. |
| 35 | +
|
| 36 | + Raises: |
| 37 | + ImportError/DeserializationError: If deserialization fails for any reason (e.g., missing PyTorch |
| 38 | + dtype mismatch, shape inconsistency, device error, etc.). |
| 39 | + """ |
| 40 | + try: |
| 41 | + import torch |
| 42 | + except ImportError as e: |
| 43 | + raise ImportError( |
| 44 | + "Zero-copy tensor deserialization failed: PyTorch is not installed." |
| 45 | + ) from e |
| 46 | + |
| 47 | + try: |
| 48 | + # Step 1: Convert uint8 numpy array back to torch tensor |
| 49 | + uint8_tensor = torch.from_numpy(np_array) |
| 50 | + |
| 51 | + # Step 2: Restore original dtype |
| 52 | + dtype_name = dtype_str.split(".")[-1] |
| 53 | + if not hasattr(torch, dtype_name): |
| 54 | + raise ValueError(f"Invalid or unsupported dtype string: {dtype_str}") |
| 55 | + original_dtype = getattr(torch, dtype_name) |
| 56 | + |
| 57 | + # Compute number of bytes per element |
| 58 | + dtype_size = torch.tensor([], dtype=original_dtype).element_size() |
| 59 | + if np_array.size % dtype_size != 0: |
| 60 | + raise ValueError( |
| 61 | + f"Byte array size ({np_array.size}) is not divisible by " |
| 62 | + f"dtype size ({dtype_size}) for dtype {dtype_str}" |
| 63 | + ) |
| 64 | + |
| 65 | + # Step 3: Reshape and reinterpret bytes as target dtype |
| 66 | + restored_tensor = uint8_tensor.view(original_dtype).reshape(shape) |
| 67 | + |
| 68 | + # Step 4: Move to target device |
| 69 | + return restored_tensor.to(device=device_str) |
| 70 | + |
| 71 | + except Exception as e: |
| 72 | + from ray._private.serialization import DeserializationError |
| 73 | + |
| 74 | + raise DeserializationError( |
| 75 | + f"Failed to deserialize zero-copy tensor from byte array. " |
| 76 | + f"Input dtype={dtype_str}, shape={shape}, device={device_str}. " |
| 77 | + f"Underlying error: {type(e).__name__}: {e}" |
| 78 | + ) from e |
| 79 | + |
| 80 | + |
def zero_copy_tensors_reducer(tensor: "torch.Tensor") -> Tuple[Any, Tuple[Any, ...]]:
    """Pickle serializer for zero-copy serialization of read-only torch.Tensor.

    Serializes the tensor as a flat NumPy uint8 view of its raw bytes so that
    pickle5 can ship the buffer out-of-band without copying. True zero-copy is
    only achieved when the input is already CPU-resident, detached from the
    autograd graph, and contiguous in memory.

    When those conditions do not hold, this function normalizes the tensor by
    calling `.detach()`, moving it to CPU (copies accelerator-resident data),
    and making it contiguous (copies non-contiguous data) — which may cost one
    or two full copies and negates the zero-copy benefit. A warning is issued
    in such cases.

    Args:
        tensor: The torch.Tensor to serialize. Any device, gradient state, or
            layout is accepted, but zero-copy only happens for CPU, detached,
            contiguous inputs.

    Returns:
        A (deserializer_callable, args_tuple) pair suitable for pickle.
    """
    warnings.warn(
        "Zero-copy tensor serialization is enabled, but it only works safely for read-only tensors "
        "(detached, no gradients, contiguous). Modifiable or non-contiguous tensors may cause data corruption.",
        ZeroCopyTensorsWarning,
        stacklevel=3,
    )

    import torch

    # Drop autograd tracking and ensure the data lives in host memory.
    # `.cpu()` is a no-op for CPU tensors but copies accelerator-resident data.
    host_tensor = tensor.detach().cpu()
    if not host_tensor.is_contiguous():
        warnings.warn(
            "The input tensor is non-contiguous. A copy will be made to ensure contiguity. "
            "For zero-copy serialization, please ensure the tensor is contiguous before passing it "
            "(e.g., by calling `.contiguous()`).",
            ZeroCopyTensorsWarning,
            stacklevel=3,
        )
        host_tensor = host_tensor.contiguous()

    # Flatten first so 0-dim scalars also admit a uint8 view, then expose the
    # raw bytes as a NumPy array without copying.
    raw_bytes = host_tensor.reshape(-1).view(torch.uint8).numpy()

    # Metadata needed to reverse the transformation on the receiving side.
    metadata = (str(tensor.dtype), tuple(tensor.shape), str(tensor.device))
    return _zero_copy_tensors_deserializer, (raw_bytes, *metadata)
0 commit comments