diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index f80eb1adc7fd..8b980458ddaf 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -279,6 +279,24 @@ def stateless_init_dp_group(self) -> ProcessGroup: assert last_exc is not None raise last_exc + # The all_reduce at the end of attention (during o_proj) means that + # inputs are replicated across each rank of the tensor parallel group. + # If using expert-parallelism with DeepEP All2All ops, replicated + # tokens results in useless duplicate computation and communication. + # + # In this case, ensure the input to the experts is sequence parallel + # to avoid the excess work. + # + # Not needed for pplx-kernels as it can handle duplicate input tokens. + @property + def use_sequence_parallel_moe(self) -> bool: + return (envs.VLLM_ALL2ALL_BACKEND + in ("allgather_reducescatter", "naive", + "deepep_high_throughput", "deepep_low_latency") + and self.enable_expert_parallel + and self.tensor_parallel_size > 1 + and self.data_parallel_size > 1) + @staticmethod def has_unfinished_dp(dp_group: ProcessGroup, has_unfinished: bool) -> bool: diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 661ed939608a..bb3fd657facd 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -6,7 +6,7 @@ import torch.distributed as dist import vllm.envs as envs -from vllm.distributed import get_dp_group +from vllm.distributed import get_dp_group, get_ep_group from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.utils import has_deep_ep, has_pplx @@ -34,41 +34,60 @@ def __init__(self, cpu_group): super().__init__(cpu_group) def naive_multicast(self, x: torch.Tensor, - cu_tokens_across_dp_cpu: torch.Tensor): + cu_tokens_across_sp_cpu: torch.Tensor, + is_sequence_parallel: bool) -> torch.Tensor: assert (len(x.shape) == 2) - buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), + buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)), device=x.device, dtype=x.dtype) - start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ - self.dp_rank - 1] - end = cu_tokens_across_dp_cpu[self.dp_rank] + rank = self.rank if is_sequence_parallel else self.dp_rank + world_size = (self.world_size + if is_sequence_parallel else self.dp_world_size) + + start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1] + end = cu_tokens_across_sp_cpu[rank] buffer[start:end, :].copy_(x) - for idx in range(self.dp_world_size): - start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] - end = cu_tokens_across_dp_cpu[idx] - self.dp_group.broadcast(buffer[start:end, :], idx) + for idx in range(world_size): + start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1] + end = cu_tokens_across_sp_cpu[idx] + get_ep_group().broadcast(buffer[start:end, :], idx) return buffer - def dispatch(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): - sizes = get_forward_context( - ).dp_metadata.get_chunk_sizes_across_dp_rank() - hidden_states, router_logits = get_dp_group().all_gatherv( - [hidden_states, router_logits], - dim=0, - sizes=sizes, - ) - + def dispatch( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False + ) -> tuple[torch.Tensor, torch.Tensor]: + sp_size = self.tp_group.world_size if is_sequence_parallel else 1 + dp_metadata = get_forward_context().dp_metadata + cu_tokens_across_sp_cpu = 
dp_metadata.cu_tokens_across_sp(sp_size) + + hidden_states = self.naive_multicast(hidden_states, + cu_tokens_across_sp_cpu, + is_sequence_parallel) + router_logits = self.naive_multicast(router_logits, + cu_tokens_across_sp_cpu, + is_sequence_parallel) return hidden_states, router_logits - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: - sizes = get_forward_context( - ).dp_metadata.get_chunk_sizes_across_dp_rank() - hidden_states = get_dp_group().reduce_scatterv(hidden_states, - dim=0, - sizes=sizes) + def combine(self, + hidden_states: torch.Tensor, + is_sequence_parallel: bool = False) -> torch.Tensor: + + ep_rank = self.rank if is_sequence_parallel else self.dp_rank + + dp_metadata = get_forward_context().dp_metadata + sp_size = self.tp_group.world_size if is_sequence_parallel else 1 + cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size) + + start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1] + end = cu_tokens_across_sp_cpu[ep_rank] + + all_hidden_states = get_ep_group().all_reduce(hidden_states) + hidden_states = all_hidden_states[start:end, :] return hidden_states def destroy(self): @@ -84,29 +103,40 @@ class AgRsAll2AllManager(All2AllManagerBase): def __init__(self, cpu_group): super().__init__(cpu_group) - def dispatch(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + def dispatch( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False + ) -> tuple[torch.Tensor, torch.Tensor]: """ Gather hidden_states and router_logits from all dp ranks. """ sizes = get_forward_context( ).dp_metadata.get_chunk_sizes_across_dp_rank() - hidden_states, router_logits = get_dp_group().all_gatherv( + + dist_group = get_ep_group() if is_sequence_parallel else get_dp_group() + assert sizes[dist_group.rank_in_group] == hidden_states.shape[0] + hidden_states, router_logits = dist_group.all_gatherv( [hidden_states, router_logits], dim=0, sizes=sizes, ) return hidden_states, router_logits - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + def combine(self, + hidden_states: torch.Tensor, + is_sequence_parallel: bool = False) -> torch.Tensor: """ Reduce-scatter hidden_states across all dp ranks. 
""" sizes = get_forward_context( ).dp_metadata.get_chunk_sizes_across_dp_rank() - hidden_states = get_dp_group().reduce_scatterv(hidden_states, - dim=0, - sizes=sizes) + + dist_group = get_ep_group() if is_sequence_parallel else get_dp_group() + hidden_states = dist_group.reduce_scatterv(hidden_states, + dim=0, + sizes=sizes) return hidden_states def destroy(self): @@ -148,11 +178,17 @@ def get_handle(self, kwargs): kwargs, pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode) - def dispatch(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + def dispatch( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False + ) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + def combine(self, + hidden_states: torch.Tensor, + is_sequence_parallel: bool = False) -> torch.Tensor: raise NotImplementedError def destroy(self): @@ -184,11 +220,17 @@ def __init__(self, cpu_group): def get_handle(self, kwargs): raise NotImplementedError - def dispatch(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + def dispatch( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False + ) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + def combine(self, + hidden_states: torch.Tensor, + is_sequence_parallel: bool = False) -> torch.Tensor: raise NotImplementedError def destroy(self): @@ -395,4 +437,4 @@ def cleanup(self): self.workspace_tensor = None self.prepare_workspace_tensor = None self.mapping = None - self.initialized = False \ No newline at end of file + self.initialized = False diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 586441c91783..a42081fb0c15 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -28,6 +28,8 @@ def get_or_create(self, kwargs, func): class All2AllManagerBase: + rank: int + world_size: int def __init__(self, cpu_group): self.cpu_group = cpu_group @@ -40,6 +42,7 @@ def __init__(self, cpu_group): # all2all lives in ep group, which is merged from dp and tp group self.dp_group = get_dp_group() self.tp_group = get_tp_group() + # no self.ep_group since self.ep_group is still in construction # when we create this object self.dp_rank = self.dp_group.rank_in_group @@ -60,17 +63,21 @@ def get_handle(self, kwargs): # and reuse it for the same config. 
raise NotImplementedError + def dispatch(self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False): + raise NotImplementedError + def set_num_sms(self, num_sms: int): pass def max_sms_used(self) -> Optional[int]: return None # None means it could use the whole GPU - def dispatch(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): - raise NotImplementedError - - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + def combine(self, + hidden_states: torch.Tensor, + is_sequence_parallel: bool = False): raise NotImplementedError def destroy(self): @@ -267,15 +274,20 @@ def prepare_communication_buffer_for_model(self, module.quant_method.init_prepare_finalize(module) def dispatch( - self, hidden_states: torch.Tensor, - router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False + ) -> tuple[torch.Tensor, torch.Tensor]: """ Dispatch the hidden states and router logits to the appropriate device. This is a no-op in the base class. """ return hidden_states, router_logits - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + def combine(self, + hidden_states: torch.Tensor, + is_sequence_parallel: bool = False) -> torch.Tensor: """ Combine the hidden states and router logits from the appropriate device. This is a no-op in the base class. diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index bab372b722db..30d1bf10138b 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -39,10 +39,6 @@ def __init__(self, use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM - # ep does not use pynccl - use_pynccl = "ep" not in unique_name - - self.use_pynccl = use_pynccl self.use_custom_allreduce = use_custom_allreduce self.use_torch_symm_mem = use_torch_symm_mem @@ -57,7 +53,7 @@ def __init__(self, SymmMemCommunicator) self.pynccl_comm: Optional[PyNcclCommunicator] = None - if use_pynccl and self.world_size > 1: + if self.world_size > 1: self.pynccl_comm = PyNcclCommunicator( group=self.cpu_group, device=self.device, @@ -308,14 +304,20 @@ def _all_gather_single(input_: torch.Tensor, return output_list def dispatch( - self, hidden_states: torch.Tensor, - router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False + ) -> tuple[torch.Tensor, torch.Tensor]: assert self.all2all_manager is not None hidden_states, router_logits = self.all2all_manager.dispatch( - hidden_states, router_logits) + hidden_states, router_logits, is_sequence_parallel) return hidden_states, router_logits - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + def combine(self, + hidden_states: torch.Tensor, + is_sequence_parallel: bool = False) -> torch.Tensor: assert self.all2all_manager is not None - hidden_states = self.all2all_manager.combine(hidden_states) + hidden_states = self.all2all_manager.combine(hidden_states, + is_sequence_parallel) return hidden_states diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py index b236bae261e0..27bd176554af 100644 --- a/vllm/distributed/device_communicators/xpu_communicator.py +++ 
b/vllm/distributed/device_communicators/xpu_communicator.py @@ -75,14 +75,20 @@ def broadcast(self, input_: torch.Tensor, src: int = 0) -> None: dist.broadcast(input_, src=src, group=self.device_group) def dispatch( - self, hidden_states: torch.Tensor, - router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False + ) -> tuple[torch.Tensor, torch.Tensor]: assert self.all2all_manager is not None hidden_states, router_logits = self.all2all_manager.dispatch( - hidden_states, router_logits) + hidden_states, router_logits, is_sequence_parallel) return hidden_states, router_logits - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + def combine(self, + hidden_states: torch.Tensor, + is_sequence_parallel: bool = False) -> torch.Tensor: assert self.all2all_manager is not None - hidden_states = self.all2all_manager.combine(hidden_states) + hidden_states = self.all2all_manager.combine(hidden_states, + is_sequence_parallel) return hidden_states diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 69f98eb54f36..638170963e2b 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -871,17 +871,24 @@ def prepare_communication_buffer_for_model(self, model: torch.nn.Module): model) def dispatch( - self, hidden_states: torch.Tensor, - router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False + ) -> tuple[torch.Tensor, torch.Tensor]: if self.device_communicator is not None: return self.device_communicator.dispatch(hidden_states, - router_logits) + router_logits, + is_sequence_parallel) else: return hidden_states, router_logits - def combine(self, hidden_states) -> torch.Tensor: + def combine(self, + hidden_states, + is_sequence_parallel: bool = False) -> torch.Tensor: if self.device_communicator is not None: - return self.device_communicator.combine(hidden_states) + return self.device_communicator.combine(hidden_states, + is_sequence_parallel) else: return hidden_states diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 2bf4e1804521..09defade00dc 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -49,16 +49,29 @@ def non_uniform(self) -> "BatchDescriptor": return BatchDescriptor(self.num_tokens, uniform_decode=False) -def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int], +def _compute_sp_num_tokens(num_tokens_across_dp_cpu: torch.Tensor, + sequence_parallel_size: int) -> list[int]: + sp_tokens = ((num_tokens_across_dp_cpu + sequence_parallel_size - 1) // + sequence_parallel_size) + + sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size) + return sp_tokens.tolist() + + +def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: torch.Tensor, + sequence_parallel_size: int, max_num_tokens: int, chunk_idx: int) -> list[int]: - dp_size = len(num_tokens_across_dp_cpu) - local_size = [-1] * dp_size - for i in range(dp_size): - dp_tokens = num_tokens_across_dp_cpu[i] + sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu, + sequence_parallel_size) + sp_size = len(sp_tokens) + + local_size = [-1] * sp_size + for i in range(sp_size): + # Take into account sharding if MoE activation is sequence parallel. 
local_size[i] = min(max_num_tokens, - dp_tokens - (max_num_tokens * chunk_idx)) + sp_tokens[i] - (max_num_tokens * chunk_idx)) if local_size[i] <= 0: local_size[i] = 1 # ensure lockstep even if done return local_size @@ -67,7 +80,9 @@ def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int], @dataclass class DPMetadata: max_tokens_across_dp_cpu: torch.Tensor - cu_tokens_across_dp_cpu: torch.Tensor + num_tokens_across_dp_cpu: torch.Tensor + + # NOTE: local_sizes should only be set by the chunked_sizes context manager local_sizes: Optional[list[int]] = None @staticmethod @@ -98,6 +113,17 @@ def num_tokens_across_dp(num_tokens: int, dp_size: int, dist.all_reduce(num_tokens_tensor, group=group) return num_tokens_tensor.cpu() + # Get the cumulative tokens across sequence parallel ranks. + # In this case the input to the MoEs will be distributed w.r.t both + # DP and TP rank. + # When sp_size==1, this is just the cummulative num tokens across DP. + def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor: + num_tokens_across_sp_cpu = ( + (self.num_tokens_across_dp_cpu - 1 + sp_size) // sp_size) + num_tokens_across_sp_cpu = ( + num_tokens_across_sp_cpu.repeat_interleave(sp_size)) + return torch.cumsum(num_tokens_across_sp_cpu, dim=0) + @staticmethod def should_ubatch_across_dp( should_ubatch: bool, orig_num_tokens_per_ubatch: int, @@ -147,10 +173,10 @@ def should_ubatch_across_dp( @staticmethod def make( - parallel_config: ParallelConfig, - attn_metadata: Any, - num_tokens: int, - num_tokens_across_dp: Optional[torch.Tensor] = None + parallel_config: ParallelConfig, + attn_metadata: Any, + num_tokens: int, + num_tokens_across_dp_cpu: Optional[torch.Tensor] = None ) -> "DPMetadata": assert parallel_config.data_parallel_size > 1 @@ -167,18 +193,18 @@ def make( # If num_tokens_across_dp is None, it will be computed by all_reduce # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize - assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank] - == batchsize), f"{num_tokens_across_dp[dp_rank]} {batchsize}" - if num_tokens_across_dp is None: - num_tokens_across_dp = DPMetadata.num_tokens_across_dp( + assert (num_tokens_across_dp_cpu is None + or num_tokens_across_dp_cpu[dp_rank] == batchsize + ), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}" + if num_tokens_across_dp_cpu is None: + num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp( batchsize, dp_size, dp_rank) - max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp) - cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0) - return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu, - num_tokens_across_dp) + max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu) + return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu) @contextmanager - def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int): + def chunked_sizes(self, sequence_parallel_size: int, + max_chunk_size_per_rank: int, chunk_idx: int): """ Context manager to compute and temporarily set the per-rank local token sizes for a specific chunk during chunked forward execution. @@ -192,31 +218,40 @@ def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int): `chunk_idx`, this context manager sets `self.local_sizes` to the number of tokens to process in that chunk on each rank. - It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the - number of tokens per rank, and calls `_compute_chunked_local_num_tokens` - to determine the chunk-wise split. 
- `self.local_sizes` is only valid inside the context. Args: + sequence_parallel_size: When Attn is TP and MoE layers are EP, + we use SP between the layers to avoid + redundant ops. We need this value to + compute the chunked sizes. max_chunk_size_per_rank: The max number of tokens each rank is allowed to process in this chunk. chunk_idx: The index of the chunk to compute sizes for. """ - cu_sizes = self.cu_tokens_across_dp_cpu - num_tokens_across_dp_cpu = [ - (cu_sizes[i] - - cu_sizes[i - 1]).item() if i > 0 else cu_sizes[0].item() - for i in range(len(cu_sizes)) - ] self.local_sizes = _compute_chunked_local_num_tokens( - num_tokens_across_dp_cpu, max_chunk_size_per_rank, chunk_idx) + self.num_tokens_across_dp_cpu, sequence_parallel_size, + max_chunk_size_per_rank, chunk_idx) + try: + yield self.local_sizes + finally: + self.local_sizes = None + + @contextmanager + def sp_local_sizes(self, sequence_parallel_size: int): + """ + Context mamager for setting self.local_sizes. Same as self.chunked_sizes + but without any chunking. + """ + self.local_sizes = _compute_sp_num_tokens( + self.num_tokens_across_dp_cpu, sequence_parallel_size) try: yield self.local_sizes finally: self.local_sizes = None def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]: + assert self.local_sizes is not None return self.local_sizes diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index eccae8b2a7af..8de1d14d46b3 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -3,6 +3,7 @@ from abc import abstractmethod from collections.abc import Iterable +from contextlib import nullcontext from enum import Enum from typing import Callable, Literal, Optional, Union, get_args, overload @@ -983,8 +984,7 @@ def __init__( if dp_size is not None else get_dp_group().world_size) self.is_sequence_parallel = is_sequence_parallel - if self.is_sequence_parallel: - self.sp_size = tp_size_ + self.sp_size = tp_size_ if is_sequence_parallel else 1 self.moe_parallel_config: FusedMoEParallelConfig = ( FusedMoEParallelConfig.make( @@ -1966,7 +1966,8 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): # clamp start and end chunk_start = min(chunk_start, num_tokens - 1) chunk_end = min(chunk_end, num_tokens) - with ctx.dp_metadata.chunked_sizes(moe_dp_chunk_size_per_rank, + with ctx.dp_metadata.chunked_sizes(self.sp_size, + moe_dp_chunk_size_per_rank, chunk_idx): process_chunk(chunk_start, chunk_end, @@ -2011,65 +2012,73 @@ def forward_impl( else: shared_output = None - if do_naive_dispatch_combine: - hidden_states, router_logits = get_ep_group().dispatch( - hidden_states, router_logits) + ctx = get_forward_context() + sp_ctx = ctx.dp_metadata.sp_local_sizes( + self.sp_size) if ctx.dp_metadata else nullcontext() - # Matrix multiply. 
- final_hidden_states = self.quant_method.apply( - layer=self, - x=hidden_states, - router_logits=router_logits, - top_k=self.top_k, - renormalize=self.renormalize, - use_grouped_topk=self.use_grouped_topk, - global_num_experts=self.global_num_experts, - expert_map=self.expert_map, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - custom_routing_function=self.custom_routing_function, - scoring_func=self.scoring_func, - routed_scaling_factor=self.routed_scaling_factor, - e_score_correction_bias=self.e_score_correction_bias, - activation=self.activation, - apply_router_weight_on_input=self.apply_router_weight_on_input, - enable_eplb=self.enable_eplb, - expert_load_view=self.expert_load_view, - logical_to_physical_map=self.logical_to_physical_map, - logical_replica_count=self.logical_replica_count, - ) + with sp_ctx: + if do_naive_dispatch_combine: + hidden_states, router_logits = get_ep_group().dispatch( + hidden_states, router_logits, self.is_sequence_parallel) - if shared_output is not None: - assert not isinstance(final_hidden_states, tuple) - assert self.shared_experts is not None - final_hidden_states = ( - shared_output, - final_hidden_states, + # Matrix multiply. + final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.global_num_experts, + expert_map=self.expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + routed_scaling_factor=self.routed_scaling_factor, + e_score_correction_bias=self.e_score_correction_bias, + activation=self.activation, + apply_router_weight_on_input=self.apply_router_weight_on_input, + enable_eplb=self.enable_eplb, + expert_load_view=self.expert_load_view, + logical_to_physical_map=self.logical_to_physical_map, + logical_replica_count=self.logical_replica_count, ) - elif self.zero_expert_num is not None and self.zero_expert_num > 0: - assert isinstance(final_hidden_states, tuple) - final_hidden_states, zero_expert_result = final_hidden_states - def reduce_output(states: torch.Tensor, - do_combine: bool = True) -> torch.Tensor: - if do_naive_dispatch_combine and do_combine: - states = get_ep_group().combine(states) + if shared_output is not None: + assert not isinstance(final_hidden_states, tuple) + assert self.shared_experts is not None + final_hidden_states = ( + shared_output, + final_hidden_states, + ) + elif self.zero_expert_num is not None and self.zero_expert_num > 0: + assert isinstance(final_hidden_states, tuple) + final_hidden_states, zero_expert_result = final_hidden_states + + def reduce_output(states: torch.Tensor, + do_combine: bool = True) -> torch.Tensor: + if do_naive_dispatch_combine and do_combine: + states = get_ep_group().combine(states, + self.is_sequence_parallel) - if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): - states = self.maybe_all_reduce_tensor_model_parallel(states) + if (not self.is_sequence_parallel and self.reduce_results + and (self.tp_size > 1 or self.ep_size > 1)): + states = self.maybe_all_reduce_tensor_model_parallel( + states) - return states + return states - if self.shared_experts is not None: - return ( - reduce_output(final_hidden_states[0], do_combine=False), - reduce_output(final_hidden_states[1]), - ) - elif self.zero_expert_num is not None and self.zero_expert_num > 0: - assert 
isinstance(final_hidden_states, torch.Tensor) - return reduce_output(final_hidden_states) + zero_expert_result - else: - return reduce_output(final_hidden_states) + if self.shared_experts is not None: + return ( + reduce_output(final_hidden_states[0], do_combine=False), + reduce_output(final_hidden_states[1]), + ) + elif self.zero_expert_num is not None and self.zero_expert_num > 0: + assert isinstance(final_hidden_states, torch.Tensor) + return reduce_output(final_hidden_states) + zero_expert_result + else: + return reduce_output(final_hidden_states) @classmethod def make_expert_params_mapping( diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 6cef5e134a4b..e0d7af0b1c3e 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -9,7 +9,7 @@ from transformers.models.aria.modeling_aria import AriaCrossAttention from transformers.models.aria.processing_aria import AriaProcessor -from vllm.config import CacheConfig, QuantizationConfig, VllmConfig +from vllm.config import QuantizationConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_rank from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.fused_moe import FusedMoE @@ -297,14 +297,12 @@ class AriaTextDecoderLayer(LlamaDecoderLayer): Experts (MoE) Layer. """ - def __init__( - self, - config: AriaTextConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__(config, cache_config, quant_config, prefix) + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__(vllm_config, prefix) + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.mlp = AriaTextMoELayer(config, quant_config=quant_config, prefix=f"{prefix}.mlp") diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index aab522390a7a..2e0bcbe5d2e5 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -32,7 +32,6 @@ from torch import nn from transformers import DeepseekV2Config, DeepseekV3Config -import vllm.envs as envs from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ParallelConfig, VllmConfig @@ -56,8 +55,8 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.sequence import IntermediateTensors -from vllm.utils import cdiv, direct_register_custom_op from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP from .utils import (PPMissingLayer, is_pp_missing_parameter, @@ -108,43 +107,6 @@ def forward(self, x): return x -# Chunk x along the num_tokens axis for sequence parallelism -# NOTE: This is wrapped in a torch custom op to work around the following issue: -# The output tensor can have a sequence length 0 at small input sequence lengths -# even though we explicitly pad to avoid this. 
-def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor: - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - - # all_gather needs the sequence length to be divisible by tp_size - seq_len = x.size(0) - remainder = seq_len % tp_size - if remainder != 0: - pad_len = tp_size - remainder - x = nn.functional.pad(x, (0, 0, 0, pad_len)) - - chunk = x.shape[0] // tp_size - start = tp_rank * chunk - return torch.narrow(x, 0, start, chunk) - - -def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor: - tp_size = get_tensor_model_parallel_world_size() - seq_len = cdiv(x.size(0), tp_size) - shape = list(x.shape) - shape[0] = seq_len - out = torch.empty(shape, dtype=x.dtype, device=x.device) - return out - - -direct_register_custom_op( - op_name="sequence_parallel_chunk", - op_func=sequence_parallel_chunk, - fake_impl=sequence_parallel_chunk_fake, - tags=(torch.Tag.needs_fixed_stride_order, ), -) - - class DeepseekV2MoE(nn.Module): def __init__( @@ -166,20 +128,7 @@ def __init__( self.n_routed_experts: int = config.n_routed_experts self.n_shared_experts: int = config.n_shared_experts - # The all_reduce at the end of attention (during o_proj) means that - # inputs are replicated across each rank of the tensor parallel group. - # If using expert-parallelism with DeepEP All2All ops, replicated - # tokens results in useless duplicate computation and communication. - # - # In this case, ensure the input to the experts is sequence parallel - # to avoid the excess work. - # - # Not needed for pplx-kernels as it can handle duplicate input tokens. - self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND - in ("deepep_high_throughput", - "deepep_low_latency") - and parallel_config.enable_expert_parallel - and self.tp_size > 1) + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe if config.hidden_act != "silu": raise ValueError(f"Unsupported activation: {config.hidden_act}. " @@ -278,8 +227,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # TODO: We can replace the all_reduce at the end of attn with a # reduce_scatter instead of chunking here. 
if self.is_sequence_parallel: - hidden_states = torch.ops.vllm.sequence_parallel_chunk( - hidden_states) + hidden_states = sequence_parallel_chunk(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) diff --git a/vllm/model_executor/models/ernie_mtp.py b/vllm/model_executor/models/ernie_mtp.py index 3b24bf2f1ef8..2e6ef2d476a6 100644 --- a/vllm/model_executor/models/ernie_mtp.py +++ b/vllm/model_executor/models/ernie_mtp.py @@ -29,10 +29,9 @@ import torch.nn as nn from transformers import PretrainedConfig -from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.config import VllmConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -47,13 +46,11 @@ class ErnieMultiTokenPredictorLayer(nn.Module): def __init__( self, - config: PretrainedConfig, + vllm_config: VllmConfig, prefix: str, - model_config: ModelConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() + config = vllm_config.model_config.hf_config self.mtp_emb_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -62,8 +59,7 @@ def __init__( self.mtp_linear_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) - self.mtp_block = LlamaDecoderLayer(config, cache_config, quant_config, - prefix) + self.mtp_block = LlamaDecoderLayer(vllm_config, prefix) def forward( self, @@ -102,10 +98,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.layers = torch.nn.ModuleDict({ str(idx): ErnieMultiTokenPredictorLayer( - config, + vllm_config, f"{prefix}.layers.{idx}", - model_config=vllm_config.model_config, - cache_config=vllm_config.cache_config, ) for idx in range(self.mtp_start_layer_idx, self.mtp_start_layer_idx + self.num_mtp_layers) diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py index b9d5e24e9f6f..f49f21a40f82 100644 --- a/vllm/model_executor/models/glm4.py +++ b/vllm/model_executor/models/glm4.py @@ -136,14 +136,16 @@ def forward( class Glm4DecoderLayer(nn.Module): - def __init__( - self, - config: Glm4Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, + vllm_config: VllmConfig, + prefix: str = "", + config: Optional[Glm4Config] = None) -> None: super().__init__() + + config = config or vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 1000000) rope_scaling = getattr(config, "rope_scaling", None) diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 7c755a00e1c9..47ba5084d608 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -13,7 +13,8 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_ep_group, get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) from vllm.model_executor.layers.fused_moe 
import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, @@ -24,6 +25,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.sequence import IntermediateTensors from vllm.utils import cdiv @@ -132,12 +134,18 @@ class MLPBlock(torch.nn.Module): def __init__( self, - config: GptOssConfig, + vllm_config: VllmConfig, layer_idx: int, - quant_config: QuantizationConfig, prefix: str = "", ): super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config + + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe + self.layer_idx = layer_idx self.num_experts = config.num_local_experts self.experts_per_token = config.num_experts_per_tok @@ -155,11 +163,20 @@ def __init__( prefix=f"{prefix}.experts", apply_router_weight_on_input=False, has_bias=True, - activation="swigluoai") + activation="swigluoai", + is_sequence_parallel=self.is_sequence_parallel) def forward(self, x: torch.Tensor) -> torch.Tensor: + num_tokens = x.shape[0] + if self.is_sequence_parallel: + x = sequence_parallel_chunk(x) + g = self.router(x) x = self.experts(hidden_states=x, router_logits=g) + + if self.is_sequence_parallel: + x = tensor_model_parallel_all_gather(x.contiguous(), 0) + x = x[:num_tokens] return x @@ -167,19 +184,20 @@ class TransformerBlock(torch.nn.Module): def __init__( self, - config: GptOssConfig, - cache_config: CacheConfig, - quant_config: QuantizationConfig, + vllm_config: VllmConfig, prefix: str = "", ): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + self.layer_idx = extract_layer_index(prefix) self.attn = OAIAttention(config, prefix=f"{prefix}.attn", cache_config=cache_config) - self.mlp = MLPBlock(config, + self.mlp = MLPBlock(vllm_config, self.layer_idx, - quant_config=quant_config, prefix=f"{prefix}.mlp") self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5) @@ -216,8 +234,6 @@ def __init__( ): super().__init__() self.config = vllm_config.model_config.hf_config - self.cache_config = vllm_config.cache_config - self.quant_config = vllm_config.quant_config self.parallel_config = vllm_config.parallel_config self.config.hidden_size = self.config.hidden_size self.embedding = VocabParallelEmbedding( @@ -227,9 +243,7 @@ def __init__( self.start_layer, self.end_layer, self.layers = make_layers( self.config.num_hidden_layers, lambda prefix: TransformerBlock( - self.config, - cache_config=self.cache_config, - quant_config=self.quant_config, + vllm_config, prefix=prefix, ), prefix=f"{prefix}.layers", diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 47ac22c4aeaa..76a5745a4f51 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -29,12 +29,13 @@ import torch from torch import nn -from transformers.models.granitemoe import GraniteMoeConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from 
vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, @@ -48,6 +49,7 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -71,9 +73,11 @@ def __init__(self, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, + is_sequence_parallel=False, prefix: str = ""): super().__init__() self.hidden_size = hidden_size + self.is_sequence_parallel = is_sequence_parallel # Gate always runs at half / full precision for now. self.gate = ReplicatedLinear(hidden_size, @@ -92,15 +96,27 @@ def __init__(self, renormalize=True, quant_config=quant_config, tp_size=tp_size, - prefix=f"{prefix}.experts") + prefix=f"{prefix}.experts", + is_sequence_parallel=self.is_sequence_parallel) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. orig_shape = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) + + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, router_logits) + + if self.is_sequence_parallel: + final_hidden_states = tensor_model_parallel_all_gather( + final_hidden_states, 0) + num_tokens = orig_shape[0] + final_hidden_states = final_hidden_states[:num_tokens] + return final_hidden_states.view(orig_shape) @@ -191,12 +207,16 @@ class GraniteMoeDecoderLayer(nn.Module): def __init__( self, - config: GraniteMoeConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, prefix: str = "", ) -> None: super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config + self.hidden_size = config.hidden_size # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 10000) @@ -218,6 +238,7 @@ def __init__( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, quant_config=quant_config, + is_sequence_parallel=parallel_config.use_sequence_parallel_moe, prefix=f"{prefix}.block_sparse_moe") self.input_layernorm = RMSNorm(config.hidden_size, @@ -255,7 +276,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -275,9 +295,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: GraniteMoeDecoderLayer( - config, cache_config, quant_config=quant_config, prefix=prefix - ), + lambda prefix: GraniteMoeDecoderLayer(vllm_config, prefix=prefix), prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, 
eps=config.rms_norm_eps) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 1b03cbef501b..c7dd134ea47e 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -68,6 +68,7 @@ def __init__( bias: bool = False, prefix: str = "", reduce_results: bool = True, + disable_tp: bool = False, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -75,6 +76,7 @@ def __init__( output_sizes=[intermediate_size] * 2, bias=bias, quant_config=quant_config, + disable_tp=disable_tp, prefix=f"{prefix}.gate_up_proj", ) self.down_proj = RowParallelLinear( @@ -83,6 +85,7 @@ def __init__( bias=bias, quant_config=quant_config, reduce_results=reduce_results, + disable_tp=disable_tp, prefix=f"{prefix}.down_proj", ) if hidden_act != "silu": @@ -237,14 +240,16 @@ def _init_rotary_emb(self, config: LlamaConfig, class LlamaDecoderLayer(nn.Module): - def __init__( - self, - config: LlamaConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, + vllm_config: VllmConfig, + prefix: str = "", + config: Optional[LlamaConfig] = None) -> None: super().__init__() + + config = config or vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) @@ -335,7 +340,6 @@ def __init__(self, super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -357,10 +361,7 @@ def __init__(self, self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: layer_type(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix), prefix=f"{prefix}.layers", ) if get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index ddd7e6a5936e..32d4f69c6bf1 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -28,7 +28,8 @@ from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import (get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, @@ -39,6 +40,7 @@ from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.utils import sequence_parallel_chunk from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk, @@ -59,13 +61,16 @@ def custom_routing_function( router_scores = torch.sigmoid(router_scores.float()) return (router_scores, router_indices.to(torch.int32)) - def __init__(self, - config: Llama4TextConfig, - 
quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + parallel_config = vllm_config.parallel_config + quant_config = vllm_config.quant_config + self.tp_size = get_tensor_model_parallel_world_size() self.top_k = config.num_experts_per_tok + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe intermediate_size_moe = config.intermediate_size self.router = ReplicatedLinear(config.hidden_size, @@ -82,6 +87,7 @@ def __init__(self, bias=False, prefix=f"{prefix}.shared_expert", reduce_results=False, + disable_tp=self.is_sequence_parallel, ) self.experts = SharedFusedMoE( @@ -96,9 +102,14 @@ def __init__(self, renormalize=False, quant_config=quant_config, prefix=f"{prefix}.experts", + is_sequence_parallel=self.is_sequence_parallel, ) def forward(self, hidden_states): + num_tokens = hidden_states.shape[0] + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + router_logits, _ = self.router(hidden_states) shared_out, routed_out = self.experts( @@ -107,7 +118,10 @@ def forward(self, hidden_states): ) experts_out = routed_out + shared_out - if self.tp_size > 1: + if self.is_sequence_parallel: + experts_out = tensor_model_parallel_all_gather(experts_out, 0) + experts_out = experts_out[:num_tokens] + elif self.tp_size > 1: experts_out = self.experts.maybe_all_reduce_tensor_model_parallel( experts_out) @@ -257,15 +271,16 @@ def forward( class Llama4DecoderLayer(nn.Module): - def __init__( - self, - config: Llama4TextConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, + vllm_config: VllmConfig, + prefix: str = "", + config: Optional[Llama4TextConfig] = None) -> None: super().__init__() + config = config or vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.layer_idx = extract_layer_index(prefix) self.global_layer = config.no_rope_layers[self.layer_idx] == 0 self.hidden_size = config.hidden_size @@ -291,8 +306,7 @@ def __init__( self.layer_idx + 1) % config.interleave_moe_layer_step == 0 if is_moe_layer: self.feed_forward = Llama4MoE( - config=config, - quant_config=quant_config, + vllm_config=vllm_config, prefix=f"{prefix}.feed_forward", ) else: diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py index 235275c0940a..0768edd08315 100644 --- a/vllm/model_executor/models/llama4_eagle.py +++ b/vllm/model_executor/models/llama4_eagle.py @@ -68,9 +68,9 @@ def __init__( self.layers = nn.ModuleList([ Llama4DecoderLayer( - self.config, - quant_config=quant_config, + vllm_config=vllm_config, prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + config=self.config, ) for i in range(self.config.num_hidden_layers) ]) self.fc = torch.nn.Linear(self.config.hidden_size * 2, diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index d6e6fd3fcfe9..d7d6b1745fc8 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -28,11 +28,12 @@ class LlamaDecoderLayer(LlamaDecoderLayer): def __init__( self, - config: LlamaConfig, + vllm_config: VllmConfig, disable_input_layernorm: bool, prefix: str = "", + config: Optional[LlamaConfig] = None, ) -> None: - super().__init__(config, prefix=prefix) + 
super().__init__(vllm_config, prefix=prefix, config=config) # Skip the input_layernorm # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427 @@ -64,9 +65,10 @@ def __init__( self.layers = nn.ModuleList([ LlamaDecoderLayer( - self.config, + vllm_config, i == 0, prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + config=self.config, ) for i in range(self.config.num_hidden_layers) ]) self.fc = torch.nn.Linear(self.config.hidden_size * 2, diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 34b8ea0ca536..7192a76c8749 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -8,13 +8,11 @@ import torch.nn as nn from transformers import LlamaConfig -from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.config import VllmConfig, get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import QKVParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -28,17 +26,14 @@ class LlamaDecoderLayer(LlamaDecoderLayer): - def __init__( - self, - config: LlamaConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__(config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix) + def __init__(self, + vllm_config: VllmConfig, + prefix: str = "", + config: Optional[LlamaConfig] = None) -> None: + super().__init__(vllm_config, prefix=prefix, config=config) + + config = config or vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config # override qkv self.self_attn.qkv_proj = QKVParallelLinear( @@ -125,9 +120,9 @@ def __init__( self.layers = nn.ModuleList([ LlamaDecoderLayer( - config=self.config, - cache_config=current_vllm_config.cache_config, + current_vllm_config, prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"), + config=self.config, ) ]) if hasattr(self.config, "target_hidden_size"): diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index cb2ff97a5df2..45b9c656a4bb 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -29,13 +29,13 @@ import torch from torch import nn -from transformers import Qwen3MoeConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.distributed import (get_ep_group, get_pp_group, - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE @@ -51,6 +51,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.utils import sequence_parallel_chunk from 
vllm.sequence import IntermediateTensors from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP @@ -101,12 +102,15 @@ class Qwen3MoeSparseMoeBlock(nn.Module): def __init__( self, - config: Qwen3MoeConfig, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, prefix: str = "", - enable_eplb: bool = False, ): super().__init__() + + config = vllm_config.model_config.hf_config + parallel_config = vllm_config.parallel_config + quant_config = vllm_config.quant_config + self.tp_size = get_tensor_model_parallel_world_size() self.ep_group = get_ep_group().device_group @@ -114,6 +118,8 @@ def __init__( self.ep_size = self.ep_group.size() self.n_routed_experts = config.num_experts + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe + if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " @@ -122,7 +128,7 @@ def __init__( # Load balancing settings. vllm_config = get_current_vllm_config() eplb_config = vllm_config.parallel_config.eplb_config - self.enable_eplb = enable_eplb + self.enable_eplb = parallel_config.enable_eplb self.n_logical_experts = self.n_routed_experts self.n_redundant_experts = eplb_config.num_redundant_experts @@ -144,7 +150,8 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.experts", enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts) + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel) self.gate = ReplicatedLinear(config.hidden_size, config.num_experts, @@ -156,14 +163,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: assert hidden_states.dim( ) <= 2, "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs" is_input_1d = hidden_states.dim() == 1 - hidden_dim = hidden_states.shape[-1] + num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits) + if self.is_sequence_parallel: + final_hidden_states = tensor_model_parallel_all_gather( + final_hidden_states, 0) + final_hidden_states = final_hidden_states[:num_tokens] + # return to 1d if input is 1d return final_hidden_states.squeeze(0) if is_input_1d else \ final_hidden_states @@ -275,15 +290,13 @@ def forward( class Qwen3MoeDecoderLayer(nn.Module): - def __init__( - self, - config: Qwen3MoeConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - enable_eplb: bool = False, - ) -> None: + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) @@ -315,10 +328,8 @@ def __init__( if (layer_idx not in mlp_only_layers) and ( config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0): - self.mlp = Qwen3MoeSparseMoeBlock(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - enable_eplb=enable_eplb) + self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config, + prefix=f"{prefix}.mlp") else: self.mlp = 
Qwen3MoeMLP(hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, @@ -362,10 +373,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config.get_text_config() - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config - enable_eplb = parallel_config.enable_eplb eplb_config = parallel_config.eplb_config self.num_redundant_experts = eplb_config.num_redundant_experts @@ -379,11 +388,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=f"{prefix}.embed_tokens") self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Qwen3MoeDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix, - enable_eplb=enable_eplb), + lambda prefix: Qwen3MoeDecoderLayer(vllm_config=vllm_config, + prefix=prefix), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index dc3153fcc826..14d19874a51e 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -17,7 +17,8 @@ VllmConfig, get_current_vllm_config) from vllm.distributed import (divide, get_ep_group, get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.fla.ops import ( @@ -47,6 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, sharded_weight_loader) from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP +from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -69,14 +71,13 @@ class Qwen3NextSparseMoeBlock(nn.Module): - def __init__( - self, - config: Qwen3NextConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - enable_eplb: bool = False, - ): + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + parallel_config = vllm_config.parallel_config + quant_config = vllm_config.quant_config + self.tp_size = get_tensor_model_parallel_world_size() self.ep_group = get_ep_group().device_group @@ -84,6 +85,8 @@ def __init__( self.ep_size = self.ep_group.size() self.n_routed_experts = config.num_experts + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe + if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " @@ -92,7 +95,7 @@ def __init__( # Load balancing settings. 
diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py
index dc3153fcc826..14d19874a51e 100644
--- a/vllm/model_executor/models/qwen3_next.py
+++ b/vllm/model_executor/models/qwen3_next.py
@@ -17,7 +17,8 @@
                          VllmConfig, get_current_vllm_config)
 from vllm.distributed import (divide, get_ep_group, get_pp_group,
                               get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fla.ops import (
@@ -47,6 +48,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, sharded_weight_loader)
 from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
+from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
@@ -69,14 +71,13 @@
 class Qwen3NextSparseMoeBlock(nn.Module):

-    def __init__(
-        self,
-        config: Qwen3NextConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-        enable_eplb: bool = False,
-    ):
+    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        parallel_config = vllm_config.parallel_config
+        quant_config = vllm_config.quant_config
+
         self.tp_size = get_tensor_model_parallel_world_size()

         self.ep_group = get_ep_group().device_group
@@ -84,6 +85,8 @@ def __init__(
         self.ep_size = self.ep_group.size()
         self.n_routed_experts = config.num_experts

+        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
+
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -92,7 +95,7 @@
         # Load balancing settings.
         vllm_config = get_current_vllm_config()
         eplb_config = vllm_config.parallel_config.eplb_config
-        self.enable_eplb = enable_eplb
+        self.enable_eplb = parallel_config.enable_eplb

         self.n_logical_experts = self.n_routed_experts
         self.n_redundant_experts = eplb_config.num_redundant_experts
@@ -114,7 +117,8 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.experts",
             enable_eplb=self.enable_eplb,
-            num_redundant_experts=self.n_redundant_experts)
+            num_redundant_experts=self.n_redundant_experts,
+            is_sequence_parallel=self.is_sequence_parallel)

         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.num_experts,
@@ -141,9 +145,12 @@ def __init__(
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
         orig_shape = hidden_states.shape
-        hidden_dim = hidden_states.shape[-1]
+        num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)

+        if self.is_sequence_parallel:
+            hidden_states = sequence_parallel_chunk(hidden_states)
+
         shared_output = None
         if self.shared_expert is not None:
             shared_output = self.shared_expert(hidden_states)
@@ -158,7 +165,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
+
+        if self.is_sequence_parallel:
+            final_hidden_states = tensor_model_parallel_all_gather(
+                final_hidden_states, 0)
+            final_hidden_states = final_hidden_states[:num_tokens]
+        elif self.tp_size > 1:
             final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                 final_hidden_states)
@@ -719,17 +731,17 @@
 class Qwen3NextDecoderLayer(nn.Module):

     def __init__(
         self,
-        config: Qwen3NextConfig,
+        vllm_config: VllmConfig,
         layer_type: str,
-        model_config: Optional[ModelConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        speculative_config: Optional[SpeculativeConfig] = None,
         prefix: str = "",
-        enable_eplb: bool = False,
     ) -> None:
         super().__init__()
-        self.config = config
+
+        config = vllm_config.model_config.hf_config
+        model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        speculative_config = vllm_config.speculative_config

         self.layer_type = layer_type
         self.layer_idx = extract_layer_index(prefix)
@@ -759,10 +771,8 @@ def __init__(
                 config.num_experts > 0 and
             (self.layer_idx + 1) % config.decoder_sparse_step == 0):
             self.mlp = Qwen3NextSparseMoeBlock(
-                config=config,
-                quant_config=quant_config,
+                vllm_config=vllm_config,
                 prefix=f"{prefix}.mlp",
-                enable_eplb=enable_eplb,
             )
         else:
             self.mlp = Qwen3NextMLP(
@@ -783,14 +793,14 @@ def __init__(
             torch.zeros(
                 1,
                 1,
-                self.config.hidden_size,
+                config.hidden_size,
                 dtype=config.torch_dtype,
             ),
         )
         self.ffn_layer_scale = torch.nn.Parameter(
             torch.zeros(
                 1,
                 1,
-                self.config.hidden_size,
+                config.hidden_size,
                 dtype=config.torch_dtype,
             ),
         )
@@ -858,13 +868,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config: Qwen3NextConfig = vllm_config.model_config.hf_config
-        model_config = vllm_config.model_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
         parallel_config = vllm_config.parallel_config
         lora_config = vllm_config.lora_config
-        speculative_config = vllm_config.speculative_config
-        enable_eplb = parallel_config.enable_eplb
         eplb_config = parallel_config.eplb_config
         self.num_redundant_experts = eplb_config.num_redundant_experts
@@ -881,14 +886,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         def get_layer(prefix: str):
             return Qwen3NextDecoderLayer(
-                config,
+                vllm_config,
                 layer_type=config.layer_types[extract_layer_index(prefix)],
-                model_config=model_config,
-                cache_config=cache_config,
-                quant_config=quant_config,
-                speculative_config=speculative_config,
                 prefix=prefix,
-                enable_eplb=enable_eplb,
             )

         self.start_layer, self.end_layer, self.layers = make_layers(
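A quick token count shows what the is_sequence_parallel branch in Qwen3NextSparseMoeBlock.forward buys. With TP=2 and 8 tokens on a rank after attention, the activations are replicated across both TP ranks, so without sequence parallelism the expert dispatch handles 2 × 8 = 16 copies of 8 unique tokens. With the chunking above, each rank keeps 8 / 2 = 4 tokens, the two ranks together dispatch exactly 8, and the trailing tensor_model_parallel_all_gather (used here instead of maybe_all_reduce_tensor_model_parallel) plus the [:num_tokens] slice restore the full, unpadded activation.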
diff --git a/vllm/model_executor/models/qwen3_next_mtp.py b/vllm/model_executor/models/qwen3_next_mtp.py
index c054339842e6..e950699a0c49 100644
--- a/vllm/model_executor/models/qwen3_next_mtp.py
+++ b/vllm/model_executor/models/qwen3_next_mtp.py
@@ -38,7 +38,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         model_config = vllm_config.model_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
         config: Qwen3NextConfig = model_config.hf_config
@@ -68,11 +67,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.layers = torch.nn.ModuleList(
             Qwen3NextDecoderLayer(
-                config,
+                vllm_config,
                 layer_type="full_attention",
-                model_config=model_config,
-                cache_config=cache_config,
-                quant_config=quant_config,
                 prefix=f'{prefix}.layers.{idx}',
             ) for idx in range(self.num_mtp_layers))
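The qwen3_next_mtp.py change is the same mechanical refactor applied throughout this diff: instead of threading model_config, cache_config, quant_config and enable_eplb through every constructor, a layer now takes the single VllmConfig and pulls out what it needs. A minimal sketch of the convention (ExampleLayer is hypothetical and not part of the diff):

import torch.nn as nn
from vllm.config import VllmConfig

class ExampleLayer(nn.Module):
    # Hypothetical layer illustrating the constructor convention used above.
    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config      # HF model config
        cache_config = vllm_config.cache_config          # passed to sub-layers in real code
        quant_config = vllm_config.quant_config          # passed to sub-layers in real code
        parallel_config = vllm_config.parallel_config
        self.enable_eplb = parallel_config.enable_eplb
        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe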
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index bb6a0bd02202..4bf151fbf62d 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -13,11 +13,14 @@
 import vllm.envs as envs
 from vllm.config import VllmConfig
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size)
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal import NestedTensors
 from vllm.sequence import IntermediateTensors
-from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available,
+from vllm.utils import (cdiv, direct_register_custom_op,
+                        get_cuda_view_from_cpu_tensor, is_pin_memory_available,
                         is_uva_available)

 logger = init_logger(__name__)
@@ -743,3 +746,46 @@ def get_model_hidden_size(hf_config: PretrainedConfig) -> int:
         return hf_config.hidden_size
     text_config = hf_config.get_text_config()
     return text_config.hidden_size
+
+
+# Chunk x along the num_tokens axis for sequence parallelism
+# NOTE: This is wrapped in a torch custom op to work around the following issue:
+# The output tensor can have a sequence length 0 at small input sequence lengths
+# even though we explicitly pad to avoid this.
+def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
+    return torch.ops.vllm.sequence_parallel_chunk_impl(x)
+
+
+def sequence_parallel_chunk_impl(x: torch.Tensor) -> torch.Tensor:
+    tp_size = get_tensor_model_parallel_world_size()
+    tp_rank = get_tensor_model_parallel_rank()
+
+    # all_gather needs the sequence length to be divisible by tp_size
+    seq_len = x.size(0)
+    remainder = seq_len % tp_size
+    if remainder != 0:
+        pad_len = tp_size - remainder
+        y = nn.functional.pad(x, (0, 0, 0, pad_len))
+    else:
+        y = x
+
+    chunk = y.shape[0] // tp_size
+    start = tp_rank * chunk
+    return torch.narrow(y, 0, start, chunk)
+
+
+def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor:
+    tp_size = get_tensor_model_parallel_world_size()
+    seq_len = cdiv(x.size(0), tp_size)
+    shape = list(x.shape)
+    shape[0] = seq_len
+    out = torch.empty(shape, dtype=x.dtype, device=x.device)
+    return out
+
+
+direct_register_custom_op(
+    op_name="sequence_parallel_chunk_impl",
+    op_func=sequence_parallel_chunk_impl,
+    fake_impl=sequence_parallel_chunk_impl_fake,
+    tags=(torch.Tag.needs_fixed_stride_order, ),
+)
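To make the padding arithmetic in sequence_parallel_chunk_impl concrete: with tp_size = 4 and a 10-token input, the sequence is padded by 2 rows to 12, each rank narrows to a 12 / 4 = 3-row chunk (rank r keeps rows 3r..3r+2), and the fake implementation reports the same cdiv(10, 4) = 3 rows so shape propagation under torch.compile stays consistent. Below is a single-process sketch of the same slicing, with tp_size and tp_rank passed explicitly instead of read from the TP group:

import torch
import torch.nn.functional as F

def chunk_for_rank(x: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # Pad the token dim to a multiple of tp_size, then keep this rank's slice.
    remainder = x.size(0) % tp_size
    if remainder != 0:
        x = F.pad(x, (0, 0, 0, tp_size - remainder))
    chunk = x.size(0) // tp_size
    return torch.narrow(x, 0, tp_rank * chunk, chunk)

x = torch.randn(10, 8)  # 10 tokens, hidden size 8
chunks = [chunk_for_rank(x, tp_size=4, tp_rank=r) for r in range(4)]
assert all(c.shape == (3, 8) for c in chunks)
# torch.cat(chunks)[:10] recovers x, which is exactly what the
# tensor_model_parallel_all_gather + [:num_tokens] trim does in the MoE
# forward passes earlier in this diff.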