157 changes: 129 additions & 28 deletions vllm/model_executor/models/deepseek_v2.py
@@ -206,6 +206,7 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
alt_stream: Optional[torch.cuda.Stream] = None,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
@@ -216,6 +217,7 @@
self.ep_size = self.ep_group.size()
self.n_routed_experts: int = config.n_routed_experts
self.n_shared_experts: int = config.n_shared_experts
self.alt_stream = alt_stream

if config.hidden_act != "silu":
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
@@ -282,6 +284,59 @@ def __init__(
prefix=f"{prefix}.shared_experts",
)

    # Kept out of Dynamo-compiled graphs: the secondary-stream overlap below is
    # managed eagerly.
    @torch._dynamo.disable
    def forward_shared_experts(
            self, hidden_states: torch.Tensor,
            hidden_states_shared: tuple[torch.Tensor, torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        hidden_states_shared, hidden_states_shared_scale = hidden_states_shared
        # One fused Triton kernel computes the shared-expert gate_up_proj and the
        # MoE router gate together.
        shared_output_q, shared_output_s, router_logits = torch.ops.vllm.rocm_aiter_triton_fused_shared_expert(
            hidden_states_shared=hidden_states_shared,
            hidden_states_shared_scale=hidden_states_shared_scale,
            weight_gate_up=self.shared_experts.gate_up_proj.weight,
            weight_scale_gate_up=self.shared_experts.gate_up_proj.weight_scale_inv,
            hidden_states_moe_gate=hidden_states,
            weight_moe_gate=self.gate.weight,
            bias_shared=self.shared_experts.gate_up_proj.bias
            if not self.shared_experts.gate_up_proj.skip_bias_add else None,
            bias_moe_gate=self.gate.bias if not self.gate.skip_bias_add else None,
        )

        # Overlap the shared-expert down_proj with the routed MoE: the down_proj
        # runs on the secondary stream while the routed experts run on the
        # current stream.
        if self.alt_stream is not None:
            current_stream = torch.cuda.current_stream()
            self.alt_stream.wait_stream(current_stream)

            with torch.cuda.stream(self.alt_stream):
                shared_output, _ = self.shared_experts.down_proj(
                    shared_output_q, x_quant_scales=shared_output_s)

            final_hidden_states = self.experts(
                hidden_states=hidden_states,
                router_logits=router_logits)

            current_stream.wait_stream(self.alt_stream)
        else:
            # No secondary stream available: run the two steps sequentially.
            shared_output, _ = self.shared_experts.down_proj(
                shared_output_q, x_quant_scales=shared_output_s)
            final_hidden_states = self.experts(
                hidden_states=hidden_states,
                router_logits=router_logits)

        # fused_mul_add(a, b, c) combines the routed and shared outputs as
        # a * b + c, matching the unfused fallback in forward().
        final_hidden_states = fused_mul_add(
            final_hidden_states, self.routed_scaling_factor, shared_output)

        return shared_output, final_hidden_states

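The method above follows the standard two-stream overlap pattern: the secondary stream first waits on everything already queued on the current stream, the two branches of compute then run concurrently, and the current stream waits on the secondary stream before the results are combined. A minimal, self-contained sketch of that pattern, with placeholder compute in place of the vLLM kernels (CUDA device required):

import torch

def overlapped_branches(x: torch.Tensor, alt_stream: torch.cuda.Stream) -> torch.Tensor:
    current_stream = torch.cuda.current_stream()
    # The secondary stream must not start before x (and anything else already
    # queued on the current stream) is ready.
    alt_stream.wait_stream(current_stream)

    with torch.cuda.stream(alt_stream):
        side = x.relu()        # stands in for shared_experts.down_proj(...)

    main = x.sigmoid()         # stands in for self.experts(...); runs concurrently

    # The current stream must not consume `side` before the secondary stream
    # has finished producing it.
    current_stream.wait_stream(alt_stream)
    return main * 2.0 + side   # stands in for fused_mul_add(main, scale, side)

if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    out = overlapped_branches(torch.randn(8, 8, device="cuda"), stream)
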
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if isinstance(hidden_states, tuple):
            hidden_states_shared, hidden_states = hidden_states
@@ -290,47 +345,89 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        did_fma = False
        if VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS and self.n_shared_experts is not None:
            shared_output, final_hidden_states = self.forward_shared_experts(
                hidden_states, hidden_states_shared)
            did_fma = True
        else:
            if self.n_shared_experts is not None:
                shared_output = self.shared_experts(hidden_states_shared)
            # router_logits: (num_tokens, n_experts)
            router_logits, _ = self.gate(hidden_states)

        if not did_fma:
            if VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD and hidden_states.dtype != torch.float16 and shared_output is not None:
                final_hidden_states = self.experts(hidden_states=hidden_states,
                                                   router_logits=router_logits)
                final_hidden_states = fused_mul_add(final_hidden_states,
                                                    self.routed_scaling_factor,
                                                    shared_output)
            else:
                if hidden_states.dtype != torch.float16:
                    final_hidden_states = self.experts(
                        hidden_states=hidden_states,
                        router_logits=router_logits) * self.routed_scaling_factor
                else:
                    # Fix FP16 overflow
                    # See DeepseekV2DecoderLayer for more details.
                    final_hidden_states = self.experts(hidden_states=hidden_states,
                                                       router_logits=router_logits)
                if shared_output is not None:
                    if hidden_states.dtype != torch.float16:
                        final_hidden_states = final_hidden_states + shared_output
                    else:
                        # Fix FP16 overflow
                        # See DeepseekV2DecoderLayer for more details.
                        final_hidden_states = final_hidden_states + shared_output \
                            * (1. / self.routed_scaling_factor)

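For reference, the fused and unfused combine steps are meant to agree: fused_mul_add(a, b, c) is used where the fallback computes a * b + c, and the FP16 branch avoids the scaling here by pre-dividing the shared output, with the overall routed_scaling_factor reconciled in DeepseekV2DecoderLayer. A small sketch of that relationship, using plain PyTorch stand-ins rather than the AITER kernel:

import torch

def combine(routed: torch.Tensor, shared: torch.Tensor, scale: float) -> torch.Tensor:
    # Non-FP16 path; fused_mul_add(routed, scale, shared) is assumed to produce
    # the same value in a single kernel.
    return routed * scale + shared

def combine_fp16_style(routed: torch.Tensor, shared: torch.Tensor, scale: float) -> torch.Tensor:
    # FP16 path: leave the routed output unscaled (scaling here can overflow FP16)
    # and pre-divide the shared output; the factor of `scale` is applied downstream.
    return routed + shared * (1.0 / scale)

routed, shared, scale = torch.randn(4, 16), torch.randn(4, 16), 2.5
# The FP16-style result is the standard result divided by `scale`:
assert torch.allclose(combine(routed, shared, scale),
                      combine_fp16_style(routed, shared, scale) * scale)
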
        if self.tp_size > 1:
            final_hidden_states = (
@@ -706,6 +803,7 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
enable_eplb: bool = False,
alt_stream: Optional[torch.cuda.Stream] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -747,6 +845,7 @@
quant_config=quant_config,
prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb,
alt_stream=alt_stream,
)
else:
self.mlp = DeepseekV2MLP(
@@ -865,6 +964,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
else:
self.embed_tokens = PPMissingLayer()

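# One secondary CUDA stream, created once for the whole model and passed to every
# decoder layer so the shared-expert down_proj can overlap with the routed MoE.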
self.alt_stream = torch.cuda.Stream() if torch.cuda.is_available() else None
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekV2DecoderLayer(
@@ -874,6 +974,7 @@
cache_config=cache_config,
quant_config=quant_config,
enable_eplb=enable_eplb,
alt_stream=self.alt_stream,
),
prefix=f"{prefix}.layers")
