Skip to content

Commit a44c7a3

Browse files
committed
Optimize perf of Qwen3
Signed-off-by: rjg-lyh <1318825571@qq.com>
1 parent 69b817e commit a44c7a3

File tree

4 files changed

+261
-5
lines changed

4 files changed

+261
-5
lines changed

vllm_ascend/models/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ def register_model():
1111
from .qwen2_5_vl import \
1212
AscendQwen2_5_VLForConditionalGeneration # noqa: F401
1313
from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401
14+
from .qwen3 import CustomQwen3ForCausalLM # noqa: F401
1415

1516
ModelRegistry.register_model(
1617
"DeepSeekMTPModel",
@@ -47,3 +48,7 @@ def register_model():
4748
ModelRegistry.register_model(
4849
"Qwen3MoeForCausalLM",
4950
"vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
51+
52+
ModelRegistry.register_model(
53+
"Qwen3ForCausalLM",
54+
"vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")

vllm_ascend/models/qwen3.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
from collections.abc import Iterable
2+
from typing import Optional, Union
3+
4+
import torch
5+
from torch import nn
6+
from transformers import Qwen3Config
7+
8+
from vllm.attention import AttentionType
9+
from vllm.compilation.decorators import support_torch_compile
10+
from vllm.config import CacheConfig, VllmConfig
11+
from vllm.distributed import get_pp_group
12+
from vllm.model_executor.layers.layernorm import RMSNorm
13+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
14+
from vllm.model_executor.layers.quantization import QuantizationConfig
15+
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
16+
from vllm.model_executor.sampling_metadata import SamplingMetadata
17+
from vllm.sequence import IntermediateTensors
18+
19+
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
20+
from vllm.model_executor.models.qwen2 import Qwen2MLP as Qwen3MLP
21+
from vllm.model_executor.models.qwen2 import Qwen2Model
22+
from vllm.model_executor.models.qwen3 import Qwen3Attention
23+
from vllm.model_executor.models.utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
24+
25+
from vllm_ascend.ops.layernorm import AddRMSNormQuant
26+
27+
class CustomQwen3DecoderLayer(nn.Module):
28+
29+
def __init__(
30+
self,
31+
config: Qwen3Config,
32+
cache_config: Optional[CacheConfig] = None,
33+
quant_config: Optional[QuantizationConfig] = None,
34+
prefix: str = "",
35+
) -> None:
36+
super().__init__()
37+
self.hidden_size = config.hidden_size
38+
# Requires transformers > 4.32.0
39+
rope_theta = getattr(config, "rope_theta", 1000000)
40+
rope_scaling = getattr(config, "rope_scaling", None)
41+
42+
# By default, Qwen3 uses causal attention as it is a decoder-only model.
43+
# You can override the HF config with `is_causal=False` to enable
44+
# bidirectional attention, which is used in some embedding models
45+
# (e.g. Alibaba-NLP/gte-Qwen3-7B-instruct)
46+
if getattr(config, "is_causal", True):
47+
attn_type = AttentionType.DECODER
48+
else:
49+
attn_type = AttentionType.ENCODER_ONLY
50+
51+
self.self_attn = Qwen3Attention(
52+
hidden_size=self.hidden_size,
53+
num_heads=config.num_attention_heads,
54+
max_position=config.max_position_embeddings,
55+
num_kv_heads=config.num_key_value_heads,
56+
rope_theta=rope_theta,
57+
rms_norm_eps=config.rms_norm_eps,
58+
qkv_bias=getattr(config, 'attention_bias', False),
59+
head_dim=getattr(config, 'head_dim', None),
60+
cache_config=cache_config,
61+
quant_config=quant_config,
62+
rope_scaling=rope_scaling,
63+
prefix=f"{prefix}.self_attn",
64+
attn_type=attn_type,
65+
)
66+
self.mlp = Qwen3MLP(
67+
hidden_size=self.hidden_size,
68+
intermediate_size=config.intermediate_size,
69+
hidden_act=config.hidden_act,
70+
quant_config=quant_config,
71+
prefix=f"{prefix}.mlp",
72+
)
73+
if quant_config is None:
74+
self.input_layernorm = RMSNorm(config.hidden_size,
75+
eps=config.rms_norm_eps)
76+
self.post_attention_layernorm = RMSNorm(config.hidden_size,
77+
eps=config.rms_norm_eps)
78+
else:
79+
from vllm_ascend.quantization.quant_config import AscendQuantConfig
80+
assert isinstance(quant_config, AscendQuantConfig)
81+
self.input_layernorm = AddRMSNormQuant(config.hidden_size,
82+
layer=self.self_attn.qkv_proj,
83+
eps=config.rms_norm_eps)
84+
self.post_attention_layernorm = AddRMSNormQuant(config.hidden_size,
85+
layer=self.mlp.gate_up_proj,
86+
eps=config.rms_norm_eps)
87+
88+
def forward(
89+
self,
90+
positions: torch.Tensor,
91+
hidden_states: torch.Tensor,
92+
residual: Optional[torch.Tensor],
93+
) -> tuple[torch.Tensor, torch.Tensor]:
94+
# Self Attention
95+
if residual is None:
96+
residual = hidden_states
97+
hidden_states = self.input_layernorm(hidden_states)
98+
else:
99+
hidden_states, residual = self.input_layernorm(
100+
hidden_states, residual)
101+
hidden_states = self.self_attn(
102+
positions=positions,
103+
hidden_states=hidden_states,
104+
)
105+
hidden_states, residual = self.post_attention_layernorm(
106+
hidden_states, residual)
107+
hidden_states = self.mlp(hidden_states)
108+
return hidden_states, residual
109+
110+
111+
ALL_DECODER_LAYER_TYPES = {
112+
"attention": CustomQwen3DecoderLayer,
113+
}
114+
115+
116+
@support_torch_compile(
117+
dynamic_arg_dims={
118+
"input_ids": 0,
119+
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
120+
# otherwise (seq_len, ).
121+
"positions": -1,
122+
"intermediate_tensors": 0,
123+
"inputs_embeds": 0,
124+
})
125+
class CustomQwen3Model(Qwen2Model):
126+
127+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
128+
super().__init__(vllm_config=vllm_config,
129+
prefix=prefix,
130+
decoder_layer_type=CustomQwen3DecoderLayer)
131+
132+
133+
class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
134+
# add `CustomQwen3Model` to init self.model
135+
packed_modules_mapping = {
136+
"qkv_proj": [
137+
"q_proj",
138+
"k_proj",
139+
"v_proj",
140+
],
141+
"gate_up_proj": [
142+
"gate_proj",
143+
"up_proj",
144+
],
145+
}
146+
147+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
148+
super().__init__()
149+
config = vllm_config.model_config.hf_config
150+
quant_config = vllm_config.quant_config
151+
lora_config = vllm_config.lora_config
152+
153+
self.config = config
154+
self.lora_config = lora_config
155+
156+
self.quant_config = quant_config
157+
self.model = CustomQwen3Model(vllm_config=vllm_config,
158+
prefix=maybe_prefix(prefix, "model"))
159+
160+
if get_pp_group().is_last_rank:
161+
if config.tie_word_embeddings:
162+
self.lm_head = self.model.embed_tokens
163+
else:
164+
self.lm_head = ParallelLMHead(config.vocab_size,
165+
config.hidden_size,
166+
quant_config=quant_config,
167+
prefix=maybe_prefix(
168+
prefix, "lm_head"))
169+
else:
170+
self.lm_head = PPMissingLayer()
171+
172+
self.logits_processor = LogitsProcessor(config.vocab_size)
173+
174+
self.make_empty_intermediate_tensors = (
175+
self.model.make_empty_intermediate_tensors)
176+
177+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
178+
return self.model.get_input_embeddings(input_ids)
179+
180+
def forward(
181+
self,
182+
input_ids: torch.Tensor,
183+
positions: torch.Tensor,
184+
intermediate_tensors: Optional[IntermediateTensors] = None,
185+
inputs_embeds: Optional[torch.Tensor] = None,
186+
) -> Union[torch.Tensor, IntermediateTensors]:
187+
hidden_states = self.model(input_ids, positions, intermediate_tensors,
188+
inputs_embeds)
189+
return hidden_states
190+
191+
def compute_logits(
192+
self,
193+
hidden_states: torch.Tensor,
194+
sampling_metadata: SamplingMetadata,
195+
) -> Optional[torch.Tensor]:
196+
logits = self.logits_processor(self.lm_head, hidden_states,
197+
sampling_metadata)
198+
return logits
199+
200+
def load_weights(self, weights: Iterable[tuple[str,
201+
torch.Tensor]]) -> set[str]:
202+
loader = AutoWeightsLoader(
203+
self,
204+
skip_prefixes=(["lm_head."]
205+
if self.config.tie_word_embeddings else None),
206+
)
207+
return loader.load_weights(weights)

vllm_ascend/ops/layernorm.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,45 @@
2121
from vllm.model_executor.layers.layernorm import RMSNorm
2222

2323

24+
class AddRMSNormQuant(RMSNorm):
25+
"""Root mean square normalization.
26+
27+
Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
28+
Refer to https://arxiv.org/abs/1910.07467
29+
"""
30+
def __init__(
31+
self,
32+
hidden_size: int,
33+
layer: torch.nn.Module,
34+
eps: float = 1e-6,
35+
var_hidden_size: Optional[int] = None,
36+
has_weight: bool = True,
37+
dtype: Optional[torch.dtype] = None,
38+
) -> None:
39+
super().__init__(hidden_size,
40+
eps,
41+
var_hidden_size,
42+
has_weight,
43+
dtype)
44+
self.layer = layer
45+
46+
def forward(
47+
self,
48+
x: torch.Tensor,
49+
residual: Optional[torch.Tensor] = None,
50+
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
51+
import torch_npu
52+
53+
if residual is not None:
54+
x, _, residual = torch_npu.npu_add_rms_norm_quant(x, residual, self.weight,
55+
self.layer.aclnn_input_scale,
56+
self.layer.aclnn_input_offset,
57+
epsilon=self.variance_epsilon)
58+
return x, residual
59+
60+
x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
61+
return x
62+
2463
def forward_oot(
2564
self,
2665
x: torch.Tensor,

vllm_ascend/quantization/w8a8.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class AscendW8A8LinearMethod:
3333
Args:
3434
w_sym: whether the linear weight is symmetrically quantized.
3535
"""
36+
params_dtype: torch.dtype = torch.bfloat16
3637

3738
def __init__(self) -> None:
3839
# aclnn quant matmul requires to transpose matrix B, set to true by default.
@@ -54,6 +55,7 @@ def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
5455
params_dict = {}
5556
params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
5657
params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
58+
AscendW8A8LinearMethod.params_dtype = params_dtype
5759
return params_dict
5860

5961
@staticmethod
@@ -75,6 +77,7 @@ def get_perchannel_param(
7577
params_dict["weight_offset"] = torch.empty(output_size,
7678
1,
7779
dtype=params_dtype)
80+
AscendW8A8LinearMethod.params_dtype = params_dtype
7881
return params_dict
7982

8083
@staticmethod
@@ -84,11 +87,10 @@ def apply(
8487
bias: Optional[torch.Tensor] = None,
8588
tp_rank: Optional[int] = 0,
8689
) -> torch.Tensor:
87-
original_dtype = x.dtype
88-
if original_dtype != torch.int8:
90+
if x.dtype != torch.int8:
8991
x = quant_per_tensor(
9092
x,
91-
layer.aclnn_input_scale,
93+
layer.aclnn_input_scale_reciprocal,
9294
layer.aclnn_input_offset,
9395
)
9496
quant_bias = layer.quant_bias if tp_rank == 0 else None
@@ -97,12 +99,15 @@ def apply(
9799
layer.weight,
98100
layer.deq_scale,
99101
bias=quant_bias,
100-
output_dtype=original_dtype,
102+
output_dtype=AscendW8A8LinearMethod.params_dtype,
101103
)
102104

103105
def process_weights_after_loading(self, layer):
104106
expanding_factor = layer.weight.data.shape[1]
105-
layer.aclnn_input_scale = 1 / torch.nn.Parameter(
107+
layer.aclnn_input_scale = torch.nn.Parameter(
108+
layer.input_scale.data.repeat(expanding_factor),
109+
requires_grad=False)
110+
layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
106111
layer.input_scale.data.repeat(expanding_factor),
107112
requires_grad=False)
108113
layer.aclnn_input_offset = torch.nn.Parameter(

0 commit comments

Comments
 (0)