[v1] Add __repr__ to KVCacheBlock to avoid recursive print
#14081
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
As KVCacheBlock is a doubly linked list, the default print of one block will recursively print all KVCacheBlock objects in the list, and can cause stack overflow. To avoid this problem, print
prev_block_idandnext_block_idinstead.Reproduce script
from vllm.multimodal.inputs import MultiModalKwargs from vllm.sampling_params import SamplingParams from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.request import Requestdef make_request(request_id,
prompt_token_ids,
mm_positions=None,
mm_hashes=None):
if mm_positions is None:
multi_modal_inputs = None
else:
multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions)
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10000,
max_model_len=819200,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=16,
)
common_token_ids = [i for i in range(10) for _ in range(16)]
req0 = make_request("0", common_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
blocks = manager.allocate_slots(req0, 55, computed_blocks)
manager.free(req0)
req1 = make_request("1", common_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
print("computed_blocks", computed_blocks)