
Commit a086071

Introduce and use CustomDeepseekV2MergedReplicatedLinear

A replicated counterpart of MergedColumnParallelLinear, aimed at removing the TP communication of DeepSeek-V2's `gate_up_proj` linear. With replicated weights, the chunked input hidden_states can also be consumed directly by the shared experts.

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
1 parent 2db0dce commit a086071

File tree

3 files changed: +99, -58 lines

vllm_ascend/models/deepseek_v2.py

Lines changed: 74 additions & 32 deletions
@@ -99,6 +99,41 @@ def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor,
         return super().forward_oot(x)
 
 
+class CustomDeepseekV2MergedReplicatedLinear(ReplicatedLinear):
+
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: list[int],
+        bias: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        self.output_sizes = output_sizes
+        super().__init__(input_size,
+                         sum(output_sizes),
+                         bias=bias,
+                         quant_config=quant_config,
+                         prefix=prefix)
+
+    def weight_loader(self, param: torch.nn.Parameter,
+                      loaded_weight: torch.Tensor, loaded_shard_id: int):
+        # No support for the GGUF format yet.
+        assert not getattr(param, "is_gguf_weight", False)
+        assert not getattr(param, "is_gguf_weight_type", False)
+
+        assert loaded_shard_id < len(self.output_sizes)
+        shard_offset = sum(self.output_sizes[:loaded_shard_id])
+        shard_size = self.output_sizes[loaded_shard_id]
+        shard = param.data.narrow(param.output_dim, shard_offset, shard_size)
+
+        assert shard.size() == loaded_weight.size(), (
+            f"Tried to load weights of size {loaded_weight.size()} "
+            f"to a parameter shard of id {loaded_shard_id} size {shard.size()}"
+        )
+        shard.copy_(loaded_weight)
+
+
 class CustomDeepseekV2MLP(nn.Module):
 
     def __init__(
@@ -108,20 +143,33 @@ def __init__(
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        force_replicate: bool = False,
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config,
-                                           reduce_results=reduce_results,
-                                           prefix=f"{prefix}.down_proj")
+        if not force_replicate:
+            self.gate_up_proj = MergedColumnParallelLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj")
+            self.down_proj = RowParallelLinear(intermediate_size,
+                                               hidden_size,
+                                               bias=False,
+                                               quant_config=quant_config,
+                                               reduce_results=reduce_results,
+                                               prefix=f"{prefix}.down_proj")
+        else:
+            self.gate_up_proj = CustomDeepseekV2MergedReplicatedLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj")
+            self.down_proj = ReplicatedLinear(intermediate_size,
+                                              hidden_size,
+                                              bias=False,
+                                              quant_config=quant_config,
+                                              prefix=f"{prefix}.down_proj")
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -183,6 +231,12 @@ def __init__(
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                              "Only silu is supported for now.")
 
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
+        self.enable_multistream_moe = \
+            ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2
+
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.n_routed_experts,
                                      bias=False,
@@ -218,6 +272,7 @@ def __init__(
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=True,
+                force_replicate=self.enable_multistream_moe,
                 prefix=f"{prefix}.shared_experts",
             )
         else:
@@ -232,12 +287,6 @@ def __init__(
 
         self.params_dtype = torch.get_default_dtype()
 
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
-        self.enable_multistream_moe = \
-            ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -276,27 +325,22 @@ def forward(
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
-        kwargs = {}
-        if not use_separated_shared_experts:
-            kwargs.update({
-                "shared_experts": self.shared_experts,
-                "shared_experts_input": old_hidden_states
-            })
-
         experts_hidden_states = self.experts(
             hidden_states=hidden_states,
            router_logits=router_logits,
            is_prefill=is_prefill,
            top_k=CustomDeepseekV2MoE.top_k,
            enable_force_load_balance=enable_force_load_balance,
-            **kwargs)
+            shared_experts=(self.shared_experts
+                            if not use_separated_shared_experts else None),
+        )
 
         if not isinstance(experts_hidden_states, tuple):
             hidden_states = experts_hidden_states * self.routed_scaling_factor
         else:
-            hidden_states = experts_hidden_states[
-                0] * self.routed_scaling_factor
-            shared_hidden_states = experts_hidden_states[1]
+            hidden_states = (
+                experts_hidden_states[0] * self.routed_scaling_factor +
+                experts_hidden_states[1])
 
         if self.tp_size > 1:
             if (VLLM_ENABLE_MC2
@@ -311,10 +355,8 @@ def forward(
                 hidden_states = tensor_model_parallel_all_reduce(hidden_states)
 
         if use_separated_shared_experts:
-            shared_hidden_states = self.shared_experts(old_hidden_states)
-
-        if self.shared_experts is not None:
-            hidden_states = hidden_states + shared_hidden_states
+            hidden_states = hidden_states + self.shared_experts(
+                old_hidden_states)
 
         return hidden_states.view(num_tokens, hidden_size)

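The new `CustomDeepseekV2MergedReplicatedLinear.weight_loader` above is the heart of the change: because the layer is replicated rather than column-parallel, each checkpoint shard (gate or up projection) is simply copied into its slice of the full merged weight, with no per-rank partitioning and no TP collective afterwards. A minimal standalone sketch of that loading pattern, using plain PyTorch tensors and hypothetical sizes rather than the actual vLLM layer classes:

```python
import torch

# Hypothetical sizes for illustration only.
hidden_size = 16
intermediate_size = 32
output_sizes = [intermediate_size, intermediate_size]  # [gate_proj, up_proj]

# Replicated merged weight: every rank holds the full matrix.
merged_weight = torch.empty(sum(output_sizes), hidden_size)


def load_shard(merged: torch.Tensor, loaded_weight: torch.Tensor,
               loaded_shard_id: int, output_dim: int = 0) -> None:
    """Copy one checkpoint shard into its slice of the merged weight."""
    shard_offset = sum(output_sizes[:loaded_shard_id])
    shard_size = output_sizes[loaded_shard_id]
    shard = merged.narrow(output_dim, shard_offset, shard_size)
    assert shard.size() == loaded_weight.size()
    shard.copy_(loaded_weight)


# The checkpoint stores gate_proj and up_proj separately; load both into one matrix.
load_shard(merged_weight, torch.randn(intermediate_size, hidden_size), loaded_shard_id=0)
load_shard(merged_weight, torch.randn(intermediate_size, hidden_size), loaded_shard_id=1)
```
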
vllm_ascend/ops/fused_moe.py

Lines changed: 6 additions & 7 deletions
@@ -16,7 +16,7 @@
 # Adapted from vllm/tests/kernels/test_moe.py
 
 import os
-from typing import Callable, List, Optional
+from typing import Any, Callable, List, Optional
 
 import torch
 import torch.distributed as dist
@@ -1099,8 +1099,8 @@ def forward(self,
                 router_logits: torch.Tensor,
                 is_prefill: bool,
                 enable_force_load_balance: bool = False,
-                top_k=None,
-                **kwargs):
+                top_k: Optional[int] = None,
+                shared_experts: Optional[Any] = None):
         assert self.quant_method is not None
 
         if top_k:
@@ -1147,14 +1147,13 @@ def forward(self,
             enable_force_load_balance=enable_force_load_balance,
             log2phy=self.log2phy,
             global_redundant_expert_num=self.global_redundant_expert_num,
-            **kwargs)
+            shared_experts=shared_experts,
+        )
 
-        shared_experts = kwargs.get("shared_experts", None)
-        shared_experts_input = kwargs.get("shared_experts_input", None)
         if shared_experts is not None:
             # Provide dummy implementation of "non-separated" shared experts.
             if not isinstance(e_hidden_states, tuple):
-                return e_hidden_states, shared_experts(shared_experts_input)
+                return e_hidden_states, shared_experts(hidden_states)
             else:
                 return e_hidden_states

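Replacing `**kwargs` with an explicit `shared_experts` argument also pins down the return contract of `forward`: the op hands back either a plain routed-expert tensor, or a `(routed_output, shared_output)` tuple when the shared experts ride along. A toy sketch of the caller-side combination performed in `CustomDeepseekV2MoE.forward` (the helper name is illustrative, not part of the codebase):

```python
from typing import Tuple, Union

import torch


def combine_moe_output(
    experts_out: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    routed_scaling_factor: float,
) -> torch.Tensor:
    # Plain tensor: only routed experts were computed by the fused op.
    if not isinstance(experts_out, tuple):
        return experts_out * routed_scaling_factor
    # Tuple: (routed_output, shared_expert_output); only the routed part is scaled.
    routed, shared = experts_out
    return routed * routed_scaling_factor + shared
```
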
vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 19 additions & 19 deletions
@@ -105,19 +105,21 @@ def apply_mlp(hidden_states: torch.Tensor,
     return hidden_states
 
 
-def fused_experts_with_mc2(hidden_states: torch.Tensor,
-                           w1: torch.Tensor,
-                           w2: torch.Tensor,
-                           w1_scale: torch.Tensor,
-                           w2_scale: torch.Tensor,
-                           topk_weights: torch.Tensor,
-                           topk_ids: torch.Tensor,
-                           top_k: int,
-                           expert_map: torch.Tensor = None,
-                           moe_all_to_all_group_name: str = "",
-                           log2phy: torch.Tensor = None,
-                           global_redundant_expert_num: int = 0,
-                           **kwargs) -> torch.Tensor:
+def fused_experts_with_mc2(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    top_k: int,
+    expert_map: torch.Tensor = None,
+    moe_all_to_all_group_name: str = "",
+    log2phy: torch.Tensor = None,
+    global_redundant_expert_num: int = 0,
+    shared_experts: Optional[Any] = None,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     if log2phy:
         topk_ids = log2phy[topk_ids]
     global_bs = 0
@@ -161,13 +163,10 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]
 
-    shared_experts = kwargs.get("shared_experts", None)
-    shared_experts_input = kwargs.get("shared_experts_input", None)
     if shared_experts is not None:
         with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(shared_experts_input, topk_weights)
-            shared_gate_up, _ = shared_experts.gate_up_proj(
-                shared_experts_input)
+            npu_wait_tensor(hidden_states, topk_weights)
+            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
             npu_wait_tensor(shared_gate_up[0], expand_x)
             shared_act = shared_experts.act_fn(shared_gate_up)
 
@@ -616,6 +615,7 @@ def apply(
         enable_force_load_balance: bool = True,
         log2phy: torch.Tensor = None,
         global_redundant_expert_num: int = 0,
+        shared_experts: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
@@ -672,7 +672,7 @@ def apply(
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
                 log2phy=log2phy,
                 global_redundant_expert_num=global_redundant_expert_num,
-                **kwargs)
+                shared_experts=shared_experts)
         elif self.torchair_graph_enabled or self.ep_group.world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,

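The `npu_stream_switch`/`npu_wait_tensor` block above overlaps the shared experts' `gate_up_proj` and activation on a secondary stream while the MC2 dispatch and routed-expert computation proceed on the main stream; replicated weights make this possible because the chunked `hidden_states` can be consumed as-is. A rough CUDA-stream analogy of the same overlap pattern, purely illustrative and not using the Ascend-specific helpers:

```python
import torch


def overlapped_moe(routed_fn, shared_fn, hidden_states: torch.Tensor):
    """Illustrative sketch: run shared experts on a side stream while the
    routed experts run on the current stream, then join before combining."""
    side_stream = torch.cuda.Stream()
    # The side stream must not start before hidden_states is produced.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        shared_out = shared_fn(hidden_states)   # e.g. gate_up_proj -> act_fn -> down_proj
    routed_out = routed_fn(hidden_states)       # e.g. dispatch, expert GEMMs, combine
    # Join: the main stream must not read shared_out before it is ready.
    torch.cuda.current_stream().wait_stream(side_stream)
    return routed_out, shared_out
```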