Commit a73721a

updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
1 parent 31d807e commit a73721a

File tree

3 files changed: +10 -28 lines changed
Lines changed: 2 additions & 2 deletions

@@ -1,5 +1,5 @@
 rm -rf local_storage/
 rm output.txt
 
-VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=6 python3 prefill_example.py
-VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=6 python3 decode_example.py
+VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=5 python3 prefill_example.py
+VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=5 python3 decode_example.py

vllm/attention/layer.py

Lines changed: 2 additions & 1 deletion

@@ -353,6 +353,7 @@ def maybe_save_kv_layer_to_connector(
 ):
     if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
         return
+
     connector = get_kv_transfer_group()
 
     forward_context: ForwardContext = get_forward_context()
@@ -370,7 +371,7 @@ def unified_attention(
     value: torch.Tensor,
     layer_name: str,
 ) -> torch.Tensor:
-    # wait_for_kv_layer_from_connector(layer_name)
+    wait_for_kv_layer_from_connector(layer_name)
 
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
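The second hunk re-enables the per-layer wait before attention runs, so decode blocks until the connector has finished loading that layer's KV. As a rough sketch (not the vLLM source), assuming the load-side helper mirrors the save-side guard shown above and that the v1 connector exposes a wait_for_layer_load(layer_name) method:

    # Sketch only: block until the connector has finished loading this layer's KV.
    def wait_for_kv_layer_from_connector(layer_name: str):
        # Same guard as the save-side helper above.
        if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
            return
        connector = get_kv_transfer_group()
        # Assumed v1 connector API: waits for the async load of `layer_name`
        # to land in the paged KV cache before attention reads it.
        connector.wait_for_layer_load(layer_name)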

vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py

Lines changed: 6 additions & 25 deletions

@@ -197,12 +197,11 @@ def extract_kv_from_layer(
             Assume the shape of the layer is (2, num_pages, page_size, xxx).
             """
             num_pages, page_size = layer.shape[1], layer.shape[2]
-            reshaped = layer.reshape(2, num_pages * page_size, -1)
             print(f"{layer.shape=}")
-            print(f"{reshaped.shape=}")
-            print(f"{slot_mapping}")
-
-            return reshaped[:, slot_mapping, ...]
+            print(f"{layer.reshape(2, num_pages * page_size, -1)=}")
+            print(f"{slot_mapping.shape=}")
+            return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
+                                                               ...]
 
         connector_metadata = self._get_connector_metadata()
         assert isinstance(connector_metadata, SharedStorageConnectorMetadata)
@@ -212,8 +211,8 @@ def extract_kv_from_layer(
                 layer_name, request.token_ids)
             kv_cache = extract_kv_from_layer(kv_layer,
                                              request.slot_mapping)
-            assert False
-            # torch.ops.save_lib.save_safetensors(kv_cache, filename)
+            tensors = {"kv_cache": kv_cache.detach().cpu()}
+            safetensors.torch.save_file(tensors, filename)
 
     def wait_for_save(self):
         return
@@ -366,21 +365,3 @@ def align_to_block_size(num_tokens: int, block_size) -> int:
     """Align the number of tokens to the block size.
     """
     return (num_tokens - 1) // block_size * block_size
-
-
-# Register a custom library and print operator
-import torch
-from torch.library import Library, impl
-
-lib = Library("save_lib", "DEF")
-lib.define("save_safetensors(Tensor kv_cache, str filename) -> ()")
-
-
-@impl(lib, "save_safetensors", "CompositeExplicitAutograd")
-def save_safetensors(kv_cache, filename):
-    # tensors = {"kv_cache": kv_cache.detach().cpu()}
-    # kv_cache = kv_cache.cpu()
-    # tensors = {"kv_cache": kv_cache}
-    # safetensors.torch.save_file(tensors, filename)
-    a = torch.empty(10)
-    return
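The reworked extract_kv_from_layer flattens the paged cache so slot_mapping can gather one KV entry per token. A toy illustration of that gather, with made-up shapes:

    import torch

    # Toy paged KV layer: (K/V, num_pages, page_size, head_dim). Shapes are
    # invented for illustration; the real layer is vLLM's paged KV cache.
    num_pages, page_size, head_dim = 4, 16, 8
    layer = torch.randn(2, num_pages, page_size, head_dim)

    # Flatten pages so every token slot is addressable by one flat index.
    flat = layer.reshape(2, num_pages * page_size, -1)

    # slot_mapping holds one flat slot index per scheduled token.
    slot_mapping = torch.tensor([3, 17, 18, 40])
    kv_cache = flat[:, slot_mapping, ...]
    print(kv_cache.shape)  # torch.Size([2, 4, 8])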

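The save path now writes the extracted KV straight to disk with safetensors instead of routing through the custom op that the last hunk deletes. A round-trip sketch of how the decode side can read such a file back; the filename here is hypothetical:

    import safetensors.torch

    # Hypothetical path; the connector derives the real one from the layer
    # name and the request's token ids.
    filename = "local_storage/layer_0.safetensors"

    tensors = safetensors.torch.load_file(filename)
    kv_cache = tensors["kv_cache"]  # same (2, num_tokens, ...) layout saved above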