diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 77cb6a893d66..885b179a99fd 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -223,7 +223,7 @@ class CompilationConfig:
     constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""
 
     # CudaGraph compilation
-    cudagraph_mode: Optional[CUDAGraphMode] = CUDAGraphMode.FULL
+    cudagraph_mode: Optional[CUDAGraphMode] = CUDAGraphMode.PIECEWISE  # Hard-coded to PIECEWISE for now; revert before merging.
     """
     The mode of the cudagraph:
@@ -340,6 +340,7 @@ class CompilationConfig:
         "vllm.mamba_mixer",
         "vllm.short_conv",
         "vllm.linear_attention",
+        "vllm.streams_breaks",
     ]
 
     def compute_hash(self) -> str:
@@ -569,6 +570,10 @@ def set_splitting_ops_for_v1(self):
                 "any problems.")
             self.cudagraph_mode = CUDAGraphMode.FULL
             self.splitting_ops = []
+        if "vllm.streams_breaks" not in self.splitting_ops:
+            self.splitting_ops.append("vllm.streams_breaks")
+        logger.debug("final splitting_ops=%s mode=%s level=%s",
+                     self.splitting_ops, self.cudagraph_mode, self.level)
 
     def splitting_ops_contain_attention(self) -> bool:
         return self.splitting_ops is not None and all(
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 0450443ede70..d729d20b9f15 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -26,7 +26,6 @@
 import typing
 from collections.abc import Callable, Iterable
 from typing import Any, Optional, Union
-
 import torch
 from torch import nn
 from transformers import DeepseekV2Config, DeepseekV3Config
@@ -63,6 +62,7 @@
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 from vllm.logger import init_logger
+from vllm.forward_context import get_forward_context
 logger = init_logger(__name__)
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     is_rocm_aiter_moe_enabled)
@@ -88,10 +88,111 @@
 if VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS:
     from aiter.ops.triton.fused_gemm_a8w8_blockscale_a16w16 import fused_gemm_a8w8_blockscale_a16w16
     from aiter.ops.triton.fused_fp8_quant import fused_reduce_act_mul_fp8_group_quant
+
     import aiter as rocm_aiter
     rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
     rocm_aiter_fp8_quant_group_size = 128
-
+
+    # Earlier stand-alone prototype of the stream-overlap op, kept for reference:
+    # @torch._dynamo.disable
+    # def streams_breaks(dummy_input: torch.Tensor) -> torch.Tensor:
+    #     device = torch.device("cuda")
+    #     current_stream = torch.cuda.current_stream(device=device)
+    #     alt_stream = torch.cuda.Stream(device=device)
+
+    #     alt_stream.wait_stream(current_stream)
+
+    #     with torch.cuda.stream(alt_stream):
+    #         a = torch.randn(256, 256, device=device)
+    #         b = torch.randn(256, 256, device=device)
+    #         mm = a @ b
+    #         y = torch.tanh(mm)
+    #         z = y * 1.0001 + y.sin()
+
+    #     current_stream.wait_stream(alt_stream)
+
+    #     out = z.sum().reshape(1)
+    #     return out
+
+    # def streams_breaks_fake(dummy_input: torch.Tensor) -> torch.Tensor:
+    #     """
+    #     Fake implementation for compile-time shape inference.
+    #     Returns a meta tensor with the same shape/dtype contract as runtime.
+    #     """
+    #     return torch.empty((1,), device="meta", dtype=torch.float32)
+
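+    # How this hangs together (as this patch reads): `streams_breaks` below is
+    # wrapped in @torch._dynamo.disable and registered as the custom op
+    # torch.ops.vllm.streams_breaks, and "vllm.streams_breaks" is appended to
+    # the splitting-op list in vllm/config/compilation.py. torch.compile then
+    # treats the call as an opaque node and the piecewise CUDA graphs are split
+    # around it, so the routed experts and the shared-expert down_proj run
+    # eagerly inside the op, which is presumably where the multi-stream overlap
+    # is meant to live (the stream switches themselves are still commented out
+    # below).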
+ # """ + # return torch.empty((1,), device="meta", dtype=torch.float32) + + @torch._dynamo.disable + def streams_breaks( + layer_prefix: str, + routed_scaling_factor: float, + hidden_states: torch.Tensor, + shared_output_q: torch.Tensor, + shared_output_s: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + + #assert not torch.cuda.graphs.is_current_stream_capturing(), "OMAR OMAR OMAR OMAR" + ctx = get_forward_context() + m = ctx.no_compile_layers[layer_prefix] # DeepseekV2MoE instance + if m is None: + m = get_current_vllm_config().compilation_config.static_forward_context[layer_prefix] + + # if torch.cuda.graphs.is_current_stream_capturing(): + # print("WE ARE WE ARE WE ARE") + print(m) + print("[DEBUG] m above") + #current_stream = torch.cuda.current_stream() #change to m.curr + #m.alt_stream.wait_stream(current_stream) + assert m.experts is not None, "[DEBUG OMAR] m.experts is None" + assert m.shared_experts.down_proj is not None, "[DEBUG OMAR] m.down_proj is None" + assert m.shared_experts.down_proj is not None, "[DEBUG OMAR] m.fusedmuladd is None" + + #with torch.cuda.stream(m.alt_stream): + final_hidden_states = m.experts( + hidden_states=hidden_states, + router_logits=router_logits, + ) + shared_output, _ = m.shared_experts.down_proj( + shared_output_q, x_quant_scales=shared_output_s + ) + + #current_stream.wait_stream(m.alt_stream) + + final_hidden_states = fused_mul_add( + final_hidden_states, + routed_scaling_factor, shared_output + ) + print("[DEBUG] WE ARE ABOUT TO RETURN< AKA STREAMS WORKED AND RAN") + print(shared_output) + print(final_hidden_states) + return shared_output, final_hidden_states + + + # minimal fake impl for compile-time (shape-only) + def streams_breaks_fake( + layer_prefix: str, + routed_scaling_factor: float, + hidden_states: torch.Tensor, + shared_output_q: torch.Tensor, + shared_output_s: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Shape-only fake: outputs mirror hidden_states (M, H) + # print("[DEBUG] WE ARE IN FAKE") + M, H = hidden_states.shape + device = hidden_states.device + dtype = hidden_states.dtype + shared_out = torch.empty((M, H), device=device, dtype=dtype) + final_out = torch.empty((M, H), device=device, dtype=dtype) + # print("[DEBUG] WE ARE OUT OF FAKE") + + return shared_out, final_out + + + def rocm_aiter_triton_fused_shared_expert_impl( hidden_states_shared: torch.Tensor, hidden_states_shared_scale: torch.Tensor, @@ -102,6 +203,7 @@ def rocm_aiter_triton_fused_shared_expert_impl( bias_shared: torch.Tensor, bias_moe_gate: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + #assert not torch.cuda.graphs.is_current_stream_capturing(), "fused code" shared_output, router_logits = fused_gemm_a8w8_blockscale_a16w16(hidden_states_shared, weight_gate_up, hidden_states_shared_scale, weight_scale_gate_up, hidden_states_moe_gate, weight_moe_gate, bias_fp8=bias_shared, bias_bf16=bias_moe_gate, dtype=hidden_states_moe_gate.dtype, skip_reduce=True) if shared_output.dim() == 3: @@ -133,6 +235,15 @@ def rocm_aiter_triton_fused_shared_expert_fake( router_logits = torch.empty((M, N_moe), dtype=hidden_states_moe_gate.dtype, device=device) return shared_output_q, shared_output_s, router_logits + # register as torch.ops.vllm.streams_breaks + direct_register_custom_op( + op_name="streams_breaks", + op_func=streams_breaks, + mutates_args=[], + fake_impl=streams_breaks_fake, + dispatch_key=current_platform.dispatch_key, # use platform's key + ) + 
     direct_register_custom_op(
         op_name="rocm_aiter_triton_fused_shared_expert",
         op_func=rocm_aiter_triton_fused_shared_expert_impl,
@@ -198,6 +309,7 @@ def forward(self, x):
         return x
 
 
+
 class DeepseekV2MoE(nn.Module):
 
     def __init__(
@@ -206,16 +318,19 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         enable_eplb: bool = False,
+        alt_stream: Optional[torch.cuda.Stream] = None,
+        # TODO: also plumb the current stream through explicitly.
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
 
         self.routed_scaling_factor = config.routed_scaling_factor
-
+        self.prefix = prefix
         self.ep_group = get_ep_group().device_group
         self.ep_rank = self.ep_group.rank()
         self.ep_size = self.ep_group.size()
         self.n_routed_experts: int = config.n_routed_experts
         self.n_shared_experts: int = config.n_shared_experts
+        self.alt_stream = alt_stream
 
         if config.hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
@@ -281,6 +396,61 @@ def __init__(
                 ),
                 prefix=f"{prefix}.shared_experts",
             )
+
+    # @torch._dynamo.disable
+    def forward_shared_experts(self, hidden_states: torch.Tensor, hidden_states_shared: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        hidden_states_shared, hidden_states_shared_scale = hidden_states_shared
+        shared_output_q, shared_output_s, router_logits = torch.ops.vllm.rocm_aiter_triton_fused_shared_expert(
+            hidden_states_shared = hidden_states_shared,
+            hidden_states_shared_scale = hidden_states_shared_scale,
+            weight_gate_up = self.shared_experts.gate_up_proj.weight,
+            weight_scale_gate_up = self.shared_experts.gate_up_proj.weight_scale_inv,
+            hidden_states_moe_gate = hidden_states,
+            weight_moe_gate = self.gate.weight,
+            bias_shared = self.shared_experts.gate_up_proj.bias if not self.shared_experts.gate_up_proj.skip_bias_add else None,
+            bias_moe_gate = self.gate.bias if not self.gate.skip_bias_add else None,
+        )
+
+        # Overlap the fused shared experts with the routed MoE.
+        if self.alt_stream is not None:  # and VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD and hidden_states.dtype != torch.float16:  # relaxed for debugging
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+
+            with torch.cuda.stream(self.alt_stream):
+                shared_output, _ = self.shared_experts.down_proj(
+                    shared_output_q, x_quant_scales = shared_output_s
+                )
+
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits
+            )
+
+            current_stream.wait_stream(self.alt_stream)
+
+            final_hidden_states = fused_mul_add(
+                final_hidden_states, self.routed_scaling_factor, shared_output
+            )
+        else:
+            shared_output, _ = self.shared_experts.down_proj(
+                shared_output_q, x_quant_scales = shared_output_s
+            )
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits
+            )
+            final_hidden_states = fused_mul_add(
+                final_hidden_states, self.routed_scaling_factor, shared_output
+            )
+
+        return shared_output, final_hidden_states
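+
+    # forward_shared_experts is the in-graph variant of the shared-expert /
+    # routed-MoE overlap. As of this patch it is only referenced from a
+    # commented-out call in forward(), which routes through the
+    # torch.ops.vllm.streams_breaks custom op instead.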
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if isinstance(hidden_states, tuple):
@@ -290,6 +460,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
+
+        did_fma = False
         if VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS and self.n_shared_experts is not None:
             hidden_states_shared, hidden_states_shared_scale = hidden_states_shared
             shared_output_q, shared_output_s, router_logits = torch.ops.vllm.rocm_aiter_triton_fused_shared_expert(
@@ -301,41 +473,76 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                 weight_moe_gate = self.gate.weight,
                 bias_shared = self.shared_experts.gate_up_proj.bias if not self.shared_experts.gate_up_proj.skip_bias_add else None,
                 bias_moe_gate = self.gate.bias if not self.gate.skip_bias_add else None,
             )
-            shared_output, _ = self.shared_experts.down_proj(shared_output_q, x_quant_scales = shared_output_s)
+
+            if self.alt_stream is not None and VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD and hidden_states.dtype != torch.float16:
+                shared_output, final_hidden_states = torch.ops.vllm.streams_breaks(
+                    layer_prefix=self.prefix,
+                    routed_scaling_factor=self.routed_scaling_factor,
+                    hidden_states=hidden_states,
+                    shared_output_q=shared_output_q,
+                    shared_output_s=shared_output_s,
+                    router_logits=router_logits,
+                )
+                # Alternative in-graph path, kept for comparison:
+                # shared_output, final_hidden_states = self.forward_shared_experts(hidden_states, hidden_states_shared)
+                did_fma = True
+            else:
+                shared_output, _ = self.shared_experts.down_proj(shared_output_q, x_quant_scales = shared_output_s)
         else:
             if self.n_shared_experts is not None:
                 shared_output = self.shared_experts(hidden_states_shared)
             # router_logits: (num_tokens, n_experts)
             router_logits, _ = self.gate(hidden_states)
 
-        if VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD and hidden_states.dtype != torch.float16 and shared_output is not None:
-            final_hidden_states = self.experts(hidden_states=hidden_states,
-                                               router_logits=router_logits)
-            final_hidden_states = fused_mul_add(final_hidden_states, self.routed_scaling_factor, shared_output)
-        else:
-            if hidden_states.dtype != torch.float16:
-                final_hidden_states = self.experts(
-                    hidden_states=hidden_states,
-                    router_logits=router_logits) * self.routed_scaling_factor
-            else:
-                # Fix FP16 overflow
-                # See DeepseekV2DecoderLayer for more details.
-                final_hidden_states = self.experts(hidden_states=hidden_states,
-                                                   router_logits=router_logits)
-            if shared_output is not None:
-                if hidden_states.dtype != torch.float16:
-                    final_hidden_states = final_hidden_states + shared_output
-                else:
-                    # Fix FP16 overflow
-                    # See DeepseekV2DecoderLayer for more details.
-                    final_hidden_states = final_hidden_states + shared_output \
-                        * (1. / self.routed_scaling_factor)
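+        # did_fma is set only on the streams_breaks fast path above, where the
+        # routed output and the shared-expert output have already been combined
+        # via fused_mul_add; in that case the pre-existing combine logic below
+        # is skipped.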
+        if not did_fma:
+            if VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD and hidden_states.dtype != torch.float16 and shared_output is not None:
+                final_hidden_states = self.experts(hidden_states=hidden_states,
+                                                   router_logits=router_logits)
+                final_hidden_states = fused_mul_add(final_hidden_states, self.routed_scaling_factor, shared_output)
+            else:
+                if hidden_states.dtype != torch.float16:
+                    final_hidden_states = self.experts(
+                        hidden_states=hidden_states,
+                        router_logits=router_logits) * self.routed_scaling_factor
+                else:
+                    # Fix FP16 overflow
+                    # See DeepseekV2DecoderLayer for more details.
+                    final_hidden_states = self.experts(hidden_states=hidden_states,
+                                                       router_logits=router_logits)
+                if shared_output is not None:
+                    if hidden_states.dtype != torch.float16:
+                        final_hidden_states = final_hidden_states + shared_output
+                    else:
+                        # Fix FP16 overflow
+                        # See DeepseekV2DecoderLayer for more details.
+                        final_hidden_states = final_hidden_states + shared_output \
+                            * (1. / self.routed_scaling_factor)
 
         if self.tp_size > 1:
             final_hidden_states = (
                 self.experts.maybe_all_reduce_tensor_model_parallel(
-                    final_hidden_states))
+                    final_hidden_states
+                )
+            )
 
         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -706,6 +913,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         enable_eplb: bool = False,
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -747,6 +955,7 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
                 enable_eplb=enable_eplb,
+                alt_stream=alt_stream,
             )
         else:
             self.mlp = DeepseekV2MLP(
@@ -756,6 +965,14 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
             )
+
+        if isinstance(self.mlp, DeepseekV2MoE):
+            compilation_config = get_current_vllm_config().compilation_config
+            name = self.mlp.prefix  # e.g., "model.layers.12.mlp"
+            if name in compilation_config.static_forward_context:
+                raise ValueError(f"Duplicate layer name in static_forward_context: {name}")
+            compilation_config.static_forward_context[name] = self.mlp
+
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -865,6 +1082,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         else:
             self.embed_tokens = PPMissingLayer()
 
+        self.alt_stream = torch.cuda.Stream() if torch.cuda.is_available() else None
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: DeepseekV2DecoderLayer(
@@ -874,9 +1092,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 cache_config=cache_config,
                 quant_config=quant_config,
                 enable_eplb=enable_eplb,
+                alt_stream=self.alt_stream,
             ),
             prefix=f"{prefix}.layers")
 
+        self._moe_modules = []
+        for layer in self.layers:
+            mlp = getattr(layer, "mlp", None)
+            if isinstance(mlp, DeepseekV2MoE):
+                self._moe_modules.append(mlp)
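+        # The MoE modules are collected here (and registered under their prefix
+        # in static_forward_context by DeepseekV2DecoderLayer above) so that the
+        # streams_breaks custom op can look the layer up by prefix at runtime;
+        # see also the commented-out no_compile_layers registration in forward()
+        # below.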
+
         if get_pp_group().is_last_rank:
             self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         else:
@@ -895,6 +1120,16 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
+        # One-time registration of no-compile MoE layers (disabled for now):
+        # if self not in _registered_models:
+        #     from vllm.forward_context import get_forward_context
+        #     ctx = get_forward_context()  # valid now (we're in a real forward)
+        #     for m in self._moe_modules:
+        #         # register by name -> module object; streams_breaks will honor these
+        #         ctx.no_compile_layers[m.prefix] = m
+        #     _registered_models.add(self)
+
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds