6 changes: 5 additions & 1 deletion vllm/config/parallel.py
@@ -73,6 +73,8 @@ class ParallelConfig:
"""Number of pipeline parallel groups."""
tensor_parallel_size: int = 1
"""Number of tensor parallel groups."""
context_parallel_size: int = 1
"""Number of context parallel groups."""
data_parallel_size: int = 1
"""Number of data parallel groups. MoE layers will be sharded according to
the product of the tensor parallel size and data parallel size."""
@@ -103,6 +105,8 @@ class ParallelConfig:
between local data parallel ranks, but an external LB balances
between vLLM nodes/replicas. Set explicitly in conjunction with
--data-parallel-start-rank."""
enable_sequence_parallel: bool = False
"""Enable sequence parallel."""
enable_expert_parallel: bool = False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
enable_eplb: bool = False
@@ -314,7 +318,7 @@ def __post_init__(self) -> None:

# Continue with the rest of the initialization
self.world_size = self.pipeline_parallel_size * \
self.tensor_parallel_size
self.tensor_parallel_size * self.context_parallel_size

if self.data_parallel_size_local > self.data_parallel_size:
raise ValueError(
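As a quick sanity check on the revised world-size formula in __post_init__ above, a standalone sketch (sizes are illustrative, not from this PR):

# world_size now multiplies in the context-parallel dimension.
pipeline_parallel_size, tensor_parallel_size, context_parallel_size = 2, 4, 2
world_size = pipeline_parallel_size * tensor_parallel_size * context_parallel_size
assert world_size == 16  # data-parallel replicas still multiply on top, as before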
54 changes: 47 additions & 7 deletions vllm/distributed/parallel_state.py
@@ -935,6 +935,24 @@ def get_pipeline_model_parallel_group():
return get_pp_group()


_CP: Optional[GroupCoordinator] = None


def get_cp_group() -> GroupCoordinator:
assert _CP is not None, ("context parallel group is not initialized")
return _CP


def get_context_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
return get_cp_group().world_size


def get_context_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
return get_cp_group().rank_in_group


@contextmanager
def graph_capture(device: torch.device):
"""
@@ -1034,6 +1052,7 @@ def init_distributed_environment(
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
context_model_parallel_size: int = 1,
backend: Optional[str] = None,
) -> None:
"""
@@ -1082,7 +1101,7 @@ def initialize_model_parallel(
# to get group_ranks for each dimension, transpose that dimension to the
# last dimension, then reshape to 2D, then unbind the last dimension
all_ranks = torch.arange(world_size).reshape(
-1, data_parallel_size, pipeline_model_parallel_size,
-1, data_parallel_size, pipeline_model_parallel_size, context_model_parallel_size,
tensor_model_parallel_size) # noqa

# Build the tensor model-parallel groups.
@@ -1102,7 +1121,7 @@ def initialize_model_parallel(
global _PP
assert _PP is None, (
"pipeline model parallel group is already initialized")
group_ranks = all_ranks.transpose(2, 3).reshape(
group_ranks = all_ranks.transpose(2, 4).reshape(
-1, pipeline_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_PP = init_model_parallel_group(group_ranks,
@@ -1113,7 +1132,7 @@ def initialize_model_parallel(
global _DP
assert _DP is None, ("data parallel group is already initialized")
group_ranks = all_ranks.transpose(1,
3).reshape(-1,
4).reshape(-1,
data_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_DP = init_model_parallel_group(group_ranks,
@@ -1124,23 +1143,34 @@ def initialize_model_parallel(
global _EP
assert _EP is None, ("expert parallel group is already initialized")
group_ranks = all_ranks.transpose(1, 2).reshape(
-1, data_parallel_size * tensor_model_parallel_size).unbind(0)
-1, data_parallel_size * tensor_model_parallel_size * context_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_EP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="ep")

global _CP
assert _CP is None, ("context parallel group is already initialized")
group_ranks = all_ranks.transpose(3, 4).reshape(
-1, context_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_CP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="cp")

logger.info(
"rank %s in world size %s is assigned as "
"DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size,
"DP rank %s, PP rank %s, TP rank %s, EP rank %s, CP rank %s", rank, world_size,
_DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
_EP.rank_in_group)
_EP.rank_in_group, _CP.rank_in_group)


def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int,
context_model_parallel_size: int,
backend: Optional[str] = None,
) -> None:
"""Helper to initialize model parallel groups if they are not initialized,
@@ -1151,7 +1181,7 @@ def ensure_model_parallel_initialized(
get_world_group().device_group)
if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size,
pipeline_model_parallel_size, backend)
pipeline_model_parallel_size, context_model_parallel_size, backend)
return

assert (
@@ -1164,6 +1194,11 @@ def ensure_model_parallel_initialized(
"pipeline parallel group already initialized, but of unexpected size. "
f"got: {pp_world_size=} vs. "
f"wanted: {pipeline_model_parallel_size=}")
cp_world_size = get_cp_group().world_size
assert (cp_world_size == context_model_parallel_size), (
"context parallel group already initialized, but of unexpected size: "
f"{cp_world_size=} vs. "
f"{context_model_parallel_size=}")


def prepare_communication_buffer_for_model(model: torch.nn.Module):
@@ -1256,6 +1291,11 @@ def destroy_model_parallel():
_EP.destroy()
_EP = None

global _CP
if _CP:
_CP.destroy()
_CP = None


def destroy_distributed_environment():
global _WORLD, _NODE_COUNT
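To make the group construction above concrete, here is a standalone sketch of the 5D rank grid (-1, DP, PP, CP, TP) and the transpose/reshape/unbind pattern. The sizes are illustrative, and the assumption that TP groups take the trailing dimension directly mirrors the existing TP code collapsed in this diff:

import torch

# Illustrative sizes: DP=1, PP=2, CP=2, TP=2 -> world_size = 8.
dp, pp, cp, tp = 1, 2, 2, 2
all_ranks = torch.arange(dp * pp * cp * tp).reshape(-1, dp, pp, cp, tp)

# TP groups: the trailing dimension is already TP, so flatten directly.
tp_groups = [x.tolist() for x in all_ranks.reshape(-1, tp).unbind(0)]
# CP groups: move the CP dimension (index 3) to the end, then flatten.
cp_groups = [x.tolist() for x in all_ranks.transpose(3, 4).reshape(-1, cp).unbind(0)]
# PP groups: move the PP dimension (index 2) to the end, then flatten.
pp_groups = [x.tolist() for x in all_ranks.transpose(2, 4).reshape(-1, pp).unbind(0)]

print(tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(cp_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
print(pp_groups)  # [[0, 4], [2, 6], [1, 5], [3, 7]]

Each rank lands in exactly one group per dimension; for example, rank 5 is PP rank 1, CP rank 0 and TP rank 1 under these sizes.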
9 changes: 9 additions & 0 deletions vllm/engine/arg_utils.py
@@ -295,6 +295,7 @@ class EngineArgs:
# number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
context_parallel_size: int = ParallelConfig.context_parallel_size
data_parallel_size: int = ParallelConfig.data_parallel_size
data_parallel_rank: Optional[int] = None
data_parallel_start_rank: Optional[int] = None
@@ -303,6 +304,7 @@ class EngineArgs:
data_parallel_rpc_port: Optional[int] = None
data_parallel_hybrid_lb: bool = False
data_parallel_backend: str = ParallelConfig.data_parallel_backend
enable_sequence_parallel: bool = ParallelConfig.enable_sequence_parallel
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
enable_eplb: bool = ParallelConfig.enable_eplb
@@ -623,6 +625,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
**parallel_kwargs["pipeline_parallel_size"])
parallel_group.add_argument("--tensor-parallel-size", "-tp",
**parallel_kwargs["tensor_parallel_size"])
parallel_group.add_argument("--context-parallel-size", "-cp",
**parallel_kwargs["context_parallel_size"])
parallel_group.add_argument("--data-parallel-size", "-dp",
**parallel_kwargs["data_parallel_size"])
parallel_group.add_argument(
@@ -660,6 +664,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parallel_group.add_argument(
"--data-parallel-hybrid-lb",
**parallel_kwargs["data_parallel_hybrid_lb"])
parallel_group.add_argument(
"--enable-sequence-parallel",
**parallel_kwargs["enable_sequence_parallel"])
parallel_group.add_argument(
"--enable-expert-parallel",
**parallel_kwargs["enable_expert_parallel"])
@@ -1273,6 +1280,7 @@ def create_engine_config(
parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
context_parallel_size=self.context_parallel_size,
data_parallel_size=self.data_parallel_size,
data_parallel_rank=self.data_parallel_rank or 0,
data_parallel_external_lb=data_parallel_external_lb,
@@ -1281,6 +1289,7 @@ def create_engine_config(
data_parallel_rpc_port=data_parallel_rpc_port,
data_parallel_backend=self.data_parallel_backend,
data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
enable_sequence_parallel=self.enable_sequence_parallel,
enable_expert_parallel=self.enable_expert_parallel,
enable_eplb=self.enable_eplb,
eplb_config=self.eplb_config,
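A minimal sketch of driving the new knobs programmatically, using the field and flag names added above (the model name is only a placeholder):

from vllm.engine.arg_utils import EngineArgs

# CLI equivalent: -tp 2 -cp 2 --enable-sequence-parallel
args = EngineArgs(model="facebook/opt-125m",
                  tensor_parallel_size=2,
                  context_parallel_size=2,
                  enable_sequence_parallel=True)
# create_engine_config() forwards these into ParallelConfig, whose
# world_size becomes pp * tp * cp = 1 * 2 * 2 = 4.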
18 changes: 13 additions & 5 deletions vllm/model_executor/layers/fused_moe/config.py
@@ -10,7 +10,7 @@

import vllm.envs as envs
from vllm.config import ParallelConfig
from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank, get_context_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
@@ -163,9 +163,11 @@ def make(
@dataclass
class FusedMoEParallelConfig:
tp_size: int
cp_size: int
dp_size: int
ep_size: int
tp_rank: int
cp_rank: int
dp_rank: int
ep_rank: int

@@ -197,7 +199,7 @@ def use_flashinfer_cutlass_kernels(self):
and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")

@staticmethod
def make(tp_size_: int, dp_size_: int,
def make(tp_size_: int, dp_size_: int, cp_size_: int,
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
"""
Determine MoE parallel configuration. Based on the input `tp_size_`,
@@ -278,16 +280,20 @@ def flatten_tp_across_dp(dp_rank: int):
tp_rank = dp_rank * tp_size_ + tp_rank
return tp_size, tp_rank

use_ep = (dp_size_ * tp_size_ > 1
use_ep = (dp_size_ * tp_size_ * cp_size_ > 1
and vllm_parallel_config.enable_expert_parallel)

dp_size = dp_size_
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
tp_size, tp_rank = flatten_tp_across_dp(dp_rank)
cp_size = cp_size_
cp_rank = get_context_model_parallel_rank() if cp_size_ > 1 else 0

if not use_ep:
return FusedMoEParallelConfig(tp_size=tp_size,
tp_rank=tp_rank,
cp_size=cp_size,
cp_rank=cp_rank,
dp_size=dp_size,
dp_rank=dp_rank,
ep_size=1,
@@ -297,10 +303,12 @@ def flatten_tp_across_dp(dp_rank: int):
assert use_ep
# In EP, each device owns a set of experts fully. There is no tensor
# parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that.
ep_size = tp_size
ep_rank = tp_rank
ep_size = tp_size * cp_size
ep_rank = tp_rank + tp_size * cp_rank
return FusedMoEParallelConfig(tp_size=1,
tp_rank=0,
cp_size=1,
cp_rank=0,
dp_size=dp_size,
dp_rank=dp_rank,
ep_size=ep_size,
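The EP branch above now folds the CP dimension into the expert-parallel rank: ep_size = tp_size * cp_size and ep_rank = tp_rank + tp_size * cp_rank. A small standalone sketch of that mapping with illustrative sizes:

tp_size, cp_size = 2, 2
for cp_rank in range(cp_size):
    for tp_rank in range(tp_size):
        # TP ranks vary fastest, matching the (..., CP, TP) ordering of the rank grid.
        ep_rank = tp_rank + tp_size * cp_rank
        print(f"cp_rank={cp_rank} tp_rank={tp_rank} -> ep_rank={ep_rank}")
# cp_rank=0 tp_rank=0 -> ep_rank=0
# cp_rank=0 tp_rank=1 -> ep_rank=1
# cp_rank=1 tp_rank=0 -> ep_rank=2
# cp_rank=1 tp_rank=1 -> ep_rank=3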
5 changes: 5 additions & 0 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -14,6 +14,7 @@
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_dp_group, get_ep_group,
get_tensor_model_parallel_world_size,
get_context_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.forward_context import ForwardContext, get_forward_context
@@ -759,6 +760,7 @@ def __init__(
tp_size: Optional[int] = None,
ep_size: Optional[int] = None,
dp_size: Optional[int] = None,
cp_size: Optional[int] = None,
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
@@ -778,12 +780,15 @@ def __init__(
get_tensor_model_parallel_world_size())
dp_size_ = (dp_size
if dp_size is not None else get_dp_group().world_size)
cp_size_ = (cp_size
if cp_size is not None else get_context_model_parallel_world_size())

vllm_config = get_current_vllm_config()
self.moe_parallel_config: FusedMoEParallelConfig = (
FusedMoEParallelConfig.make(
tp_size_=tp_size_,
dp_size_=dp_size_,
cp_size_=cp_size_,
vllm_parallel_config=vllm_config.parallel_config))

self.global_num_experts = num_experts + num_redundant_experts
6 changes: 6 additions & 0 deletions vllm/v1/core/sched/output.py
@@ -32,6 +32,7 @@ class NewRequestData:
block_ids: tuple[list[int], ...]
num_computed_tokens: int
lora_request: Optional[LoRARequest]
num_computed_tokens_of_cp_sp: Optional[list[list[int]]]

@classmethod
def from_request(
Expand All @@ -50,6 +51,7 @@ def from_request(
block_ids=block_ids,
num_computed_tokens=request.num_computed_tokens,
lora_request=request.lora_request,
num_computed_tokens_of_cp_sp=request.num_computed_tokens_of_cp_sp,
)

def __repr__(self):
Expand Down Expand Up @@ -93,6 +95,8 @@ class CachedRequestData:
new_token_ids: list[list[int]]
new_block_ids: list[Optional[tuple[list[int], ...]]]
num_computed_tokens: list[int]
kv_rank: list[tuple[int]]
num_computed_tokens_of_cp_sp: list[list[list[int]]]

@property
def num_reqs(self) -> int:
@@ -106,6 +110,8 @@ def make_empty(cls) -> CachedRequestData:
new_token_ids=[],
new_block_ids=[],
num_computed_tokens=[],
kv_rank=[],
num_computed_tokens_of_cp_sp=[],
)


8 changes: 8 additions & 0 deletions vllm/v1/core/sched/scheduler.py
@@ -615,6 +615,10 @@ def _make_cached_request_data(
new_block_ids: list[Optional[tuple[list[int], ...]]] = []
num_computed_tokens: list[int] = []

# Context parallel (CP) parameters
kv_rank: list[tuple[int]] = []
num_computed_tokens_of_cp_sp: list[list[list[int]]] = []

use_connector = self.connector is not None
for req in itertools.chain(running_reqs, resumed_reqs):
req_id = req.request_id
@@ -638,6 +642,8 @@ def _make_cached_request_data(
new_block_ids.append(
req_to_new_blocks[req_id].get_block_ids(allow_none=True))
num_computed_tokens.append(req.num_computed_tokens)
kv_rank.append(req.kv_rank)
num_computed_tokens_of_cp_sp.append(req.num_computed_tokens_of_cp_sp)
# Because resumed_reqs is usually empty, it is more efficient to do
# in-place appending so that we don't need to allocate a new list.
resumed_from_preemption = [False] * len(running_reqs)
@@ -649,6 +655,8 @@ def _make_cached_request_data(
new_token_ids=new_token_ids,
new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens,
kv_rank=kv_rank,
num_computed_tokens_of_cp_sp=num_computed_tokens_of_cp_sp,
)

def _try_schedule_encoder_inputs(
8 changes: 5 additions & 3 deletions vllm/v1/executor/multiproc_executor.py
@@ -55,10 +55,12 @@ def _init_executor(self) -> None:
self.world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
pp_parallel_size = self.parallel_config.pipeline_parallel_size
assert self.world_size == tensor_parallel_size * pp_parallel_size, (
context_parallel_size = self.parallel_config.context_parallel_size
assert self.world_size == tensor_parallel_size * pp_parallel_size * context_parallel_size, (
f"world_size ({self.world_size}) must be equal to the "
f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
f"_parallel_size ({pp_parallel_size}). ")
f"_parallel_size ({pp_parallel_size}) x context"
f"_parallel_size ({context_parallel_size}). ")

# Set multiprocessing envs that are common to V0 and V1
set_multiprocessing_worker_envs(self.parallel_config)
@@ -323,7 +325,7 @@ def _get_output_rank(self) -> int:
# 16-23, PP rank 2
# 24-31, PP rank 3
# so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
return self.world_size - self.parallel_config.tensor_parallel_size
return self.world_size - self.parallel_config.tensor_parallel_size * self.parallel_config.context_parallel_size


@dataclass
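Reworking the arithmetic from the comment in _get_output_rank for the CP case (sizes are illustrative): the last pipeline stage now spans tp_size * cp_size ranks, so the output rank moves to world_size - tp_size * cp_size.

# Illustrative sizes: PP=4, CP=2, TP=4 -> world_size = 32.
tp_size, cp_size, pp_size = 4, 2, 4
world_size = pp_size * cp_size * tp_size
# Ranks 24-31 form the last PP stage; its first rank is the output rank.
output_rank = world_size - tp_size * cp_size
assert output_rank == 24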