215 changes: 215 additions & 0 deletions tests/v1/kv_connector/unit/test_shared_storage_connector.py
@@ -0,0 +1,215 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import asdict
from typing import NamedTuple

from PIL import Image

from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import KVTransferConfig
from vllm.multimodal.utils import encode_image_base64

MODEL_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"

SAMPLING_PARAMS = SamplingParams(temperature=0.0, top_k=1, max_tokens=128)

TEXT_PROMPTS = [
"What's in the image(s)? Around 30 words. What's special in 2nd image?",
"The future of AI is",
]


class InputCase(NamedTuple):
text: str
    img: list[Image.Image]
expected_len: int
info: str


def _check_path_len(path):
"""Return the latest length in path"""
return len(list(path.iterdir()))


def _list_path(path):
"""Return the list of foldername (hashes generatd) under the path"""
return list(path.iterdir())


def run_test(tmp_path, processor, llm: LLM, question: str,
             image_urls: list[Image.Image], expected_len: int, info: str):
    """
    Run one test case: process the prompt and generate output for one set of
    inputs, then check that the number of entries under the storage path
    matches the expected length.
    `info` describes the purpose of the individual test case.
    """
print(f"***info: {info}***")
print(
f"**Expected storage path length after llm generate: {expected_len}**")
process_prompt(processor, llm, question, image_urls)

print(f"Path matched expected length: {_check_path_len(tmp_path)}")
print(f"Hashes under the storage path: {_list_path(tmp_path)}")

    assert _check_path_len(tmp_path) == expected_len, (
        f"Expected storage path length {expected_len}, "
        f"but got {_check_path_len(tmp_path)} instead. Info: {info}")


def process_prompt(processor, llm: LLM, question: str,
                   image_urls: list[Image.Image]):
    """
    Build the prompt from the text and image inputs, then generate output
    with the LLM.
    """
placeholders = [{
"type": "image_url",
"image_url": {
"url": f"data:image;base64,{encode_image_base64(image_pil)}"
}
} for image_pil in image_urls]

messages = [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": [
*placeholders,
{
"type": "text",
"text": question
},
],
},
]

prompt = processor.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)

outputs = llm.generate(
{
"prompt":
prompt,
**({
"multi_modal_data": {
"image": [*image_urls]
}
} if image_urls else {})
},
sampling_params=SAMPLING_PARAMS,
)

print("-" * 50)
print("Output:")
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
print("-" * 50)


def test_shared_storage_connector_hashes(tmp_path):
"""
    Tests that SharedStorageConnector saves KV caches to storage locations
    with proper hashes that are unique for inputs with identical text but
    different images (of the same size), or the same images in different
    orders.
"""
# Using tmp_path as the storage path to store KV
print(f"KV storage path at: {str(tmp_path)}")

# Configure the SharedStorageConnector
kv_transfer_config = KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": str(tmp_path)})

engine_args = EngineArgs(
model=MODEL_NAME,
max_model_len=8192,
max_num_seqs=1,
kv_transfer_config=kv_transfer_config,
limit_mm_per_prompt={"image": 2},
)

    # Don't put this import at the top level:
    # it will call torch.cuda.device_count() at import time.
from transformers import AutoProcessor # noqa: F401

# Create processor to handle the chat prompt
processor = AutoProcessor.from_pretrained(MODEL_NAME)

# Prepare images for the tests
    # Resize to the same size so hash uniqueness comes from image content,
    # not from differing dimensions
image_1 = ImageAsset("stop_sign").pil_image.resize((1280, 720))
image_2 = ImageAsset("cherry_blossom").pil_image.resize((1280, 720))

# Make sure that they are not the same picture
assert image_1 != image_2, "The images should not be identical"

    # Create the LLM instance
    llm = LLM(**asdict(engine_args))

# Prepare the input cases
input_cases = [
InputCase(text=TEXT_PROMPTS[0],
img=[image_1],
expected_len=1,
info="image_1 single input the first time."),
InputCase(text=TEXT_PROMPTS[0],
img=[image_2],
expected_len=2,
info=("image_2 single input the first time. "
"It is in same pixel size with image_1, yet it "
"should be able to form a new unique hash.")),
InputCase(text=TEXT_PROMPTS[0],
img=[image_1],
expected_len=2,
info=("image_1 single input the 2nd time. "
"It should not form aother new hash.")),
InputCase(text=TEXT_PROMPTS[0],
img=[image_2],
expected_len=2,
info=("image_2 single input the 2nd time. "
"It should not form aother new hash.")),
InputCase(text=TEXT_PROMPTS[0],
img=[image_1, image_2],
expected_len=3,
info="image_1 with image_2 input the first time."),
InputCase(text=TEXT_PROMPTS[0],
img=[image_2, image_1],
expected_len=4,
info="The image order is swapped. Should form new hash."),
InputCase(text=TEXT_PROMPTS[0],
img=[image_1, image_2],
expected_len=4,
info=("[image_1, image_2] input the 2nd time. "
"It should not form aother new hash.")),
InputCase(text=TEXT_PROMPTS[0],
img=[image_2, image_1],
expected_len=4,
info=("[image_2, image_1] input the 2nd time. "
"It should not form aother new hash.")),
InputCase(text=TEXT_PROMPTS[0],
img=[],
expected_len=5,
info="Pure text input test as a case-control"),
InputCase(text=TEXT_PROMPTS[0],
img=[],
expected_len=5,
info="Identical pure text input as a case-control"),
InputCase(text=TEXT_PROMPTS[1],
img=[],
expected_len=6,
info="Another pure text input as a case-control"),
]

# Run tests
for case_id, (text, img, expected_len, info) in enumerate(input_cases):
print("\n", "=" * 25, f"Below running input case: {case_id}", "=" * 25)
run_test(tmp_path, processor, llm, text, img, expected_len, info)

print("All tests passed successfully!")
@@ -32,10 +32,11 @@ class ReqMeta:
slot_mapping: torch.Tensor
# Is store or load
is_store: bool
mm_hashes: list[str]

@staticmethod
def make_meta(token_ids: list[int], block_ids: list[int], block_size: int,
is_store: bool) -> "ReqMeta":
is_store: bool, mm_hashes: list[str]) -> "ReqMeta":
valid_num_tokens = align_to_block_size(len(token_ids), block_size)
token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens]
block_ids_tensor = torch.tensor(block_ids)
@@ -48,6 +49,7 @@ def make_meta(token_ids: list[int], block_ids: list[int], block_size: int,
token_ids=token_ids_tensor,
slot_mapping=slot_mapping,
is_store=is_store,
mm_hashes=mm_hashes,
)


@@ -64,9 +66,11 @@ def add_request(
block_ids: list[int],
block_size: int,
is_store: bool,
mm_hashes: list[str],
) -> None:
self.requests.append(
ReqMeta.make_meta(token_ids, block_ids, block_size, is_store))
ReqMeta.make_meta(token_ids, block_ids, block_size, is_store,
mm_hashes))


class SharedStorageConnector(KVConnectorBase_V1):
@@ -161,7 +165,7 @@ def inject_kv_into_layer(
forward_context.virtual_engine]

filename = self._generate_filename_debug(
layer_name, request.token_ids)
layer_name, request.token_ids, request.mm_hashes)
kv_cache = safetensors.torch.load_file(
filename)["kv_cache"].cuda()
inject_kv_into_layer(kv_cache_layer, kv_cache,
@@ -213,7 +217,7 @@ def extract_kv_from_layer(
for request in connector_metadata.requests:
if request.is_store:
filename = self._generate_filename_debug(
layer_name, request.token_ids)
layer_name, request.token_ids, request.mm_hashes)
kv_cache = extract_kv_from_layer(kv_layer,
request.slot_mapping)
tensors = {"kv_cache": kv_cache.detach().cpu()}
@@ -291,7 +295,8 @@ def build_connector_meta(
meta.add_request(token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
is_store=False)
is_store=False,
mm_hashes=new_req.mm_hashes)
total_need_load += 1
else:
# NOTE: here, we set the store and load being exclusive,
@@ -302,7 +307,8 @@
meta.add_request(token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
is_store=True)
is_store=True,
mm_hashes=new_req.mm_hashes)

cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
@@ -330,7 +336,8 @@
meta.add_request(token_ids=token_ids,
block_ids=block_ids,
block_size=self._block_size,
is_store=False)
is_store=False,
mm_hashes=request.mm_hashes)
total_need_load += 1

assert total_need_load == len(self._requests_need_load)
@@ -351,20 +358,28 @@ def _found_match_for_request(
len(request.prompt_token_ids) - 1, self._block_size)
foldername = self._generate_foldername_debug(torch.tensor(
request.prompt_token_ids)[:num_tokens_to_check],
request.mm_hashes,
create_folder=False)
return os.path.exists(foldername)

def _generate_foldername_debug(
self,
input_ids: torch.Tensor,
token_ids: torch.Tensor,
mm_hashes: list[str],
create_folder=False,
) -> str:
"""Generate a folder name based on the hash of the bytes of the input
ids.
"""
input_ids_bytes = input_ids.numpy().tobytes()
input_ids_hash = hashlib.md5(input_ids_bytes,
token_bytes = token_ids.numpy().tobytes()
# Add mm_hashes to the bytes being hashed to avoid path traversal and
# to create a canonical key.
if mm_hashes:
mm_str = "-".join(mm_hashes)
token_bytes += mm_str.encode('utf-8')
input_ids_hash = hashlib.md5(token_bytes,
usedforsecurity=False).hexdigest()

foldername = os.path.join(self._storage_path, input_ids_hash)
if create_folder:
os.makedirs(foldername, exist_ok=True)
@@ -373,12 +388,14 @@ def _generate_foldername_debug(
def _generate_filename_debug(
self,
layer_name: str,
input_ids: torch.Tensor,
token_ids: torch.Tensor,
mm_hashes: list[str],
) -> str:
"""Generate a file name based on the layer name and the hash
of the bytes of the input ids.
"""
foldername = self._generate_foldername_debug(input_ids,
foldername = self._generate_foldername_debug(token_ids,
mm_hashes=mm_hashes,
create_folder=True)
return os.path.join(foldername, f"{layer_name}.safetensors")

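Taken together, the diff threads mm_hashes from the request metadata down to the debug storage naming. Below is a minimal standalone sketch of the resulting key derivation (a hypothetical folder_hash helper that mirrors _generate_foldername_debug above; it is not part of the vLLM API), showing why the test's expected folder counts hold:

import hashlib

import torch


def folder_hash(token_ids: torch.Tensor, mm_hashes: list[str]) -> str:
    """Re-derive the storage key the way _generate_foldername_debug does."""
    token_bytes = token_ids.numpy().tobytes()
    # Fold the multimodal hashes into the hashed bytes so they affect the
    # key without ever appearing verbatim in a filesystem path.
    if mm_hashes:
        token_bytes += "-".join(mm_hashes).encode("utf-8")
    return hashlib.md5(token_bytes, usedforsecurity=False).hexdigest()


tokens = torch.tensor([1, 2, 3])
# Same tokens, different images: different storage folders.
assert folder_hash(tokens, ["img_a"]) != folder_hash(tokens, ["img_b"])
# Same images in swapped order: different folder as well.
assert (folder_hash(tokens, ["img_a", "img_b"])
        != folder_hash(tokens, ["img_b", "img_a"]))
# Text-only requests hash the raw token bytes, unchanged from before.
assert folder_hash(tokens, []) == hashlib.md5(
    tokens.numpy().tobytes(), usedforsecurity=False).hexdigest()

Hashing the joined mm_hashes, rather than joining them into the path itself, is what addresses the path-traversal concern called out in the diff comment: multimodal hash strings never reach the filesystem as literal path components.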