Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Support disaggregated prefill with Mooncake Transfer Engine #10884

Merged
merged 25 commits into from
Dec 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
d52dbc8
Rebase from main to work with PR 10502.
ShangmingCai Dec 2, 2024
c8e9d07
Update format of mooncake config ValueError.
ShangmingCai Dec 2, 2024
b718f1e
Modify metadata transfer logic to support tp.
ShangmingCai Dec 3, 2024
08e2800
Fix format to make ruff happy.
ShangmingCai Dec 3, 2024
8179746
Add instructions when mooncake is not installed.
ShangmingCai Dec 3, 2024
ba82d71
Merge branch 'main' into upstream-mooncake-integration
ShangmingCai Dec 3, 2024
76d484c
Merge branch 'main' into upstream-mooncake-integration
ShangmingCai Dec 4, 2024
e912055
fix import order to make isort happy.
ShangmingCai Dec 4, 2024
2396f01
Fix format to make yapf happy.
ShangmingCai Dec 4, 2024
31514a0
Add solution for ports conflict on the same node.
ShangmingCai Dec 4, 2024
2ef10be
Fix format to make mypy happy.
ShangmingCai Dec 4, 2024
0823e47
Get head_size and num_heads from model config to address bugs on Volt…
ShangmingCai Dec 10, 2024
6fb95fb
Add support for other metadata server backend.
ShangmingCai Dec 10, 2024
a5758b1
Change code to align with PR 11058.
ShangmingCai Dec 10, 2024
33e4455
Fix typo.
ShangmingCai Dec 11, 2024
f3312b9
Merge branch 'main' into upstream-mooncake-integration
ShangmingCai Dec 15, 2024
343c474
Reuse simple connector for mooncake pipe.
ShangmingCai Dec 15, 2024
eaa1a45
Remove mooncake connector.
ShangmingCai Dec 15, 2024
83e4db9
fix isort.
ShangmingCai Dec 15, 2024
e8ee5c2
fix mypy.
ShangmingCai Dec 15, 2024
bc01eae
move PyNcclPipe import to fix mypy.
ShangmingCai Dec 15, 2024
b45ff65
still trying to fix mypy.
ShangmingCai Dec 15, 2024
aeccf4f
fix typo and fix mypy.
ShangmingCai Dec 15, 2024
8c2135a
trying to fix mypy again.
ShangmingCai Dec 15, 2024
875ca4c
remove unused kvpipe base.
ShangmingCai Dec 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2171,13 +2171,14 @@ def from_cli(cls, cli_value: str) -> "KVTransferConfig":
return KVTransferConfig.model_validate_json(cli_value)

def model_post_init(self, __context: Any) -> None:
supported_kv_connector = ["PyNcclConnector", "MooncakeConnector"]
if all([
self.kv_connector is not None,
self.kv_connector != "PyNcclConnector"
self.kv_connector is not None, self.kv_connector
not in supported_kv_connector
]):
raise ValueError(f"Unsupported kv_connector: {self.kv_connector}. "
f"Supported connectors are "
f"`PyNcclConnector`.")
f"{supported_kv_connector}.")

if self.kv_role is not None and self.kv_role not in [
"kv_producer", "kv_consumer", "kv_both"
Expand Down
3 changes: 2 additions & 1 deletion vllm/distributed/kv_transfer/kv_connector/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ class KVConnectorFactory:
@staticmethod
def create_connector(rank: int, local_rank: int,
config: "VllmConfig") -> KVConnectorBase:
if config.kv_transfer_config.kv_connector == 'PyNcclConnector':
supported_kv_connector = ["PyNcclConnector", "MooncakeConnector"]
if config.kv_transfer_config.kv_connector in supported_kv_connector:
from .simple_connector import SimpleConnector
return SimpleConnector(rank, local_rank, config)
else:
Expand Down
101 changes: 74 additions & 27 deletions vllm/distributed/kv_transfer/kv_connector/simple_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
Simple KV Cache Connector for Distributed Machine Learning Inference

The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe.
producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or
MooncakePipe.

But the logic can be extended to support other pipe and lookup buffer.
"""
Expand All @@ -15,7 +16,6 @@
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
SimpleBuffer)
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors

Expand All @@ -36,32 +36,66 @@ def __init__(

self.config = config.kv_transfer_config

logger.info("Initializing PyNcclConfig under kv_transfer_config %s",
if self.config.kv_connector == "PyNcclConnector":
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
PyNcclPipe)
logger.info(
"Initializing PyNcclConfig under kv_transfer_config %s",
self.config)
elif self.config.kv_connector == "MooncakeConnector":
# Check if MOONCAKE_CONFIG_PATH is set
import os
use_mooncake_distributed_pipe = os.getenv(
'MOONCAKE_CONFIG_PATH') is not None

if not use_mooncake_distributed_pipe:
raise ValueError(
"To use MooncakeConnector, you need to pass the ENV: "
"'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.")
else:
from vllm.distributed.kv_transfer.kv_pipe.mooncake_pipe import ( # noqa: E501
MooncakePipe)
logger.info(
"Initializing MooncakeConfig under kv_transfer_config %s",
self.config)

self.lookup_buffer_size = self.config.kv_buffer_size

self.producer_buffer: Optional[SimpleBuffer] = None
self.consumer_buffer: Optional[SimpleBuffer] = None

self.producer_data_pipe: Union[PyNcclPipe, MooncakePipe]
self.consumer_data_pipe: Union[PyNcclPipe, MooncakePipe]
self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe]

# 2 pipes for every rank in the world
port_offset_base = 2 * rank

# In disaggregated prefill, the prefill vLLM only uses send pipe
# and the decode vLLM only uses recv pipe
if self.config.is_kv_producer:

self.producer_data_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.producer_signal_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
device="cpu",
)
if self.config.kv_connector == "PyNcclConnector":
self.producer_data_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.producer_signal_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
device="cpu",
)
elif self.config.kv_connector == "MooncakeConnector":
self.producer_data_pipe = MooncakePipe(
local_rank=local_rank,
config=self.config,
)
# We only need to initialize MooncakePipe once
self.producer_signal_pipe = self.producer_data_pipe

self.producer_buffer = SimpleBuffer(self.producer_signal_pipe,
self.producer_data_pipe,
self.config.kv_buffer_size)
Expand All @@ -70,17 +104,25 @@ def __init__(

# the current vLLM instance is KV consumer, so it needs to connect
# its recv pipe to the send pipe of KV producder
self.consumer_data_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.consumer_signal_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
device="cpu",
)
if self.config.kv_connector == "PyNcclConnector":
self.consumer_data_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.consumer_signal_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
device="cpu",
)
elif self.config.kv_connector == "MooncakeConnector":
self.consumer_data_pipe = MooncakePipe(
local_rank=local_rank,
config=self.config,
)
self.consumer_signal_pipe = self.consumer_data_pipe

self.consumer_buffer = SimpleBuffer(
self.consumer_signal_pipe,
self.consumer_data_pipe,
Expand Down Expand Up @@ -260,6 +302,11 @@ def recv_kv_caches_and_hidden_states(

def close(self):
self.producer_data_pipe.close()
self.producer_signal_pipe.close()
self.consumer_data_pipe.close()
self.consumer_signal_pipe.close()
if self.config.kv_connector == "PyNcclConnector":
self.producer_signal_pipe.close()
self.consumer_signal_pipe.close()
elif self.config.kv_connector == "MooncakeConnector":
# MooncakePipe reuses data_pipe for signal_pipe, so we only have to
# close the data_pipe.
pass
Loading
Loading