[main][prefill optimization] Optimize parallel strategies to reduce communication overhead #2198
Changes from all commits: 4ecbd33, 63967cb, b165945, 5e9bb72, 31bd0d7, 19ac51c, b00fa60, a4117fb, 0cf9e1d
```diff
@@ -141,7 +141,8 @@ class CustomDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear):
     def forward(
         self,
         input_,
-        is_prefill=True
+        is_prefill=True,
+        is_force_scatter=False
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
         if self.input_is_parallel:
             input_parallel = input_
```
```diff
@@ -160,7 +161,13 @@ def forward(
             input_parallel,
             bias=bias_)
         if self.reduce_results and self.tp_size > 1:
-            if not is_prefill and output_parallel.shape[0] % self.tp_size == 0:
+            num_tokens = output_parallel.shape[0]
+            if is_force_scatter and num_tokens % self.tp_size:
+                output_parallel = nn.functional.pad(
+                    output_parallel, (0, 0, 0, -num_tokens % self.tp_size))
+            if is_force_scatter or (not is_prefill
+                                    and output_parallel.shape[0] % self.tp_size
+                                    == 0):
                 output = tensor_model_parallel_reduce_scatter(output_parallel,
                                                               dim=0)
             else:
```
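A note on the force-scatter branch above: `-num_tokens % self.tp_size` is exactly the number of rows needed to round the token count up to a multiple of `tp_size`, so the padded activation can be reduce-scattered evenly along dim 0. A minimal sketch of that padding arithmetic in plain PyTorch (the `tp_size` value and tensor sizes are made up for illustration; the real code uses the project's `tensor_model_parallel_reduce_scatter`):

```python
import torch
import torch.nn.functional as F

tp_size = 4                              # hypothetical TP world size
output_parallel = torch.randn(10, 8)     # 10 tokens, not divisible by tp_size

num_tokens = output_parallel.shape[0]
pad_rows = -num_tokens % tp_size         # (-10) % 4 == 2 extra rows
if num_tokens % tp_size:
    # (0, 0, 0, pad_rows): no padding on the last dim, pad_rows appended on dim 0
    output_parallel = F.pad(output_parallel, (0, 0, 0, pad_rows))

assert output_parallel.shape == (12, 8)
# Each TP rank would then receive 12 // tp_size == 3 rows from the reduce-scatter.
```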
```diff
@@ -180,7 +187,8 @@ class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
     def forward(
         self,
         input_,
-        is_prefill=True
+        is_prefill=True,
+        is_force_scatter=False
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
         if self.input_is_parallel:
             input_parallel = input_
```
```diff
@@ -347,13 +355,15 @@ def __init__(
             reduce_results = not self.all_reduce_merge
             intermediate_size = (config.moe_intermediate_size *
                                  config.n_shared_experts)
+            enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
             self.shared_experts = CustomDeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=reduce_results,
-                force_replicate=self.enable_multistream_moe,
+                force_replicate=self.enable_multistream_moe
+                or enable_shared_expert_dp,
                 prefix=f"{prefix}.shared_experts",
             )
         else:
```
```diff
@@ -447,9 +457,11 @@ def __init__(
         self.kv_lora_rank = kv_lora_rank

         self.num_heads = num_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        assert num_heads % tp_size == 0
-        self.num_local_heads = num_heads // tp_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+        assert num_heads % self.tp_size == 0
+        self.num_local_heads = num_heads // self.tp_size
+        self.layers = config.num_hidden_layers
+        self.first_k_dense_replace = config.first_k_dense_replace

         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
```
```diff
@@ -462,6 +474,7 @@ def __init__(
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_multistream_mla = \
             ascend_config.torchair_graph_config.enable_multistream_mla
+        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

         if self.q_lora_rank is not None:
             self.q_a_proj = ReplicatedLinear(self.hidden_size,
```
```diff
@@ -501,8 +514,9 @@ def __init__(
                                              prefix=f"{prefix}.kv_b_proj")
         if (config.n_routed_experts is not None
                 and self.debug_layer_idx >= config.first_k_dense_replace
-                and self.debug_layer_idx % config.moe_layer_freq == 0 and
-                ascend_config.torchair_graph_config.enable_multistream_moe):
+                and self.debug_layer_idx % config.moe_layer_freq == 0
+                and (ascend_config.torchair_graph_config.enable_multistream_moe
+                     or self.enable_shared_expert_dp)):
             self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce(
                 self.num_heads * self.v_head_dim,
                 self.hidden_size,
```
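Why the `ReplaceAllreduce` projection is selected here: with `enable_shared_expert_dp` (or multistream MoE) on, `o_proj` can reduce-scatter its output so each rank keeps only its 1/tp_size shard of the token rows instead of an all-reduced full copy, which is where the communication saving comes from. A rough shape-only sketch with illustrative numbers:

```python
# Shape-only comparison of what one rank holds after each collective.
tp_size = 4
num_tokens = 12

rows_after_all_reduce = num_tokens                 # every rank holds all rows
rows_after_reduce_scatter = num_tokens // tp_size  # each rank holds its own shard

assert rows_after_all_reduce == 12
assert rows_after_reduce_scatter == 3
```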
```diff
@@ -596,13 +610,27 @@ def forward(
             output = output.view(-1, output_shape[-1])
             return output
         else:
-            kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
+            kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
+            if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
+                hidden_states_or_q_c = get_tp_group().all_gather(
+                    hidden_states_or_q_c, 0)
+                kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
+
+            kv_c, k_pe = kv_no_split.split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
             kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
+                output_shape = hidden_states.shape
+            else:
+                num_tokens = hidden_states_or_q_c.shape[0]
+                rows = num_tokens // self.tp_size
+                if num_tokens % self.tp_size:
+                    rows += 1
+                output_shape = (rows, hidden_states.shape[1])
             return self.mla_attn(hidden_states_or_q_c,
                                  kv_c_normed,
                                  k_pe,
-                                 output_shape=hidden_states.shape)
+                                 output_shape=output_shape)


 class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
```

Reviewer comment on the `all_gather` above: In the case of severe load imbalance between DPs, could the DP domain become stuck?

Author reply: The DP domain will not block here; this is an all_gather over the TP ranks within a DP domain, and different DP domains do not affect each other.
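About the new `output_shape` branch: once the hidden states have been all-gathered across TP ranks, the attention output written back per rank is sized by a ceiling division of the gathered token count by `tp_size`, matching the padded reduce-scatter in `o_proj`. A small sketch of that row calculation (the helper name and values are illustrative only):

```python
def shard_rows(num_tokens: int, tp_size: int) -> int:
    """Ceiling division used to size the per-rank output shard."""
    rows = num_tokens // tp_size
    if num_tokens % tp_size:
        rows += 1
    return rows

assert shard_rows(10, 4) == 3  # 10 gathered tokens -> 3 (padded) rows per rank
assert shard_rows(12, 4) == 3  # evenly divisible case, no padding needed
```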
```diff
@@ -677,6 +705,8 @@ def __init__(
                                                eps=config.rms_norm_eps)
         self.routed_scaling_factor = config.routed_scaling_factor
         self.first_k_dense_replace = config.first_k_dense_replace
+        self.tp_group = get_tp_group().device_group
+        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

     def forward(
         self,
```
```diff
@@ -731,6 +761,18 @@ def forward(
                 # first layer.
                 residual *= 1. / self.routed_scaling_factor

+        tp_size = get_tensor_model_parallel_world_size()
+        if self.enable_shared_expert_dp and (
+                self.layer_idx == self.first_k_dense_replace
+                or self.layer_idx == self.layers) and tp_size > 1:
+            num_tokens, _ = residual.shape
+            if num_tokens % tp_size:
+                residual = nn.functional.pad(residual,
+                                             (0, 0, 0, -num_tokens % tp_size))
+            chunk_residual = torch.tensor_split(residual, tp_size, dim=0)
+            tp_rank = get_tensor_model_parallel_rank()
+            residual = chunk_residual[tp_rank]
+
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
```
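What the new block does at the boundary layers: pad the residual so its token dimension divides evenly by `tp_size`, split it along dim 0, and keep only the current rank's chunk. A standalone sketch with made-up `tp_size`/`tp_rank` values:

```python
import torch
import torch.nn.functional as F

tp_size, tp_rank = 4, 1                  # hypothetical rank layout
residual = torch.randn(10, 8)            # 10 tokens, hidden size 8

num_tokens, _ = residual.shape
if num_tokens % tp_size:
    residual = F.pad(residual, (0, 0, 0, -num_tokens % tp_size))  # -> 12 rows

chunk_residual = torch.tensor_split(residual, tp_size, dim=0)
residual = chunk_residual[tp_rank]       # rank 1 keeps rows 3..5
assert residual.shape == (3, 8)
```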
```diff
@@ -756,6 +798,22 @@ def forward(
                                                          dim=0)
             residual = tensor_model_parallel_all_gather(residual, dim=0)

+        # for last layer of main model and mtp layer.
+        if self.enable_shared_expert_dp and self.layer_idx >= (
+                self.layers - 1) and tp_size > 1:
+            hidden_states = get_tp_group().all_gather(hidden_states, 0)
+            residual = get_tp_group().all_gather(residual, 0)
+
+            attn_metadata = get_forward_context().attn_metadata
+            if attn_metadata is not None:
+                num_tokens = attn_metadata.num_actual_tokens
+            else:
+                num_tokens = hidden_states.shape[0]
+
+            if num_tokens < hidden_states.shape[0]:
+                hidden_states = hidden_states[:num_tokens]
+                residual = residual[:num_tokens]
+
         return hidden_states, residual
```
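At the end of the last decoder layer (and the MTP layer), the per-rank shards are gathered back and the rows introduced by padding are dropped, using `attn_metadata.num_actual_tokens` when it is available. A sketch of just the trim step, with plain tensors standing in for the gathered shards and a made-up token count:

```python
import torch

# Pretend 4 ranks each held a 3-row shard; after the all-gather we have 12 rows,
# but only 10 of them are real tokens (the rest came from padding).
hidden_states = torch.randn(12, 8)
residual = torch.randn(12, 8)
num_actual_tokens = 10   # e.g. attn_metadata.num_actual_tokens

if num_actual_tokens < hidden_states.shape[0]:
    hidden_states = hidden_states[:num_actual_tokens]
    residual = residual[:num_actual_tokens]

assert hidden_states.shape[0] == residual.shape[0] == 10
```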
Added the note; please take a look, thanks. @wangxiyuan