|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +import logging |
| 4 | +import random |
| 5 | +from dataclasses import dataclass |
| 6 | +from typing import TYPE_CHECKING |
| 7 | + |
| 8 | +import torch |
| 9 | + |
| 10 | +from vllm.config import VllmConfig |
| 11 | +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( |
| 12 | + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) |
| 13 | +from vllm.v1.core.sched.output import SchedulerOutput |
| 14 | + |
| 15 | +if TYPE_CHECKING: |
| 16 | + from vllm.attention.backends.abstract import AttentionMetadata |
| 17 | + from vllm.v1.core.kv_cache_manager import KVCacheBlocks |
| 18 | + from vllm.v1.request import Request |
| 19 | + |
# Use a module-scoped logger (not the root logger) so log records are
# attributable to this file and can be filtered by name.
logger = logging.getLogger(__name__)
# NOTE: configuring logging at import time is normally avoided in library
# code; it is acceptable here because this connector exists only for
# fault-tolerance testing and wants INFO output by default.
logging.basicConfig(level=logging.INFO)
| 22 | + |
| 23 | + |
@dataclass
class RandomDropConnectorMetadata(KVConnectorMetadata):
    """Scheduler-to-worker payload for RandomDropConnector."""

    # Maps request_id -> the prompt token ids whose KV should be fetched
    # from the remote store for that request.
    req_meta: dict[str, list[int]]
| 27 | + |
| 28 | + |
class RandomDropConnector(KVConnectorBase_V1):
    """
    A connector designed for fault tolerance testing by randomly dropping
    kv data during the process of loading or receiving KV cache.

    This class simulates real-world scenarios where requests or data
    might be lost or timeout, allowing developers to test and validate the
    system's ability to handle such failures.

    Attributes:
        failure_request (list[str]): Request IDs registered through
            add_failure_request(). The next KV lookup for such a request
            is reported as a miss, simulating a failed remote lookup.
        _reqs_need_recv (dict[str, list[int]]): Maps a request ID to the
            prompt token ids whose KV must be fetched from the remote
            store. Filled by update_state_after_alloc() and drained by
            build_connector_meta().
        _finish_load (dict[str, int]): Maps a request ID to the number of
            tokens actually loaded (possibly fewer than requested, to
            simulate dropped data). Drained by get_finished_loading().
    """

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)

        # Request IDs whose next KV lookup should be treated as a failure.
        self.failure_request: list[str] = []
        # request_id -> prompt token ids pending a remote KV fetch.
        self._reqs_need_recv: dict[str, list[int]] = {}
        # request_id -> number of tokens actually loaded for that request.
        self._finish_load: dict[str, int] = {}

        # Transfer granularity in tokens. NOTE(review): not referenced
        # anywhere in this class — presumably consumed by callers; confirm.
        self.chunk_size = 256

    ############################################################
    # Scheduler Side Methods
    ############################################################

    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        """Report how many prompt tokens the remote KV store can serve.

        Returns (0, False) — a miss — for requests previously registered
        via add_failure_request(); otherwise claims all but the last
        prompt token are available remotely and will be loaded
        asynchronously (hence the True flag).
        """
        if request.request_id in self.failure_request:
            # Consume the failure marker so only this one lookup fails.
            self.failure_request.remove(request.request_id)
            return 0, False
        # All but the final prompt token are treated as a remote hit.
        num_external_hit_tokens = request.num_prompt_tokens - 1
        logger.info(
            "request %s num_prompt_tokens %d num_external_hit_tokens %d",
            request.request_id, request.num_prompt_tokens,
            num_external_hit_tokens)
        return num_external_hit_tokens, True

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """Record which prompt tokens to fetch now that blocks exist."""
        if num_external_tokens > 0:
            self._reqs_need_recv[
                request.
                request_id] = request.prompt_token_ids[:num_external_tokens]

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Snapshot and clear the pending-receive map into metadata."""
        req_meta = self._reqs_need_recv.copy()
        self._reqs_need_recv.clear()
        return RandomDropConnectorMetadata(req_meta)

    def add_failure_request(self, request: "Request"):
        """Mark a request so its next KV lookup is reported as a miss."""
        self.failure_request.append(request.request_id)

    def start_load_kv(self, forward_context, **kwargs) -> None:
        """Run the (simulated) KV load for every pending request."""
        for request_id, hit_tokens in self._get_connector_metadata(
        ).req_meta.items():
            num_actual_load_tokens = self.load_kv(request_id, hit_tokens)
            logger.info("request %s hit_tokens %d num_actual_load_tokens %d",
                        request_id, len(hit_tokens), num_actual_load_tokens)
            self._finish_load[request_id] = num_actual_load_tokens

    def wait_for_layer_load(self, layer_name: str) -> None:
        # No per-layer synchronization is needed for the simulated load.
        pass

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        # Saving is a no-op: this connector only exercises the load path.
        pass

    def wait_for_save(self):
        # Nothing to flush — saving is a no-op for this connector.
        pass

    def load_kv(self, request_id, hit_tokens):
        """Simulate a lossy load of the requested tokens' KV data.

        Returns a uniformly random count in [0, len(hit_tokens)], standing
        in for the number of tokens actually received before the rest
        were "dropped".
        """
        num_actual_load_tokens = random.randint(0, len(hit_tokens))
        return num_actual_load_tokens

    def get_finished_loading(self) -> dict[str, int]:
        """Drain and return the per-request loaded-token counts."""
        if not self._finish_load:
            return {}
        finished_loading = self._finish_load.copy()
        self._finish_load.clear()

        return finished_loading