
[RFC]: Automatic Kernel Fusion via torch.fx.graph and graph rewriter for vLLM-Ascend #2386

@ganyi1996ppo

Description


Motivation.

Currently, vLLM's high-performance execution path relies heavily on PyTorch Inductor to perform automatic kernel fusion. For vLLM-Ascend, however, we face two major limitations:

  1. No Inductor Support in torch_npu

    • The torch_npu backend does not implement Inductor because Triton support on the platform is limited, so we cannot leverage vLLM's existing kernel fusion mechanisms out of the box.
  2. High Maintenance Burden from Model Duplication

    • To achieve equivalent fusion behavior, we have been rewriting and maintaining large portions of model code directly inside vllm-ascend (e.g., Qwen3, Qwen3Moe, Qwen2VL). Some of these copies differ only slightly from the community versions; the Qwen3 copy, for instance, merely adds a fused kernel for add_rms_norm plus quantize.
    • This results in:
      • Code duplication across vllm and vllm-ascend.
      • Divergence over time, making upgrades from upstream vllm difficult.
      • Reduced maintainability and increased development cost.

Without an automatic fusion mechanism, Ascend NPU execution suffers from suboptimal kernel launch patterns, leading to:

  • Higher latency due to excessive kernel dispatch.
  • Poor hardware utilization.
  • Increased complexity in managing performance-critical paths.

We need a solution that:

  • Keeps all model definitions inside vllm.
  • Implements NPU-specific fusions in vllm-ascend without forking or rewriting model code.
  • Is maintainable and extensible for future optimizations.

Proposed Change.

We propose implementing an Automatic Kernel Fusion Compiler Interface for Ascend NPUs that:

  • Operates on the torch.fx.Graph extracted from the original vLLM models.
  • Uses PyTorch’s graph rewriter API (torch.fx.subgraph_rewriter) to:
    • Match known computational subgraphs that can be fused.
    • Replace them with fused kernels implemented in torch_npu.
  • Runs before the model is executed, ensuring the runtime graph is already optimized for Ascend NPUs. (A minimal sketch of this mechanism follows.)
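
For readers unfamiliar with the rewriter API, here is a minimal, self-contained sketch of the mechanism, independent of vLLM. The toy "fused" replacement (torch.clamp standing in for a real torch_npu fused kernel) is purely illustrative:

import torch
import torch.fx
from torch.fx.subgraph_rewriter import replace_pattern

class Toy(torch.nn.Module):
    def forward(self, x, y):
        return torch.relu(torch.add(x, y))

def pattern(x, y):
    # Subgraph to match: elementwise add followed by relu.
    return torch.relu(torch.add(x, y))

def replacement(x, y):
    # Stand-in "fused" op; a real pass would call a torch_npu kernel here.
    return torch.clamp(torch.add(x, y), min=0.0)

gm = torch.fx.symbolic_trace(Toy())
matches = replace_pattern(gm, pattern, replacement)  # rewrites gm in place
print(len(matches), "subgraph(s) fused")
print(gm.graph)  # relu(add(x, y)) is now the clamp-based replacement subgraph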

Design Overview

Our approach builds on the existing vllm execution pipeline while replacing only the Inductor-specific fusion logic with an NPU-compatible implementation. The key changes are:

  1. Leverage vLLM’s Existing VllmBackend Path

    • We keep the original VllmBackend execution flow in place.
    • Inside this flow, we intercept and replace the PostGradPassManager with our own GraphRewritePassManager, which will host all NPU-specific graph fusion passes.
  2. Custom Fusion Pass Implementation

    • We implement our fusion pass by inheriting from vllm’s VllmInductorPass.
    • This allows us to:
      • Reuse vLLM’s existing FX graph inspection and debugging utilities.
      • Keep the fusion logic consistent with vLLM’s pass structure.
      • Maintain an easy migration path to the original Inductor passes when torch_npu gains Inductor support.
  3. Adopt an NPU-Specific Compiler Interface

    • We introduce our own compiler interface that integrates with torch.compile.
    • This compiler interface will:
      • Receive the FX graph from the compiled model.
      • Apply the GraphRewritePassManager to perform pattern-based graph rewrites and kernel fusions.
      • Return the optimized GraphModule for execution on the Ascend NPU. (A minimal backend sketch follows this list.)
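
As a rough illustration, such a compiler interface could be registered as a torch.compile backend. The sketch below is hypothetical: the pass manager body is reduced to a loop over registered passes, and the backend function name is ours, not an existing API:

import torch
from typing import Callable, List

class GraphRewritePassManager:
    """Reduced sketch of the pass manager proposed in this RFC."""
    def __init__(self):
        self.passes: List[Callable[[torch.fx.Graph], None]] = []

    def __call__(self, graph: torch.fx.Graph) -> None:
        for fusion_pass in self.passes:  # run each registered fusion pass
            fusion_pass(graph)

def ascend_graph_rewrite_backend(gm: torch.fx.GraphModule, example_inputs):
    # Hypothetical torch.compile backend: apply the NPU rewrite passes,
    # then hand the mutated module back for execution.
    GraphRewritePassManager()(gm.graph)  # in practice, configured from VllmConfig
    gm.recompile()                       # regenerate code after graph mutation
    return gm

# Usage: compiled_model = torch.compile(model, backend=ascend_graph_rewrite_backend)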

By integrating at the VllmBackend level and reusing VllmInductorPass infrastructure, we minimize code duplication, retain compatibility with upstream vLLM’s debug and inspection tools, and prepare for a seamless transition back to Inductor when NPU support becomes available.

Integration Example For a Fusion Pass

Below is a practical example showing how to implement and integrate a custom fusion pass into the GraphRewritePassManager.

Step 1: Write a Specific Fusion Pass

We define a fusion pass that registers a pattern to fuse torch.ops.npu.npu_add_rms_norm followed by torch.ops.npu.npu_quantize into a single torch.ops.npu.npu_add_rms_norm_quant op.

# Define a fusion pattern here
from typing import Callable, List, Tuple

import torch
import torch_npu  # noqa: F401  # registers the torch.ops.npu.* kernels

class AddRMSNormQuantPattern:
    def __init__(self, vllm_config):
        self.vllm_config = vllm_config

    def register(self, patterns: List[Tuple[Callable, Callable]]):
        def pattern(rms_norm_input, residual, rms_norm_weight, scale, offset):
            """
            Pattern for AddRMSNorm + Quant fusion.
            """
            output = torch.ops.npu.npu_add_rms_norm(
                rms_norm_input,
                residual,
                rms_norm_weight,
                1e-6
            )
            new_output = output[0]   # normalized output
            residual = output[2]     # updated residual
            quantized_output = torch.ops.npu.npu_quantize(
                new_output,
                scale,               # npu_quantize consumes the reciprocal scale
                offset,
                torch.qint8,
                -1,
                False
            )
            return quantized_output, residual

        def replace(rms_norm_input, residual, rms_norm_weight, scale, offset):
            """
            Replacement for the AddRMSNormQuant fusion.
            """
            output = torch.ops.npu.npu_add_rms_norm_quant(
                rms_norm_input,
                residual,
                rms_norm_weight,
                1. / scale,          # the fused kernel expects the scale itself
                offset,
                epsilon=1e-6
            )
            quantized_output = output[0]
            residual = output[2]
            return quantized_output, residual

        patterns.append((pattern, replace))

# Define a fusion pass here.
from torch.fx.subgraph_rewriter import replace_pattern
from vllm.compilation.vllm_inductor_pass import VllmInductorPass  # import path may vary across vLLM versions

class AscendQuantFusionPass(VllmInductorPass):
    """
    A pass for fusing AddRMSNorm and W8A8 quantization operations on Ascend.
    """

    def __init__(self, vllm_config):
        super().__init__(vllm_config)
        self.patterns = []
        # Register the AddRMSNormQuant fusion pattern into the graph rewriter pattern list
        AddRMSNormQuantPattern(vllm_config).register(self.patterns)

    def __call__(self, graph: torch.fx.Graph):
        self.begin()
        self.dump_graph(graph, "before_ascend_quant_fusion")
        for pattern, replace in self.patterns:
            # replace_pattern operates on a GraphModule, so rewrite through
            # the module that owns this graph.
            replace_pattern(graph.owning_module, pattern, replace)
        self.dump_graph(graph, "after_ascend_quant_fusion")
        self.end_and_log()
Step 2: Add the Pass to the Pass Manager via configure
# inside the file graph_rewriter_pass_manager.py
class GraphRewritePassManager:
    ...
    def configure(self, config: VllmConfig):
        self.pass_config = config.additional_config.ascend_pass_config
        if self.pass_config.enable_addrms_norm_quant_fusion:
            # Put the specific fusion pass into the pass manager
            from .quant_fusion_pass import AscendQuantFusionPass
            self.passes.append(AscendQuantFusionPass(config))
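
With this plumbing in place, enabling the fusion from user code might look like the following. This is a hypothetical sketch: the additional_config keys mirror the configure() snippet above and are not a finalized API, and the model name is just an example.

# Hypothetical usage sketch; key names mirror the configure() snippet above.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-8B",  # example model only
    additional_config={
        "ascend_pass_config": {
            "enable_addrms_norm_quant_fusion": True,
        },
    },
)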

With these two steps in place, the fusion happens automatically inside torch.compile in our target scenario, as the graph structures below show.

The graph before fusion:
graph():
    %output_82 : [num_users=4] = placeholder[target=output_82]
    %s0 : torch.SymInt [num_users=6] = placeholder[target=s0]
    %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_scale_reciprocal : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_scale_reciprocal]
    %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_offset : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_offset]
    %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_weight_]
    %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_deq_scale_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_deq_scale_]
    %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_quant_bias_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_quant_bias_]
    %residual_64 : [num_users=1] = placeholder[target=residual_64]
    %l_self_modules_layers_modules_16_modules_post_attention_layernorm_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_post_attention_layernorm_parameters_weight_]
    %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_scale_reciprocal : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_scale_reciprocal]
    %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_offset : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_offset]
    %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_weight_]
    %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_deq_scale_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_deq_scale_]
    %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_quant_bias_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_quant_bias_]
    %l_self_modules_layers_modules_16_modules_mlp_modules_down_proj_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_down_proj_parameters_weight_]
    %l_self_modules_layers_modules_17_modules_input_layernorm_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_input_layernorm_parameters_weight_]
    %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_scale_reciprocal : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_scale_reciprocal]
    %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_offset : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_offset]
    %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_weight_]
    %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_deq_scale_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_deq_scale_]
    %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_quant_bias_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_quant_bias_]
    %l_self_modules_layers_modules_17_modules_self_attn_modules_q_norm_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_q_norm_parameters_weight_]
    %l_self_modules_layers_modules_17_modules_self_attn_modules_k_norm_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_k_norm_parameters_weight_]
    %l_positions_ : torch.Tensor [num_users=1] = placeholder[target=l_positions_]
    %s1 : torch.SymInt [num_users=0] = placeholder[target=s1]
    %l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_ : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_]
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%output_82, (slice(None, s0, None), slice(None, None, None), slice(None, None, None))), kwargs = {})
    %setitem : [num_users=0] = call_function[target=operator.setitem](args = (%output_82, (slice(None, None, None), slice(None, None, None), slice(None, None, None)), %getitem), kwargs = {})
    %view : [num_users=0] = call_method[target=view](args = (%output_82, %s0, 4096), kwargs = {})
    %view_1 : [num_users=1] = call_method[target=view](args = (%output_82, -1, 4096), kwargs = {})
    %npu_quantize : [num_users=1] = call_function[target=torch.ops.npu.npu_quantize](args = (%view_1, %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_scale_reciprocal, %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_offset, torch.qint8, -1, False), kwargs = {})
    %npu_quant_matmul : [num_users=1] = call_function[target=torch.ops.npu.npu_quant_matmul](args = (%npu_quantize, %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_weight_, %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_deq_scale_), kwargs = {bias: %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_quant_bias_, output_dtype: torch.bfloat16})
    %npu_add_rms_norm : [num_users=2] = call_function[target=torch.ops.npu.npu_add_rms_norm](args = (%npu_quant_matmul, %residual_64, %l_self_modules_layers_modules_16_modules_post_attention_layernorm_parameters_weight_, 1e-06), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_add_rms_norm, 0), kwargs = {})
    %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_add_rms_norm, 2), kwargs = {})
    %npu_quantize_1 : [num_users=1] = call_function[target=torch.ops.npu.npu_quantize](args = (%getitem_1, %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_scale_reciprocal, %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_offset, torch.qint8, -1, False), kwargs = {})
    %npu_quant_matmul_1 : [num_users=2] = call_function[target=torch.ops.npu.npu_quant_matmul](args = (%npu_quantize_1, %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_weight_, %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_deq_scale_), kwargs = {bias: %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_quant_bias_, output_dtype: torch.bfloat16})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_quant_matmul_1, (Ellipsis, slice(None, 12288, None))), kwargs = {})
    %silu : [num_users=1] = call_function[target=torch.nn.functional.silu](args = (%getitem_3,), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_quant_matmul_1, (Ellipsis, slice(12288, None, None))), kwargs = {})
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%silu, %getitem_4), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch._C._nn.linear](args = (%mul, %l_self_modules_layers_modules_16_modules_mlp_modules_down_proj_parameters_weight_, None), kwargs = {})
    %npu_add_rms_norm_1 : [num_users=2] = call_function[target=torch.ops.npu.npu_add_rms_norm](args = (%linear, %getitem_2, %l_self_modules_layers_modules_17_modules_input_layernorm_parameters_weight_, 1e-06), kwargs = {})
    %getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_add_rms_norm_1, 0), kwargs = {})
    %getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_add_rms_norm_1, 2), kwargs = {})
    %npu_quantize_2 : [num_users=1] = call_function[target=torch.ops.npu.npu_quantize](args = (%getitem_5, %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_scale_reciprocal, %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_offset, torch.qint8, -1, False), kwargs = {})
    %npu_quant_matmul_2 : [num_users=1] = call_function[target=torch.ops.npu.npu_quant_matmul](args = (%npu_quantize_2, %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_weight_, %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_deq_scale_), kwargs = {bias: %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_quant_bias_, output_dtype: torch.bfloat16})
    %split : [num_users=3] = call_method[target=split](args = (%npu_quant_matmul_2, [4096, 1024, 1024]), kwargs = {dim: -1})
    %getitem_7 : [num_users=2] = call_function[target=operator.getitem](args = (%split, 0), kwargs = {})
    %getitem_8 : [num_users=2] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {})
    %getitem_9 : [num_users=1] = call_function[target=operator.getitem](args = (%split, 2), kwargs = {})
    %view_2 : [num_users=1] = call_method[target=view](args = (%getitem_7, %s0, 32, 128), kwargs = {})
    %npu_rms_norm : [num_users=1] = call_function[target=torch.ops.npu.npu_rms_norm](args = (%view_2, %l_self_modules_layers_modules_17_modules_self_attn_modules_q_norm_parameters_weight_, 1e-06), kwargs = {})
    %getitem_10 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_rms_norm, 0), kwargs = {})
    %size : [num_users=1] = call_method[target=size](args = (%getitem_7,), kwargs = {})
    %view_3 : [num_users=2] = call_method[target=view](args = (%getitem_10, %size), kwargs = {})
    %view_4 : [num_users=1] = call_method[target=view](args = (%getitem_8, %s0, 8, 128), kwargs = {})
    %npu_rms_norm_1 : [num_users=1] = call_function[target=torch.ops.npu.npu_rms_norm](args = (%view_4, %l_self_modules_layers_modules_17_modules_self_attn_modules_k_norm_parameters_weight_, 1e-06), kwargs = {})
    %getitem_11 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_rms_norm_1, 0), kwargs = {})
    %size_1 : [num_users=1] = call_method[target=size](args = (%getitem_8,), kwargs = {})
    %view_5 : [num_users=2] = call_method[target=view](args = (%getitem_11, %size_1), kwargs = {})
    %size_2 : [num_users=1] = call_method[target=size](args = (%view_3,), kwargs = {})
    %size_3 : [num_users=1] = call_method[target=size](args = (%view_5,), kwargs = {})
    %contiguous : [num_users=1] = call_method[target=contiguous](args = (%view_3,), kwargs = {})
    %view_6 : [num_users=2] = call_method[target=view](args = (%contiguous, %s0, -1), kwargs = {})
    %contiguous_1 : [num_users=1] = call_method[target=contiguous](args = (%view_5,), kwargs = {})
    %view_7 : [num_users=2] = call_method[target=view](args = (%contiguous_1, %s0, -1), kwargs = {})
    %_npu_rotary_embedding : [num_users=0] = call_function[target=torch.ops.atb._npu_rotary_embedding](args = (%l_positions_, %view_6, %view_7, 128, %l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_, True), kwargs = {})
    %view_8 : [num_users=2] = call_method[target=view](args = (%view_6, %size_2), kwargs = {})
    %view_9 : [num_users=1] = call_method[target=view](args = (%view_7, %size_3), kwargs = {})
    %size_4 : [num_users=1] = call_method[target=size](args = (%view_8,), kwargs = {})
    %zeros : [num_users=1] = call_function[target=torch.zeros](args = (%size_4,), kwargs = {dtype: torch.bfloat16, device: npu:0})
    %view_10 : [num_users=1] = call_method[target=view](args = (%view_8, -1, 32, 128), kwargs = {})
    %view_11 : [num_users=1] = call_method[target=view](args = (%zeros, -1, 32, 128), kwargs = {})
    %view_12 : [num_users=1] = call_method[target=view](args = (%view_9, -1, 8, 128), kwargs = {})
    %view_13 : [num_users=1] = call_method[target=view](args = (%getitem_9, -1, 8, 128), kwargs = {})
    return (view_10, view_12, view_13, view_11, getitem_6)
The graph after fusion:
graph():
    %output_82 : [num_users=4] = placeholder[target=output_82]
    %s0 : torch.SymInt [num_users=6] = placeholder[target=s0]
    %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_scale_reciprocal : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_scale_reciprocal]
    %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_offset : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_offset]
    %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_weight_]
    %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_deq_scale_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_deq_scale_]
    %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_quant_bias_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_quant_bias_]
    %residual_64 : [num_users=1] = placeholder[target=residual_64]
    %l_self_modules_layers_modules_16_modules_post_attention_layernorm_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_post_attention_layernorm_parameters_weight_]
    %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_scale_reciprocal : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_scale_reciprocal]
    %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_offset : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_offset]
    %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_weight_]
    %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_deq_scale_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_deq_scale_]
    %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_quant_bias_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_quant_bias_]
    %l_self_modules_layers_modules_16_modules_mlp_modules_down_proj_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_down_proj_parameters_weight_]
    %l_self_modules_layers_modules_17_modules_input_layernorm_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_input_layernorm_parameters_weight_]
    %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_scale_reciprocal : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_scale_reciprocal]
    %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_offset : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_offset]
    %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_weight_]
    %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_deq_scale_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_deq_scale_]
    %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_quant_bias_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_quant_bias_]
    %l_self_modules_layers_modules_17_modules_self_attn_modules_q_norm_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_q_norm_parameters_weight_]
    %l_self_modules_layers_modules_17_modules_self_attn_modules_k_norm_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_k_norm_parameters_weight_]
    %l_positions_ : torch.Tensor [num_users=1] = placeholder[target=l_positions_]
    %s1 : torch.SymInt [num_users=0] = placeholder[target=s1]
    %l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_ : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_]
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%output_82, (slice(None, s0, None), slice(None, None, None), slice(None, None, None))), kwargs = {})
    %setitem : [num_users=0] = call_function[target=operator.setitem](args = (%output_82, (slice(None, None, None), slice(None, None, None), slice(None, None, None)), %getitem), kwargs = {})
    %view : [num_users=0] = call_method[target=view](args = (%output_82, %s0, 4096), kwargs = {})
    %view_1 : [num_users=1] = call_method[target=view](args = (%output_82, -1, 4096), kwargs = {})
    %npu_quantize : [num_users=1] = call_function[target=torch.ops.npu.npu_quantize](args = (%view_1, %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_scale_reciprocal, %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_offset, torch.qint8, -1, False), kwargs = {})
    %npu_quant_matmul : [num_users=1] = call_function[target=torch.ops.npu.npu_quant_matmul](args = (%npu_quantize, %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_weight_, %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_deq_scale_), kwargs = {bias: %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_quant_bias_, output_dtype: torch.bfloat16})
    %truediv : [num_users=1] = call_function[target=operator.truediv](args = (1, %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_scale_reciprocal), kwargs = {})
    %npu_add_rms_norm_quant : [num_users=2] = call_function[target=torch.ops.npu.npu_add_rms_norm_quant](args = (%npu_quant_matmul, %residual_64, %l_self_modules_layers_modules_16_modules_post_attention_layernorm_parameters_weight_, %truediv, %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_offset), kwargs = {epsilon: 1e-06})
    %getitem_12 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_add_rms_norm_quant, 0), kwargs = {})
    %getitem_13 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_add_rms_norm_quant, 2), kwargs = {})
    %npu_quant_matmul_1 : [num_users=2] = call_function[target=torch.ops.npu.npu_quant_matmul](args = (%getitem_12, %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_weight_, %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_deq_scale_), kwargs = {bias: %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_quant_bias_, output_dtype: torch.bfloat16})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_quant_matmul_1, (Ellipsis, slice(None, 12288, None))), kwargs = {})
    %silu : [num_users=1] = call_function[target=torch.nn.functional.silu](args = (%getitem_3,), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_quant_matmul_1, (Ellipsis, slice(12288, None, None))), kwargs = {})
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%silu, %getitem_4), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch._C._nn.linear](args = (%mul, %l_self_modules_layers_modules_16_modules_mlp_modules_down_proj_parameters_weight_, None), kwargs = {})
    %truediv_1 : [num_users=1] = call_function[target=operator.truediv](args = (1, %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_scale_reciprocal), kwargs = {})
    %npu_add_rms_norm_quant_1 : [num_users=2] = call_function[target=torch.ops.npu.npu_add_rms_norm_quant](args = (%linear, %getitem_13, %l_self_modules_layers_modules_17_modules_input_layernorm_parameters_weight_, %truediv_1, %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_offset), kwargs = {epsilon: 1e-06})
    %getitem_14 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_add_rms_norm_quant_1, 0), kwargs = {})
    %getitem_15 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_add_rms_norm_quant_1, 2), kwargs = {})
    %npu_quant_matmul_2 : [num_users=1] = call_function[target=torch.ops.npu.npu_quant_matmul](args = (%getitem_14, %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_weight_, %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_deq_scale_), kwargs = {bias: %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_quant_bias_, output_dtype: torch.bfloat16})
    %split : [num_users=3] = call_method[target=split](args = (%npu_quant_matmul_2, [4096, 1024, 1024]), kwargs = {dim: -1})
    %getitem_7 : [num_users=2] = call_function[target=operator.getitem](args = (%split, 0), kwargs = {})
    %getitem_8 : [num_users=2] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {})
    %getitem_9 : [num_users=1] = call_function[target=operator.getitem](args = (%split, 2), kwargs = {})
    %view_2 : [num_users=1] = call_method[target=view](args = (%getitem_7, %s0, 32, 128), kwargs = {})
    %npu_rms_norm : [num_users=1] = call_function[target=torch.ops.npu.npu_rms_norm](args = (%view_2, %l_self_modules_layers_modules_17_modules_self_attn_modules_q_norm_parameters_weight_, 1e-06), kwargs = {})
    %getitem_10 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_rms_norm, 0), kwargs = {})
    %size : [num_users=1] = call_method[target=size](args = (%getitem_7,), kwargs = {})
    %view_3 : [num_users=2] = call_method[target=view](args = (%getitem_10, %size), kwargs = {})
    %view_4 : [num_users=1] = call_method[target=view](args = (%getitem_8, %s0, 8, 128), kwargs = {})
    %npu_rms_norm_1 : [num_users=1] = call_function[target=torch.ops.npu.npu_rms_norm](args = (%view_4, %l_self_modules_layers_modules_17_modules_self_attn_modules_k_norm_parameters_weight_, 1e-06), kwargs = {})
    %getitem_11 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_rms_norm_1, 0), kwargs = {})
    %size_1 : [num_users=1] = call_method[target=size](args = (%getitem_8,), kwargs = {})
    %view_5 : [num_users=2] = call_method[target=view](args = (%getitem_11, %size_1), kwargs = {})
    %size_2 : [num_users=1] = call_method[target=size](args = (%view_3,), kwargs = {})
    %size_3 : [num_users=1] = call_method[target=size](args = (%view_5,), kwargs = {})
    %contiguous : [num_users=1] = call_method[target=contiguous](args = (%view_3,), kwargs = {})
    %view_6 : [num_users=2] = call_method[target=view](args = (%contiguous, %s0, -1), kwargs = {})
    %contiguous_1 : [num_users=1] = call_method[target=contiguous](args = (%view_5,), kwargs = {})
    %view_7 : [num_users=2] = call_method[target=view](args = (%contiguous_1, %s0, -1), kwargs = {})
    %_npu_rotary_embedding : [num_users=0] = call_function[target=torch.ops.atb._npu_rotary_embedding](args = (%l_positions_, %view_6, %view_7, 128, %l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_, True), kwargs = {})
    %view_8 : [num_users=2] = call_method[target=view](args = (%view_6, %size_2), kwargs = {})
    %view_9 : [num_users=1] = call_method[target=view](args = (%view_7, %size_3), kwargs = {})
    %size_4 : [num_users=1] = call_method[target=size](args = (%view_8,), kwargs = {})
    %zeros : [num_users=1] = call_function[target=torch.zeros](args = (%size_4,), kwargs = {dtype: torch.bfloat16, device: npu:0})
    %view_10 : [num_users=1] = call_method[target=view](args = (%view_8, -1, 32, 128), kwargs = {})
    %view_11 : [num_users=1] = call_method[target=view](args = (%zeros, -1, 32, 128), kwargs = {})
    %view_12 : [num_users=1] = call_method[target=view](args = (%view_9, -1, 8, 128), kwargs = {})
    %view_13 : [num_users=1] = call_method[target=view](args = (%getitem_9, -1, 8, 128), kwargs = {})
    return (view_10, view_12, view_13, view_11, getitem_15)

Feedback Period.

No response

CC List.

@wangxiyuan @Yikun @jianzs @jgong5 @realliujiaxu

Any Other Things.

No response
