Merged
39 commits
b83e888  Apply TP Attn + EP MoE fix to other models (tlrmchlsmth, Sep 16, 2025)
1ae573a  llama4 (tlrmchlsmth, Sep 16, 2025)
9b969ee  llama4 eagle (tlrmchlsmth, Sep 16, 2025)
d3bc2cf  Qwen3-Next (tlrmchlsmth, Sep 18, 2025)
f1f3f63  Merge branch 'main' into tp_attn_fix_more_models (tlrmchlsmth, Sep 18, 2025)
2ba92e7  Use SP for AG RS All2All backend (tlrmchlsmth, Sep 18, 2025)
84d57d3  WIP debugging (tlrmchlsmth, Sep 19, 2025)
42b9a0c  WIP implementing for naive and ag_rs a2a (tlrmchlsmth, Sep 21, 2025)
9cff96c  Merge branch 'main' into tp_attn_fix_more_models (tlrmchlsmth, Sep 21, 2025)
d5daf2c  WIP making sure it works when the model isn't SP. (tlrmchlsmth, Sep 22, 2025)
c3d7c76  Fixes and cleanup (tlrmchlsmth, Sep 22, 2025)
f0dc75f  hit gpt_oss (tlrmchlsmth, Sep 22, 2025)
3f6b0cc  Merge branch 'main' into tp_attn_fix_more_models (tlrmchlsmth, Sep 22, 2025)
d561aa7  granitemoe (tlrmchlsmth, Sep 22, 2025)
3ea7bea  fixup (tlrmchlsmth, Sep 23, 2025)
9f1d196  nix use_pynccl (tlrmchlsmth, Sep 23, 2025)
33cfd3f  Merge branch 'main' into tp_attn_fix_more_models (tlrmchlsmth, Sep 23, 2025)
5789423  fix pure TP (tlrmchlsmth, Sep 23, 2025)
2c02081  Fixup Llama4Decoder interface changes (tlrmchlsmth, Sep 23, 2025)
675f3fc  Merge branch 'main' into tp_attn_fix_more_models (tlrmchlsmth, Sep 23, 2025)
f5682b8  fix xpu_communicator (tlrmchlsmth, Sep 23, 2025)
98a660b  fixup (tlrmchlsmth, Sep 23, 2025)
bcecc33  Merge branch 'main' into tp_attn_fix_more_models (tlrmchlsmth, Sep 23, 2025)
6e505f5  fixup (tlrmchlsmth, Sep 23, 2025)
99cf467  eagle fixups (tlrmchlsmth, Sep 23, 2025)
75d51da  fixup glm4 (tlrmchlsmth, Sep 23, 2025)
0a66ca3  Merge branch 'main' into tp_attn_fix_more_models (tlrmchlsmth, Sep 23, 2025)
cbc8696  Merge branch 'main' into tp_attn_fix_more_models (tlrmchlsmth, Sep 24, 2025)
96d241c  fixup llama4 - replicated MLP (tlrmchlsmth, Sep 24, 2025)
b5d8564  Merge branch 'main' into tp_attn_fix_more_models (tlrmchlsmth, Sep 24, 2025)
0237ff1  gmu (tlrmchlsmth, Sep 25, 2025)
c51619d  Merge branch 'main' into tp_attn_fix_more_models (tlrmchlsmth, Sep 25, 2025)
210bcc3  fixup (tlrmchlsmth, Sep 25, 2025)
2b27ab5  fixup (tlrmchlsmth, Sep 25, 2025)
88bfcab  move whisper to the top (tlrmchlsmth, Sep 25, 2025)
6de0995  Merge branch 'tp_attn_fix_more_models' of http://github.com/tlrmchlsm… (tlrmchlsmth, Sep 25, 2025)
82c77ee  Merge branch 'main' into tp_attn_fix_more_models (tlrmchlsmth, Sep 26, 2025)
ce03982  Apply Varun's Patch (simon-mo, Sep 27, 2025)
15c15bd  Merge branch 'main' into tp_attn_fix_more_models (tlrmchlsmth, Sep 27, 2025)
18 changes: 18 additions & 0 deletions vllm/config/parallel.py
@@ -279,6 +279,24 @@ def stateless_init_dp_group(self) -> ProcessGroup:
assert last_exc is not None
raise last_exc

# The all_reduce at the end of attention (during o_proj) means that
# inputs are replicated across each rank of the tensor parallel group.
# If using expert-parallelism with DeepEP All2All ops, replicated
# tokens results in useless duplicate computation and communication.
#
# In this case, ensure the input to the experts is sequence parallel
# to avoid the excess work.
#
# Not needed for pplx-kernels as it can handle duplicate input tokens.
@property
def use_sequence_parallel_moe(self) -> bool:
return (envs.VLLM_ALL2ALL_BACKEND
in ("allgather_reducescatter", "naive",
"deepep_high_throughput", "deepep_low_latency")
and self.enable_expert_parallel
and self.tensor_parallel_size > 1
and self.data_parallel_size > 1)

@staticmethod
def has_unfinished_dp(dp_group: ProcessGroup,
has_unfinished: bool) -> bool:
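For illustration, here is a minimal standalone sketch of the gating logic introduced by the new use_sequence_parallel_moe property. The function signature below is hypothetical: the real property lives on ParallelConfig and reads envs.VLLM_ALL2ALL_BACKEND plus the config fields shown in the diff above.

```python
# Illustrative sketch only; not the actual vLLM property.
def use_sequence_parallel_moe(all2all_backend: str,
                              enable_expert_parallel: bool,
                              tensor_parallel_size: int,
                              data_parallel_size: int) -> bool:
    # SP for MoE only pays off when attention output is replicated across
    # TP ranks (tp > 1) and the All2All backend cannot deduplicate tokens
    # on its own (pplx-kernels can, so it is excluded).
    return (all2all_backend in ("allgather_reducescatter", "naive",
                                "deepep_high_throughput",
                                "deepep_low_latency")
            and enable_expert_parallel
            and tensor_parallel_size > 1
            and data_parallel_size > 1)

# Example: EP enabled, TP=2, DP=2 with the DeepEP high-throughput backend.
assert use_sequence_parallel_moe("deepep_high_throughput", True, 2, 2)
# pplx handles duplicate input tokens itself, so SP is not enabled for it.
assert not use_sequence_parallel_moe("pplx", True, 2, 2)
```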
122 changes: 82 additions & 40 deletions vllm/distributed/device_communicators/all2all.py
@@ -6,7 +6,7 @@
import torch.distributed as dist

import vllm.envs as envs
from vllm.distributed import get_dp_group
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils import has_deep_ep, has_pplx
@@ -34,41 +34,60 @@ def __init__(self, cpu_group):
super().__init__(cpu_group)

def naive_multicast(self, x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor):
cu_tokens_across_sp_cpu: torch.Tensor,
is_sequence_parallel: bool) -> torch.Tensor:
assert (len(x.shape) == 2)
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)),
device=x.device,
dtype=x.dtype)

start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
self.dp_rank - 1]
end = cu_tokens_across_dp_cpu[self.dp_rank]
rank = self.rank if is_sequence_parallel else self.dp_rank
world_size = (self.world_size
if is_sequence_parallel else self.dp_world_size)

start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
end = cu_tokens_across_sp_cpu[rank]
buffer[start:end, :].copy_(x)
for idx in range(self.dp_world_size):
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
end = cu_tokens_across_dp_cpu[idx]
self.dp_group.broadcast(buffer[start:end, :], idx)
for idx in range(world_size):
start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1]
end = cu_tokens_across_sp_cpu[idx]
get_ep_group().broadcast(buffer[start:end, :], idx)

return buffer

def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states, router_logits = get_dp_group().all_gatherv(
[hidden_states, router_logits],
dim=0,
sizes=sizes,
)

def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

hidden_states = self.naive_multicast(hidden_states,
cu_tokens_across_sp_cpu,
is_sequence_parallel)
router_logits = self.naive_multicast(router_logits,
cu_tokens_across_sp_cpu,
is_sequence_parallel)
return hidden_states, router_logits

def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
dim=0,
sizes=sizes)
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:

ep_rank = self.rank if is_sequence_parallel else self.dp_rank

dp_metadata = get_forward_context().dp_metadata
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1]
end = cu_tokens_across_sp_cpu[ep_rank]

all_hidden_states = get_ep_group().all_reduce(hidden_states)
hidden_states = all_hidden_states[start:end, :]
return hidden_states

def destroy(self):
@@ -84,29 +103,40 @@ class AgRsAll2AllManager(All2AllManagerBase):
def __init__(self, cpu_group):
super().__init__(cpu_group)

def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Gather hidden_states and router_logits from all dp ranks.
"""
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states, router_logits = get_dp_group().all_gatherv(

dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
hidden_states, router_logits = dist_group.all_gatherv(
[hidden_states, router_logits],
dim=0,
sizes=sizes,
)
return hidden_states, router_logits

def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
"""
Reduce-scatter hidden_states across all dp ranks.
"""
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
dim=0,
sizes=sizes)

dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
hidden_states = dist_group.reduce_scatterv(hidden_states,
dim=0,
sizes=sizes)
return hidden_states

def destroy(self):
@@ -148,11 +178,17 @@ def get_handle(self, kwargs):
kwargs, pplx.AllToAll.internode
if self.internode else pplx.AllToAll.intranode)

def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
raise NotImplementedError

def destroy(self):
@@ -184,11 +220,17 @@ def __init__(self, cpu_group):
def get_handle(self, kwargs):
raise NotImplementedError

def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
raise NotImplementedError

def destroy(self):
@@ -395,4 +437,4 @@ def cleanup(self):
self.workspace_tensor = None
self.prepare_workspace_tensor = None
self.mapping = None
self.initialized = False
self.initialized = False
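The naive dispatch/combine path above carves each rank's token slice out of a flat concatenated buffer using cumulative token counts. Below is a small self-contained illustration of that indexing rule (plain PyTorch, no distributed setup; the tensor name mirrors cu_tokens_across_sp_cpu from the diff, and the per-rank counts are made up).

```python
import torch

# Made-up token counts for 4 sequence-parallel ranks.
tokens_per_rank = torch.tensor([3, 5, 2, 4])
cu_tokens_across_sp_cpu = torch.cumsum(tokens_per_rank, dim=0)  # tensor([3, 8, 10, 14])

hidden_size = 8
# Flat buffer holding every rank's tokens back to back, as naive_multicast builds it.
buffer = torch.randn(int(cu_tokens_across_sp_cpu[-1]), hidden_size)

for rank in range(len(tokens_per_rank)):
    # Same rule as in the diff: rank 0 starts at 0; every other rank starts
    # where the previous rank's tokens end.
    start = 0 if rank == 0 else int(cu_tokens_across_sp_cpu[rank - 1])
    end = int(cu_tokens_across_sp_cpu[rank])
    local_slice = buffer[start:end, :]
    assert local_slice.shape[0] == int(tokens_per_rank[rank])
```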
28 changes: 20 additions & 8 deletions vllm/distributed/device_communicators/base_device_communicator.py
@@ -28,6 +28,8 @@ def get_or_create(self, kwargs, func):


class All2AllManagerBase:
rank: int
world_size: int

def __init__(self, cpu_group):
self.cpu_group = cpu_group
@@ -40,6 +42,7 @@ def __init__(self, cpu_group):
# all2all lives in ep group, which is merged from dp and tp group
self.dp_group = get_dp_group()
self.tp_group = get_tp_group()

# no self.ep_group since self.ep_group is still in construction
# when we create this object
self.dp_rank = self.dp_group.rank_in_group
@@ -60,17 +63,21 @@ def get_handle(self, kwargs):
# and reuse it for the same config.
raise NotImplementedError

def dispatch(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False):
raise NotImplementedError

def set_num_sms(self, num_sms: int):
pass

def max_sms_used(self) -> Optional[int]:
return None # None means it could use the whole GPU

def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
raise NotImplementedError

def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False):
raise NotImplementedError

def destroy(self):
@@ -267,15 +274,20 @@ def prepare_communication_buffer_for_model(self,
module.quant_method.init_prepare_finalize(module)

def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
return hidden_states, router_logits

def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
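For reference, a minimal sketch of the updated All2All manager interface as a standalone class. This pass-through stand-in is illustrative only, not one of the real backends, and it deliberately omits the process-group wiring done in All2AllManagerBase.__init__.

```python
import torch

class PassthroughAll2All:
    """Stand-in mirroring the All2AllManagerBase interface after this
    change: dispatch and combine both accept is_sequence_parallel."""

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # A real backend gathers tokens across the EP group when
        # is_sequence_parallel is True, or across the DP group otherwise.
        return hidden_states, router_logits

    def combine(self,
                hidden_states: torch.Tensor,
                is_sequence_parallel: bool = False) -> torch.Tensor:
        # A real backend reduce-scatters expert outputs back to the
        # caller's token slice.
        return hidden_states
```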
22 changes: 12 additions & 10 deletions vllm/distributed/device_communicators/cuda_communicator.py
@@ -39,10 +39,6 @@ def __init__(self,
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM

# ep does not use pynccl
use_pynccl = "ep" not in unique_name

self.use_pynccl = use_pynccl
self.use_custom_allreduce = use_custom_allreduce
self.use_torch_symm_mem = use_torch_symm_mem

Expand All @@ -57,7 +53,7 @@ def __init__(self,
SymmMemCommunicator)

self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
if self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
device=self.device,
@@ -308,14 +304,20 @@ def _all_gather_single(input_: torch.Tensor,
return output_list

def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits)
hidden_states, router_logits, is_sequence_parallel)
return hidden_states, router_logits

def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states)
hidden_states = self.all2all_manager.combine(hidden_states,
is_sequence_parallel)
return hidden_states
16 changes: 11 additions & 5 deletions vllm/distributed/device_communicators/xpu_communicator.py
@@ -75,14 +75,20 @@ def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
dist.broadcast(input_, src=src, group=self.device_group)

def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits)
hidden_states, router_logits, is_sequence_parallel)
return hidden_states, router_logits

def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states)
hidden_states = self.all2all_manager.combine(hidden_states,
is_sequence_parallel)
return hidden_states
17 changes: 12 additions & 5 deletions vllm/distributed/parallel_state.py
@@ -871,17 +871,24 @@ def prepare_communication_buffer_for_model(self, model: torch.nn.Module):
model)

def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
if self.device_communicator is not None:
return self.device_communicator.dispatch(hidden_states,
router_logits)
router_logits,
is_sequence_parallel)
else:
return hidden_states, router_logits

def combine(self, hidden_states) -> torch.Tensor:
def combine(self,
hidden_states,
is_sequence_parallel: bool = False) -> torch.Tensor:
if self.device_communicator is not None:
return self.device_communicator.combine(hidden_states)
return self.device_communicator.combine(hidden_states,
is_sequence_parallel)
else:
return hidden_states

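Putting the pieces together, here is a hedged sketch of how a caller might thread the new flag through the GroupCoordinator methods above. ep_group stands in for the coordinator returned by get_ep_group(); the surrounding MoE-layer code is hypothetical and not part of this diff.

```python
import torch

def route_through_experts(ep_group, hidden_states: torch.Tensor,
                          router_logits: torch.Tensor,
                          use_sequence_parallel_moe: bool) -> torch.Tensor:
    # Gather tokens to the layout the experts expect. With sequence
    # parallelism, each TP rank contributes its own slice of the sequence
    # instead of a fully replicated copy.
    hidden_states, router_logits = ep_group.dispatch(
        hidden_states, router_logits,
        is_sequence_parallel=use_sequence_parallel_moe)

    # ... expert computation on the gathered tokens would go here ...

    # Scatter/reduce the expert outputs back to this rank's slice.
    return ep_group.combine(hidden_states,
                            is_sequence_parallel=use_sequence_parallel_moe)
```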