From 69bc21cc460e8432ef53cde13a06ee915e091977 Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Sun, 27 Jul 2025 08:48:24 +0800 Subject: [PATCH] [Misc]Remove PD v0 code Signed-off-by: wangxiyuan --- .../disaggregated_prefill_offline.py | 141 ------ examples/disaggregated_prefill/dp_proxy.py | 466 ----------------- .../disaggregated_prefill/find_device_ips.py | 69 --- .../p2p_disaggrefated_prefill_proxy.py | 196 -------- .../run_decode_server.sh | 37 -- .../run_prefill_server.sh | 37 -- .../kv_transfer/test_simple_buffer.py | 71 --- .../kv_transfer/test_simple_connector.py | 146 ------ .../kv_transfer/test_simple_pipe.py | 145 ------ vllm_ascend/distributed/__init__.py | 8 - .../distributed/kv_transfer/__init__.py | 0 .../distributed/kv_transfer/simple_buffer.py | 207 -------- .../kv_transfer/simple_connector.py | 379 -------------- .../distributed/kv_transfer/simple_pipe.py | 207 -------- vllm_ascend/distributed/kv_transfer/utils.py | 40 -- .../distributed/llmdatadist_connector.py | 470 ------------------ 16 files changed, 2619 deletions(-) delete mode 100644 examples/disaggregated_prefill/disaggregated_prefill_offline.py delete mode 100644 examples/disaggregated_prefill/dp_proxy.py delete mode 100644 examples/disaggregated_prefill/find_device_ips.py delete mode 100644 examples/disaggregated_prefill/p2p_disaggrefated_prefill_proxy.py delete mode 100644 examples/disaggregated_prefill/run_decode_server.sh delete mode 100644 examples/disaggregated_prefill/run_prefill_server.sh delete mode 100644 tests/ut/distributed/kv_transfer/test_simple_buffer.py delete mode 100644 tests/ut/distributed/kv_transfer/test_simple_connector.py delete mode 100644 tests/ut/distributed/kv_transfer/test_simple_pipe.py delete mode 100644 vllm_ascend/distributed/kv_transfer/__init__.py delete mode 100644 vllm_ascend/distributed/kv_transfer/simple_buffer.py delete mode 100644 vllm_ascend/distributed/kv_transfer/simple_connector.py delete mode 100644 vllm_ascend/distributed/kv_transfer/simple_pipe.py delete mode 100644 vllm_ascend/distributed/kv_transfer/utils.py delete mode 100644 vllm_ascend/distributed/llmdatadist_connector.py diff --git a/examples/disaggregated_prefill/disaggregated_prefill_offline.py b/examples/disaggregated_prefill/disaggregated_prefill_offline.py deleted file mode 100644 index ea131034b5..0000000000 --- a/examples/disaggregated_prefill/disaggregated_prefill_offline.py +++ /dev/null @@ -1,141 +0,0 @@ -""" - This file demonstrates the example usage of disaggregated prefilling - We will launch 2 vllm instances (NPU 0,1 for prefill and NPU 2,3 for decode), - and then transfer the KV cache between them. 
- prompy_device_ips denotes device ip of NPU 0,1 - decode_device_ips denotes device ip of NPU 2,3 - The device ips of all NPUs in current server can be found through - examples/disaggregated_prefill/find_device_ips.py - """ -import multiprocessing as mp -import os -import time -from multiprocessing import Event, Process - -os.environ["VLLM_USE_MODELSCOPE"] = "True" -os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - -kv_connector_extra_config = { - "prefill_device_ips": ["1.2.3.1", "1.2.3.2"], - "decode_device_ips": ["1.2.3.9", "1.2.3.10"], - "llmdatadist_comm_port": 26000, -} - - -def clean_up(): - import gc - - import torch - from vllm.distributed.parallel_state import ( - destroy_distributed_environment, destroy_model_parallel) - destroy_model_parallel() - destroy_distributed_environment() - gc.collect() - torch.npu.empty_cache() - - -def run_prefill(prefill_done, process_close): - os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1" - - from vllm import LLM, SamplingParams - from vllm.config import KVTransferConfig - - prompts = [ - "Hello, how are you today?", "Hi, what is your name?", - "Tell me a very long story.", "what is your favourite book?" - ] - sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) - - ktc = KVTransferConfig.from_cli( - '{"kv_connector":"AscendSimpleConnector","kv_buffer_device":"npu","kv_role":"kv_producer", "kv_parallel_size":2}' - ) - global kv_connector_extra_config - ktc.kv_connector_extra_config = kv_connector_extra_config - llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", - kv_transfer_config=ktc, - max_model_len=2000, - gpu_memory_utilization=0.8, - tensor_parallel_size=2) - - llm.generate(prompts, sampling_params) - print("Prefill node is finished.") - prefill_done.set() - - # To keep the prefill node running in case the decode node is not done; - # otherwise, the script might exit prematurely, causing incomplete decoding. - try: - while not process_close.is_set(): - time.sleep(1) - except KeyboardInterrupt: - print("Script stopped by user.") - finally: - print("Cleanup prefill resources") - del llm - clean_up() - - -def run_decode(prefill_done): - os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "2,3" - - from vllm import LLM, SamplingParams - from vllm.config import KVTransferConfig - - prompts = [ - "Hello, how are you today?", - "Hi, what is your name?", - ] - sampling_params = SamplingParams(temperature=0, top_p=0.95) - - ktc = KVTransferConfig.from_cli( - '{"kv_connector":"AscendSimpleConnector","kv_buffer_device":"npu","kv_role":"kv_consumer","kv_parallel_size":2}' - ) - global kv_connector_extra_config - ktc.kv_connector_extra_config = kv_connector_extra_config - llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", - kv_transfer_config=ktc, - max_model_len=2000, - gpu_memory_utilization=0.8, - tensor_parallel_size=2) - - # Wait for the producer to start the consumer - print("Waiting for prefill node to finish...") - prefill_done.wait() - - # At this point when the prefill_done is set, the kv-cache should have been - # transferred to this decode node, so we can start decoding. 
- outputs = llm.generate(prompts, sampling_params) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - - del llm - clean_up() - - -if __name__ == "__main__": - mp.get_context('spawn') - - prefill_done = Event() - process_close = Event() - prefill_process = Process(target=run_prefill, - args=( - prefill_done, - process_close, - )) - decode_process = Process(target=run_decode, args=(prefill_done, )) - - # Start prefill node - prefill_process.start() - - # Start decode node - decode_process.start() - - # Terminate the prefill node when decode is finished - decode_process.join() - - # Terminate prefill process - process_close.set() - prefill_process.join() - prefill_process.terminate() - print("All process done!") diff --git a/examples/disaggregated_prefill/dp_proxy.py b/examples/disaggregated_prefill/dp_proxy.py deleted file mode 100644 index 415e981343..0000000000 --- a/examples/disaggregated_prefill/dp_proxy.py +++ /dev/null @@ -1,466 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import asyncio -import copy -import logging -import os -import threading -import time -import uuid - -import aiohttp -import msgpack # type: ignore -import zmq -from quart import Quart, make_response, request - -os.environ["VLLM_USE_MODELSCOPE"] = "True" -os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - -DP_PROXY_HTTP_PORT = 10004 -DP_PROXY_ZMQ_REG_PORT = 30006 -DP_PROXY_ZMQ_NOTIFY_PORT = 30005 - -PD_PROXY_ADDRESS = "127.0.0.1:30002" - -MY_HTTP_ADDRESS = f"127.0.0.1:{DP_PROXY_HTTP_PORT}" -MY_ZMQ_ADDRESS_PLACEHOLDER = f"127.0.0.1:{DP_PROXY_ZMQ_REG_PORT}" - -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -TIME_INTERVAL_FOR_IDLE_RUN = 5e-4 -DP_SIZE = 2 - -dp_instances: dict[str, bool] = {} -dp_cv = threading.Condition() -round_robin_index = 0 -_idle_send_loop = None - - -def make_idle_request(): - # Same as before - data = { - "prompt": "hi", - "max_tokens": 1, - "temperature": 0, - } - return data - - -def random_uuid() -> str: - return str(uuid.uuid4().hex) - - -async def send_idle_token_to_client(schedule_dict): - for key, value in schedule_dict.items(): - if value: - continue - request_received_id = random_uuid() - idle_request_data = make_idle_request() - forward_request_id = f"dp_idle_{key}_{request_received_id}" - target_url = f'http://{key}/v1/completions' - logger.debug( - f"DP Decode Proxy: Sending idle token to D node {key} at {target_url}" - ) - generator = forward_request_internal(target_url, idle_request_data, - forward_request_id) - try: - async for response in generator: - logger.debug( - f"DP Decode Proxy: Idle Request {request_received_id}: response from {key}, got response: {response}" - ) - except Exception as e: - logger.warning( - f"DP Decode Proxy: Error sending idle token to {key}: {e}") - - -def metadata_collect_trigger(poller, router_socket): - global dp_instances - global dp_cv - global _idle_send_loop - with dp_cv: - dp_cv.wait() - while True: - try: - schedule_dict = copy.deepcopy(dp_instances) - for key in schedule_dict.keys(): - schedule_dict[key] = False - first_start = False - start_time = None - while not all(schedule_dict.values()): - if start_time is not None: - time_interval = time.time() - start_time - logger.debug("check time interval: ", time_interval) - if time_interval > TIME_INTERVAL_FOR_IDLE_RUN: - logger.debug( - "exceeds max time interval send idle token to 
client" - ) - # Send idle token to client in case of single dp rank run solo and block on the CCL part - asyncio.run_coroutine_threadsafe( - send_idle_token_to_client(schedule_dict), - _idle_send_loop) # type: ignore - # Note: Reset start time prevent consistently send idle token to client - # We only reset start time here, for some of the client may loss the idle token send from this proxy - # and we only exit this while loop when we make sure all the client are exactly start inference in this - # step - start_time = time.time() - socks = dict(poller.poll(timeout=500)) # timeout in 500ms - if socks: - logger.debug("receive socks from monitor threads: ", socks) - if router_socket in socks: - messages = router_socket.recv_multipart() - try: - # {"info": "notify_step", "http_address": ""} - for message in messages: - data = msgpack.loads(message) - http_addr = None - logger.debug(f"receive message {data}") - if data.get("info") == "notify_step": - http_addr = data.get("http_address") - if http_addr in schedule_dict.keys(): - schedule_dict[http_addr] = True - logger.debug("set first time") - if not first_start: - logger.debug("record start time") - first_start = True - start_time = time.time() - else: - logger.warning("Unrecognize http address") - else: - logger.warning( - "Got unrecognize info type! We only accept notify step info yet" - ) - except (msgpack.UnpackException, TypeError, KeyError) as e: - logger.error( - f"Error processing message from {http_addr}: {e}. Message: {data}" - ) - except zmq.ZMQError as e: # type: ignore - logger.error(f"ZMQ Error in monitor thread: {e}") - if e.errno == zmq.ETERM: # type: ignore - logger.error( - "Monitor thread terminating due to context termination.") - break - time.sleep(1) - except Exception as e: - logger.error(f"Unexpected error in monitor thread: {e}") - import traceback - traceback.print_exc() - time.sleep(1) - - -def _listen_for_d_register(poller, router_socket): - global dp_instances - global dp_cv - global DP_SIZE - logger.info( - f"DP Decode Proxy: D Node ZMQ Listener started on ROUTER port {DP_PROXY_ZMQ_REG_PORT}" - ) - - while True: - try: - socks = dict(poller.poll(timeout=1000)) - if router_socket in socks: - remote_id, message = router_socket.recv_multipart() - try: - data = msgpack.loads(message) - if data.get("type") == "DP": - http_addr = data.get("http_address") - zmq_addr = data.get("zmq_address") - if http_addr: - with dp_cv: - if http_addr not in dp_instances: - logger.info( - f"DP Decode Proxy: Registering D Node instance: http={http_addr}, zmq={zmq_addr}" - ) - dp_instances[http_addr] = True - if len(dp_instances) >= DP_SIZE: - logger.info( - f"DP Decode Proxy: Reached expected D Node count ({DP_SIZE}). Notifying metadata collector." - ) - dp_cv.notify_all() - else: - pass - else: - logger.warning( - f"DP Decode Proxy: Received D Node registration from {remote_id.decode()} without http_address. Data: {data}" - ) - else: - logger.warning( - f"DP Decode Proxy: Received message with unexpected type from {remote_id.decode()}. Type: {data.get('type')}, Data: {data}" - ) - - except (msgpack.UnpackException, TypeError, KeyError) as e: - logger.error( - f"DP Decode Proxy: Error processing D Node registration from {remote_id.decode()}: {e}. 
Message: {message}" - ) - except Exception as e: - logger.error( - f"DP Decode Proxy: Unexpected error processing D Node registration from {remote_id.decode()}: {e}" - ) - - except zmq.ZMQError as e: # type: ignore - logger.error( - f"DP Decode Proxy: ZMQ Error in D Node listener thread: {e}") - if e.errno == zmq.ETERM: # type: ignore - logger.info( - "DP Decode Proxy: D Node Listener thread terminating.") - break - time.sleep(1) - except Exception as e: - logger.error( - f"DP Decode Proxy: Unexpected error in D Node listener thread: {e}" - ) - import traceback - traceback.print_exc() - time.sleep(1) - - -def _register_to_pd_proxy(pd_proxy_zmq_addr, my_http_addr, my_zmq_addr): - context = None - sock = None - while True: - try: - if context is None: - context = zmq.Context() # type: ignore - if sock is None: - sock = context.socket(zmq.DEALER) # type: ignore - identity = f"dp_proxy_{my_http_addr}".encode('utf-8') - sock.setsockopt(zmq.IDENTITY, identity) # type: ignore - sock.setsockopt(zmq.LINGER, 0) # type: ignore - logger.info( - f"DP Decode Proxy: Attempting to connect to PD Proxy at {pd_proxy_zmq_addr}..." - ) - sock.connect(f"tcp://{pd_proxy_zmq_addr}") - logger.info( - f"DP Decode Proxy: Connected to PD Proxy at {pd_proxy_zmq_addr}." - ) - - data = { - "type": "D", - "http_address": my_http_addr, - "zmq_address": my_zmq_addr - } - logger.debug( - f"DP Decode Proxy: Sending registration/heartbeat to PD Proxy: {data}" - ) - sock.send(msgpack.dumps(data)) - time.sleep(5) - - except zmq.ZMQError as e: # type: ignore - logger.error( - f"DP Decode Proxy: ZMQ Error connecting/sending to PD Proxy ({pd_proxy_zmq_addr}): {e}" - ) - if sock: - sock.close() - sock = None - time.sleep(10) - except Exception as e: - logger.error( - f"DP Decode Proxy: Unexpected error in PD Proxy registration thread: {e}" - ) - import traceback - traceback.print_exc() - if sock: - sock.close() - sock = None - time.sleep(10) - finally: - pass - - -def start_zmq_thread(hostname, port, socket_type, target_func, thread_name): - """Generic ZMQ thread starter for ROUTER or PULL.""" - if not hostname: - hostname = "0.0.0.0" - context = zmq.Context.instance() # type: ignore - socket = context.socket(socket_type) - socket.setsockopt(zmq.LINGER, 0) # type: ignore - try: - socket.bind(f"tcp://{hostname}:{port}") - except zmq.ZMQError as e: # type: ignore - logger.error( - f"DP Decode Proxy: Error binding ZMQ {socket_type} socket to tcp://{hostname}:{port}: {e}" - ) - socket.close() - raise - - poller = zmq.Poller() # type: ignore - poller.register(socket, zmq.POLLIN) # type: ignore - - thread = threading.Thread(target=target_func, - args=(poller, socket), - daemon=True, - name=thread_name) - thread.start() - return thread, socket - - -def start_thread_with_event_loop(): - global _idle_send_loop - asyncio.set_event_loop(_idle_send_loop) - _idle_send_loop.run_forever() # type: ignore - - -async def forward_request_internal(url, data, request_id): - try: - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - headers = { - "Authorization": - f"Bearer {os.environ.get('OPENAI_API_KEY', '')}", - "X-Request-Id": request_id, - "Content-Type": "application/json" - } - async with session.post(url=url, json=data, - headers=headers) as response: - if response.status == 200: - async for chunk_bytes in response.content.iter_chunked( - 1024): - yield chunk_bytes - else: - error_content = await response.read() - logger.warning( - f"DP Decode Proxy: Error from D node {url} (status {response.status}): 
{error_content.decode(errors='ignore')}" - ) - yield error_content - - except aiohttp.ClientError as e: - logger.warning( - f"DP Decode Proxy: Error forwarding request {request_id} to D node {url}: {e}" - ) - error_msg = f"Failed to connect or communicate with D node at {url}: {e}".encode( - 'utf-8') - yield error_msg - - -AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) -app = Quart(__name__) - - -@app.route('/v1/completions', methods=['POST']) -async def handle_request(): - global dp_instances - global dp_cv - global round_robin_index - - request_received_id = request.headers.get("X-Request-Id") - if not request_received_id: - fallback_id = f"dp_fallback_{random_uuid()}" - logger.warning( - f"DP Decode Proxy: Received request without X-Request-Id header. Using fallback ID: {fallback_id}" - ) - request_received_id = fallback_id - else: - logger.info( - f"DP Decode Proxy: Received request from PD Proxy, using propagated ID: {request_received_id}" - ) - - try: - original_request_data = await request.get_json() - if not original_request_data: - return await make_response("Request body must be valid JSON.", 400) - - target_addr = None - with dp_cv: - if not dp_instances: - logger.warning( - f"DP Decode Proxy: Request {request_received_id}: No D Node instances available/registered." - ) - return await make_response("No Decode instances available.", - 503) - - dp_addresses = list(dp_instances.keys()) - if not dp_addresses: - logger.error( - f"DP Decode Proxy: Request {request_received_id}: Internal error - dp_instances populated but list is empty." - ) - return await make_response("Internal Server Error", 500) - - current_selection_index = round_robin_index % len(dp_addresses) - target_addr = dp_addresses[current_selection_index] - round_robin_index += 1 - - logger.info( - f"DP Decode Proxy: Request {request_received_id}: Routing Decode to D Node {target_addr} (Index {current_selection_index})" - ) - - target_url = f'http://{target_addr}/v1/completions' - - generator = forward_request_internal(target_url, original_request_data, - request_received_id) - - response = await make_response(generator) - response.timeout = None - - if original_request_data.get("stream", False): - response.headers['Content-Type'] = 'text/event-stream' - response.headers['Cache-Control'] = 'no-cache' - else: - response.headers['Content-Type'] = 'application/json' - - logger.debug( - f"DP Decode Proxy: Request {request_received_id}: Streaming response from D node {target_addr}" - ) - return response - - except Exception as e: - logger.error( - f"DP Decode Proxy: Error handling request {request_received_id}: {e}" - ) - return await make_response("Internal Server Error", 500) - - -if __name__ == '__main__': - d_listener_thread, d_reg_socket = start_zmq_thread( - "0.0.0.0", - DP_PROXY_ZMQ_REG_PORT, - zmq.ROUTER, # type: ignore - _listen_for_d_register, # type: ignore - "DP_DNodeListenerThread") - - metadata_thread, notify_socket = start_zmq_thread( - "0.0.0.0", - DP_PROXY_ZMQ_NOTIFY_PORT, - zmq.PULL, # type: ignore - metadata_collect_trigger, - "DP_MetadataMonitorThread") - - _idle_send_loop = asyncio.new_event_loop() - idle_loop_thread = threading.Thread(target=start_thread_with_event_loop, - daemon=True, - name="DP_IdleSendLoopThread") - idle_loop_thread.start() - - pd_register_thread = threading.Thread(target=_register_to_pd_proxy, - args=(PD_PROXY_ADDRESS, - MY_HTTP_ADDRESS, - MY_ZMQ_ADDRESS_PLACEHOLDER), - daemon=True, - name="DP_PDRegisterThread") - pd_register_thread.start() - - logger.info( - f"DP Decode 
Proxy: Starting Quart web server on http://0.0.0.0:{DP_PROXY_HTTP_PORT}" - ) - zmq_context = zmq.Context.instance() # type: ignore - try: - app.run(host='0.0.0.0', port=DP_PROXY_HTTP_PORT) - except KeyboardInterrupt: - logger.info("DP Decode Proxy: KeyboardInterrupt received, stopping...") - except Exception as e: - logger.error(f"DP Decode Proxy: Failed to run Quart server: {e}") - finally: - logger.info("DP Decode Proxy: Shutting down...") - if _idle_send_loop and _idle_send_loop.is_running(): - logger.info("DP Decode Proxy: Stopping idle send loop...") - _idle_send_loop.call_soon_threadsafe(_idle_send_loop.stop) - - if d_reg_socket: - d_reg_socket.close() - if notify_socket: - notify_socket.close() - if zmq_context: - zmq_context.term() - - logger.info("DP Decode Proxy: Shutdown complete.") diff --git a/examples/disaggregated_prefill/find_device_ips.py b/examples/disaggregated_prefill/find_device_ips.py deleted file mode 100644 index 48fd7b9d32..0000000000 --- a/examples/disaggregated_prefill/find_device_ips.py +++ /dev/null @@ -1,69 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. -# Adapted from vllm-project/vllm/examples/offline_inference/basic.py -# -""" - This file provides a function to obtain ips of all NPU Devices in current machine. 
-""" - -import os -import re -import subprocess - -import vllm_ascend.envs as envs - -# Get all device ips using hccn_tool -HCCN_TOOL_PATH = envs.HCCN_PATH - - -def get_device_ips(): - npu_info = subprocess.run(['npu-smi', 'info', '-m'], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - universal_newlines=True) - if npu_info.returncode != 0 or not os.path.exists(HCCN_TOOL_PATH): - raise RuntimeError("No npu-smi/hccn_tool tools provided for NPU.") - - # ‌Extract NPU IDs for all Ascend devices (excluding Mcu rows) - device_ids = [] - for line in npu_info.stdout.strip().split('\n'): - match = re.match(r'^\s*(\d+)\s+\d+\s+\d+\s+Ascend', line) - if match: - device_ids.append(int(match.group(1))) - - if not device_ids: - raise RuntimeError( - "Cannot parse any valid device ID from npu-smi output.") - - device_ip_list = [] - for device_id in device_ids: - cmd = [HCCN_TOOL_PATH, '-i', str(device_id), '-ip', '-g'] - device_ip_info = subprocess.run(cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - universal_newlines=True) - ip_match = re.search(r'ipaddr:(.*)', device_ip_info.stdout) - if not ip_match: - raise RuntimeError( - f"Cannot parse IP from hccn_tool for device {device_id}") - device_ip = ip_match.group(1).strip() - device_ip_list.append(device_ip) - - return device_ip_list - - -print(get_device_ips()) diff --git a/examples/disaggregated_prefill/p2p_disaggrefated_prefill_proxy.py b/examples/disaggregated_prefill/p2p_disaggrefated_prefill_proxy.py deleted file mode 100644 index 5baa355a05..0000000000 --- a/examples/disaggregated_prefill/p2p_disaggrefated_prefill_proxy.py +++ /dev/null @@ -1,196 +0,0 @@ -import os -import socket -import threading -import uuid - -import aiohttp -import msgpack # type: ignore -import zmq -from quart import Quart, make_response, request - -os.environ["VLLM_USE_MODELSCOPE"] = "True" -os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - -prefill_instances: dict[str, str] = {} # http_address: zmq_address -decode_instances: dict[str, str] = {} # http_address: zmq_address - -prefill_cv = threading.Condition() -decode_cv = threading.Condition() - - -def _listen_for_register(poller, router_socket): - while True: - socks = dict(poller.poll()) - if router_socket in socks: - remote_address, message = router_socket.recv_multipart() - # data: {"type": "P", "http_address": "ip:port", - # "zmq_address": "ip:port"} - data = msgpack.loads(message) - if data["type"] == "P": - global prefill_instances - global prefill_cv - with prefill_cv: - prefill_instances[ - data["http_address"]] = data["zmq_address"] - print( - "Get a prefill register with http_addr %s and zmq_addr %s", - data["http_address"], - data["zmq_address"], - ) - elif data["type"] == "D": - global decode_instances - global decode_cv - with decode_cv: - decode_instances[ - data["http_address"]] = data["zmq_address"] - print( - "Get a decode register with http_addr %s and zmq_addr %s", - data["http_address"], - data["zmq_address"], - ) - else: - print( - "Unexpected, Received message from %s, data: %s", - remote_address, - data, - ) - - -def start_service_discovery(hostname, port): - if not hostname: - hostname = socket.gethostname() - if port == 0: - raise ValueError("Port cannot be 0") - - context = zmq.Context() # type: ignore - router_socket = context.socket(zmq.ROUTER) # type: ignore - router_socket.bind(f"tcp://{hostname}:{port}") - - poller = zmq.Poller() # type: ignore - poller.register(router_socket, zmq.POLLIN) # type: ignore - - _listener_thread = threading.Thread(target=_listen_for_register, - 
args=[poller, router_socket], - daemon=True) - _listener_thread.start() - return _listener_thread - - -AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) - -app = Quart(__name__) - - -def random_uuid() -> str: - return str(uuid.uuid4().hex) - - -async def forward_request(url, data, request_id): - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - headers = { - "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", - "X-Request-Id": request_id, - } - async with session.post(url=url, json=data, - headers=headers) as response: - if response.status == 200: - async for chunk_bytes in response.content.iter_chunked(1024): - yield chunk_bytes - - -@app.route("/v1/completions", methods=["POST"]) -async def handle_request(): - try: - original_request_data = await request.get_json() - - prefill_request = original_request_data.copy() - # change max_tokens = 1 to let it only do prefill - prefill_request["max_tokens"] = 1 - - global prefill_instances - global prefill_cv - with prefill_cv: - if len(prefill_instances) > 1: - print( - "Found more than 1 Prefill instances. Currently we only support 1P1D, so only" - f"the first Prefill instance({list(prefill_instances.keys())[0]}) will be used!" - ) - if len(prefill_instances) == 0: - res_str = ( - "No Prefill instances has been registered to proxy. Please confirm that you have successfully" - " and correctly started a Prefill vLLM instance.") - print(res_str) - response = await make_response(res_str) - return response - # prefill_addr, prefill_zmq_addr = random.choice( - # list(prefill_instances.items())) - prefill_addr, prefill_zmq_addr = list(prefill_instances.items())[0] - print( - "handle_request, prefill_addr: %s, zmq_addr: %s", - prefill_addr, - prefill_zmq_addr, - ) - - global decode_instances - global decode_cv - with decode_cv: - if len(decode_instances) > 1: - print( - "Found more than 1 Decode instances. Currently we only support 1P1D, so only" - f"the first Decode instance({list(decode_instances.keys())[0]}) will be used!" - ) - if len(decode_instances) == 0: - res_str = ( - "No Decode instances has been registered to proxy. 
Please confirm that you have successfully" - " and correctly started a Decode vLLM instance.") - print(res_str) - response = await make_response(res_str) - return response - # decode_addr, decode_zmq_addr = random.choice( - # list(decode_instances.items())) - decode_addr, decode_zmq_addr = list(decode_instances.items())[0] - print( - "handle_request, decode_addr: %s, zmq_addr: %s", - decode_addr, - decode_zmq_addr, - ) - - request_id = f"___prefill_addr_{prefill_addr}___decode_addr_{decode_addr}_{random_uuid()}" - - # finish prefill - async for _ in forward_request(f"http://{prefill_addr}/v1/completions", - prefill_request, request_id): - continue - - # return decode - generator = forward_request( - f"http://{decode_addr}/v1/completions", - original_request_data, - request_id, - ) - response = await make_response(generator) - response.timeout = None - - return response - - except Exception as e: - import sys - import traceback - - exc_info = sys.exc_info() - print("Error occurred in disagg prefill proxy server") - print(e) - print("".join(traceback.format_exception(*exc_info))) - - -if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser( - description="args of disaggregated-prefill proxy") - parser.add_argument("--http-port", type=int, default=10001) - parser.add_argument("--register-port", type=int, default=10002) - args = parser.parse_args() - - t = start_service_discovery("0.0.0.0", args.register_port) - app.run(host="0.0.0.0", port=args.http_port) - t.join() diff --git a/examples/disaggregated_prefill/run_decode_server.sh b/examples/disaggregated_prefill/run_decode_server.sh deleted file mode 100644 index a3bbaa189f..0000000000 --- a/examples/disaggregated_prefill/run_decode_server.sh +++ /dev/null @@ -1,37 +0,0 @@ -export HCCL_IF_IP=2.0.0.0 -export GLOO_SOCKET_IFNAME="enp189s0f0" -export TP_SOCKET_IFNAME="enp189s0f0" -export HCCL_SOCKET_IFNAME="enp189s0f0" - -export OMP_PROC_BIND=false -export OMP_NUM_THREADS=100 - -export VLLM_USE_V1=0 - -export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - - -vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \ - --host 0.0.0.0 \ - --port 20002 \ - --tensor-parallel-size 8 \ - --seed 1024 \ - --served-model-name deepseek \ - --max-model-len 2000 \ - --max-num-batched-tokens 2000 \ - --trust-remote-code \ - --gpu-memory-utilization 0.9 \ - --kv-transfer-config \ - '{"kv_connector": "AscendSimpleConnector", - "kv_buffer_device": "npu", - "kv_role": "kv_consumer", - "kv_parallel_size": 8, - "kv_port":"21001", - "kv_connector_extra_config": - {"prompt_device_ips": ["1.2.3.1", "1.2.3.2", "1.2.3.3", "1.2.3.4", "1.2.3.5", "1.2.3.6", "1.2.3.7", "1.2.3.8"], - "decode_device_ips": ["1.2.3.9", "1.2.3.10", "1.2.3.11", "1.2.3.12", "1.2.3.13", "1.2.3.14", "1.2.3.15", "1.2.3.16"], - "llmdatadist_comm_port": 26000, - "proxy_ip":"3.0.0.0", - "proxy_port":"30001", - "http_port": 10002} - }' diff --git a/examples/disaggregated_prefill/run_prefill_server.sh b/examples/disaggregated_prefill/run_prefill_server.sh deleted file mode 100644 index dc929f8a49..0000000000 --- a/examples/disaggregated_prefill/run_prefill_server.sh +++ /dev/null @@ -1,37 +0,0 @@ -export HCCL_IF_IP=1.0.0.0 -export GLOO_SOCKET_IFNAME="enp189s0f0" -export TP_SOCKET_IFNAME="enp189s0f0" -export HCCL_SOCKET_IFNAME="enp189s0f0" - -export OMP_PROC_BIND=false -export OMP_NUM_THREADS=100 - -export VLLM_USE_V1=0 - -export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - - -vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \ - --host 0.0.0.0 \ - --port 10002 \ - --tensor-parallel-size 
8 \ - --seed 1024 \ - --served-model-name deepseek \ - --max-model-len 2000 \ - --max-num-batched-tokens 2000 \ - --trust-remote-code \ - --gpu-memory-utilization 0.9 \ - --kv-transfer-config \ - '{"kv_connector": "AscendSimpleConnector", - "kv_buffer_device": "npu", - "kv_role": "kv_producer", - "kv_parallel_size": 8, - "kv_port":"11001", - "kv_connector_extra_config": - {"prompt_device_ips": ["1.2.3.1", "1.2.3.2", "1.2.3.3", "1.2.3.4", "1.2.3.5", "1.2.3.6", "1.2.3.7", "1.2.3.8"], - "decode_device_ips": ["1.2.3.9", "1.2.3.10", "1.2.3.11", "1.2.3.12", "1.2.3.13", "1.2.3.14", "1.2.3.15", "1.2.3.16"], - "llmdatadist_comm_port": 26000, - "proxy_ip":"3.0.0.0", - "proxy_port":"30001", - "http_port": 10002} - }' diff --git a/tests/ut/distributed/kv_transfer/test_simple_buffer.py b/tests/ut/distributed/kv_transfer/test_simple_buffer.py deleted file mode 100644 index 1ff81bc3b6..0000000000 --- a/tests/ut/distributed/kv_transfer/test_simple_buffer.py +++ /dev/null @@ -1,71 +0,0 @@ -import zlib -from unittest.mock import MagicMock - -import torch - -from tests.ut.base import TestBase -from vllm_ascend.distributed.kv_transfer.simple_buffer import (SimpleBuffer, - int32_hash) - - -class MockSimplePipe: - - def __init__(self): - self.cluster_id = 0 - self.send_tensor = MagicMock() - self.recv_tensor = MagicMock() - self.deallocate_buffer = MagicMock() - - -class TestSimpleBuffer(TestBase): - - def setUp(self): - self.pipe = MockSimplePipe() - self.buffer = SimpleBuffer(self.pipe) - - def test_int32_hash(self): - self.assertEqual(int32_hash("test"), zlib.adler32(b"test")) - - def test_insert(self): - input_tokens = torch.tensor([1, 2, 3]) - roi = torch.tensor([1, 0, 1]) - key = torch.randn(2, 3, 4, 5) - value = torch.randn(2, 3, 4, 5) - hidden = torch.randn(3, 6) - - self.buffer.num_layers = 2 - self.buffer.num_heads = 4 - self.buffer.head_size = 5 - self.buffer.hidden_size = 6 - self.buffer.dtype = torch.float32 - - self.buffer.insert(input_tokens, roi, key, value, hidden, "req1") - - self.pipe.send_tensor.assert_called() - - def test_drop_select(self): - input_tokens = torch.tensor([1, 2, 3]) - roi = None - - self.buffer.num_layers = 2 - self.buffer.num_heads = 4 - self.buffer.head_size = 5 - self.buffer.hidden_size = 6 - self.buffer.dtype = torch.float32 - - self.pipe.recv_tensor.side_effect = [ - (MagicMock(), torch.randn(1, 2, 3 * 4 * 5)), - (MagicMock(), torch.randn(1, 2, 3 * 4 * 5)), - (MagicMock(), torch.randn(1, 3, 6)) - ] - - result = self.buffer.drop_select(input_tokens, roi, "req1") - self.assertEqual(len(result), 4) - self.assertIsInstance(result[0], torch.Tensor) - self.assertIsInstance(result[1], torch.Tensor) - self.assertIsInstance(result[2], torch.Tensor) - self.assertIsNone(result[3]) - self.assertEqual(result[0].shape, (2, 3, 4, 5)) - - def test_close(self): - self.buffer.close() diff --git a/tests/ut/distributed/kv_transfer/test_simple_connector.py b/tests/ut/distributed/kv_transfer/test_simple_connector.py deleted file mode 100644 index 2c81943a7c..0000000000 --- a/tests/ut/distributed/kv_transfer/test_simple_connector.py +++ /dev/null @@ -1,146 +0,0 @@ -from unittest.mock import MagicMock, patch - -import torch -from vllm.config import VllmConfig -from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - -from tests.ut.base import TestBase -from vllm_ascend.distributed.kv_transfer.simple_buffer import SimpleBuffer -from vllm_ascend.distributed.kv_transfer.simple_connector import \ - SimpleConnector -from vllm_ascend.distributed.kv_transfer.simple_pipe import 
SimplePipe - - -class TestSimpleConnector(TestBase): - - def setUp(self): - self.mock_pipe = MagicMock(spec=SimplePipe) - self.mock_buffer = MagicMock(spec=SimpleBuffer) - - patcher = patch( - 'vllm_ascend.distributed.kv_transfer.simple_buffer.SimpleBuffer') - self.addCleanup(patcher.stop) - self.MockSimpleBuffer = patcher.start() - self.MockSimpleBuffer.return_value = self.mock_buffer - - def _create_mock_config(self, kv_role): - mock_config = MagicMock() - mock_config.kv_role = "kv_producer" - mock_config.kv_connector_extra_config = { - "prefill_device_ips": ["127.0.0.1"], - "decode_device_ips": ["127.0.0.1"], - "llmdatadist_comm_port": 26000, - "http_port": 8000, - "proxy_ip": "127.0.0.1", - "proxy_port": "8000", - "port": 5500 - } - mock_config.kv_port = 5500 - self.mock_config = MagicMock(spec=VllmConfig) - self.mock_config.kv_transfer_config.is_kv_producer = True - self.mock_config.model_config.hf_config.hidden_size = 128 - self.mock_config.model_config.hf_config.num_attention_heads = 8 - self.mock_config.model_config.hf_config.num_key_value_heads = 8 - self.mock_config.model_config.hf_config.qk_rope_head_dim = 16 - self.mock_config.model_config.hf_config.kv_lora_rank = 16 - self.mock_config.model_config.is_deepseek_mla = True - # 模拟 parallel_config - self.mock_config.parallel_config = MagicMock() - self.mock_config.parallel_config.tensor_parallel_size = 1 - self.mock_config.parallel_config.get_num_layers.return_value = 4 - - if kv_role == "kv_producer": - self.mock_config.kv_transfer_config.kv_role = "kv_producer" - else: - self.mock_config.kv_transfer_config.kv_role = "kv_consumer" - return mock_config - - @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe') - @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer') - @patch('llm_datadist.LLMDataDist') - def test_select_init(self, mock_pipe, mock_buffer, MockLLMDataDist): - """Test select method when buffer retrieval succeeds.""" - connector = SimpleConnector( - rank=0, - local_rank=0, - config=self._create_mock_config("kv_producer")) - assert connector.producer_data_pipe is not None - assert connector.producer_buffer is not None - mock_data_dist = MockLLMDataDist.return_value - mock_data_dist.init.return_value = None - - @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe') - @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer') - @patch('llm_datadist.LLMDataDist') - def test_select_select(self, mock_pipe, mock_buffer, MockLLMDataDist): - - connector = SimpleConnector( - rank=0, - local_rank=0, - config=self._create_mock_config("kv_consumer")) - connector.consumer_data_pipe = mock_pipe - connector.consumer_buffer = mock_buffer - assert connector.consumer_data_pipe is not None - assert connector.consumer_buffer is not None - input_tokens = torch.tensor([1, 2, 3]) - roi = torch.tensor([True, True, True]) - req_id = "test_req" - connector.select(input_tokens, roi, req_id) - - @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe') - @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer') - @patch('llm_datadist.LLMDataDist') - def test_insert(self, mock_pipe, mock_buffer, MockLLMDataDist): - """Test insert operation""" - connector = SimpleConnector( - rank=0, - local_rank=0, - config=self._create_mock_config("kv_producer")) - - connector.producer_buffer = mock_buffer - - input_tokens = torch.randint(0, 1000, (5, )) - roi = torch.ones_like(input_tokens, dtype=torch.bool) - keys = torch.randn(3, 5, 1, 96) - values = 
torch.randn(3, 5, 1, 96) - hidden = torch.randn(5, 768) - req_id = "test_req" - - connector.insert(input_tokens, roi, keys, values, hidden, req_id) - - mock_buffer.insert.assert_called_once_with(input_tokens, roi, keys, - values, hidden, req_id) - - @patch.object(SimpleConnector, 'insert') - @patch('torch.distributed.get_rank', return_value=0) - @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe') - @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer') - @patch('llm_datadist.LLMDataDist') - def test_send_kv_caches_and_hidden_states(self, mock_pipe, mock_buffer, - MockLLMDataDist, mock_insert, - mock_rank): - """Test sending KV caches and hidden states""" - connector = SimpleConnector( - rank=0, - local_rank=0, - config=self._create_mock_config("kv_producer")) - - mock_model_executable = MagicMock() - mock_model_executable.model.start_layer = 0 - mock_model_executable.model.end_layer = 3 - - mock_model_input = MagicMock(spec=ModelInputForGPUWithSamplingMetadata) - mock_model_input.input_tokens = torch.randint(0, 1000, (10, )) - mock_model_input.attn_metadata.seq_lens = [5, 5] - mock_model_input.attn_metadata.slot_mapping = torch.randint( - 0, 100, (10, )) - mock_model_input.attn_metadata.num_prefill_tokens = 10 - mock_model_input.request_ids_to_seq_ids = {"req1": [0], "req2": [1]} - - kv_caches = [torch.randn(2, 100, 1, 96) for _ in range(3)] - - hidden_states = torch.randn(10, 768) - - connector.send_kv_caches_and_hidden_states(mock_model_executable, - mock_model_input, kv_caches, - hidden_states) diff --git a/tests/ut/distributed/kv_transfer/test_simple_pipe.py b/tests/ut/distributed/kv_transfer/test_simple_pipe.py deleted file mode 100644 index ccc984b62a..0000000000 --- a/tests/ut/distributed/kv_transfer/test_simple_pipe.py +++ /dev/null @@ -1,145 +0,0 @@ -from unittest.mock import MagicMock, patch - -import torch - -from tests.ut.base import TestBase -from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe - - -class TestSimplePipe(TestBase): - - @classmethod - def _create_mock_config(self): - mock_config = MagicMock() - mock_config.kv_role = "kv_producer" - mock_config.kv_connector_extra_config = { - "prefill_device_ips": ["127.0.0.1"], - "decode_device_ips": ["127.0.0.1"], - "llmdatadist_comm_port": 26000, - "http_port": 8000, - "proxy_ip": "127.0.0.1", - "proxy_port": "8000", - "port": 5500 - } - mock_config.kv_port = 5500 - return mock_config - - @patch('threading.Thread') - @patch('llm_datadist.LLMDataDist') - def test_init_success(self, mock_thread, MockLLMDataDist): - - mock_config = self._create_mock_config() - - self.pipe = SimplePipe(rank=5, - local_rank=0, - kv_transfer_config=mock_config, - hostname="127.0.0.1", - port_offset=0) - - self.pipe.router_socket.close() - - @patch('threading.Thread') - @patch('llm_datadist.LLMDataDist') - def test_prepare_data_dist(self, mock_thread, MockLLMDataDist): - self.pipe = SimplePipe(rank=5, - local_rank=0, - kv_transfer_config=self._create_mock_config(), - hostname="127.0.0.1", - port_offset=0) - mock_data_dist = MockLLMDataDist.return_value - mock_data_dist.init.return_value = None - self.pipe.router_socket.close() - - def test_init_with_invalid_kv_role(self): - with self.assertRaises(NotImplementedError): - mock_config = MagicMock() - mock_config.kv_role = "err_role" - mock_config.kv_connector_extra_config = { - "prefill_device_ips": ["127.0.0.1"], - "decode_device_ips": ["127.0.0.1"], - "llmdatadist_comm_port": 26000, - "http_port": 8000, - "proxy_ip": "127.0.0.1", - 
"proxy_port": "8000", - "port": 5500 - } - pipe = SimplePipe(rank=5, - local_rank=0, - kv_transfer_config=mock_config, - hostname="127.0.0.1", - port_offset=0) - pipe.router_socket.close() - - def test_init_with_missing_device_ips(self): - with self.assertRaises(ValueError): - mock_config = MagicMock() - mock_config.kv_role = "kv_producer" - mock_config.kv_connector_extra_config = { - "llmdatadist_comm_port": 26000, - "http_port": 8000, - "proxy_ip": "127.0.0.1", - "proxy_port": "8000", - "port": 5500 - } - pipe = SimplePipe(rank=0, - local_rank=0, - kv_transfer_config=mock_config, - hostname="127.0.0.1", - port_offset=0) - pipe.router_socket.close() - - @patch('threading.Thread') - @patch('llm_datadist.LLMDataDist') - def test_create_register_thread_address_is_empty(self, MockThread, - MockLLMDataDist): - - mock_config = self._create_mock_config() - pipe = SimplePipe(rank=5, - local_rank=0, - kv_transfer_config=mock_config, - hostname="127.0.0.1", - port_offset=0) - self.assertIsNotNone(pipe._register_thread) - mock_data_dist = MockLLMDataDist.return_value - mock_data_dist.init.return_value = None - pipe.router_socket.close() - - @patch('threading.Thread') - @patch('llm_datadist.LLMDataDist') - def test_create_register_thread_address_is_not_empty( - self, MockThread, MockLLMDataDist): - mock_config = MagicMock() - mock_config.kv_role = "kv_producer" - mock_config.kv_connector_extra_config = { - "prefill_device_ips": [""], - "decode_device_ips": [""], - "llmdatadist_comm_port": 26000, - "http_port": 8000, - "proxy_ip": "127.0.0.1", - "proxy_port": "8000", - "port": 5500 - } - pipe = SimplePipe(rank=5, - local_rank=0, - kv_transfer_config=mock_config, - hostname="127.0.0.1", - port_offset=0) - self.assertIsNotNone(pipe._register_thread) - mock_data_dist = MockLLMDataDist.return_value - mock_data_dist.init.return_value = None - pipe.router_socket.close() - - @patch('vllm_ascend.distributed.kv_transfer.simple_pipe.SimplePipe') - @patch('llm_datadist.LLMDataDist') - def test_should_send_tensor_when_valid_input(self, MockSimplePipe, - MockLLMDataDist): - pipe = MockSimplePipe() - tensor = torch.randn(3, 3) - tensor_desc = MockLLMDataDist.CacheDesc( - num_tensors=1, - shape=(3, 3), - data_type=MockLLMDataDist.DataType.DT_FLOAT, - seq_len_dim_index=1) - tensor_key = MockLLMDataDist.CacheKey(1, 0, 1) - result = pipe.send_tensor(tensor, tensor_desc, tensor_key) - self.assertIsNotNone(result) diff --git a/vllm_ascend/distributed/__init__.py b/vllm_ascend/distributed/__init__.py index d7be705c2b..ebe8694b09 100644 --- a/vllm_ascend/distributed/__init__.py +++ b/vllm_ascend/distributed/__init__.py @@ -18,14 +18,6 @@ from vllm.distributed.kv_transfer.kv_connector.factory import \ KVConnectorFactory -KVConnectorFactory.register_connector( - "AscendHcclConnector", "vllm_ascend.distributed.llmdatadist_connector", - "LLMDataDistConnector") - -KVConnectorFactory.register_connector( - "AscendSimpleConnector", - "vllm_ascend.distributed.kv_transfer.simple_connector", "SimpleConnector") - KVConnectorFactory.register_connector( "LLMDataDistCMgrConnector", "vllm_ascend.distributed.llmdatadist_c_mgr_connector", diff --git a/vllm_ascend/distributed/kv_transfer/__init__.py b/vllm_ascend/distributed/kv_transfer/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/vllm_ascend/distributed/kv_transfer/simple_buffer.py b/vllm_ascend/distributed/kv_transfer/simple_buffer.py deleted file mode 100644 index 57474d07f4..0000000000 --- a/vllm_ascend/distributed/kv_transfer/simple_buffer.py +++ 
/dev/null @@ -1,207 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import zlib -from typing import List, Optional - -import llm_datadist # type: ignore -import torch -from vllm.distributed.kv_transfer.kv_lookup_buffer.base import \ - KVLookupBufferBase -from vllm.logger import logger - -from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe -from vllm_ascend.distributed.kv_transfer.utils import TORCH_DTYPE_TO_NPU_DTYPE - - -# Hash a string into a int32 value. -def int32_hash(data): - assert isinstance(data, str) - data = data.encode("utf-8") - return zlib.adler32(data) - - -class SimpleBuffer(KVLookupBufferBase): - - def __init__(self, data_pipe: SimplePipe): - self.data_pipe = data_pipe - # Consumer buffer need these information to construct receiving buffer. - self.num_layers = None - self.num_heads = None - self.head_size = None - self.dtype = None - self.hidden_size = None - self.key_buffer = None - self.value_buffer = None - self.hidden_buffer = None - - def insert( - self, - input_tokens: torch.Tensor, - roi: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - hidden: torch.Tensor, - req_id: str, - ) -> None: - """ - seq_len: num_tokens of current request. - input_tokens: [seq_len] - roi: [seq_len] - key: [num_layers, seq_len, num_kv_heads, head_size] - value: [num_layers, seq_len, num_kv_heads, head_size] - hidden: [seq_len, hidden_size] - """ - orig_k_shape = key.shape - num_layers = orig_k_shape[0] - - # unsequeeze all tensors to make first dim to 1. - # This is because D node can only pull one batch data from P. - # So we make first dim to 1 here in order to pull full data. 
- key = key.view(num_layers, -1).unsqueeze(0) - value = value.view(num_layers, -1).unsqueeze(0) - hidden = hidden.unsqueeze(0) - - hidden_dtype = key.dtype - # initialize LLMDatadist data structure - key_desc = llm_datadist.CacheDesc( - 1, - key.shape, - TORCH_DTYPE_TO_NPU_DTYPE[hidden_dtype], - seq_len_dim_index=1, - ) - value_desc = llm_datadist.CacheDesc( - 1, - value.shape, - TORCH_DTYPE_TO_NPU_DTYPE[hidden_dtype], - seq_len_dim_index=1, - ) - hidden_desc = llm_datadist.CacheDesc( - 1, - hidden.shape, - TORCH_DTYPE_TO_NPU_DTYPE[hidden_dtype], - seq_len_dim_index=-1, - ) - - req_id = int32_hash(req_id) - key_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id, - req_id, 1) - value_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id, - req_id, 2) - hidden_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id, - req_id, 3) - - # Currently we use hash value of request id as key, so no need to send input_tokens - self.key_buffer = self.data_pipe.send_tensor(key, key_desc, - key_cache_key) - self.value_buffer = self.data_pipe.send_tensor(value, value_desc, - value_cache_key) - self.hidden_buffer = self.data_pipe.send_tensor( - hidden, hidden_desc, hidden_cache_key) - - def drop_select( - self, - input_tokens: torch.Tensor, - roi: Optional[torch.Tensor], - req_id: str, - ) -> List[Optional[torch.Tensor]]: - """Select and *drop* KV cache entries from the lookup buffer. - - The functionality is similar to the following python statements - ``` - ret = buffer.pop(input_tokens, roi) - return ret - ``` - - Args: - input_tokens (torch.Tensor): token IDs. - roi (torch.Tensor): A binary mask on top of the input tokens - - Returns: - A list of tensors including: - key: [num_layers, num_tokens, num_heads, head_size] - value: [num_layers, num_tokens, num_heads, head_size] - hidden_or_intermediate_states: [num_tokens, hidden_size] - roi: None (Currently we don't supported roi) - """ - orig_req_id = req_id - req_id = int32_hash(req_id) - num_tokens = input_tokens.shape[0] - kv_shape = ( - 1, - self.num_layers, - num_tokens * self.num_heads * self.head_size, - ) - hidden_shape = (1, num_tokens, self.hidden_size) - key_desc = llm_datadist.CacheDesc( - 1, - kv_shape, - TORCH_DTYPE_TO_NPU_DTYPE[self.dtype], - seq_len_dim_index=-1, - ) - value_desc = llm_datadist.CacheDesc( - 1, - kv_shape, - TORCH_DTYPE_TO_NPU_DTYPE[self.dtype], - seq_len_dim_index=-1, - ) - hidden_desc = llm_datadist.CacheDesc( - 1, - hidden_shape, - TORCH_DTYPE_TO_NPU_DTYPE[self.dtype], - seq_len_dim_index=-1, - ) - - key_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id, - req_id, 1) - value_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id, - req_id, 2) - hidden_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id, - req_id, 3) - - # Deallocate buffer allocated in last round. 
- if self.key_buffer: - try: - self.data_pipe.deallocate_buffer(self.key_buffer) - self.data_pipe.deallocate_buffer(self.value_buffer) - self.data_pipe.deallocate_buffer(self.hidden_buffer) - except Exception as e: - logger.warning( - f"Failed to free kv cache buffer, Error code: {str(e)}") - - try: - self.key_buffer, key = self.data_pipe.recv_tensor( - key_desc, key_cache_key) - self.value_buffer, value = self.data_pipe.recv_tensor( - value_desc, value_cache_key) - self.hidden_buffer, hidden = self.data_pipe.recv_tensor( - hidden_desc, hidden_cache_key) - key = key.view(self.num_layers, num_tokens, self.num_heads, - self.head_size) - value = value.view(self.num_layers, num_tokens, self.num_heads, - self.head_size) - hidden = hidden.view(num_tokens, self.hidden_size) - except Exception as e: - logger.warning( - f"Fail to receive kv cache and hidden states of request: {orig_req_id} " - f"Error is {str(e)}") - return [None, None, None, None] - - return [key, value, hidden, roi] - - def close(self): - pass diff --git a/vllm_ascend/distributed/kv_transfer/simple_connector.py b/vllm_ascend/distributed/kv_transfer/simple_connector.py deleted file mode 100644 index 31b38c068e..0000000000 --- a/vllm_ascend/distributed/kv_transfer/simple_connector.py +++ /dev/null @@ -1,379 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -from typing import TYPE_CHECKING, List, Optional, Tuple, Union - -import torch -import torch_npu -import vllm.envs as vllm_envs -from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -from vllm.distributed.parallel_state import get_dp_group -from vllm.logger import logger -from vllm.sequence import IntermediateTensors - -from vllm_ascend.distributed.kv_transfer.simple_buffer import SimpleBuffer -from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - - -class SimpleConnector(KVConnectorBase): - - def __init__( - self, - rank: int, - local_rank: int, - config: VllmConfig, - ): - self.config = config - self.model_config = config.model_config.hf_config - self.tp_size = config.parallel_config.tensor_parallel_size - self.rank = rank - self.local_rank = local_rank - self.is_deepseek_mla = config.model_config.is_deepseek_mla - self.use_mla_opt = not vllm_envs.VLLM_MLA_DISABLE - self.n_layer = self.config.model_config.get_num_layers( - self.config.parallel_config) - - self.producer_data_pipe: Optional[SimplePipe] - self.consumer_data_pipe: Optional[SimplePipe] - - self.producer_buffer: Optional[SimpleBuffer] - self.consumer_buffer: Optional[SimpleBuffer] - - if self.config.kv_transfer_config.is_kv_producer: - self.producer_data_pipe = SimplePipe( - rank=rank, - local_rank=local_rank, - kv_transfer_config=config.kv_transfer_config, - hostname="", - port_offset=rank, - ) - self.producer_buffer = SimpleBuffer(self.producer_data_pipe) - else: - self.consumer_data_pipe = SimplePipe( - rank=rank, - local_rank=local_rank, - kv_transfer_config=config.kv_transfer_config, - hostname="", - port_offset=rank, - ) - self.consumer_buffer = SimpleBuffer(self.consumer_data_pipe) - - def select( - self, - input_tokens: Optional[torch.Tensor], - roi: Optional[torch.Tensor], - req_id: str, - ) -> List[Optional[torch.Tensor]]: - - assert self.consumer_buffer is not None, ( - "Please initialize the " - "consumer buffer before calling select.") - return self.consumer_buffer.drop_select(input_tokens, roi, req_id) - - def insert( - self, - input_tokens: torch.Tensor, - roi: torch.Tensor, - keys: torch.Tensor, - values: torch.Tensor, - hidden: torch.Tensor, - req_id: str, - ) -> None: - - assert self.producer_buffer is not None, ( - "Please initialize the " - "producer buffer before calling insert.") - self.producer_buffer.insert(input_tokens, roi, keys, values, hidden, - req_id) - - def send_kv_caches_and_hidden_states( - self, - model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: List[torch.Tensor], - hidden_or_intermediate_states: Union[torch.Tensor, - IntermediateTensors], - ) -> None: - input_tokens_tensor = model_input.input_tokens - seq_lens = model_input.attn_metadata.seq_lens - slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() - num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens - start_layer = model_executable.model.start_layer - end_layer = model_executable.model.end_layer - - model_config = self.model_config - num_heads = int(model_config.num_key_value_heads / self.tp_size) - hidden_size = model_config.hidden_size - num_attention_heads = model_config.num_attention_heads - - # Deepseek's MLA (Multi-head Latent Attention) uses two different - # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0. 
- # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied, - # resulting in a kv_cache shape of [num_blks, blk_size, 1, - # kv_lora_rank + qk_rope_head_dim]. - # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading - # to a kv_cache shape of [2, num_blks, blk_size, - # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim]. - # For more details, see vllm/attention/backends/mla/common.py. - if self.is_deepseek_mla and self.use_mla_opt: - head_size = (model_config.kv_lora_rank + - model_config.qk_rope_head_dim) - num_heads = 1 - elif self.is_deepseek_mla and not self.use_mla_opt: - head_size = (model_config.qk_nope_head_dim + - model_config.qk_rope_head_dim) - else: - head_size = getattr( - model_config, - "head_dim", - int(hidden_size // num_attention_heads), - ) - # Enumerate over all requests and insert them one by one. - for idx, slen in enumerate(seq_lens): - start_pos = sum(seq_lens[:idx]) - end_pos = start_pos + slen - - if start_pos >= num_prefill_tokens: - # vllm/worker/model_runner.py::_prepare_model_input_tensors: - # - input_tokens[:num_prefill_tokens] contains prefill tokens. - # - input_tokens[num_prefill_tokens:] contains decode tokens. - logger.warning("You have some decode requests while using " - "SimpleConnector. Their KVCache won't be sent.") - break - - current_tokens = input_tokens_tensor[start_pos:end_pos] - - keys, values = [], [] - - for layer_id in range(start_layer, end_layer): - kv_cache = kv_caches[layer_id - start_layer] - - if self.is_deepseek_mla and self.use_mla_opt: - key_cache = kv_cache.reshape(-1, num_heads, head_size) - value_cache = kv_cache.reshape(-1, num_heads, head_size) - else: - key_cache = kv_cache[0].reshape(-1, num_heads, head_size) - value_cache = kv_cache[1].reshape(-1, num_heads, head_size) - - current_slot_mapping = slot_mapping_flat[start_pos:end_pos] - - keys.append(key_cache[current_slot_mapping].unsqueeze(0)) - values.append(value_cache[current_slot_mapping].unsqueeze(0)) - - # shape: [num_layers, num_tokens, num_heads, head_size] - keys = torch.cat(keys, dim=0) - values = torch.cat(values, dim=0) - cur_req_id = list(model_input.request_ids_to_seq_ids.keys())[idx] - # Currently we haven't considered situation of roi, pass None here. 
- self.insert( - current_tokens, - None, - keys, - values, - hidden_or_intermediate_states[start_pos:end_pos], - cur_req_id, - ) - - logger.info("[rank%d][P]: KV send DONE.", torch.distributed.get_rank()) - - def recv_kv_caches_and_hidden_states( - self, - model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: List[torch.Tensor], - ) -> Tuple[ - Union[torch.Tensor, IntermediateTensors], - bool, - "ModelInputForGPUWithSamplingMetadata", - ]: - bypass_model_exec = True - - model_config = self.model_config - - # get model config - start_layer = model_executable.model.start_layer - end_layer = model_executable.model.end_layer - num_heads, head_dim = kv_caches[0].shape[-2:] - hidden_size = model_config.hidden_size - num_attention_heads = model_config.num_attention_heads - num_layers = end_layer - start_layer - if self.is_deepseek_mla and self.use_mla_opt: - head_size = (model_config.kv_lora_rank + - model_config.qk_rope_head_dim) - num_heads = 1 - elif self.is_deepseek_mla and not self.use_mla_opt: - head_size = (model_config.qk_nope_head_dim + - model_config.qk_rope_head_dim) - else: - head_size = getattr( - model_config, - "head_dim", - int(hidden_size // num_attention_heads), - ) - self.consumer_buffer.num_heads = num_heads # type: ignore - self.consumer_buffer.num_layers = num_layers # type: ignore - self.consumer_buffer.head_size = head_size # type: ignore - self.consumer_buffer.dtype = kv_caches[0].dtype # type: ignore - self.consumer_buffer.hidden_size = hidden_size # type: ignore - - input_tokens_tensor = model_input.input_tokens - seq_lens = model_input.attn_metadata.seq_lens - num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens - slot_mapping = model_input.attn_metadata.slot_mapping.flatten() - - total_tokens = model_input.attn_metadata.num_prefill_tokens + model_input.attn_metadata.num_decode_tokens - hidden_or_intermediate_states_for_one_req = [] - - input_tokens_list = [] - num_computed_tokens_list = [] - start_pos_list = [] - - # enumerate different requests - for idx, slen in enumerate(seq_lens): - start_pos = sum(seq_lens[:idx]) - end_pos = start_pos + slen - - if start_pos >= num_prefill_tokens: - logger.warning("You should set --enable_chunked_prefill=False " - "and --max_num_batched_tokens " - "should be equal to --max_seq_len_to_capture") - bypass_model_exec = False - assert start_pos == num_prefill_tokens - break - - current_tokens = input_tokens_tensor[start_pos:end_pos] - num_tokens = slen - - # collecting data for rebuilding the input - input_tokens_list.append(current_tokens) - start_pos_list.append(start_pos) - - cur_req_id = list(model_input.request_ids_to_seq_ids.keys())[idx] - - ret = self.select( - current_tokens, - torch.ones_like(current_tokens, dtype=bool), - cur_req_id, - ) - if ret[0] is None: - # didn't find any match. - bypass_model_exec = False - num_computed_tokens_list.append(0) - continue - - keys: torch.Tensor = ret[0] - values: torch.Tensor = ret[1] - hidden: torch.Tensor = ret[2] - - num_computed_tokens = keys.shape[1] - num_computed_tokens_list.append(num_computed_tokens) - - # check if both KV cache and the hidden states are received - # If not, need to redo the forwarding to compute missing states - if not all([(num_computed_tokens == num_tokens), hidden is not None - ]): - bypass_model_exec = False - - # update the end position based on how many tokens are cached. 
- end_pos = start_pos + num_computed_tokens - - # put received KV caches into paged memory - for i in range( - model_executable.model.start_layer, - model_executable.model.end_layer, - ): - - kv_cache = kv_caches[i - model_executable.model.start_layer] - layer = model_executable.model.layers[i] - - if self.is_deepseek_mla and self.use_mla_opt: - layer.self_attn.attn = layer.self_attn.mla_attn - key_cache = kv_cache - slots = slot_mapping[start_pos:end_pos] - sliced_key = keys[i - model_executable.model.start_layer] - torch_npu._npu_reshape_and_cache_siso(key=sliced_key, - key_cache=key_cache, - slot_indices=slots) - else: - key_cache, value_cache = kv_cache[0], kv_cache[1] - sliced_key = keys[i - model_executable.model.start_layer] - sliced_value = values[i - - model_executable.model.start_layer] - torch_npu._npu_reshape_and_cache( - key=sliced_key, - value=sliced_value, - key_cache=key_cache, - value_cache=value_cache, - slot_indices=slot_mapping[start_pos:end_pos], - ) - - hidden_or_intermediate_states_for_one_req.append(hidden) - - if not bypass_model_exec: - # Some of the KV cache is not retrieved - # Here we will fall back to normal model forwarding - # But optionally you can adjust model_input so that you only do - # prefilling on those tokens that are missing KV caches. - if get_dp_group().world_size > 1: - bypass_model_exec = True - hidden_or_intermediate_states = torch.empty( - [total_tokens, hidden_size], - dtype=kv_caches[0].dtype, - device=kv_caches[0].device) - logger.warning( - "[Detect there is more one DP rank in this decode node, in this scenario, no recompute is expected when kv cache dose not received.]" - ) - else: - logger.warning( - "[rank%d]: Failed to receive all KVs and hidden " - "states, redo model forwarding.", - torch.distributed.get_rank()) - hidden_or_intermediate_states = None - else: - logger.debug( - "[rank%d]: Successfully received all KVs and hidden " - "states, skip model forwarding.", - torch.distributed.get_rank(), - ) - # Can't directly concat here which might cause error when bs = 1. - # hidden_or_intermediate_states = torch.empty(total_num_tokens, hidden_size, dtype=kv_caches[0].dtype, device=kv_caches[0].device) - if len(hidden_or_intermediate_states_for_one_req) == 1: - hidden = hidden_or_intermediate_states_for_one_req[0] - tmp_indice = torch.tensor([0] * hidden.shape[0], - dtype=torch.int64).npu() - hidden_or_intermediate_states = torch.empty_like(hidden) - torch_npu.scatter_update_( - hidden_or_intermediate_states, - tmp_indice, - hidden, - axis=-1, - ) - else: - hidden_or_intermediate_states = torch.cat( - hidden_or_intermediate_states_for_one_req, dim=0) - - return hidden_or_intermediate_states, bypass_model_exec, model_input - - def close(self): - self.producer_data_pipe.close() # type: ignore - self.consumer_data_pipe.close() # type: ignore - self.producer_buffer.close() # type: ignore - self.consumer_buffer.close() # type: ignore diff --git a/vllm_ascend/distributed/kv_transfer/simple_pipe.py b/vllm_ascend/distributed/kv_transfer/simple_pipe.py deleted file mode 100644 index ef9dd3ca36..0000000000 --- a/vllm_ascend/distributed/kv_transfer/simple_pipe.py +++ /dev/null @@ -1,207 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import threading -import time -from typing import Optional - -import llm_datadist # type: ignore -import msgpack # type: ignore -import torch -import torch_npu -import torchair # type: ignore -import zmq # type: ignore -from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase -from vllm.logger import logger -from vllm.utils import get_ip - -import vllm_ascend.envs as envs -from vllm_ascend.distributed.kv_transfer.utils import NPU_DTYPE_TO_TORCH_DTYPE - - -class SimplePipe(KVPipeBase): - - def __init__( - self, - rank, - local_rank, - kv_transfer_config, - hostname: str = "", - port_offset: int = 0, # NPU offset in current P/D instance. - ): - self.rank = rank - self.local_rank = local_rank - # Currently for 1P1D situation, we use cluster_id=0 for both Prefill and Decode - # Will change here in the future to support xPyD. - self.cluster_id = 0 - self.config = kv_transfer_config - kv_connector_extra_config = kv_transfer_config.kv_connector_extra_config - kv_role = kv_transfer_config.kv_role - if kv_role == "kv_producer": - self.role = llm_datadist.LLMRole.PROMPT - elif kv_role == "kv_consumer": - self.role = llm_datadist.LLMRole.DECODER - else: - raise NotImplementedError( - "kv_role should be inside [kv_producer, kv_consumer]") - - prefill_device_ips = kv_connector_extra_config.get( - "prefill_device_ips", None) - decode_device_ips = kv_connector_extra_config.get( - "decode_device_ips", None) - if prefill_device_ips is None or decode_device_ips is None: - raise ValueError( - "Please specify prefill_device_ips and decode_device_ips" - "in kv_transfer_config.kv_connector_extra_config") - p_device_num = len(prefill_device_ips) - d_device_num = len(decode_device_ips) - # When number of devices in P and D is not equal, - # we assume that device in D can be mapped to any device in P. - self.p_device_rank = self.rank % p_device_num - self.d_device_rank = self.rank % d_device_num - - self.prompt_ip_list = prefill_device_ips - self.decode_ip_list = decode_device_ips - self.llmdatadist_comm_port = kv_connector_extra_config.get( - "llmdatadist_comm_port", 26000) - # LLMDataDist initializing. - self.data_dist = llm_datadist.LLMDataDist(self.role, self.cluster_id) - self._prepare_data_dist() - # Decoder needs to initialize and link cluster - if self.role == llm_datadist.LLMRole.DECODER: - self.cluster = self._make_cluster() - _, ret = self.data_dist.link_clusters([self.cluster], 20000) - logger.info( - f"rank {self.rank}, local_rank {self.local_rank} link, ret={ret}" - ) - - # If `proxy_ip` or `proxy_port` is `""`, - # then the ping thread will not be enabled. - proxy_ip = self.config.get_from_extra_config("proxy_ip", "") - proxy_port = self.config.get_from_extra_config("proxy_port", "") - if proxy_ip == "" or proxy_port == "": - self.proxy_address = "" - else: - self.proxy_address = proxy_ip + ":" + str(proxy_port) - - self._register_thread = None - if port_offset == 0 and self.proxy_address != "": - # Initialize zmq socket and register to proxy. - # Note that only NPU 0 of each P/D instance register to proxy. - if not hostname: - hostname = get_ip() # Get ip of current host. 
- port = int(kv_transfer_config.kv_port) + port_offset - if port == 0: - raise ValueError("Port cannot be 0") - self._hostname = hostname - self._port = port - # Each card corresponds to a ZMQ address. - self.zmq_address = f"{self._hostname}:{self._port}" - - self.context = zmq.Context() # type: ignore - self.router_socket = self.context.socket( - zmq.ROUTER) # type: ignore - self.router_socket.bind(f"tcp://{self.zmq_address}") - # The `http_port` must be consistent with the serving port of OpenAI. - self.http_address = ( - f"{self._hostname}:" - f"{self.config.kv_connector_extra_config['http_port']}") - self._register_thread = threading.Thread( - target=self._register_to_proxy, daemon=True) - self._register_thread.start() - - def _prepare_data_dist(self): - options = { - "llm.SyncKvCacheWaitTime": envs.LLMDATADIST_SYNC_CACHE_WAIT_TIME, - } - if self.role == llm_datadist.LLMRole.PROMPT: - options["ge.exec.deviceId"] = str(self.local_rank) - options["llm.listenIpInfo"] = ( - f"{self.prompt_ip_list[self.p_device_rank]}:{self.llmdatadist_comm_port}" - ) - else: - options["ge.exec.deviceId"] = str(self.local_rank) - print(f"prepare datadist, options: {options}") - self.data_dist.init(options) - self.kv_transfer = self.data_dist.kv_cache_manager - print(f"{self.rank} rank data dist is ready") - - def _make_cluster(self): - cluster = llm_datadist.LLMClusterInfo() - cluster.remote_cluster_id = self.cluster_id - local_ip = self.decode_ip_list[self.d_device_rank] - remote_ip = self.prompt_ip_list[self.p_device_rank] - cluster.append_local_ip_info(local_ip, 0) - cluster.append_remote_ip_info(remote_ip, self.llmdatadist_comm_port) - return cluster - - def _register_to_proxy(self): - sock = self.context.socket(zmq.DEALER) # type: ignore - sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) # type: ignore - logger.debug("ping start, zmq_address:%s", self.zmq_address) - sock.connect(f"tcp://{self.proxy_address}") - data = { - "type": "P" if self.config.is_kv_producer else "D", - "http_address": self.http_address, - "zmq_address": self.zmq_address, - } - while True: - sock.send(msgpack.dumps(data)) - time.sleep(3) - - def send_tensor( - self, - tensor: Optional[torch.Tensor], - tensor_desc: llm_datadist.CacheDesc, - tensor_key: llm_datadist.CacheKey, - ) -> llm_datadist.Cache: - buffer = self.kv_transfer.allocate_cache(tensor_desc, [tensor_key]) - buffer_addr = buffer.per_device_tensor_addrs[0] - data_tensor = torchair.llm_datadist.create_npu_tensors( - tensor_desc.shape, tensor.dtype, buffer_addr)[0] # type: ignore - update_indices = torch.tensor( - [0] * tensor.shape[0], # type: ignore - dtype=torch.int64).npu() - torch_npu.scatter_update_(data_tensor, update_indices, tensor, axis=-1) - # Free cache_id of buffer, actual deallocate will happen after consumer performing pull_cache. - self.kv_transfer.deallocate_cache(buffer) - return buffer - - def recv_tensor( - self, - tensor_desc: llm_datadist.CacheDesc, - tensor_key: llm_datadist.CacheKey, - ) -> llm_datadist.Cache: - """Note that this function only creates empty tensor on buffer addr and returns it.""" - tmp_buffer = self.kv_transfer.allocate_cache(tensor_desc) - buffer_addr = tmp_buffer.per_device_tensor_addrs[0] - data_tensor = torchair.llm_datadist.create_npu_tensors( - tensor_desc.shape, - NPU_DTYPE_TO_TORCH_DTYPE[tensor_desc.data_type], - buffer_addr, - )[0] - self.kv_transfer.pull_cache(tensor_key, tmp_buffer, 0) - # tmp_buffer is allocated without key and will be deallocated here immediately. 
- # Free buffer here will cause accuracy problem. - # self.kv_transfer.deallocate_cache(tmp_buffer) - return tmp_buffer, data_tensor - - def deallocate_buffer(self, buffer: llm_datadist.Cache): - self.kv_transfer.deallocate_cache(buffer) - - def close(self): - self.data_dist.unlink_clusters([self.cluster], 5000) diff --git a/vllm_ascend/distributed/kv_transfer/utils.py b/vllm_ascend/distributed/kv_transfer/utils.py deleted file mode 100644 index 9dc43a06d3..0000000000 --- a/vllm_ascend/distributed/kv_transfer/utils.py +++ /dev/null @@ -1,40 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import llm_datadist # type: ignore -import torch - -TORCH_DTYPE_TO_NPU_DTYPE = { - torch.half: llm_datadist.DataType.DT_FLOAT16, - torch.float16: llm_datadist.DataType.DT_FLOAT16, - torch.bfloat16: llm_datadist.DataType.DT_BF16, - torch.float: llm_datadist.DataType.DT_FLOAT, - torch.float32: llm_datadist.DataType.DT_FLOAT, - torch.int8: llm_datadist.DataType.DT_INT8, - torch.int64: llm_datadist.DataType.DT_INT64, - torch.int32: llm_datadist.DataType.DT_INT32, -} - -NPU_DTYPE_TO_TORCH_DTYPE = { - llm_datadist.DataType.DT_FLOAT16: torch.half, - llm_datadist.DataType.DT_FLOAT16: torch.float16, - llm_datadist.DataType.DT_BF16: torch.bfloat16, - llm_datadist.DataType.DT_FLOAT: torch.float, - llm_datadist.DataType.DT_FLOAT: torch.float32, - llm_datadist.DataType.DT_INT8: torch.int8, - llm_datadist.DataType.DT_INT64: torch.int64, - llm_datadist.DataType.DT_INT32: torch.int32, -} \ No newline at end of file diff --git a/vllm_ascend/distributed/llmdatadist_connector.py b/vllm_ascend/distributed/llmdatadist_connector.py deleted file mode 100644 index 19a759a850..0000000000 --- a/vllm_ascend/distributed/llmdatadist_connector.py +++ /dev/null @@ -1,470 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -import os -import re -import subprocess -from typing import TYPE_CHECKING, List, Tuple, Union - -import torch -import torch_npu -import torchair # type: ignore -from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -from vllm.logger import logger -from vllm.sequence import IntermediateTensors - -import vllm_ascend.envs as envs - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - -import llm_datadist # type: ignore - -TORCH_DTYPE_TO_NPU_DTYPE = { - torch.half: llm_datadist.DataType.DT_FLOAT16, - torch.float16: llm_datadist.DataType.DT_FLOAT16, - torch.bfloat16: llm_datadist.DataType.DT_BF16, - torch.float: llm_datadist.DataType.DT_FLOAT, - torch.float32: llm_datadist.DataType.DT_FLOAT, - torch.int8: llm_datadist.DataType.DT_INT8, - torch.int64: llm_datadist.DataType.DT_INT64, - torch.int32: llm_datadist.DataType.DT_INT32 -} - -# Get all device ips using hccn_tool -HCCN_TOOL_PATH = envs.HCCN_PATH - - -def get_device_ips(): - world_size = 8 - npu_info = subprocess.run(['npu-smi', 'info', '-m'], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - universal_newlines=True) - if npu_info.returncode != 0 or not os.path.exists(HCCN_TOOL_PATH): - raise RuntimeError("No npu-smi/hccn_tool tools provided for NPU.") - re_result = re.match(r'.*\n\t([0-9]+).*', npu_info.stdout) - if re_result is None: - raise RuntimeError("Can't find npu start index") - npu_start_idx = int(re_result.group(1)) - device_ip_list = [] - for ip_offset in range(world_size): - cmd = [ - HCCN_TOOL_PATH, '-i', f'{npu_start_idx + ip_offset}', '-ip', '-g' - ] - device_ip_info = subprocess.run(cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - universal_newlines=True) - re_result = re.match(r'ipaddr:(.*)\n', device_ip_info.stdout) - if re_result is None: - raise RuntimeError("Can't find npu ip") - device_ip = re_result.group(1) - device_ip_list.append(device_ip) - return device_ip_list - - -class KVTransferEngine: - - def __init__(self, world_size, n_layer, role, local_rank): - self.world_size = world_size - self.n_layer = n_layer - self.role = role - self.device_ip_list = get_device_ips() - self.local_rank = local_rank - self.cluster_id = local_rank - self.data_dist = llm_datadist.LLMDataDist(self.role, self.cluster_id) - - prompt_device_ids = envs.PROMPT_DEVICE_ID - decode_device_ids = envs.DECODE_DEVICE_ID - if prompt_device_ids is None or decode_device_ids is None: - raise ValueError( - "Please specify env PROMPT_DEVICE_ID or DECODE_DEVICE_ID") - - prompt_ids = [ - int(x.strip()) for x in prompt_device_ids.split(",") if x.strip() - ] - decode_ids = [ - int(x.strip()) for x in decode_device_ids.split(",") if x.strip() - ] - - self.prompt_ip_list = [self.device_ip_list[i] for i in prompt_ids] - self.decode_ip_list = [self.device_ip_list[i] for i in decode_ids] - - def prepare_data_dist(self): - options = { - "llm.SyncKvCacheWaitTime": envs.LLMDATADIST_SYNC_CACHE_WAIT_TIME, - } - if self.role == llm_datadist.LLMRole.PROMPT: - options["ge.exec.deviceId"] = str(self.local_rank) - options[ - "llm.listenIpInfo"] = f"{self.prompt_ip_list[self.local_rank]}:{envs.LLMDATADIST_COMM_PORT}" - else: - options["ge.exec.deviceId"] = str(self.local_rank) - self.data_dist.init(options) - self.kv_transfer = self.data_dist.kv_cache_manager - logger.info( - f"{self.local_rank}/{self.world_size} rank data dist is ready") - - def make_cluster(self, prefill_ip, cluster_id=-1): - cluster = llm_datadist.LLMClusterInfo() - 
cluster.remote_cluster_id = cluster_id - local_ip = self.decode_ip_list[self.local_rank] - remote_ip = prefill_ip - cluster.append_local_ip_info(local_ip, 0) - cluster.append_remote_ip_info(remote_ip, 26000) - return cluster - - -class LLMDataDistConnector(KVConnectorBase): - - def __init__( - self, - rank: int, - local_rank: int, - config: VllmConfig, - ): - self.config = config - self.tp_size = config.parallel_config.tensor_parallel_size - self.rank = rank - self.local_rank = local_rank - - if self.config.kv_transfer_config.kv_role == "kv_producer": - self.role = llm_datadist.LLMRole.PROMPT - elif self.config.kv_transfer_config.kv_role == "kv_consumer": - self.role = llm_datadist.LLMRole.DECODER - else: - raise NotImplementedError( - "kv_role should be inside [kv_producer, kv_consumer]") - - self.world_size = self.config.parallel_config.world_size - self.n_layer = self.config.model_config.get_num_layers( - self.config.parallel_config) - - self.llm_datadist_engine = KVTransferEngine(self.world_size, - self.n_layer, self.role, - self.local_rank) - if self.role == llm_datadist.LLMRole.PROMPT: - self.llm_datadist_engine.prepare_data_dist() - else: - self.llm_datadist_engine.prepare_data_dist() - self.cluster = self.llm_datadist_engine.make_cluster( - self.llm_datadist_engine.prompt_ip_list[self.local_rank], - self.llm_datadist_engine.cluster_id) - _, ret = self.llm_datadist_engine.data_dist.link_clusters( - [self.cluster], 20000) - logger.info(f"local_rank {self.local_rank} link, ret={ret}") - - def send_kv_caches_and_hidden_states( - self, model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: List[torch.Tensor], - hidden_or_intermediate_states: Union[torch.Tensor, IntermediateTensors] - ) -> None: - input_tokens_tensor = model_input.input_tokens - seq_lens = model_input.attn_metadata.seq_lens - slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() - start_layer = model_executable.model.start_layer - end_layer = model_executable.model.end_layer - - model_config = model_executable.model.config - num_heads = int(model_config.num_key_value_heads / self.tp_size) - hidden_size = model_config.hidden_size - num_attention_heads = model_config.num_attention_heads - head_size = int(hidden_size / num_attention_heads) - - num_layer = end_layer - start_layer - - # Get shape of input_tokens_tensor and kv_cache - input_shape = (1, input_tokens_tensor.shape[0], 1, 1) - hidden_shape = (1, input_tokens_tensor.shape[0], 1, hidden_size) - kv_shape = (1, input_tokens_tensor.shape[0], num_heads, head_size) - - assert kv_caches[0].dtype == hidden_or_intermediate_states.dtype - kv_hidden_dtype = kv_caches[0].dtype - input_dtype = torch.int32 - - # initialize LLMDatadist data structure - key_desc = llm_datadist.CacheDesc( - num_layer, - kv_shape, - TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype], - seq_len_dim_index=1) - value_desc = llm_datadist.CacheDesc( - num_layer, - kv_shape, - TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype], - seq_len_dim_index=1) - input_desc = llm_datadist.CacheDesc( - 1, - input_shape, - TORCH_DTYPE_TO_NPU_DTYPE[input_dtype], - seq_len_dim_index=-1) - hidden_desc = llm_datadist.CacheDesc( - 1, - hidden_shape, - TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype], - seq_len_dim_index=-1) - - key_cache_keys = [ - llm_datadist.CacheKey(self.llm_datadist_engine.cluster_id, 0, 1) - ] - value_cache_keys = [ - llm_datadist.CacheKey(self.llm_datadist_engine.cluster_id, 0, 2) - ] - input_cache_keys = [ - 
llm_datadist.CacheKey(self.llm_datadist_engine.cluster_id, 0, 3) - ] - hidden_cache_keys = [ - llm_datadist.CacheKey(self.llm_datadist_engine.cluster_id, 0, 4) - ] - - self.key_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache( - key_desc, key_cache_keys) - self.value_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache( - value_desc, value_cache_keys) - self.input_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache( - input_desc, input_cache_keys) - self.hidden_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache( - hidden_desc, hidden_cache_keys) - - key_buffer_addr = self.key_buffer.per_device_tensor_addrs[0] - value_buffer_addr = self.value_buffer.per_device_tensor_addrs[0] - input_buffer_addr = self.input_buffer.per_device_tensor_addrs[0] - hidden_buffer_addr = self.hidden_buffer.per_device_tensor_addrs[0] - - self.key_cache = torchair.llm_datadist.create_npu_tensors( - key_desc.shape, kv_hidden_dtype, key_buffer_addr) - self.value_cache = torchair.llm_datadist.create_npu_tensors( - value_desc.shape, kv_hidden_dtype, value_buffer_addr) - self.input_cache = torchair.llm_datadist.create_npu_tensors( - input_desc.shape, input_dtype, input_buffer_addr) - self.hidden_cache = torchair.llm_datadist.create_npu_tensors( - hidden_desc.shape, kv_hidden_dtype, hidden_buffer_addr) - - indices = torch.tensor([0], dtype=torch.int64).npu() - - # copy cache data into llm datadist cache using scatter update - for idx, slen in enumerate(seq_lens): - start_pos = sum(seq_lens[:idx]) - end_pos = start_pos + slen - current_tokens = input_tokens_tensor[start_pos:end_pos].to( - torch.int32) - - for layer_id in range(start_layer, end_layer): - kv_cache = kv_caches[layer_id - start_layer] - - key_cache = kv_cache[0].view(-1, num_heads, head_size) - value_cache = kv_cache[1].view(-1, num_heads, head_size) - - current_slot_mapping = slot_mapping_flat[start_pos:end_pos] - - # copy key into datadist - k = self.key_cache[layer_id][:, start_pos:end_pos, :, :] - new_k = key_cache[current_slot_mapping].unsqueeze(0) - torch_npu.scatter_update_(k, indices, new_k, axis=-2) - - # copy value into datadist - val = self.value_cache[layer_id][:, start_pos:end_pos, :, :] - new_val = value_cache[current_slot_mapping].unsqueeze(0) - torch_npu.scatter_update_(val, indices, new_val, axis=-2) - - # copy input into datadist - inp = self.input_cache[0][:, start_pos:end_pos, :, :] - new_inp = current_tokens.view(1, current_tokens.shape[0], 1, 1) - torch_npu.scatter_update_(inp, indices, new_inp, axis=-2) - - # copy hidden into datadist - hid = self.hidden_cache[0][:, start_pos:end_pos, :, :] - hid_shape0, hid_shape1 = hidden_or_intermediate_states[ - start_pos:end_pos].shape - new_hid = hidden_or_intermediate_states[start_pos:end_pos].view( - 1, hid_shape0, 1, hid_shape1) - torch_npu.scatter_update_(hid, indices, new_hid, axis=-2) - - logger.info("[rank%d][P]: KV send DONE.", torch.distributed.get_rank()) - - def recv_kv_caches_and_hidden_states( - self, model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: List[torch.Tensor] - ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, - "ModelInputForGPUWithSamplingMetadata"]: - bypass_model_exec = True - - input_tokens_tensor = model_input.input_tokens - seq_lens = model_input.attn_metadata.seq_lens - slot_mapping = model_input.attn_metadata.slot_mapping.flatten() - - hidden_or_intermediate_states_for_one_req = [] - - input_tokens_list = [] - num_computed_tokens_list = [] - start_pos_list = [] - - # 
get model config - start_layer = model_executable.model.start_layer - end_layer = model_executable.model.end_layer - model_config = model_executable.model.config - num_heads = int(model_config.num_key_value_heads / self.tp_size) - hidden_size = model_config.hidden_size - num_attention_heads = model_config.num_attention_heads - head_size = int(hidden_size / num_attention_heads) - num_layer = end_layer - start_layer - - # get input_tensor_shape and hidden_shape - input_shape = (1, input_tokens_tensor.shape[0], 1, 1) - hidden_shape = (1, input_tokens_tensor.shape[0], 1, hidden_size) - kv_shape = (1, input_tokens_tensor.shape[0], num_heads, head_size) - - kv_hidden_dtype = kv_caches[0].dtype - input_dtype = torch.int32 - - # Add LLM DataDist initialization - key_desc = llm_datadist.CacheDesc( - num_layer, - kv_shape, - TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype], - seq_len_dim_index=-1) - value_desc = llm_datadist.CacheDesc( - num_layer, - kv_shape, - TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype], - seq_len_dim_index=-1) - input_desc = llm_datadist.CacheDesc( - 1, - input_shape, - TORCH_DTYPE_TO_NPU_DTYPE[input_dtype], - seq_len_dim_index=-1) - hidden_desc = llm_datadist.CacheDesc( - 1, - hidden_shape, - TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype], - seq_len_dim_index=-1) - self.decode_key_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache( - key_desc) - self.decode_value_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache( - value_desc) - self.decode_input_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache( - input_desc) - self.decode_hidden_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache( - hidden_desc) - key_buffer_addrs = self.decode_key_buffer.per_device_tensor_addrs[0] - value_buffer_addrs = self.decode_value_buffer.per_device_tensor_addrs[ - 0] - input_buffer_addrs = self.decode_input_buffer.per_device_tensor_addrs[ - 0] - hidden_buffer_addrs = self.decode_hidden_buffer.per_device_tensor_addrs[ - 0] - self.key_cache = torchair.llm_datadist.create_npu_tensors( - key_desc.shape, kv_hidden_dtype, key_buffer_addrs) - self.value_cache = torchair.llm_datadist.create_npu_tensors( - value_desc.shape, kv_hidden_dtype, value_buffer_addrs) - self.input_cache = torchair.llm_datadist.create_npu_tensors( - input_desc.shape, input_dtype, input_buffer_addrs) - self.hidden_cache = torchair.llm_datadist.create_npu_tensors( - hidden_desc.shape, kv_hidden_dtype, hidden_buffer_addrs) - - key_cache_key = llm_datadist.CacheKeyByIdAndIndex( - self.cluster.remote_cluster_id, 1, 0) - value_cache_key = llm_datadist.CacheKeyByIdAndIndex( - self.cluster.remote_cluster_id, 2, 0) - input_cache_key = llm_datadist.CacheKeyByIdAndIndex( - self.cluster.remote_cluster_id, 3, 0) - hidden_cache_key = llm_datadist.CacheKeyByIdAndIndex( - self.cluster.remote_cluster_id, 4, 0) - - self.llm_datadist_engine.kv_transfer.pull_cache( - key_cache_key, self.decode_key_buffer, 0) - self.llm_datadist_engine.kv_transfer.pull_cache( - value_cache_key, self.decode_value_buffer, 0) - self.llm_datadist_engine.kv_transfer.pull_cache( - input_cache_key, self.decode_input_buffer, 0) - self.llm_datadist_engine.kv_transfer.pull_cache( - hidden_cache_key, self.decode_hidden_buffer, 0) - - keys = self.key_cache - values = self.value_cache - inputs = self.input_cache - hidden = self.hidden_cache - - # enumerate different requests - for idx, slen in enumerate(seq_lens): - start_pos = sum(seq_lens[:idx]) - end_pos = start_pos + slen - current_tokens = input_tokens_tensor[start_pos:end_pos] - num_tokens = slen - - # 
collecting data for rebuilding the input - input_tokens_list.append(current_tokens) - start_pos_list.append(start_pos) - - num_computed_tokens = inputs[0][0, start_pos:end_pos, 0, - 0].shape[0] - num_computed_tokens_list.append(num_computed_tokens) - - # check if both KV cache and the hidden states are received - # If not, need to redo the forwarding to compute missing states - if not all([(num_computed_tokens == num_tokens), hidden is not None - ]): - bypass_model_exec = False - - # update the end position based on how many tokens are cached. - end_pos = start_pos + num_computed_tokens - - # put received KV caches into paged memory - for i in range(model_executable.model.start_layer, - model_executable.model.end_layer): - kv_cache = kv_caches[i - model_executable.model.start_layer] - key_cache, value_cache = kv_cache[0], kv_cache[1] - - sliced_key = keys[i - model_executable.model.start_layer][ - 0, start_pos:end_pos, :, :] - sliced_value = values[i - model_executable.model.start_layer][ - 0, start_pos:end_pos, :, :] - - torch_npu._npu_reshape_and_cache( - key=sliced_key, - value=sliced_value, - key_cache=key_cache, - value_cache=value_cache, - slot_indices=slot_mapping[start_pos:end_pos]) - - hidden_or_intermediate_states_for_one_req.append( - hidden[0][0, start_pos:end_pos, 0, :]) - - if not bypass_model_exec: - # Some of the KV cache is not retrieved - # Here we will fall back to normal model forwarding - # But optionally you can adjust model_input so that you only do - # prefilling on those tokens that are missing KV caches. - logger.info( - "[rank%d][D]: Failed to receive all KVs and hidden " - "states, redo model forwarding.", torch.distributed.get_rank()) - hidden_or_intermediate_states = None - else: - logger.info( - "[rank%d][D]: Successfully received all KVs and hidden " - "states, skip model forwarding.", torch.distributed.get_rank()) - hidden_or_intermediate_states = torch.cat( - hidden_or_intermediate_states_for_one_req, dim=0) - - return hidden_or_intermediate_states, bypass_model_exec, model_input - - def close(self, ): - self.llm_datadist_engine.data_dist.unlink_clusters([self.cluster], - 5000)
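
The removed SimpleConnector picked the KV-entry geometry per model: for DeepSeek MLA with forward absorb (VLLM_MLA_DISABLE unset/0) a single latent "head" of kv_lora_rank + qk_rope_head_dim, for MLA without absorb the usual qk_nope_head_dim + qk_rope_head_dim per KV head, and otherwise hidden_size / num_attention_heads. A minimal, self-contained sketch of that selection; the FakeModelConfig numbers below are placeholders, not values from any real checkpoint:

    # Illustrative sketch of the head-count/head-size selection performed in
    # the removed SimpleConnector.send_kv_caches_and_hidden_states.
    from dataclasses import dataclass

    @dataclass
    class FakeModelConfig:          # placeholder values for illustration only
        num_key_value_heads: int = 8
        num_attention_heads: int = 32
        hidden_size: int = 4096
        kv_lora_rank: int = 512
        qk_rope_head_dim: int = 64
        qk_nope_head_dim: int = 128

    def kv_cache_geometry(cfg: FakeModelConfig, tp_size: int,
                          is_deepseek_mla: bool, use_mla_opt: bool):
        """Return (num_heads, head_size) of the KV entries to transfer."""
        num_heads = cfg.num_key_value_heads // tp_size
        if is_deepseek_mla and use_mla_opt:
            # Forward-absorb MLA: cache shape is
            # [num_blks, blk_size, 1, kv_lora_rank + qk_rope_head_dim].
            return 1, cfg.kv_lora_rank + cfg.qk_rope_head_dim
        if is_deepseek_mla and not use_mla_opt:
            # Standard attention fallback (VLLM_MLA_DISABLE=1).
            return num_heads, cfg.qk_nope_head_dim + cfg.qk_rope_head_dim
        # Non-MLA models: plain head_dim.
        return num_heads, cfg.hidden_size // cfg.num_attention_heads

    print(kv_cache_geometry(FakeModelConfig(), tp_size=2,
                            is_deepseek_mla=True, use_mla_opt=True))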
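
On the producer side the connector walked seq_lens, sliced the flattened slot mapping for each prefill request, and gathered that request's key/value rows out of the paged cache before handing them to the buffer. A CPU-only sketch of the gather step with placeholder shapes (the real code ran on NPU tensors, one slice per layer):

    # CPU-only sketch of the per-request KV gather via the flattened slot mapping.
    import torch

    num_slots, num_heads, head_size = 64, 2, 4
    key_cache = torch.randn(num_slots, num_heads, head_size)   # one layer's cache

    seq_lens = [5, 3]                                           # two prefill requests
    slot_mapping_flat = torch.randperm(num_slots)[: sum(seq_lens)]

    for idx, slen in enumerate(seq_lens):
        start_pos = sum(seq_lens[:idx])
        end_pos = start_pos + slen
        current_slots = slot_mapping_flat[start_pos:end_pos]
        # One request's keys for this layer: [num_tokens, num_heads, head_size]
        req_keys = key_cache[current_slots]
        print(idx, req_keys.shape)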
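
On the consumer side the received keys and values were scattered back into the paged cache at the request's slot indices, via torch_npu._npu_reshape_and_cache (or the _siso variant for the MLA latent cache). A CPU-only stand-in using plain index assignment, again with placeholder shapes:

    # CPU-only stand-in for the decode-side write-back into the paged cache.
    import torch

    num_slots, num_heads, head_size = 64, 2, 4
    key_cache = torch.zeros(num_slots, num_heads, head_size)

    received_keys = torch.randn(5, num_heads, head_size)        # one request, 5 tokens
    slot_indices = torch.tensor([3, 9, 17, 20, 41])             # this request's slots

    key_cache[slot_indices] = received_keys
    assert torch.equal(key_cache[slot_indices], received_keys)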
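
SimplePipe registered NPU 0 of each P/D instance with the proxy by periodically sending a msgpack payload (role, HTTP address, ZMQ address) over a ZMQ DEALER socket whose identity is the instance's ZMQ address. A standalone sketch of that ping loop; the addresses in the commented-out call are placeholders:

    # Standalone sketch of the proxy "ping" registration loop.
    import time
    import msgpack
    import zmq

    def register_to_proxy(proxy_address: str, zmq_address: str,
                          http_address: str, is_producer: bool,
                          interval_s: float = 3.0) -> None:
        ctx = zmq.Context()
        sock = ctx.socket(zmq.DEALER)
        # The DEALER identity doubles as this instance's routing key at the proxy.
        sock.setsockopt_string(zmq.IDENTITY, zmq_address)
        sock.connect(f"tcp://{proxy_address}")
        payload = {
            "type": "P" if is_producer else "D",
            "http_address": http_address,
            "zmq_address": zmq_address,
        }
        while True:
            sock.send(msgpack.dumps(payload))
            time.sleep(interval_s)

    # register_to_proxy("127.0.0.1:30001", "127.0.0.1:5555",
    #                   "127.0.0.1:8000", is_producer=True)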
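
The removed kv_transfer/utils.py kept two hand-written dtype tables; the reverse table lists DT_FLOAT16 and DT_FLOAT twice, and since torch.half is torch.float16 and torch.float is torch.float32 the duplicate entries are harmless but redundant. A sketch of deriving the reverse map from the forward one so the two cannot drift; NpuDType below is a stand-in enum, not the real llm_datadist.DataType:

    # Duplicate-free dtype mapping built in one direction and inverted.
    import enum
    import torch

    class NpuDType(enum.Enum):      # stand-in for llm_datadist.DataType
        DT_FLOAT16 = enum.auto()
        DT_BF16 = enum.auto()
        DT_FLOAT = enum.auto()
        DT_INT8 = enum.auto()
        DT_INT32 = enum.auto()
        DT_INT64 = enum.auto()

    TORCH_TO_NPU = {
        torch.float16: NpuDType.DT_FLOAT16,
        torch.bfloat16: NpuDType.DT_BF16,
        torch.float32: NpuDType.DT_FLOAT,
        torch.int8: NpuDType.DT_INT8,
        torch.int32: NpuDType.DT_INT32,
        torch.int64: NpuDType.DT_INT64,
    }
    # Invert the forward table instead of writing a second literal by hand.
    NPU_TO_TORCH = {v: k for k, v in TORCH_TO_NPU.items()}

    assert NPU_TO_TORCH[TORCH_TO_NPU[torch.bfloat16]] is torch.bfloat16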
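
LLMDataDistConnector allocated four flat caches per prefill batch: per-layer key and value buffers of shape (1, total_tokens, num_heads, head_size), an int32 token-id buffer of shape (1, total_tokens, 1, 1), and a hidden-state buffer of shape (1, total_tokens, 1, hidden_size). A small helper that reproduces those shapes from placeholder sizes:

    # Sketch of the buffer shapes used for one prefill batch; sizes are placeholders.
    def datadist_buffer_shapes(total_tokens: int, num_heads: int,
                               head_size: int, hidden_size: int):
        kv_shape = (1, total_tokens, num_heads, head_size)      # per layer, keys and values
        input_shape = (1, total_tokens, 1, 1)                   # int32 token ids
        hidden_shape = (1, total_tokens, 1, hidden_size)        # final hidden states
        return kv_shape, input_shape, hidden_shape

    print(datadist_buffer_shapes(total_tokens=128, num_heads=4,
                                 head_size=128, hidden_size=4096))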