diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 7552fc889f2f..7acc046c5250 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -664,7 +664,20 @@ def add_remote_agent(self, # TODO re-evaluate refreshing for scaling/recovery if remote_tp_rank in self._remote_agents.get(engine_id, ()): return + # Number of D TP workers reading from a single P TP worker. This is + # 1 when P and D `--tensor-parallel-size` match. + assert self._tp_size[self.engine_id] % nixl_agent_meta.tp_size == 0, ( + "Local TP size must be divisible by remote TP size.") + tp_ratio = self._tp_size[self.engine_id] // nixl_agent_meta.tp_size + assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" + if remote_tp_rank != self.tp_rank // tp_ratio: + # Only register remote agents that this local rank pulls from. + logger.debug( + "Skipping registration of remote agent %s with rank %s " + "as it is not the one this local rank %s pulls from.", + engine_id, remote_tp_rank, self.tp_rank) + return if engine_id in self._tp_size: assert self._tp_size[engine_id] == nixl_agent_meta.tp_size else: @@ -676,13 +689,6 @@ def add_remote_agent(self, self._remote_agents[engine_id][ remote_tp_rank] = self.nixl_wrapper.add_remote_agent( nixl_agent_meta.agent_metadata) - - # Number of D TP workers reading from a single P TP worker. This is - # 1 when P and D `--tensor-parallel-size` match. - assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, ( - "Local TP size must be divisible by remote TP size.") - tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id] - assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" if self.use_mla: # With MLA the only difference is in the number of blocks. remote_block_size = nixl_agent_meta.block_len // (