|  | 
| 2 | 2 |  | 
| 3 | 3 | import json | 
| 4 | 4 | import os | 
| 5 |  | -import pickle | 
| 6 | 5 | from concurrent.futures import ThreadPoolExecutor | 
| 7 | 6 | from dataclasses import dataclass | 
| 8 | 7 | from typing import Optional, Union | 
| 9 | 8 |  | 
| 10 | 9 | import torch | 
| 11 | 10 | import zmq | 
|  | 11 | +from safetensors.torch import load as safetensors_load | 
|  | 12 | +from safetensors.torch import save as safetensors_save | 
| 12 | 13 |  | 
| 13 | 14 | from vllm.config import KVTransferConfig | 
| 14 | 15 | from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase | 
| @@ -237,14 +238,13 @@ def tensor_hash(self, tensor: torch.Tensor) -> int: | 
| 237 | 238 |         return hash(tensor.data_ptr()) | 
| 238 | 239 |  | 
| 239 | 240 |     def _send_impl(self, tensor: torch.Tensor) -> None: | 
| 240 |  | -        """Implement the tensor sending logic.""" | 
| 241 |  | -        value_bytes = pickle.dumps(tensor) | 
| 242 |  | -        self.transfer_engine.send_bytes(value_bytes) | 
|  | 241 | +        """Implement the tensor sending logic using safetensors.""" | 
|  | 242 | +        self.transfer_engine.send_bytes(safetensors_save({"tensor": tensor})) | 
| 243 | 243 |  | 
| 244 | 244 |     def _recv_impl(self) -> torch.Tensor: | 
| 245 |  | -        """Implement the tensor receiving logic.""" | 
|  | 245 | +        """Implement the tensor receiving logic using safetensors.""" | 
| 246 | 246 |         data = self.transfer_engine.recv_bytes() | 
| 247 |  | -        return pickle.loads(data) | 
|  | 247 | +        return safetensors_load(data)["tensor"].to(self.device) | 
| 248 | 248 |  | 
| 249 | 249 |     def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: | 
| 250 | 250 |         """Send tensor to the target process.""" | 
|  | 
0 commit comments