Commit c72fdad

add some fix for llmdatadist
Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
1 parent cc4f9fb commit c72fdad

File tree: 2 files changed, +25 -23 lines

  vllm_ascend/distributed/llmdatadist_connector_v1_a3.py
  vllm_ascend/worker/model_runner_v1.py

vllm_ascend/distributed/llmdatadist_connector_v1_a3.py

Lines changed: 16 additions & 19 deletions
@@ -167,6 +167,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: int):
         # Can not retrive the parallel config since it is not initialized.
         self.local_dp_rank = None
         self.tp_size = None
+        self.port = self.vllm_config.parallel_config.data_parallel_rank_local * self.vllm_config.parallel_config.tensor_parallel_size + envs.VLLM_LLMDD_CHANNEL_PORT

         self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}

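Note: the channel port is now derived once in __init__ from the local DP rank and the TP size, instead of being recomputed in request_finished. A minimal sketch of the port layout this formula implies, with a hypothetical base port standing in for envs.VLLM_LLMDD_CHANNEL_PORT:

    # Sketch only: each data-parallel rank owns a contiguous block of tp_size
    # ports above the base channel port (numbers below are hypothetical).
    VLLM_LLMDD_CHANNEL_PORT = 5557   # assumed base port from the environment
    tp_size = 4

    def dp_base_port(dp_rank_local: int) -> int:
        # Same expression as the added self.port assignment.
        return dp_rank_local * tp_size + VLLM_LLMDD_CHANNEL_PORT

    for dp_rank in range(2):
        ports = [dp_base_port(dp_rank) + tp_rank for tp_rank in range(tp_size)]
        print(dp_rank, ports)   # dp 0 -> 5557..5560, dp 1 -> 5561..5564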

@@ -244,12 +245,6 @@ def request_finished(
         request: "Request",
         block_ids: list[int],
     ) -> tuple[bool, Optional[dict[str, Any]]]:
-        if self.local_dp_rank is None:
-            vllm_config = get_current_vllm_config()
-            # Need this dp rank to locate the only dp rank the kv cache from
-            self.local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
-            # Need this tp size to offset the port in tp size
-            self.tp_size = vllm_config.parallel_config.tensor_parallel_size

         params = request.kv_transfer_params
         logger.debug(
@@ -267,13 +262,13 @@ def request_finished(
         # If prompt < block_size, no xfer so free blocks immediately.
         delay_free_blocks = len(computed_block_ids) > 0

-        return delay_free_blocks, dict(
+        return False, dict(
             do_remote_prefill=True,
             do_remote_decode=False,
             remote_block_ids=computed_block_ids,
             remote_engine_id=self.engine_id,
             remote_host=self.local_ip,
-            remote_port=envs.VLLM_LLMDD_CHANNEL_PORT + self.local_dp_rank * self.tp_size,
+            remote_port=self.port,
         )

 class LLMDataDistConnectorWorker():
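For context, the dict returned above is the kv_transfer_params blob the decode side later uses to locate the prefilled KV blocks; with this change remote_port is simply the cached self.port. A hypothetical example of its contents (values are illustrative, not taken from a real run):

    # Illustrative kv_transfer_params as produced by request_finished above.
    kv_transfer_params = dict(
        do_remote_prefill=True,
        do_remote_decode=False,
        remote_block_ids=[12, 13, 14],   # blocks holding this request's KV on the prefill node
        remote_engine_id=0,              # assumed engine id
        remote_host="10.0.0.1",          # prefill node IP (hypothetical)
        remote_port=5561,                # dp_rank_local * tp_size + base port
    )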
@@ -295,6 +290,7 @@ def __init__(
         self.local_ip = get_ip()
         self.kv_transfer_config: Optional[KVTransferConfig] = vllm_config.kv_transfer_config
         self.local_agent_metadata: Optional[LLMDataDistAgentMetadata] = None
+        self.vllm_config = vllm_config

         self.llm_datadist_role = None
         self.llm_datadist_remote_role = None
@@ -498,11 +494,11 @@ def start_load_kv(self, metadata: LLMDataDistConnectorMetadata):
             )
             self.finished_reqs.add(req_id)

-    def add_remote_agent(self, metadata: LLMDataDistAgentMetadata) -> bool:
+    def add_remote_agent(self, metadata: LLMDataDistAgentMetadata) -> int:
         remote_cluster_id = metadata.cluster_id
         if remote_cluster_id in self.linked_cluster:
             logger.debug(f"LLMDataDistConnectorWorker: remote cluster_id: {metadata.cluster_id} already linked with this server, skip the connection")
-            return False
+            return remote_cluster_id
         remote_super_pod_id = metadata.super_pod_id
         remote_device_id = metadata.device_id
         remote_device_ip = metadata.device_ip
@@ -618,10 +614,10 @@ def add_remote_agent(self, metadata: LLMDataDistAgentMetadata) -> bool:
                 raise RuntimeError(f"LLMDataDistConnectorWorker: Linking failed, comm id: {comm_id}")
             time.sleep(1)
             logger.info("Checking query_register_mem_status again")
-        self.linked_cluster.update({remote_server_id: (remote_cluster_id, comm_id)})
+        self.linked_cluster.update({remote_cluster_id: comm_id})
         logger.info(f"cached linked cluster: {self.linked_cluster}")
         logger.info(f"Sucessfully build link with cluster id {remote_cluster_id} with cluster name {comm_name} !")
-        return True
+        return remote_cluster_id


    def remove_remote_agent(self, cluster_id: int):
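With these two hunks, linked_cluster is keyed by the remote cluster id and maps straight to the comm id, and add_remote_agent always returns the cluster id it linked (or found already linked) instead of a bool. A small sketch of that bookkeeping, under those assumptions:

    # Sketch of the assumed bookkeeping (not the real class; the comm_id is
    # whatever llm_datadist hands back from the link step).
    linked_cluster = {}   # remote_cluster_id -> comm_id

    def add_remote_agent_sketch(remote_cluster_id, comm_id):
        if remote_cluster_id in linked_cluster:
            # Already linked: skip the handshake but still report the cluster id.
            return remote_cluster_id
        # ... link establishment elided ...
        linked_cluster[remote_cluster_id] = comm_id
        return remote_cluster_id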
@@ -641,7 +637,7 @@ def connect_to_remote_agent(
         host: str,
         port: int
     ):
-        url = f"tcp://{host}:{port + self.tp_rank}"
+        url = f"tcp://{host}:{port}"
         logger.debug(f"Querying metadata from url: {url}")
         msg_encoder = msgspec.msgpack.Encoder()
         msg_send = msg_encoder.encode(self.local_agent_metadata)
@@ -653,7 +649,8 @@ def connect_to_remote_agent(
             metadata = decoder.decode(metadata_bytes)
             metadata = LLMDataDistAgentMetadata(**metadata)
             logger.info(f"recving metadata: {metadata}")
-            self.add_remote_agent(metadata)
+            cluster_id = self.add_remote_agent(metadata)
+            return cluster_id

    def _read_blocks(
        self,
@@ -664,8 +661,8 @@ def _read_blocks(
         remote_engine_id: str,
         request_id: str,
     ):
-        if remote_ip not in self.linked_cluster:
-            self.connect_to_remote_agent(remote_ip, remote_port)
+        # if remote_ip not in self.linked_cluster:
+        self.connect_to_remote_agent(remote_ip, remote_port + self.tp_rank)
         num_local_blocks = len(local_block_ids)
         if num_local_blocks == 0:
             return
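connect_to_remote_agent now dials exactly the port it is given, and _read_blocks applies the per-TP-rank offset itself; with the membership guard commented out, the call happens on every read and de-duplication relies on add_remote_agent's early return. A sketch of which port each decode TP rank ends up dialing, assuming one prefill listener per TP rank:

    # Illustration only: ports dialed by each decode TP rank (hypothetical values).
    remote_host = "10.0.0.1"    # from kv_transfer_params
    remote_port = 5561          # prefill DP rank's base port
    tp_size = 4

    for tp_rank in range(tp_size):
        url = f"tcp://{remote_host}:{remote_port + tp_rank}"
        print(f"decode tp_rank {tp_rank} -> {url}")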
@@ -681,8 +678,8 @@ def _read_blocks(
             remote_cache_key_k_pe = BlocksCacheKey(cluster_id=remote_cluster_id, model_id=1)
             logger.info("Try pull blocks from remote server")
             try:
-                self.cache_manager.pull_blocks(remote_cache_key_k_normed, self.cache[0], local_block_ids, remote_block_ids)
-                self.cache_manager.pull_blocks(remote_cache_key_k_pe, self.cache[1], local_block_ids, remote_block_ids)
+                self.cache_manager.pull_blocks(remote_cache_key_k_normed, self.cache[0], remote_block_ids, local_block_ids)
+                self.cache_manager.pull_blocks(remote_cache_key_k_pe, self.cache[1], remote_block_ids, local_block_ids)
             except (TypeError, ValueError) as e:
                 raise RuntimeError(f"LLMDataDistConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}")
             except LLMException:
@@ -691,7 +688,7 @@ def _read_blocks(
             remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
             logger.info("Try pull blocks from remote server")
             try:
-                self.cache_manager.pull_blocks(remote_cache_key, self.cache, local_block_ids, remote_block_ids)
+                self.cache_manager.pull_blocks(remote_cache_key, self.cache, remote_block_ids, local_block_ids)
             except (TypeError, ValueError) as e:
                 raise RuntimeError(f"LLMDataDistConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}")
             except LLMException:
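Both the MLA and the non-MLA call sites now pass the remote (prefill-side) block ids before the local (decode-side) block ids, i.e. source blocks first, destination blocks second; the parameter names of llm_datadist's pull_blocks are not shown in this diff, so the sketch below only illustrates the intended block mapping:

    # Sketch of the block mapping implied by the new argument order
    # (function and parameter names here are assumptions, not the real API).
    def pull_blocks_sketch(dst_cache, src_block_ids, dst_block_ids):
        # Copy remote block src_block_ids[i] into local block dst_block_ids[i].
        for src, dst in zip(src_block_ids, dst_block_ids):
            dst_cache[dst] = fetch_remote_block(src)   # fetch_remote_block is hypothetical

    remote_block_ids = [12, 13, 14]   # where the prefill instance stored the KV
    local_block_ids = [3, 4, 5]       # where this decode instance wants it placed
    # pull_blocks_sketch(cache, remote_block_ids, local_block_ids)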

vllm_ascend/worker/model_runner_v1.py

Lines changed: 9 additions & 4 deletions
@@ -1389,6 +1389,11 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         import torch_npu
         kv_caches: Dict[str, torch.Tensor] = {}
+        def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
+            data_ptr = tensor.data_ptr()
+            aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
+            offset = (aligned_addr - data_ptr) // tensor.element_size()
+            return tensor[int(offset):]

         # Remove this after we drop 0.9.0 support
         if vllm_version_is("0.9.0"):
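The new align_memory helper slices leading elements off an over-allocated buffer so the returned view starts on an alignment-byte boundary. A quick CPU-only check of the arithmetic (illustrative sizes; the real caches live on the NPU):

    import torch

    def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
        # Round the start address up to the next multiple of `alignment` bytes
        # and skip the corresponding number of leading elements.
        data_ptr = tensor.data_ptr()
        aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
        offset = (aligned_addr - data_ptr) // tensor.element_size()
        return tensor[int(offset):]

    alignment = 2 * 1024 * 1024                                     # 2 MB boundary, as an example
    buf = torch.zeros(1024 + alignment // 2, dtype=torch.float16)   # over-allocate
    aligned = align_memory(buf, alignment)
    assert aligned.data_ptr() % alignment == 0
    print(len(buf) - len(aligned), "leading elements skipped for alignment")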
@@ -1461,10 +1466,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                     rope_allocate_shape_alignment = rope_allocate_shape + alignment
                     nope_cache_shape = (num_blocks, block_size, num_kv_heads, nope_dim)
                     rope_cache_shape = (num_blocks, block_size, num_kv_heads, rope_dim)
-                    rope_cache = torch.zeros(nope_allocate_shape_alignment, dtype=dtype, device=self.device)
-                    nope_cache = torch.zeros(rope_allocate_shape_alignment, dtype=dtype, device=self.device)
-                    rope_cache = align_memory(nope_cache, alignment)[:nope_allocate_shape].view(nope_cache_shape)
-                    nope_cache = align_memory(rope_cache, alignment)[:rope_allocate_shape].view(rope_cache_shape)
+                    nope_cache = torch.zeros(nope_allocate_shape_alignment, dtype=dtype, device=self.device)
+                    rope_cache = torch.zeros(rope_allocate_shape_alignment, dtype=dtype, device=self.device)
+                    nope_cache = align_memory(nope_cache, alignment)[:nope_allocate_shape].view(nope_cache_shape)
+                    rope_cache = align_memory(rope_cache, alignment)[:rope_allocate_shape].view(rope_cache_shape)
                     kv_caches[layer_name] = (nope_cache, rope_cache)
                 else:
                     num_caches = kv_cache_shape[0]
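The corrected lines allocate each padded buffer under its own name, align it, and only then slice and view it with its matching shape; the previous version crossed the nope/rope names. A standalone sketch of that allocate-align-view pattern with made-up sizes:

    import torch

    # Same helper as in the commit, repeated so this sketch runs on its own.
    def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
        data_ptr = tensor.data_ptr()
        aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
        return tensor[int((aligned_addr - data_ptr) // tensor.element_size()):]

    alignment = 2 * 1024 * 1024
    num_blocks, block_size, num_kv_heads, nope_dim = 8, 16, 1, 512   # illustrative only
    nope_cache_shape = (num_blocks, block_size, num_kv_heads, nope_dim)
    nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
    nope_allocate_shape_alignment = nope_allocate_shape + alignment

    # Allocate the padded buffer under its own name, align it, then view it
    # with its own shape.
    nope_cache = torch.zeros(nope_allocate_shape_alignment, dtype=torch.float16)
    nope_cache = align_memory(nope_cache, alignment)[:nope_allocate_shape].view(nope_cache_shape)
    assert nope_cache.data_ptr() % alignment == 0 and nope_cache.shape == nope_cache_shape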
