- 
          
- 
                Notifications
    You must be signed in to change notification settings 
- Fork 10.9k
[V1] Zero-copy tensor/ndarray serialization/transmission #13790
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
          
     Merged
      
      
    
  
     Merged
                    Changes from all commits
      Commits
    
    
            Show all changes
          
          
            20 commits
          
        
        Select commit
          Hold shift + click to select a range
      
      7b3b6ea
              
                [V1] Zero-copy tensor/ndarray serialization/transmission
              
              
                njhill 35d1cd9
              
                TypeAlias keyword is python >= 3.10 only
              
              
                njhill f6f26b6
              
                use highest pickle protocol
              
              
                njhill 4382a16
              
                Merge remote-tracking branch 'origin/main' into tensor-nocopy
              
              
                njhill 9d91483
              
                Merge remote-tracking branch 'refs/remotes/origin/main' into tensor-n…
              
              
                njhill ea75bd3
              
                Merge remote-tracking branch 'refs/remotes/origin/main' into tensor-n…
              
              
                njhill 95b0600
              
                Add unit test
              
              
                njhill 910f30f
              
                pre-commit fix
              
              
                njhill 747ce1c
              
                Merge remote-tracking branch 'refs/remotes/origin/main' into tensor-n…
              
              
                njhill 478ce09
              
                Fix unrecognized type decode
              
              
                njhill 7ea02a8
              
                Handle scalars properly
              
              
                njhill e7d010d
              
                Optimization: encode small tensors inline.
              
              
                njhill f946398
              
                Merge remote-tracking branch 'refs/remotes/origin/main' into tensor-n…
              
              
                njhill 60797b4
              
                Merge remote-tracking branch 'refs/remotes/origin/main' into tensor-n…
              
              
                njhill c0c6e43
              
                Update vllm/v1/serial_utils.py
              
              
                njhill 3b978ad
              
                Update vllm/v1/serial_utils.py
              
              
                njhill 80d90a5
              
                Update vllm/v1/serial_utils.py
              
              
                njhill 6bd45dc
              
                Update vllm/v1/serial_utils.py
              
              
                njhill 97c144b
              
                Update vllm/v1/serial_utils.py
              
              
                njhill c6c2a90
              
                Comment/docstring updates
              
              
                njhill File filter
Filter by extension
Conversations
          Failed to load comments.   
        
        
          
      Loading
        
  Jump to
        
          Jump to file
        
      
      
          Failed to load files.   
        
        
          
      Loading
        
  Diff view
Diff view
There are no files selected for viewing
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| from collections import UserDict | ||
| from dataclasses import dataclass | ||
|  | ||
| import numpy as np | ||
| import torch | ||
|  | ||
| from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder | ||
|  | ||
|  | ||
| class UnrecognizedType(UserDict): | ||
|  | ||
| def __init__(self, an_int: int): | ||
| super().__init__() | ||
| self.an_int = an_int | ||
|  | ||
|  | ||
| @dataclass | ||
| class MyType: | ||
| tensor1: torch.Tensor | ||
| a_string: str | ||
| list_of_tensors: list[torch.Tensor] | ||
| numpy_array: np.ndarray | ||
| unrecognized: UnrecognizedType | ||
|  | ||
|  | ||
| def test_encode_decode(): | ||
| """Test encode/decode loop with zero-copy tensors.""" | ||
|  | ||
| obj = MyType( | ||
| tensor1=torch.randint(low=0, | ||
| high=100, | ||
| size=(1024, ), | ||
| dtype=torch.int32), | ||
| a_string="hello", | ||
| list_of_tensors=[ | ||
| torch.rand((1, 10), dtype=torch.float32), | ||
| torch.rand((3, 5, 4000), dtype=torch.float64), | ||
| torch.tensor(1984), # test scalar too | ||
| ], | ||
| numpy_array=np.arange(512), | ||
| unrecognized=UnrecognizedType(33), | ||
| ) | ||
|  | ||
| encoder = MsgpackEncoder() | ||
| decoder = MsgpackDecoder(MyType) | ||
|  | ||
| encoded = encoder.encode(obj) | ||
|  | ||
| # There should be the main buffer + 2 large tensor buffers | ||
| # + 1 large numpy array. "large" is <= 256 bytes. | ||
| # The two small tensors are encoded inline. | ||
| assert len(encoded) == 4 | ||
|  | ||
| decoded: MyType = decoder.decode(encoded) | ||
|  | ||
| assert_equal(decoded, obj) | ||
|  | ||
| # Test encode_into case | ||
|  | ||
| preallocated = bytearray() | ||
|  | ||
| encoded2 = encoder.encode_into(obj, preallocated) | ||
|  | ||
| assert len(encoded2) == 4 | ||
| assert encoded2[0] is preallocated | ||
|  | ||
| decoded2: MyType = decoder.decode(encoded2) | ||
|  | ||
| assert_equal(decoded2, obj) | ||
|  | ||
|  | ||
| def assert_equal(obj1: MyType, obj2: MyType): | ||
| assert torch.equal(obj1.tensor1, obj2.tensor1) | ||
| assert obj1.a_string == obj2.a_string | ||
| assert all( | ||
| torch.equal(a, b) | ||
| for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors)) | ||
| assert np.array_equal(obj1.numpy_array, obj2.numpy_array) | ||
| assert obj1.unrecognized.an_int == obj2.unrecognized.an_int | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -1,61 +1,140 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|  | ||
| import pickle | ||
| from collections.abc import Sequence | ||
| from inspect import isclass | ||
| from types import FunctionType | ||
| from typing import Any, Optional | ||
| from typing import Any, Optional, Union | ||
|  | ||
| import cloudpickle | ||
| import numpy as np | ||
| import torch | ||
| import zmq | ||
| from msgspec import msgpack | ||
|  | ||
| CUSTOM_TYPE_TENSOR = 1 | ||
| CUSTOM_TYPE_PICKLE = 2 | ||
| CUSTOM_TYPE_CLOUDPICKLE = 3 | ||
| CUSTOM_TYPE_PICKLE = 1 | ||
| CUSTOM_TYPE_CLOUDPICKLE = 2 | ||
|  | ||
| # TODO calibrate this size | ||
| INLINE_BUF_SIZE_THRESHOLD = 256 | ||
|  | ||
| class MsgpackEncoder: | ||
| """Encoder with custom torch tensor serialization.""" | ||
| bytestr = Union[bytes, bytearray, memoryview, zmq.Frame] | ||
|  | ||
| def __init__(self): | ||
| self.encoder = msgpack.Encoder(enc_hook=custom_enc_hook) | ||
|  | ||
| def encode(self, obj: Any) -> bytes: | ||
| return self.encoder.encode(obj) | ||
| class MsgpackEncoder: | ||
| """Encoder with custom torch tensor and numpy array serialization. | ||
|  | ||
| def encode_into(self, obj: Any, buf: bytearray) -> None: | ||
| self.encoder.encode_into(obj, buf) | ||
| Note that unlike vanilla `msgspec` Encoders, this interface is generally | ||
| not thread-safe when encoding tensors / numpy arrays. | ||
| """ | ||
|  | ||
| def __init__(self): | ||
| self.encoder = msgpack.Encoder(enc_hook=self.enc_hook) | ||
| # This is used as a local stash of buffers that we can then access from | ||
| # our custom `msgspec` hook, `enc_hook`. We don't have a way to | ||
| # pass custom data to the hook otherwise. | ||
| self.aux_buffers: Optional[list[bytestr]] = None | ||
|  | ||
| def encode(self, obj: Any) -> Sequence[bytestr]: | ||
| try: | ||
| self.aux_buffers = bufs = [b''] | ||
| bufs[0] = self.encoder.encode(obj) | ||
| # This `bufs` list allows us to collect direct pointers to backing | ||
| # buffers of tensors and np arrays, and return them along with the | ||
| # top-level encoded buffer instead of copying their data into the | ||
| # new buffer. | ||
| return bufs | ||
|         
                  njhill marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| finally: | ||
| self.aux_buffers = None | ||
|  | ||
| def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]: | ||
| try: | ||
| self.aux_buffers = [buf] | ||
| bufs = self.aux_buffers | ||
| self.encoder.encode_into(obj, buf) | ||
| return bufs | ||
| finally: | ||
| self.aux_buffers = None | ||
|  | ||
| def enc_hook(self, obj: Any) -> Any: | ||
| if isinstance(obj, torch.Tensor): | ||
|         
                  njhill marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| return self._encode_ndarray(obj.numpy()) | ||
|  | ||
| # Fall back to pickle for object or void kind ndarrays. | ||
| if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'): | ||
| return self._encode_ndarray(obj) | ||
|  | ||
| if isinstance(obj, FunctionType): | ||
|         
                  njhill marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| # `pickle` is generally faster than cloudpickle, but can have | ||
| # problems serializing methods. | ||
| return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj)) | ||
|  | ||
| return msgpack.Ext(CUSTOM_TYPE_PICKLE, | ||
| pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)) | ||
|  | ||
| def _encode_ndarray( | ||
| self, obj: np.ndarray | ||
| ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: | ||
| assert self.aux_buffers is not None | ||
| if not obj.shape or obj.nbytes < INLINE_BUF_SIZE_THRESHOLD: | ||
| # Encode small arrays and scalars inline. | ||
| data = obj.data | ||
| else: | ||
| # Otherwise encode index of backing buffer. | ||
| obj = np.ascontiguousarray(obj) | ||
| data = len(self.aux_buffers) | ||
| self.aux_buffers.append(obj.data) | ||
| # We serialize the ndarray as a tuple of native types. | ||
| # The data is either inlined if small, or an index into a list of | ||
| # backing buffers that we've stashed in `aux_buffers`. | ||
| return obj.dtype.str, obj.shape, data | ||
|         
                  njhill marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
|  | ||
|  | ||
| class MsgpackDecoder: | ||
| """Decoder with custom torch tensor serialization.""" | ||
| """Decoder with custom torch tensor and numpy array serialization. | ||
|  | ||
| Note that unlike vanilla `msgspec` Decoders, this interface is generally | ||
| not thread-safe when encoding tensors / numpy arrays. | ||
| """ | ||
|  | ||
| def __init__(self, t: Optional[Any] = None): | ||
| args = () if t is None else (t, ) | ||
| self.decoder = msgpack.Decoder(*args, ext_hook=custom_ext_hook) | ||
|  | ||
| def decode(self, obj: Any): | ||
| return self.decoder.decode(obj) | ||
|  | ||
|  | ||
| def custom_enc_hook(obj: Any) -> Any: | ||
| if isinstance(obj, torch.Tensor): | ||
| # NOTE(rob): it is fastest to use numpy + pickle | ||
| # when serializing torch tensors. | ||
| # https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501 | ||
| return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy())) | ||
|  | ||
| if isinstance(obj, FunctionType): | ||
| return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj)) | ||
|  | ||
| return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj)) | ||
|  | ||
|  | ||
| def custom_ext_hook(code: int, data: memoryview) -> Any: | ||
| if code == CUSTOM_TYPE_TENSOR: | ||
| return torch.from_numpy(pickle.loads(data)) | ||
| if code == CUSTOM_TYPE_PICKLE: | ||
| return pickle.loads(data) | ||
| if code == CUSTOM_TYPE_CLOUDPICKLE: | ||
| return cloudpickle.loads(data) | ||
|  | ||
| raise NotImplementedError(f"Extension type code {code} is not supported") | ||
| self.decoder = msgpack.Decoder(*args, | ||
| ext_hook=self.ext_hook, | ||
| dec_hook=self.dec_hook) | ||
| self.aux_buffers: Sequence[bytestr] = () | ||
|  | ||
| def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any: | ||
| if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)): | ||
|         
                  njhill marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| # TODO - This check can become `isinstance(bufs, bytestr)` | ||
| # as of Python 3.10. | ||
| return self.decoder.decode(bufs) | ||
|  | ||
| self.aux_buffers = bufs | ||
| try: | ||
| return self.decoder.decode(bufs[0]) | ||
| finally: | ||
| self.aux_buffers = () | ||
|  | ||
| def dec_hook(self, t: type, obj: Any) -> Any: | ||
|         
                  njhill marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| # Given native types in `obj`, convert to type `t`. | ||
| if isclass(t): | ||
| if issubclass(t, np.ndarray): | ||
| return self._decode_ndarray(obj) | ||
| if issubclass(t, torch.Tensor): | ||
| return torch.from_numpy(self._decode_ndarray(obj)) | ||
| return obj | ||
|  | ||
| def _decode_ndarray(self, arr: Any) -> np.ndarray: | ||
| dtype, shape, data = arr | ||
| buffer = self.aux_buffers[data] if isinstance(data, int) else data | ||
| return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape) | ||
|  | ||
| def ext_hook(self, code: int, data: memoryview) -> Any: | ||
| if code == CUSTOM_TYPE_PICKLE: | ||
| return pickle.loads(data) | ||
| if code == CUSTOM_TYPE_CLOUDPICKLE: | ||
| return cloudpickle.loads(data) | ||
|  | ||
| raise NotImplementedError( | ||
| f"Extension type code {code} is not supported") | ||
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
Uh oh!
There was an error while loading. Please reload this page.