@@ -80,7 +80,7 @@ def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
8080
8181    def  enc_hook (self , obj : Any ) ->  Any :
8282        if  isinstance (obj , torch .Tensor ):
83-             return  self ._encode_ndarray (obj . numpy () )
83+             return  self ._encode_tensor (obj )
8484
8585        # Fall back to pickle for object or void kind ndarrays. 
8686        if  isinstance (obj , np .ndarray ) and  obj .dtype .kind  not  in 'O' , 'V' ):
@@ -133,9 +133,27 @@ def _encode_ndarray(
133133        # backing buffers that we've stashed in `aux_buffers`. 
134134        return  obj .dtype .str , obj .shape , data 
135135
136+     def  _encode_tensor (
137+         self , obj : torch .Tensor 
138+     ) ->  tuple [str , tuple [int , ...], Union [int , memoryview ]]:
139+         assert  self .aux_buffers  is  not None 
140+         # this creates a copy of the tensor if it's not already contiguous 
141+         obj  =  obj .contiguous ()
142+         #  view the tensor as a 1D array of bytes 
143+         arr  =  obj .view ((obj .numel (), )).view (torch .uint8 ).numpy ()
144+         if  obj .nbytes  <  self .size_threshold :
145+             # Smaller tensors are encoded inline, just like ndarrays. 
146+             data  =  msgpack .Ext (CUSTOM_TYPE_RAW_VIEW , arr .data )
147+         else :
148+             # Otherwise encode index of backing buffer to avoid copy. 
149+             data  =  len (self .aux_buffers )
150+             self .aux_buffers .append (arr .data )
151+         dtype  =  str (obj .dtype )[6 :]  # remove 'torch.' prefix 
152+         return  dtype , obj .shape , data 
153+ 
136154    def  _encode_nested_tensors (self , nt : NestedTensors ) ->  Any :
137155        if  isinstance (nt , torch .Tensor ):
138-             return  self ._encode_ndarray (nt . numpy () )
156+             return  self ._encode_tensor (nt )
139157        if  isinstance (nt , (int , float )):
140158            # Although it violates NestedTensors type, MultiModalKwargs 
141159            # values are sometimes floats. 
@@ -186,7 +204,7 @@ def dec_hook(self, t: type, obj: Any) -> Any:
186204            if  issubclass (t , np .ndarray ):
187205                return  self ._decode_ndarray (obj )
188206            if  issubclass (t , torch .Tensor ):
189-                 return  torch . from_numpy ( self ._decode_ndarray (obj ) )
207+                 return  self ._decode_tensor (obj )
190208            if  issubclass (t , MultiModalKwargs ):
191209                if  isinstance (obj , list ):
192210                    return  MultiModalKwargs .from_items (
@@ -199,11 +217,24 @@ def dec_hook(self, t: type, obj: Any) -> Any:
199217
200218    def  _decode_ndarray (self , arr : Any ) ->  np .ndarray :
201219        dtype , shape , data  =  arr 
202-         # Copy from inline representation, otherwise Torch is unhappy since 
203-         # the returned memory is non-writeable. 
220+         # zero-copy decode. We assume the ndarray will not be kept around, 
221+         # as it now locks the whole received message buffer in memory. 
222+         buffer  =  self .aux_buffers [data ] if  isinstance (data , int ) else  data 
223+         return  np .ndarray (buffer = buffer , dtype = np .dtype (dtype ), shape = shape )
224+ 
225+     def  _decode_tensor (self , arr : Any ) ->  torch .Tensor :
226+         dtype , shape , data  =  arr 
227+         # Copy from inline representation, to decouple the memory storage 
228+         # of the message from the original buffer. And also make Torch 
229+         # not complain about a readonly memoryview. 
204230        buffer  =  self .aux_buffers [data ] if  isinstance (data , int ) \
205231            else  bytearray (data )
206-         return  np .ndarray (buffer = buffer , dtype = np .dtype (dtype ), shape = shape )
232+         # Create numpy wrapper around the bytes 
233+         arr  =  np .ndarray (buffer = buffer , dtype = np .uint8 , shape = (len (buffer ), ))
234+         torch_dtype  =  getattr (torch , dtype )
235+         assert  isinstance (torch_dtype , torch .dtype )
236+         # Convert back to proper shape & type 
237+         return  torch .from_numpy (arr ).view (torch_dtype ).view (shape )
207238
208239    def  _decode_mm_items (self , obj : list ) ->  list [MultiModalKwargsItem ]:
209240        decoded_items  =  []
@@ -228,7 +259,7 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
228259        if  not  isinstance (obj , list ):
229260            raise  TypeError (f"Unexpected NestedTensors contents: { type (obj )}  )
230261        if  obj  and  isinstance (obj [0 ], str ):
231-             return  torch . from_numpy ( self ._decode_ndarray (obj ) )
262+             return  self ._decode_tensor (obj )
232263        return  [self ._decode_nested_tensors (x ) for  x  in  obj ]
233264
234265    def  ext_hook (self , code : int , data : memoryview ) ->  Any :
0 commit comments