From 12634beffc237b5f89531d29679d0f2251e2d462 Mon Sep 17 00:00:00 2001 From: jacobthebanana <50071502+jacobthebanana@users.noreply.github.com> Date: Thu, 7 Mar 2024 18:03:22 -0500 Subject: [PATCH] Possible fix for conflict between Automated Prefix Caching (#2762) and multi-LoRA support (#1804) (#3263) --- tests/test_cache_block_hashing.py | 46 +++++++++++++++++++++---------- vllm/sequence.py | 3 +- 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/tests/test_cache_block_hashing.py b/tests/test_cache_block_hashing.py index c2067e52b59c0..fb541f38f3489 100644 --- a/tests/test_cache_block_hashing.py +++ b/tests/test_cache_block_hashing.py @@ -2,8 +2,11 @@ Run `pytest tests/test_cache_block_hashing.py`. """ +from typing import List, Optional + import pytest +from vllm.lora.request import LoRARequest from vllm.transformers_utils.tokenizer import TokenizerGroup from vllm.sequence import Sequence @@ -36,7 +39,10 @@ def flatten_2d(li): @pytest.mark.parametrize("model", ["facebook/opt-125m"]) @pytest.mark.parametrize("block_size", [16]) @pytest.mark.parametrize("max_num_seqs", [256]) -def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int): +@pytest.mark.parametrize("concurrent_lora_int_ids", + [[None], [1], [None, 1], [None, 1, 2], [1, 2]]) +def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, + concurrent_lora_int_ids: List[Optional[int]]): tokenizer = TokenizerGroup( tokenizer_id="facebook/opt-125m", @@ -48,20 +54,30 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int): hashes = [] for prefix in prefixes: - hashes.append([]) - prompts = [prefix + prompt for prompt in sample_prompts] - seq_id = 0 - for prompt in prompts: - hashes[-1].append([]) - prompt_token_ids = tokenizer.encode(prompt) - seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, - tokenizer.tokenizer.eos_token_id) - - num_blocks = len(prompt_token_ids) // block_size - for idx in range(num_blocks): - hashes[-1][-1].append(seq.hash_of_block(idx)) - - seq_id += 1 + for lora_int_id in concurrent_lora_int_ids: + lora_request = None + + if lora_int_id is not None: + lora_request = LoRARequest( + f"example_lora_{lora_int_id}", + lora_int_id, + f"example/path/to/lora_{lora_int_id}", + ) + + hashes.append([]) + prompts = [prefix + prompt for prompt in sample_prompts] + seq_id = 0 + for prompt in prompts: + hashes[-1].append([]) + prompt_token_ids = tokenizer.encode(prompt) + seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, + tokenizer.tokenizer.eos_token_id, lora_request) + + num_blocks = len(prompt_token_ids) // block_size + for idx in range(num_blocks): + hashes[-1][-1].append(seq.hash_of_block(idx)) + + seq_id += 1 # Check that hashes made with two prefixes with different first blocks are # different everywhere. diff --git a/vllm/sequence.py b/vllm/sequence.py index c983c7a37a85e..0a7dacd0c0823 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -187,7 +187,8 @@ def hash_of_block(self, logical_idx: int) -> int: # TODO: The current hashing function is O(L^2). We should optimize # this in the future. num_tokens = self.num_hashed_tokens_of_block(logical_idx) - return hash(tuple(self.data.get_token_ids()[0:num_tokens])) + return hash( + (tuple(self.data.get_token_ids()[0:num_tokens]), self.lora_int_id)) def num_hashed_tokens_of_block(self, logical_idx: int): return logical_idx * self.block_size + self.block_size