
Commit 2ab27b7

faaany authored and yewentao256 committed
[XPU] Fix MOE DP accuracy issue on XPU (#25465)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent a500f7c commit 2ab27b7

File tree

2 files changed: +29 -1 lines changed

examples/offline_inference/data_parallel.py

Lines changed: 10 additions & 1 deletion
@@ -101,6 +101,13 @@ def parse_args():
         "--quantization",
         type=str,
     )
+    parser.add_argument(
+        "--disable-expert-parallel",
+        dest="enable_expert_parallel",
+        action="store_false",
+        help="Disable expert parallel (default: enabled).",
+    )
+    parser.set_defaults(enable_expert_parallel=True)
     return parser.parse_args()
 
 
@@ -113,6 +120,7 @@ def main(
     dp_master_port,
     GPUs_per_dp_rank,
     enforce_eager,
+    enable_expert_parallel,
     trust_remote_code,
     max_num_seqs,
     max_model_len,
@@ -168,7 +176,7 @@ def start(rank):
         model=model,
         tensor_parallel_size=GPUs_per_dp_rank,
         enforce_eager=enforce_eager,
-        enable_expert_parallel=True,
+        enable_expert_parallel=enable_expert_parallel,
         trust_remote_code=trust_remote_code,
         max_num_seqs=max_num_seqs,
         max_model_len=max_model_len,
@@ -229,6 +237,7 @@ def start(rank):
                 dp_master_port,
                 tp_size,
                 args.enforce_eager,
+                args.enable_expert_parallel,
                 args.trust_remote_code,
                 args.max_num_seqs,
                 args.max_model_len,

vllm/distributed/device_communicators/xpu_communicator.py

Lines changed: 19 additions & 0 deletions
@@ -25,6 +25,12 @@ def __init__(self,
         super().__init__(cpu_group, device, device_group, unique_name)
         if self.use_all2all:
             all2all_backend = envs.VLLM_ALL2ALL_BACKEND
+            if all2all_backend != "naive":
+                logger.warning(
+                    "`%s` all2all manager is not supported on XPU. "
+                    "Falling back to `naive` all2all manager for XPU.",
+                    all2all_backend)
+                all2all_backend = "naive"
             if all2all_backend == "naive":
                 from .all2all import NaiveAll2AllManager
                 self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
@@ -67,3 +73,16 @@ def gather(self,
 
     def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
         dist.broadcast(input_, src=src, group=self.device_group)
+
+    def dispatch(
+            self, hidden_states: torch.Tensor,
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        assert self.all2all_manager is not None
+        hidden_states, router_logits = self.all2all_manager.dispatch(
+            hidden_states, router_logits)
+        return hidden_states, router_logits
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        assert self.all2all_manager is not None
+        hidden_states = self.all2all_manager.combine(hidden_states)
+        return hidden_states
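
To see the control flow above in isolation, here is a minimal, self-contained sketch of the fallback-and-delegate pattern: a non-"naive" backend is warned about and replaced, and dispatch/combine simply forward to the all2all manager. DummyAll2AllManager and XpuLikeCommunicator are illustrative stand-ins, not vLLM classes, and the default backend string is just an example; the real code wires in NaiveAll2AllManager and the distributed process groups.

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


class DummyAll2AllManager:
    # Stand-in for an all2all manager: echoes inputs back unchanged.
    def dispatch(self, hidden_states, router_logits):
        return hidden_states, router_logits

    def combine(self, hidden_states):
        return hidden_states


class XpuLikeCommunicator:
    def __init__(self, all2all_backend="pplx"):
        if all2all_backend != "naive":
            logger.warning(
                "`%s` all2all manager is not supported on XPU. "
                "Falling back to `naive` all2all manager for XPU.",
                all2all_backend)
            all2all_backend = "naive"
        # Only the naive manager is wired up in this sketch.
        self.all2all_manager = DummyAll2AllManager()

    def dispatch(self, hidden_states, router_logits):
        # Forward MoE token dispatch to the all2all manager.
        assert self.all2all_manager is not None
        return self.all2all_manager.dispatch(hidden_states, router_logits)

    def combine(self, hidden_states):
        # Forward the reverse combine step to the all2all manager.
        assert self.all2all_manager is not None
        return self.all2all_manager.combine(hidden_states)


comm = XpuLikeCommunicator()  # warns, then falls back to "naive"
h, r = comm.dispatch([1.0, 2.0], [0.5, 0.5])
h = comm.combine(h)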
