11# SPDX-License-Identifier: Apache-2.0 
22
3+ import  dataclasses 
34import  pickle 
45from  collections .abc  import  Sequence 
56from  inspect  import  isclass 
1213import  zmq 
1314from  msgspec  import  msgpack 
1415
16+ from  vllm  import  envs 
17+ from  vllm .multimodal .inputs  import  (BaseMultiModalField ,
18+                                     MultiModalBatchedField ,
19+                                     MultiModalFieldConfig , MultiModalFieldElem ,
20+                                     MultiModalFlatField , MultiModalKwargs ,
21+                                     MultiModalKwargsItem ,
22+                                     MultiModalSharedField , NestedTensors )
23+ 
1524CUSTOM_TYPE_PICKLE  =  1 
1625CUSTOM_TYPE_CLOUDPICKLE  =  2 
1726CUSTOM_TYPE_RAW_VIEW  =  3 
1827
19- # TODO calibrate this size 
20- MIN_NOCOPY_BUF_SIZE  =  512 
28+ # MultiModalField class serialization type map. 
29+ # These need to list all possible field types and match them 
30+ # to factory methods in `MultiModalFieldConfig`. 
31+ MMF_CLASS_TO_FACTORY : dict [type [BaseMultiModalField ], str ] =  {
32+     MultiModalFlatField : "flat" ,
33+     MultiModalSharedField : "shared" ,
34+     MultiModalBatchedField : "batched" ,
35+ }
2136
2237bytestr  =  Union [bytes , bytearray , memoryview , zmq .Frame ]
2338
@@ -27,14 +42,20 @@ class MsgpackEncoder:
2742
2843    Note that unlike vanilla `msgspec` Encoders, this interface is generally 
2944    not thread-safe when encoding tensors / numpy arrays. 
45+ 
46+     By default, arrays below 256B are serialized inline Larger will get sent  
47+     via dedicated messages. Note that this is a per-tensor limit. 
3048    """ 
3149
32-     def  __init__ (self ):
50+     def  __init__ (self , size_threshold : Optional [int ] =  None ):
51+         if  size_threshold  is  None :
52+             size_threshold  =  envs .VLLM_MSGPACK_ZERO_COPY_THRESHOLD 
3353        self .encoder  =  msgpack .Encoder (enc_hook = self .enc_hook )
3454        # This is used as a local stash of buffers that we can then access from 
3555        # our custom `msgspec` hook, `enc_hook`. We don't have a way to 
3656        # pass custom data to the hook otherwise. 
3757        self .aux_buffers : Optional [list [bytestr ]] =  None 
58+         self .size_threshold  =  size_threshold 
3859
3960    def  encode (self , obj : Any ) ->  Sequence [bytestr ]:
4061        try :
@@ -65,6 +86,25 @@ def enc_hook(self, obj: Any) -> Any:
6586        if  isinstance (obj , np .ndarray ) and  obj .dtype .kind  not  in 'O' , 'V' ):
6687            return  self ._encode_ndarray (obj )
6788
89+         if  isinstance (obj , MultiModalKwargs ):
90+             mm : MultiModalKwargs  =  obj 
91+             if  not  mm .modalities :
92+                 # just return the main dict if there are no modalities. 
93+                 return  dict (mm )
94+ 
95+             # ignore the main dict, it will be re-indexed. 
96+             # Encode a list of MultiModalKwargsItems as plain dicts 
97+             # + special handling for .field. 
98+             # Any tensors *not* indexed by modality will be ignored. 
99+             return  [[{
100+                 "modality" : elem .modality ,
101+                 "key" : elem .key ,
102+                 "data" : self ._encode_nested_tensors (elem .data ),
103+                 "field" : self ._encode_mm_field (elem .field ),
104+             } for  elem  in  item .values ()]
105+                     for  itemlist  in  mm ._items_by_modality .values ()
106+                     for  item  in  itemlist ]
107+ 
68108        if  isinstance (obj , FunctionType ):
69109            # `pickle` is generally faster than cloudpickle, but can have 
70110            # problems serializing methods. 
@@ -77,8 +117,9 @@ def _encode_ndarray(
77117        self , obj : np .ndarray 
78118    ) ->  tuple [str , tuple [int , ...], Union [int , memoryview ]]:
79119        assert  self .aux_buffers  is  not None 
120+         # If the array is non-contiguous, we need to copy it first 
80121        arr_data  =  obj .data  if  obj .data .c_contiguous  else  obj .tobytes ()
81-         if  not  obj .shape  or  obj .nbytes  <  MIN_NOCOPY_BUF_SIZE :
122+         if  not  obj .shape  or  obj .nbytes  <  self . size_threshold :
82123            # Encode small arrays and scalars inline. Using this extension type 
83124            # ensures we can avoid copying when decoding. 
84125            data  =  msgpack .Ext (CUSTOM_TYPE_RAW_VIEW , arr_data )
@@ -92,6 +133,26 @@ def _encode_ndarray(
92133        # backing buffers that we've stashed in `aux_buffers`. 
93134        return  obj .dtype .str , obj .shape , data 
94135
136+     def  _encode_nested_tensors (self , nt : NestedTensors ) ->  Any :
137+         if  isinstance (nt , torch .Tensor ):
138+             return  self ._encode_ndarray (nt .numpy ())
139+         if  isinstance (nt , (int , float )):
140+             # Although it violates NestedTensors type, MultiModalKwargs 
141+             # values are sometimes floats. 
142+             return  nt 
143+         return  [self ._encode_nested_tensors (x ) for  x  in  nt ]
144+ 
145+     def  _encode_mm_field (self , field : BaseMultiModalField ):
146+         # Figure out the factory name for the field type. 
147+         name  =  MMF_CLASS_TO_FACTORY .get (field .__class__ )
148+         if  not  name :
149+             raise  TypeError (f"Unsupported field type: { field .__class__ }  )
150+         # We just need to copy all of the field values in order 
151+         # which will be then used to reconstruct the field. 
152+         field_values  =  (getattr (field , f .name )
153+                         for  f  in  dataclasses .fields (field ))
154+         return  name , * field_values 
155+ 
95156
96157class  MsgpackDecoder :
97158    """Decoder with custom torch tensor and numpy array serialization. 
@@ -126,13 +187,50 @@ def dec_hook(self, t: type, obj: Any) -> Any:
126187                return  self ._decode_ndarray (obj )
127188            if  issubclass (t , torch .Tensor ):
128189                return  torch .from_numpy (self ._decode_ndarray (obj ))
190+             if  issubclass (t , MultiModalKwargs ):
191+                 if  isinstance (obj , list ):
192+                     return  MultiModalKwargs .from_items (
193+                         self ._decode_mm_items (obj ))
194+                 return  MultiModalKwargs ({
195+                     k : self ._decode_nested_tensors (v )
196+                     for  k , v  in  obj .items ()
197+                 })
129198        return  obj 
130199
131200    def  _decode_ndarray (self , arr : Any ) ->  np .ndarray :
132201        dtype , shape , data  =  arr 
133-         buffer  =  self .aux_buffers [data ] if  isinstance (data , int ) else  data 
202+         # Copy from inline representation, otherwise Torch is unhappy since 
203+         # the returned memory is non-writeable. 
204+         buffer  =  self .aux_buffers [data ] if  isinstance (data , int ) \
205+             else  bytearray (data )
134206        return  np .ndarray (buffer = buffer , dtype = np .dtype (dtype ), shape = shape )
135207
208+     def  _decode_mm_items (self , obj : list ) ->  list [MultiModalKwargsItem ]:
209+         decoded_items  =  []
210+         for  item  in  obj :
211+             elems  =  []
212+             for  v  in  item :
213+                 v ["data" ] =  self ._decode_nested_tensors (v ["data" ])
214+                 # Reconstruct the field processor using MultiModalFieldConfig 
215+                 factory_meth_name , * field_args  =  v ["field" ]
216+                 factory_meth  =  getattr (MultiModalFieldConfig ,
217+                                        factory_meth_name )
218+                 v ["field" ] =  factory_meth (None , * field_args ).field 
219+                 elems .append (MultiModalFieldElem (** v ))
220+             decoded_items .append (MultiModalKwargsItem .from_elems (elems ))
221+         return  decoded_items 
222+ 
223+     def  _decode_nested_tensors (self , obj : Any ) ->  NestedTensors :
224+         if  isinstance (obj , (int , float )):
225+             # Although it violates NestedTensors type, MultiModalKwargs 
226+             # values are sometimes floats. 
227+             return  obj 
228+         if  not  isinstance (obj , list ):
229+             raise  TypeError (f"Unexpected NestedTensors contents: { type (obj )}  )
230+         if  obj  and  isinstance (obj [0 ], str ):
231+             return  torch .from_numpy (self ._decode_ndarray (obj ))
232+         return  [self ._decode_nested_tensors (x ) for  x  in  obj ]
233+ 
136234    def  ext_hook (self , code : int , data : memoryview ) ->  Any :
137235        if  code  ==  CUSTOM_TYPE_RAW_VIEW :
138236            return  data 
0 commit comments