Motivation.
Currently, the vLLM project’s high-performance execution path relies heavily on PyTorch Inductor to perform automatic kernel fusion. However, for vLLM-Ascend, we face two major limitations:
- **No Inductor Support in `torch_npu`**
  - The `torch_npu` backend does not implement Inductor because of its limited Triton support, meaning we cannot leverage vLLM's existing kernel fusion mechanisms out-of-the-box.
- **High Maintenance Burden from Model Duplication**
  - To achieve equivalent fusion behavior, we have been rewriting and maintaining large portions of model code directly inside `vllm-ascend`, e.g. Qwen3, Qwen3Moe, Qwen2VL. Some of these differ only slightly from the community versions; Qwen3, for example, only adds a simple fusion kernel for `add_rms_norm` and `quantize`.
  - This results in:
    - Code duplication across `vllm` and `vllm-ascend`.
    - Divergence over time, making upgrades from upstream `vllm` difficult.
    - Reduced maintainability and increased development cost.
Without an automatic fusion mechanism, Ascend NPU execution suffers from suboptimal kernel launch patterns, leading to:
- Higher latency due to excessive kernel dispatch.
- Poor hardware utilization.
- Increased complexity in managing performance-critical paths.
We need a solution that:
- Keeps all model definitions inside `vllm`.
- Implements NPU-specific fusions in `vllm-ascend` without forking or rewriting model code.
- Is maintainable and extensible for future optimizations.
Proposed Change.
We propose implementing an Automatic Kernel Fusion Compiler Interface for Ascend NPUs that:
- Operates on the `torch.fx.Graph` extracted from the original `vllm` models.
- Uses PyTorch's graph rewriter API (`torch.fx.subgraph_rewriter`) to (a toy illustration follows this list):
  - Match known computational subgraphs that can be fused.
  - Replace them with fused kernels implemented in `torch_npu`.
- Runs before the model is executed, ensuring the runtime graph is already optimized for Ascend NPUs.
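As a toy illustration of the rewriter mechanics (the pattern here is a plain add + relu, not a real NPU fusion; the fused op is a stand-in):

```python
import torch
import torch.fx as fx
from torch.fx.subgraph_rewriter import replace_pattern


def pattern(x, y):
    # Subgraph to search for: add followed by relu.
    return torch.relu(torch.add(x, y))


def replacement(x, y):
    # Stand-in for a fused kernel; a real pass would emit a torch_npu op here.
    return torch.clamp_min(torch.add(x, y), 0)


class Toy(torch.nn.Module):
    def forward(self, x, y):
        return torch.relu(torch.add(x, y))


gm = fx.symbolic_trace(Toy())
replace_pattern(gm, pattern, replacement)  # rewrites gm.graph in place
print(gm.graph)  # the add+relu pair is now the replacement subgraph
```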
Design Overview
Our approach builds on the existing vllm execution pipeline while replacing only the Inductor-specific fusion logic with an NPU-compatible implementation. The key changes are:
1. **Leverage vLLM's Existing `VllmBackend` Path**
   - We keep the original `VllmBackend` execution flow in place.
   - Inside this flow, we intercept and replace the `PostGradPassManager` with our own `GraphRewritePassManager`, which will host all NPU-specific graph fusion passes (a minimal sketch of this manager appears after this list).
2. **Custom Fusion Pass Implementation**
   - We implement our fusion passes by inheriting from `vllm`'s `VllmInductorPass`. This allows us to:
     - Reuse vLLM's existing FX graph inspection and debugging utilities.
     - Keep the fusion logic consistent with vLLM's pass structure.
     - Maintain an easy migration path to the original Inductor passes when `torch_npu` gains Inductor support.
3. **Adopt an NPU-Specific Compiler Interface**
   - We introduce our own compiler interface that integrates with `torch.compile`. This compiler interface will:
     - Receive the FX graph from the compiled model.
     - Apply the `GraphRewritePassManager` to perform pattern-based graph rewrites and kernel fusions.
     - Return the optimized `GraphModule` for execution on Ascend NPU.
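A minimal sketch of what the `GraphRewritePassManager` could look like (names and structure are illustrative; the real class would mirror vLLM's `PostGradPassManager` contract):

```python
import torch.fx


class GraphRewritePassManager:
    """Hosts the NPU-specific graph fusion passes (illustrative sketch)."""

    def __init__(self):
        self.passes = []  # populated by configure(), see Step 2 below

    def __call__(self, graph: torch.fx.Graph) -> torch.fx.Graph:
        # Run every registered fusion pass over the FX graph, in order.
        for fusion_pass in self.passes:
            fusion_pass(graph)
        return graph
```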
By integrating at the VllmBackend level and reusing VllmInductorPass infrastructure, we minimize code duplication, retain compatibility with upstream vLLM’s debug and inspection tools, and prepare for a seamless transition back to Inductor when NPU support becomes available.
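Continuing the `GraphRewritePassManager` sketch above, here is a hedged sketch of how such a compiler interface could plug into `torch.compile`; the backend wiring below is an assumption for exposition, not the final vllm-ascend API:

```python
import torch


def npu_fusion_backend(gm: torch.fx.GraphModule, example_inputs):
    """A custom torch.compile backend that applies the NPU fusion passes."""
    pass_manager = GraphRewritePassManager()
    # pass_manager.configure(vllm_config)  # register the enabled fusion passes
    pass_manager(gm.graph)  # rewrite matched subgraphs into fused NPU kernels
    gm.recompile()          # regenerate forward() from the rewritten graph
    return gm               # torch.compile runs the optimized GraphModule


# Usage with any nn.Module; on Ascend the fused ops would come from torch_npu.
model = torch.nn.Linear(16, 16)
compiled_model = torch.compile(model, backend=npu_fusion_backend)
```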
Integration Example for a Fusion Pass
Below is a practical example showing how to implement a custom fusion pass and integrate it into the `GraphRewritePassManager`.
Step 1: Write a Specific Fusion Pass
We define a fusion pass that registers a pattern to fuse `torch.ops.npu.npu_add_rms_norm` followed by `torch.ops.npu.npu_quantize` into a single `torch.ops.npu.npu_add_rms_norm_quant` op.
```python
from typing import Callable, List, Tuple

import torch
from torch.fx.subgraph_rewriter import replace_pattern

# NOTE: the exact import path of VllmInductorPass may vary across vLLM versions.
from vllm.compilation.vllm_inductor_pass import VllmInductorPass


# Define a fusion pattern here
class AddRMSNormQuantPattern:

    def __init__(self, vllm_config):
        self.vllm_config = vllm_config

    def register(self, patterns: List[Tuple[Callable, Callable]]):

        def pattern(rms_norm_input, residual, rms_norm_weight, scale, offset):
            """
            Pattern for AddRMSNorm + Quant fusion.
            """
            output = torch.ops.npu.npu_add_rms_norm(
                rms_norm_input,
                residual,
                rms_norm_weight,
                1e-6
            )
            new_output = output[0]
            residual = output[2]
            quantized_output = torch.ops.npu.npu_quantize(
                new_output,
                scale,
                offset,
                torch.qint8,
                -1,
                False
            )
            return quantized_output, residual

        def replace(rms_norm_input, residual, rms_norm_weight, scale, offset):
            """
            Replacement for the AddRMSNormQuant fusion.
            """
            output = torch.ops.npu.npu_add_rms_norm_quant(
                rms_norm_input,
                residual,
                rms_norm_weight,
                1. / scale,
                offset,
                epsilon=1e-6
            )
            quantized_output = output[0]
            residual = output[2]
            return quantized_output, residual

        patterns.append((pattern, replace))


# Define a fusion pass here.
class AscendQuantFusionPass(VllmInductorPass):
    """
    A pass for fusing AddRMSNorm and W8A8 quantization operations on Ascend.
    """

    def __init__(self, vllm_config):
        super().__init__(vllm_config)
        self.patterns = []
        # Register the AddRMSNormQuant fusion pattern into the graph rewriter
        # pattern list.
        AddRMSNormQuantPattern(vllm_config).register(self.patterns)

    def __call__(self, graph: torch.fx.Graph):
        self.begin()
        self.dump_graph(graph, "before_ascend_quant_fusion")
        for pattern, replace in self.patterns:
            # replace_pattern expects a GraphModule; a thin adapter is needed
            # when the pass receives a raw fx.Graph.
            replace_pattern(graph, pattern, replace)
        self.dump_graph(graph, "after_ascend_quant_fusion")
        self.end_and_log()
```
Step 2: Add the Pass to the Pass Manager via `configure`
```python
# inside the file graph_rewriter_pass_manager.py
from vllm.config import VllmConfig


class GraphRewritePassManager:
    ...

    def configure(self, config: VllmConfig):
        self.pass_config = config.additional_config.ascend_pass_config
        if self.pass_config.enable_addrms_norm_quant_fusion:
            # Put the specific fusion pass into the pass manager
            from .quant_fusion_pass import AscendQuantFusionPass
            self.passes.append(AscendQuantFusionPass(config))
```
With these two steps in place, the fusion happens automatically inside `torch.compile` in our target scenario, as the before/after graph structures below show.
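For illustration, enabling the pass from user code might look like the following; the config keys mirror the `configure()` snippet above, but the exact names and the `additional_config` plumbing are assumptions:

```python
from vllm import LLM

# Hypothetical configuration keys, following the configure() snippet above.
llm = LLM(
    model="Qwen/Qwen3-8B",
    additional_config={
        "ascend_pass_config": {
            "enable_addrms_norm_quant_fusion": True,
        },
    },
)
```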
The graph before fusion:
```
graph():
%output_82 : [num_users=4] = placeholder[target=output_82]
%s0 : torch.SymInt [num_users=6] = placeholder[target=s0]
%l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_scale_reciprocal : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_scale_reciprocal]
%l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_offset : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_offset]
%l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_weight_]
%l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_deq_scale_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_deq_scale_]
%l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_quant_bias_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_quant_bias_]
%residual_64 : [num_users=1] = placeholder[target=residual_64]
%l_self_modules_layers_modules_16_modules_post_attention_layernorm_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_post_attention_layernorm_parameters_weight_]
%l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_scale_reciprocal : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_scale_reciprocal]
%l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_offset : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_offset]
%l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_weight_]
%l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_deq_scale_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_deq_scale_]
%l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_quant_bias_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_quant_bias_]
%l_self_modules_layers_modules_16_modules_mlp_modules_down_proj_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_down_proj_parameters_weight_]
%l_self_modules_layers_modules_17_modules_input_layernorm_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_input_layernorm_parameters_weight_]
%l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_scale_reciprocal : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_scale_reciprocal]
%l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_offset : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_offset]
%l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_weight_]
%l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_deq_scale_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_deq_scale_]
%l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_quant_bias_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_quant_bias_]
%l_self_modules_layers_modules_17_modules_self_attn_modules_q_norm_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_q_norm_parameters_weight_]
%l_self_modules_layers_modules_17_modules_self_attn_modules_k_norm_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_k_norm_parameters_weight_]
%l_positions_ : torch.Tensor [num_users=1] = placeholder[target=l_positions_]
%s1 : torch.SymInt [num_users=0] = placeholder[target=s1]
%l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_ : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_]
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%output_82, (slice(None, s0, None), slice(None, None, None), slice(None, None, None))), kwargs = {})
%setitem : [num_users=0] = call_function[target=operator.setitem](args = (%output_82, (slice(None, None, None), slice(None, None, None), slice(None, None, None)), %getitem), kwargs = {})
%view : [num_users=0] = call_method[target=view](args = (%output_82, %s0, 4096), kwargs = {})
%view_1 : [num_users=1] = call_method[target=view](args = (%output_82, -1, 4096), kwargs = {})
%npu_quantize : [num_users=1] = call_function[target=torch.ops.npu.npu_quantize](args = (%view_1, %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_scale_reciprocal, %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_offset, torch.qint8, -1, False), kwargs = {})
%npu_quant_matmul : [num_users=1] = call_function[target=torch.ops.npu.npu_quant_matmul](args = (%npu_quantize, %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_weight_, %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_deq_scale_), kwargs = {bias: %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_quant_bias_, output_dtype: torch.bfloat16})
%npu_add_rms_norm : [num_users=2] = call_function[target=torch.ops.npu.npu_add_rms_norm](args = (%npu_quant_matmul, %residual_64, %l_self_modules_layers_modules_16_modules_post_attention_layernorm_parameters_weight_, 1e-06), kwargs = {})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_add_rms_norm, 0), kwargs = {})
%getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_add_rms_norm, 2), kwargs = {})
%npu_quantize_1 : [num_users=1] = call_function[target=torch.ops.npu.npu_quantize](args = (%getitem_1, %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_scale_reciprocal, %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_offset, torch.qint8, -1, False), kwargs = {})
%npu_quant_matmul_1 : [num_users=2] = call_function[target=torch.ops.npu.npu_quant_matmul](args = (%npu_quantize_1, %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_weight_, %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_deq_scale_), kwargs = {bias: %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_quant_bias_, output_dtype: torch.bfloat16})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_quant_matmul_1, (Ellipsis, slice(None, 12288, None))), kwargs = {})
%silu : [num_users=1] = call_function[target=torch.nn.functional.silu](args = (%getitem_3,), kwargs = {})
%getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_quant_matmul_1, (Ellipsis, slice(12288, None, None))), kwargs = {})
%mul : [num_users=1] = call_function[target=operator.mul](args = (%silu, %getitem_4), kwargs = {})
%linear : [num_users=1] = call_function[target=torch._C._nn.linear](args = (%mul, %l_self_modules_layers_modules_16_modules_mlp_modules_down_proj_parameters_weight_, None), kwargs = {})
%npu_add_rms_norm_1 : [num_users=2] = call_function[target=torch.ops.npu.npu_add_rms_norm](args = (%linear, %getitem_2, %l_self_modules_layers_modules_17_modules_input_layernorm_parameters_weight_, 1e-06), kwargs = {})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_add_rms_norm_1, 0), kwargs = {})
%getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_add_rms_norm_1, 2), kwargs = {})
%npu_quantize_2 : [num_users=1] = call_function[target=torch.ops.npu.npu_quantize](args = (%getitem_5, %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_scale_reciprocal, %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_offset, torch.qint8, -1, False), kwargs = {})
%npu_quant_matmul_2 : [num_users=1] = call_function[target=torch.ops.npu.npu_quant_matmul](args = (%npu_quantize_2, %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_weight_, %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_deq_scale_), kwargs = {bias: %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_quant_bias_, output_dtype: torch.bfloat16})
%split : [num_users=3] = call_method[target=split](args = (%npu_quant_matmul_2, [4096, 1024, 1024]), kwargs = {dim: -1})
%getitem_7 : [num_users=2] = call_function[target=operator.getitem](args = (%split, 0), kwargs = {})
%getitem_8 : [num_users=2] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {})
%getitem_9 : [num_users=1] = call_function[target=operator.getitem](args = (%split, 2), kwargs = {})
%view_2 : [num_users=1] = call_method[target=view](args = (%getitem_7, %s0, 32, 128), kwargs = {})
%npu_rms_norm : [num_users=1] = call_function[target=torch.ops.npu.npu_rms_norm](args = (%view_2, %l_self_modules_layers_modules_17_modules_self_attn_modules_q_norm_parameters_weight_, 1e-06), kwargs = {})
%getitem_10 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_rms_norm, 0), kwargs = {})
%size : [num_users=1] = call_method[target=size](args = (%getitem_7,), kwargs = {})
%view_3 : [num_users=2] = call_method[target=view](args = (%getitem_10, %size), kwargs = {})
%view_4 : [num_users=1] = call_method[target=view](args = (%getitem_8, %s0, 8, 128), kwargs = {})
%npu_rms_norm_1 : [num_users=1] = call_function[target=torch.ops.npu.npu_rms_norm](args = (%view_4, %l_self_modules_layers_modules_17_modules_self_attn_modules_k_norm_parameters_weight_, 1e-06), kwargs = {})
%getitem_11 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_rms_norm_1, 0), kwargs = {})
%size_1 : [num_users=1] = call_method[target=size](args = (%getitem_8,), kwargs = {})
%view_5 : [num_users=2] = call_method[target=view](args = (%getitem_11, %size_1), kwargs = {})
%size_2 : [num_users=1] = call_method[target=size](args = (%view_3,), kwargs = {})
%size_3 : [num_users=1] = call_method[target=size](args = (%view_5,), kwargs = {})
%contiguous : [num_users=1] = call_method[target=contiguous](args = (%view_3,), kwargs = {})
%view_6 : [num_users=2] = call_method[target=view](args = (%contiguous, %s0, -1), kwargs = {})
%contiguous_1 : [num_users=1] = call_method[target=contiguous](args = (%view_5,), kwargs = {})
%view_7 : [num_users=2] = call_method[target=view](args = (%contiguous_1, %s0, -1), kwargs = {})
%_npu_rotary_embedding : [num_users=0] = call_function[target=torch.ops.atb._npu_rotary_embedding](args = (%l_positions_, %view_6, %view_7, 128, %l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_, True), kwargs = {})
%view_8 : [num_users=2] = call_method[target=view](args = (%view_6, %size_2), kwargs = {})
%view_9 : [num_users=1] = call_method[target=view](args = (%view_7, %size_3), kwargs = {})
%size_4 : [num_users=1] = call_method[target=size](args = (%view_8,), kwargs = {})
%zeros : [num_users=1] = call_function[target=torch.zeros](args = (%size_4,), kwargs = {dtype: torch.bfloat16, device: npu:0})
%view_10 : [num_users=1] = call_method[target=view](args = (%view_8, -1, 32, 128), kwargs = {})
%view_11 : [num_users=1] = call_method[target=view](args = (%zeros, -1, 32, 128), kwargs = {})
%view_12 : [num_users=1] = call_method[target=view](args = (%view_9, -1, 8, 128), kwargs = {})
%view_13 : [num_users=1] = call_method[target=view](args = (%getitem_9, -1, 8, 128), kwargs = {})
return (view_10, view_12, view_13, view_11, getitem_6)
```
The graph after fusion:
```
graph():
%output_82 : [num_users=4] = placeholder[target=output_82]
%s0 : torch.SymInt [num_users=6] = placeholder[target=s0]
%l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_scale_reciprocal : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_scale_reciprocal]
%l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_offset : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_offset]
%l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_weight_]
%l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_deq_scale_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_deq_scale_]
%l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_quant_bias_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_quant_bias_]
%residual_64 : [num_users=1] = placeholder[target=residual_64]
%l_self_modules_layers_modules_16_modules_post_attention_layernorm_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_post_attention_layernorm_parameters_weight_]
%l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_scale_reciprocal : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_scale_reciprocal]
%l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_offset : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_offset]
%l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_weight_]
%l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_deq_scale_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_deq_scale_]
%l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_quant_bias_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_quant_bias_]
%l_self_modules_layers_modules_16_modules_mlp_modules_down_proj_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_16_modules_mlp_modules_down_proj_parameters_weight_]
%l_self_modules_layers_modules_17_modules_input_layernorm_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_input_layernorm_parameters_weight_]
%l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_scale_reciprocal : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_scale_reciprocal]
%l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_offset : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_offset]
%l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_weight_]
%l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_deq_scale_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_deq_scale_]
%l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_quant_bias_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_quant_bias_]
%l_self_modules_layers_modules_17_modules_self_attn_modules_q_norm_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_q_norm_parameters_weight_]
%l_self_modules_layers_modules_17_modules_self_attn_modules_k_norm_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=l_self_modules_layers_modules_17_modules_self_attn_modules_k_norm_parameters_weight_]
%l_positions_ : torch.Tensor [num_users=1] = placeholder[target=l_positions_]
%s1 : torch.SymInt [num_users=0] = placeholder[target=s1]
%l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_ : torch.Tensor [num_users=1] = placeholder[target=l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_]
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%output_82, (slice(None, s0, None), slice(None, None, None), slice(None, None, None))), kwargs = {})
%setitem : [num_users=0] = call_function[target=operator.setitem](args = (%output_82, (slice(None, None, None), slice(None, None, None), slice(None, None, None)), %getitem), kwargs = {})
%view : [num_users=0] = call_method[target=view](args = (%output_82, %s0, 4096), kwargs = {})
%view_1 : [num_users=1] = call_method[target=view](args = (%output_82, -1, 4096), kwargs = {})
%npu_quantize : [num_users=1] = call_function[target=torch.ops.npu.npu_quantize](args = (%view_1, %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_scale_reciprocal, %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_aclnn_input_offset, torch.qint8, -1, False), kwargs = {})
%npu_quant_matmul : [num_users=1] = call_function[target=torch.ops.npu.npu_quant_matmul](args = (%npu_quantize, %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_weight_, %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_deq_scale_), kwargs = {bias: %l_self_modules_layers_modules_16_modules_self_attn_modules_o_proj_parameters_quant_bias_, output_dtype: torch.bfloat16})
%truediv : [num_users=1] = call_function[target=operator.truediv](args = (1, %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_scale_reciprocal), kwargs = {})
%npu_add_rms_norm_quant : [num_users=2] = call_function[target=torch.ops.npu.npu_add_rms_norm_quant](args = (%npu_quant_matmul, %residual_64, %l_self_modules_layers_modules_16_modules_post_attention_layernorm_parameters_weight_, %truediv, %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_aclnn_input_offset), kwargs = {epsilon: 1e-06})
%getitem_12 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_add_rms_norm_quant, 0), kwargs = {})
%getitem_13 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_add_rms_norm_quant, 2), kwargs = {})
%npu_quant_matmul_1 : [num_users=2] = call_function[target=torch.ops.npu.npu_quant_matmul](args = (%getitem_12, %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_weight_, %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_deq_scale_), kwargs = {bias: %l_self_modules_layers_modules_16_modules_mlp_modules_gate_up_proj_parameters_quant_bias_, output_dtype: torch.bfloat16})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_quant_matmul_1, (Ellipsis, slice(None, 12288, None))), kwargs = {})
%silu : [num_users=1] = call_function[target=torch.nn.functional.silu](args = (%getitem_3,), kwargs = {})
%getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_quant_matmul_1, (Ellipsis, slice(12288, None, None))), kwargs = {})
%mul : [num_users=1] = call_function[target=operator.mul](args = (%silu, %getitem_4), kwargs = {})
%linear : [num_users=1] = call_function[target=torch._C._nn.linear](args = (%mul, %l_self_modules_layers_modules_16_modules_mlp_modules_down_proj_parameters_weight_, None), kwargs = {})
%truediv_1 : [num_users=1] = call_function[target=operator.truediv](args = (1, %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_scale_reciprocal), kwargs = {})
%npu_add_rms_norm_quant_1 : [num_users=2] = call_function[target=torch.ops.npu.npu_add_rms_norm_quant](args = (%linear, %getitem_13, %l_self_modules_layers_modules_17_modules_input_layernorm_parameters_weight_, %truediv_1, %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_aclnn_input_offset), kwargs = {epsilon: 1e-06})
%getitem_14 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_add_rms_norm_quant_1, 0), kwargs = {})
%getitem_15 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_add_rms_norm_quant_1, 2), kwargs = {})
%npu_quant_matmul_2 : [num_users=1] = call_function[target=torch.ops.npu.npu_quant_matmul](args = (%getitem_14, %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_weight_, %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_deq_scale_), kwargs = {bias: %l_self_modules_layers_modules_17_modules_self_attn_modules_qkv_proj_parameters_quant_bias_, output_dtype: torch.bfloat16})
%split : [num_users=3] = call_method[target=split](args = (%npu_quant_matmul_2, [4096, 1024, 1024]), kwargs = {dim: -1})
%getitem_7 : [num_users=2] = call_function[target=operator.getitem](args = (%split, 0), kwargs = {})
%getitem_8 : [num_users=2] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {})
%getitem_9 : [num_users=1] = call_function[target=operator.getitem](args = (%split, 2), kwargs = {})
%view_2 : [num_users=1] = call_method[target=view](args = (%getitem_7, %s0, 32, 128), kwargs = {})
%npu_rms_norm : [num_users=1] = call_function[target=torch.ops.npu.npu_rms_norm](args = (%view_2, %l_self_modules_layers_modules_17_modules_self_attn_modules_q_norm_parameters_weight_, 1e-06), kwargs = {})
%getitem_10 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_rms_norm, 0), kwargs = {})
%size : [num_users=1] = call_method[target=size](args = (%getitem_7,), kwargs = {})
%view_3 : [num_users=2] = call_method[target=view](args = (%getitem_10, %size), kwargs = {})
%view_4 : [num_users=1] = call_method[target=view](args = (%getitem_8, %s0, 8, 128), kwargs = {})
%npu_rms_norm_1 : [num_users=1] = call_function[target=torch.ops.npu.npu_rms_norm](args = (%view_4, %l_self_modules_layers_modules_17_modules_self_attn_modules_k_norm_parameters_weight_, 1e-06), kwargs = {})
%getitem_11 : [num_users=1] = call_function[target=operator.getitem](args = (%npu_rms_norm_1, 0), kwargs = {})
%size_1 : [num_users=1] = call_method[target=size](args = (%getitem_8,), kwargs = {})
%view_5 : [num_users=2] = call_method[target=view](args = (%getitem_11, %size_1), kwargs = {})
%size_2 : [num_users=1] = call_method[target=size](args = (%view_3,), kwargs = {})
%size_3 : [num_users=1] = call_method[target=size](args = (%view_5,), kwargs = {})
%contiguous : [num_users=1] = call_method[target=contiguous](args = (%view_3,), kwargs = {})
%view_6 : [num_users=2] = call_method[target=view](args = (%contiguous, %s0, -1), kwargs = {})
%contiguous_1 : [num_users=1] = call_method[target=contiguous](args = (%view_5,), kwargs = {})
%view_7 : [num_users=2] = call_method[target=view](args = (%contiguous_1, %s0, -1), kwargs = {})
%_npu_rotary_embedding : [num_users=0] = call_function[target=torch.ops.atb._npu_rotary_embedding](args = (%l_positions_, %view_6, %view_7, 128, %l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_, True), kwargs = {})
%view_8 : [num_users=2] = call_method[target=view](args = (%view_6, %size_2), kwargs = {})
%view_9 : [num_users=1] = call_method[target=view](args = (%view_7, %size_3), kwargs = {})
%size_4 : [num_users=1] = call_method[target=size](args = (%view_8,), kwargs = {})
%zeros : [num_users=1] = call_function[target=torch.zeros](args = (%size_4,), kwargs = {dtype: torch.bfloat16, device: npu:0})
%view_10 : [num_users=1] = call_method[target=view](args = (%view_8, -1, 32, 128), kwargs = {})
%view_11 : [num_users=1] = call_method[target=view](args = (%zeros, -1, 32, 128), kwargs = {})
%view_12 : [num_users=1] = call_method[target=view](args = (%view_9, -1, 8, 128), kwargs = {})
%view_13 : [num_users=1] = call_method[target=view](args = (%getitem_9, -1, 8, 128), kwargs = {})
return (view_10, view_12, view_13, view_11, getitem_15)
```
Feedback Period.
No response
CC List.
@wangxiyuan @Yikun @jianzs @jgong5 @realliujiaxu
Any Other Things.
No response