Merged
2 changes: 1 addition & 1 deletion tools/pre_commit/mypy.py
@@ -26,6 +26,7 @@
FILES = [
"vllm/*.py",
"vllm/assets",
"vllm/distributed",
"vllm/entrypoints",
"vllm/inputs",
"vllm/logging_utils",
@@ -42,7 +43,6 @@
"tests",
"vllm/attention",
"vllm/compilation",
"vllm/distributed",
"vllm/engine",
"vllm/executor",
"vllm/inputs",
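Moving `vllm/distributed` from the follow-only group into the strict `FILES` list means the directory is now type-checked directly by the pre-commit mypy hook. Below is a minimal sketch of how such a hook might drive mypy over the listed targets; the wrapper shape, flags, and glob handling are assumptions, not the actual `tools/pre_commit/mypy.py`.

```python
# Hypothetical pre-commit wrapper sketch: expand the FILES patterns and hand
# them to mypy; everything here is an assumption about the hook's shape.
import glob
import subprocess
import sys

FILES = [
    "vllm/*.py",
    "vllm/assets",
    "vllm/distributed",  # promoted into the strictly checked list
]

def main() -> int:
    targets: list[str] = []
    for pattern in FILES:
        # subprocess does not expand shell globs, so do it explicitly.
        targets.extend(sorted(glob.glob(pattern)) or [pattern])
    return subprocess.run(["mypy", *targets]).returncode

if __name__ == "__main__":
    sys.exit(main())
```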
2 changes: 1 addition & 1 deletion vllm/config/kv_transfer.py
@@ -27,7 +27,7 @@ class KVTransferConfig:
engine_id: str | None = None
"""The engine id for KV transfers."""

kv_buffer_device: str | None = "cuda"
kv_buffer_device: str = "cuda"
"""The device used by kv connector to buffer the KV cache. Choices are
'cuda' and 'cpu'."""

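With `kv_buffer_device` now a plain `str`, downstream code such as the NIXL connector can compare it against `"cpu"` without a None check. A hedged usage sketch follows; the field names come from this diff, but the exact constructor behavior of `KVTransferConfig` is assumed.

```python
from vllm.config.kv_transfer import KVTransferConfig

# Buffer the KV cache on the CPU instead of the default "cuda" device.
cfg = KVTransferConfig(kv_connector="NixlConnector", kv_buffer_device="cpu")
assert cfg.kv_buffer_device in ("cuda", "cpu")  # the documented choices
use_host_buffer = cfg.kv_buffer_device == "cpu"  # mirrors the connector's check
```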
36 changes: 24 additions & 12 deletions vllm/distributed/device_communicators/all2all.py
@@ -15,9 +15,11 @@
from .base_device_communicator import All2AllManagerBase, Cache

if has_flashinfer_all2all():
from flashinfer.comm import Mapping
from flashinfer.comm.mnnvl import MnnvlConfig
from flashinfer.comm.trtllm_alltoall import MnnvlMoe
from flashinfer.comm import Mapping # type: ignore[import-not-found]
from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found]
from flashinfer.comm.trtllm_alltoall import (
MnnvlMoe, # type: ignore[import-not-found]
)

logger = init_logger(__name__)

@@ -65,6 +67,7 @@ def dispatch(
) -> tuple[torch.Tensor, torch.Tensor]:
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

hidden_states = self.naive_multicast(
@@ -81,6 +84,7 @@ def combine(
ep_rank = self.rank if is_sequence_parallel else self.dp_rank

dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

@@ -113,7 +117,10 @@ def dispatch(
"""
Gather hidden_states and router_logits from all dp ranks.
"""
sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
assert sizes is not None

dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
@@ -130,7 +137,10 @@ def combine(
"""
Reduce-scatter hidden_states across all dp ranks.
"""
sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
assert sizes is not None

dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
hidden_states = dist_group.reduce_scatterv(hidden_states, dim=0, sizes=sizes)
@@ -155,7 +165,7 @@ def __init__(self, cpu_group):
if self.internode:
# inter-node communication needs nvshmem,
# intra-node communication uses p2p mapping directly
from pplx_kernels.nvshmem import (
from pplx_kernels.nvshmem import ( # type: ignore[import-not-found]
nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id,
nvshmem_init,
@@ -182,7 +192,7 @@ def __init__(self, cpu_group):
self.handle_cache = Cache()

def get_handle(self, kwargs):
import pplx_kernels as pplx
import pplx_kernels as pplx # type: ignore[import-not-found]

return self.handle_cache.get_or_create(
kwargs,
@@ -208,7 +218,9 @@ def destroy(self):
handle.destroy()

if self.internode:
from pplx_kernels.nvshmem import nvshmem_finalize
from pplx_kernels.nvshmem import (
nvshmem_finalize, # type: ignore[import-not-found]
)

logger.debug("PPLX NVSHMEM finalize")
nvshmem_finalize()
@@ -288,7 +300,7 @@ def get_handle(self, kwargs):
"args are computed in the Manager itself."
)

import deep_ep
import deep_ep # type: ignore[import-not-found]

buffer_kwargs = self._make_all2all_kwargs()
logger.debug("DeepEP all2all args %s", buffer_kwargs)
@@ -298,7 +310,7 @@ def get_handle(self, kwargs):
return handle

def set_num_sms(self, num_sms: int):
import deep_ep
import deep_ep # type: ignore[import-not-found]

# Right now the buffers are sized for only what the kernels were
# created with. So we can only reduce the number of SMS used
@@ -332,7 +344,7 @@ def _make_all2all_kwargs(
num_global_experts: Number of experts in the model.
num_local_experts: Number of experts in an EP rank.
"""
import deep_ep
import deep_ep # type: ignore[import-not-found]

# Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
@@ -358,7 +370,7 @@ def get_handle(self, kwargs):
The kwargs for DeepEPLLAll2AllManager is dictated by
_make_all2all_kwargs.
"""
import deep_ep
import deep_ep # type: ignore[import-not-found]

buffer_kwargs = self._make_all2all_kwargs(**kwargs)
logger.debug("DeepEP all2all args %s", buffer_kwargs)
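The edits in this file follow two typing patterns: optional third-party imports are tagged with `# type: ignore[import-not-found]`, and Optional values returned by the forward context are narrowed with `assert` before use. Here is an illustrative, self-contained sketch of the narrowing pattern with a stand-in metadata object.

```python
# Stand-in for the DP metadata object; only the method used below is modeled.
class FakeDPMetadata:
    def get_chunk_sizes_across_dp_rank(self) -> list[int] | None:
        return [4, 4, 8]

def chunk_sizes(dp_metadata: FakeDPMetadata | None) -> list[int]:
    # mypy only sees an Optional here, so narrow it explicitly; the assert
    # also fails loudly if this path is hit without data-parallel metadata.
    assert dp_metadata is not None
    sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
    assert sizes is not None
    return sizes

print(chunk_sizes(FakeDPMetadata()))  # [4, 4, 8]
```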
19 changes: 12 additions & 7 deletions vllm/distributed/device_communicators/custom_all_reduce.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from contextlib import contextmanager
from typing import cast

import torch
import torch.distributed as dist
@@ -118,15 +119,18 @@ def __init__(
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
device_capability = current_platform.get_device_capability().as_version_str()
device_capability = current_platform.get_device_capability()
if (
current_platform.is_cuda()
and symm_mem_enabled
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES
and device_capability is not None
):
max_size = min(
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], max_size
)
device_capability_str = device_capability.as_version_str()
if device_capability_str in CUSTOM_ALL_REDUCE_MAX_SIZES:
max_size = min(
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability_str][world_size],
max_size,
)
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
@@ -213,6 +217,7 @@ def register_graph_buffers(self):
# We cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
all_data: list[list[list[int] | None]]
all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
all_data[self.rank] = [handle, offset]
ranks = sorted(dist.get_process_group_ranks(group=self.group))
@@ -221,8 +226,8 @@ def register_graph_buffers(self):
all_data[i], src=rank, group=self.group, device="cpu"
)
# Unpack list of tuples to tuple of lists.
handles = [d[0] for d in all_data] # type: ignore
offsets = [d[1] for d in all_data] # type: ignore
handles = cast(list[list[int]], [d[0] for d in all_data])
offsets = cast(list[list[int]], [d[1] for d in all_data])
ops.register_graph_buffers(self._ptr, handles, offsets)

def should_custom_ar(self, inp: torch.Tensor):
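Two typing fixes appear here: the device capability is kept as an object and only converted with `as_version_str()` after the None check, and the gathered `[handle, offset]` pairs are unpacked with `typing.cast` instead of a blanket `# type: ignore`. A toy sketch of the cast pattern:

```python
from typing import cast

world_size = 2
# Pre-declared element type: each slot is [handles, offsets] and starts as None.
all_data: list[list[list[int] | None]] = [[None, None] for _ in range(world_size)]
all_data[0] = [[11, 12], [0, 0]]  # filled by the broadcast loop in practice
all_data[1] = [[21, 22], [8, 8]]

# Once every slot is populated, cast() records that fact for mypy instead of
# silencing the whole expression with `# type: ignore`.
handles = cast(list[list[int]], [d[0] for d in all_data])
offsets = cast(list[list[int]], [d[1] for d in all_data])
print(handles, offsets)
```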
11 changes: 8 additions & 3 deletions vllm/distributed/device_communicators/symm_mem.py
@@ -52,9 +52,14 @@ def __init__(
self.device = device
self.group = group
self.world_size = dist.get_world_size(self.group)
self.device_capability = (
current_platform.get_device_capability().as_version_str()
)
capability = current_platform.get_device_capability()
if capability is None:
logger.warning(
"SymmMemCommunicator: device capability is unknown, "
"communicator is not available."
)
return
self.device_capability = capability.as_version_str()
if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
logger.warning(
"SymmMemCommunicator: Device capability %s not supported, "
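The new guard returns early when the platform cannot report a device capability, so `.as_version_str()` is never called on None. A minimal sketch of that shape follows; the `disabled` flag and the capability argument are assumptions, not the exact `SymmMemCommunicator` attributes.

```python
class GuardedCommunicator:
    def __init__(self, capability) -> None:
        # Assume the communicator starts disabled and is only enabled once
        # every prerequisite check passes.
        self.disabled = True
        if capability is None:
            print("device capability is unknown, communicator is not available")
            return
        self.device_capability = capability.as_version_str()
        self.disabled = False
```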
14 changes: 12 additions & 2 deletions vllm/distributed/kv_transfer/kv_connector/factory.py
@@ -3,7 +3,7 @@

import importlib
from collections.abc import Callable
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

import vllm.envs as envs
from vllm.distributed.kv_transfer.kv_connector.base import (
@@ -48,6 +48,8 @@ def create_connector(
)

kv_transfer_config = config.kv_transfer_config
if kv_transfer_config is None:
raise ValueError("kv_transfer_config must be set to create a connector")
connector_cls = cls.get_connector_class(kv_transfer_config)
logger.info(
"Creating v1 connector with name: %s and engine_id: %s",
@@ -70,14 +72,22 @@ def get_connector_class(
) -> type[KVConnectorBaseType]:
"""Get the connector class by name."""
connector_name = kv_transfer_config.kv_connector
if connector_name is None:
raise ValueError("Connector name is not set in KVTransferConfig")
if connector_name in cls._registry:
connector_cls = cls._registry[connector_name]()
else:
connector_module_path = kv_transfer_config.kv_connector_module_path
if connector_module_path is None:
raise ValueError(f"Unsupported connector type: {connector_name}")
connector_module = importlib.import_module(connector_module_path)
connector_cls = getattr(connector_module, connector_name)
try:
connector_cls = getattr(connector_module, connector_name)
except AttributeError as e:
raise AttributeError(
f"Class {connector_name} not found in {connector_module_path}"
) from e
connector_cls = cast(type[KVConnectorBaseType], connector_cls)
return connector_cls


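The lookup now fails with an explicit error at each step instead of assuming the config fields are populated. Below is a condensed, self-contained sketch of the same resolution logic, with a plain dict standing in for the factory's class registry.

```python
import importlib
from collections.abc import Callable

def resolve_connector_class(
    registry: dict[str, Callable[[], type]],
    connector_name: str | None,
    connector_module_path: str | None,
) -> type:
    if connector_name is None:
        raise ValueError("Connector name is not set in KVTransferConfig")
    if connector_name in registry:
        return registry[connector_name]()
    if connector_module_path is None:
        raise ValueError(f"Unsupported connector type: {connector_name}")
    module = importlib.import_module(connector_module_path)
    try:
        connector_cls = getattr(module, connector_name)
    except AttributeError as e:
        raise AttributeError(
            f"Class {connector_name} not found in {connector_module_path}"
        ) from e
    return connector_cls

# Example: fall back to a module-path lookup when the name is not registered.
print(resolve_connector_class({}, "OrderedDict", "collections"))
```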
14 changes: 7 additions & 7 deletions vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -151,21 +151,21 @@ def update_finished_set(
aggregated_kv_connector_stats = None
invalid_block_ids = set[int]()
for model_runner_output in outputs:
output = model_runner_output.kv_connector_output
if not output:
kv_output = model_runner_output.kv_connector_output
if not kv_output:
continue
update_finished_set(
output.finished_sending, self._send_remaining_count, finished_sending
kv_output.finished_sending, self._send_remaining_count, finished_sending
)
update_finished_set(
output.finished_recving, self._recv_remaining_count, finished_recving
kv_output.finished_recving, self._recv_remaining_count, finished_recving
)

# Aggregate kv_connector_stats from all workers.
if aggregated_kv_connector_stats is None:
# Use the first worker's kv_connector_stats as accumulator.
aggregated_kv_connector_stats = output.kv_connector_stats
elif kv_connector_stats := output.kv_connector_stats:
aggregated_kv_connector_stats = kv_output.kv_connector_stats
elif kv_connector_stats := kv_output.kv_connector_stats:
if aggregated_kv_connector_stats is None:
aggregated_kv_connector_stats = kv_connector_stats
else:
@@ -176,7 +176,7 @@ def update_finished_set(
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
)

invalid_block_ids |= output.invalid_block_ids
invalid_block_ids |= kv_output.invalid_block_ids

# select output of the worker specified by output_rank
output = outputs[output_rank]
4 changes: 4 additions & 0 deletions vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -95,6 +95,10 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
)
self._connector_metadata: KVConnectorMetadata | None = None
self._vllm_config = vllm_config
if vllm_config.kv_transfer_config is not None:
self._kv_transfer_config = vllm_config.kv_transfer_config
else:
raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1")
self._role = role

@property
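Validating `kv_transfer_config` once in the base `__init__` lets subclasses read `self._kv_transfer_config` without repeating the None check, which is what the multi-connector change below relies on. A hypothetical subclass sketch; the import path and `KVConnectorRole` usage are assumptions based on this file.

```python
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1,
    KVConnectorRole,
)

class LoggingConnector(KVConnectorBase_V1):
    """Hypothetical connector used only to illustrate the new base contract."""

    def __init__(self, vllm_config, role: KVConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)
        # Guaranteed non-None: the base __init__ raises ValueError otherwise.
        extra = self._kv_transfer_config.kv_connector_extra_config
        print("connector extra config:", extra)
```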
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
@@ -86,13 +86,11 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._connectors: list[KVConnectorBase_V1] = []
self._ktc_kv_transfer_config = []
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"connectors"
)
ktcs = self._kv_transfer_config.kv_connector_extra_config.get("connectors")
assert ktcs is not None
for ktc in ktcs:
temp_config = copy.copy(vllm_config)
engine_id = ktc.get("engine_id", vllm_config.kv_transfer_config.engine_id)
engine_id = ktc.get("engine_id", self._kv_transfer_config.engine_id)
temp_config.kv_transfer_config = KVTransferConfig(
**ktc, engine_id=engine_id
)
@@ -296,6 +294,7 @@ def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
str: the required KV cache layout. e.g. HND, or NHD.
None if the connector does not require a specific layout.
"""
assert vllm_config.kv_transfer_config is not None
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"connectors"
)
18 changes: 12 additions & 6 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -291,6 +291,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
+ vllm_config.parallel_config.data_parallel_rank
* vllm_config.parallel_config.tensor_parallel_size
)
assert vllm_config.kv_transfer_config is not None
self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
logger.info("Initializing NIXL Scheduler %s", engine_id)

@@ -334,7 +335,8 @@ def get_num_new_matched_tokens(

if params is not None and params.get("do_remote_prefill"):
# Remote prefill: get all prompt blocks from remote.
count = len(request.prompt_token_ids) - num_computed_tokens
token_ids = request.prompt_token_ids or []
count = len(token_ids) - num_computed_tokens
if count > 0:
return count, True

@@ -515,6 +517,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size

if vllm_config.kv_transfer_config is None:
raise ValueError("kv_transfer_config must be set for NixlConnector")

self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config(
"backends", ["UCX"]
)
@@ -571,17 +576,18 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.use_host_buffer = self.kv_buffer_device == "cpu"
# support for oot platform which can't register nixl memory
# type based on kv_buffer_device
self.nixl_memory_type = current_platform.get_nixl_memory_type()
if self.nixl_memory_type is None:
nixl_memory_type = current_platform.get_nixl_memory_type()
if nixl_memory_type is None:
if self.kv_buffer_device == "cuda":
self.nixl_memory_type = "VRAM"
nixl_memory_type = "VRAM"
elif self.kv_buffer_device == "cpu":
self.nixl_memory_type = "DRAM"
if self.nixl_memory_type is None:
nixl_memory_type = "DRAM"
if nixl_memory_type is None:
raise RuntimeError(
f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
"is not supported."
)
self.nixl_memory_type = nixl_memory_type

# Note: host xfer buffer ops when use_host_buffer is True
self.copy_blocks: CopyBlocksOp | None = None
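Resolving the memory type in a local variable and assigning `self.nixl_memory_type` exactly once lets mypy treat the attribute as a plain `str`. A self-contained sketch of the same resolution, with stand-in arguments for the platform hook and device values:

```python
def resolve_nixl_memory_type(
    platform_memory_type: str | None,  # stand-in for get_nixl_memory_type()
    kv_buffer_device: str,
    device_type: str,
) -> str:
    nixl_memory_type = platform_memory_type
    if nixl_memory_type is None:
        if kv_buffer_device == "cuda":
            nixl_memory_type = "VRAM"
        elif kv_buffer_device == "cpu":
            nixl_memory_type = "DRAM"
    if nixl_memory_type is None:
        raise RuntimeError(
            f"{device_type} with {kv_buffer_device} kv_buffer is not supported."
        )
    return nixl_memory_type

print(resolve_nixl_memory_type(None, "cpu", "cuda"))  # -> DRAM
```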