@@ -197,12 +197,11 @@ def extract_kv_from_layer(
197197 Assume the shape of the layer is (2, num_pages, page_size, xxx).
198198 """
199199 num_pages, page_size = layer.shape[1], layer.shape[2]
200-   reshaped = layer.reshape(2, num_pages * page_size, -1)
201200 print(f"{layer.shape=}")
202-   print(f"{reshaped.shape=}")
203-   print(f"{slot_mapping}")
204-
205-   return reshaped[:, slot_mapping, ...]
201+   print(f"{layer.reshape(2, num_pages * page_size, -1)=}")
202+   print(f"{slot_mapping.shape=}")
203+   return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
204+                                                      ...]
206205
207206 connector_metadata = self._get_connector_metadata()
208207 assert isinstance(connector_metadata, SharedStorageConnectorMetadata)
@@ -212,8 +211,8 @@ def extract_kv_from_layer(
212211     layer_name, request.token_ids)
213212 kv_cache = extract_kv_from_layer(kv_layer,
214213                                  request.slot_mapping)
215-   assert False
216-   # torch.ops.save_lib.save_safetensors(kv_cache, filename)
214+   tensors = {"kv_cache": kv_cache.detach().cpu()}
215+   safetensors.torch.save_file(tensors, filename)
217216
218217 def wait_for_save(self):
219218     return
@@ -366,21 +365,3 @@ def align_to_block_size(num_tokens: int, block_size) -> int:
366365 """Align the number of tokens to the block size.
367366 """
368367 return (num_tokens - 1) // block_size * block_size
369-
370-
371- # Register a custom library and print operator
372- import torch
373- from torch.library import Library, impl
374-
375- lib = Library("save_lib", "DEF")
376- lib.define("save_safetensors(Tensor kv_cache, str filename) -> ()")
377-
378-
379- @impl(lib, "save_safetensors", "CompositeExplicitAutograd")
380- def save_safetensors(kv_cache, filename):
381-     # tensors = {"kv_cache": kv_cache.detach().cpu()}
382-     # kv_cache = kv_cache.cpu()
383-     # tensors = {"kv_cache": kv_cache}
384-     # safetensors.torch.save_file(tensors, filename)
385-     a = torch.empty(10)
386-     return
0 commit comments