
Commit ec4fc9e

Optimize perf of Qwen3
Signed-off-by: rjg-lyh <1318825571@qq.com>
1 parent 69b817e commit ec4fc9e

File tree

4 files changed: 259 additions, 3 deletions


vllm_ascend/models/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -11,6 +11,7 @@ def register_model():
     from .qwen2_5_vl import \
         AscendQwen2_5_VLForConditionalGeneration  # noqa: F401
     from .qwen2_vl import AscendQwen2VLForConditionalGeneration  # noqa: F401
+    from .qwen3 import CustomQwen3ForCausalLM  # noqa: F401
 
     ModelRegistry.register_model(
         "DeepSeekMTPModel",
@@ -47,3 +48,7 @@ def register_model():
     ModelRegistry.register_model(
         "Qwen3MoeForCausalLM",
         "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
+
+    ModelRegistry.register_model(
+        "Qwen3ForCausalLM",
+        "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")

vllm_ascend/models/qwen3.py

Lines changed: 209 additions & 0 deletions
@@ -0,0 +1,209 @@ (new file; contents shown without the leading "+" diff markers)

from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch import nn
from transformers import Qwen3Config

from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2MLP as Qwen3MLP
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM, Qwen3Attention
from vllm.model_executor.models.utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix

from vllm_ascend.ops.layernorm import AddRMSNormQuant


class CustomQwen3DecoderLayer(nn.Module):

    def __init__(
        self,
        config: Qwen3Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 1000000)
        rope_scaling = getattr(config, "rope_scaling", None)

        # By default, Qwen3 uses causal attention as it is a decoder-only model.
        # You can override the HF config with `is_causal=False` to enable
        # bidirectional attention, which is used in some embedding models
        # (e.g. Alibaba-NLP/gte-Qwen3-7B-instruct)
        if getattr(config, "is_causal", True):
            attn_type = AttentionType.DECODER
        else:
            attn_type = AttentionType.ENCODER_ONLY

        self.self_attn = Qwen3Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, 'attention_bias', False),
            head_dim=getattr(config, 'head_dim', None),
            cache_config=cache_config,
            quant_config=quant_config,
            rope_scaling=rope_scaling,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = Qwen3MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        if quant_config is None:
            self.input_layernorm = RMSNorm(config.hidden_size,
                                           eps=config.rms_norm_eps)
            self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                    eps=config.rms_norm_eps)
        else:
            from vllm_ascend.quantization.quant_config import AscendQuantConfig
            assert isinstance(quant_config, AscendQuantConfig)
            self.input_layernorm = AddRMSNormQuant(config.hidden_size,
                                                   self.self_attn.qkv_proj.aclnn_input_scale,
                                                   self.self_attn.qkv_proj.aclnn_input_offset,
                                                   eps=config.rms_norm_eps)
            self.post_attention_layernorm = AddRMSNormQuant(config.hidden_size,
                                                            self.mlp.gate_up_proj.aclnn_input_scale,
                                                            self.mlp.gate_up_proj.aclnn_input_offset,
                                                            eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


ALL_DECODER_LAYER_TYPES = {
    "attention": CustomQwen3DecoderLayer,
}


@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    })
class CustomQwen3Model(Qwen2Model):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         decoder_layer_type=CustomQwen3DecoderLayer)


class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    # uses `CustomQwen3Model` to init self.model
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = CustomQwen3Model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

        if get_pp_group().is_last_rank:
            if config.tie_word_embeddings:
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(config.vocab_size,
                                              config.hidden_size,
                                              quant_config=quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "lm_head"))
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)
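Nothing else is needed to use the optimized path: once vllm-ascend is installed, loading a Qwen3 checkpoint through vLLM dispatches to CustomQwen3ForCausalLM automatically. A hedged usage sketch (the checkpoint name is illustrative; any checkpoint whose architecture is Qwen3ForCausalLM behaves the same way):

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-8B")  # illustrative Qwen3ForCausalLM checkpoint
outputs = llm.generate(["Hello, Ascend!"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)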

vllm_ascend/ops/layernorm.py

Lines changed: 40 additions & 0 deletions
@@ -21,6 +21,46 @@
 from vllm.model_executor.layers.layernorm import RMSNorm
 
 
+class AddRMSNormQuant(RMSNorm):
+    """Root mean square normalization.
+
+    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
+    Refer to https://arxiv.org/abs/1910.07467
+    """
+    def __init__(
+        self,
+        hidden_size: int,
+        scale: torch.Tensor,
+        offset: torch.Tensor,
+        eps: float = 1e-6,
+        var_hidden_size: Optional[int] = None,
+        has_weight: bool = True,
+        dtype: Optional[torch.dtype] = None,
+    ) -> None:
+        super().__init__(hidden_size,
+                         eps,
+                         var_hidden_size,
+                         has_weight,
+                         dtype)
+        self.scale = scale
+        self.offset = offset
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        import torch_npu
+
+        if residual is not None:
+            x, _, residual = torch_npu.npu_add_rms_norm_quant(x, residual, self.weight,
+                                                              self.scale, self.offset,
+                                                              epsilon=self.variance_epsilon)
+            return x, residual
+
+        x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
+        return x
+
+
 def forward_oot(
     self,
     x: torch.Tensor,
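For context, AddRMSNormQuant folds into one torch_npu kernel what the unfused path does in separate steps: add the residual, apply RMSNorm, then quantize the result per-tensor to int8 for the following W8A8 linear (whose scale and offset are passed in at construction time, as the qwen3.py decoder layer shows). Below is a minimal reference sketch of those semantics, assuming a simple round(x * scale + offset) quantization convention; the authoritative behaviour is whatever npu_add_rms_norm_quant implements:

import torch

def add_rms_norm_quant_reference(x, residual, weight, scale, offset, eps):
    # 1) residual add; the updated residual is returned for the next layer
    residual = x + residual
    # 2) RMSNorm: w * x / sqrt(E[x^2] + eps)
    var = residual.float().pow(2).mean(dim=-1, keepdim=True)
    normed = residual.float() * torch.rsqrt(var + eps) * weight.float()
    # 3) per-tensor int8 quantization (convention assumed, see lead-in)
    quantized = torch.clamp(torch.round(normed * scale + offset),
                            -128, 127).to(torch.int8)
    return quantized, residual

Fusing these steps avoids materializing the intermediate normalized tensor in the activation dtype and saves kernel launches on every layernorm of every decoder layer, which appears to be the main source of the speed-up this commit targets.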

vllm_ascend/quantization/w8a8.py

Lines changed: 5 additions & 3 deletions
@@ -33,6 +33,7 @@ class AscendW8A8LinearMethod:
     Args:
         w_sym: whether the linear weight is symmetrically quantized.
     """
+    params_dtype: torch.dtype = torch.bfloat16
 
     def __init__(self) -> None:
         # aclnn quant matmul requires to transpose matrix B, set to true by default.
@@ -54,6 +55,7 @@ def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
         params_dict = {}
         params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
         params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
+        AscendW8A8LinearMethod.params_dtype = params_dtype
         return params_dict
 
     @staticmethod
@@ -75,6 +77,7 @@ def get_perchannel_param(
         params_dict["weight_offset"] = torch.empty(output_size,
                                                    1,
                                                    dtype=params_dtype)
+        AscendW8A8LinearMethod.params_dtype = params_dtype
         return params_dict
 
     @staticmethod
@@ -84,8 +87,7 @@ def apply(
         bias: Optional[torch.Tensor] = None,
         tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
-        original_dtype = x.dtype
-        if original_dtype != torch.int8:
+        if x.dtype != torch.int8:
            x = quant_per_tensor(
                x,
                layer.aclnn_input_scale,
@@ -97,7 +99,7 @@
             layer.weight,
             layer.deq_scale,
             bias=quant_bias,
-            output_dtype=original_dtype,
+            output_dtype=AscendW8A8LinearMethod.params_dtype,
         )
 
     def process_weights_after_loading(self, layer):
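The original_dtype change follows from the fused norm: with AddRMSNormQuant in front of the linear, x already arrives as int8, so x.dtype no longer tells apply() which floating-point dtype the dequantized matmul output should have. Recording params_dtype when the quantization parameters are created preserves that information. A small illustration (hypothetical tensors, not from the commit):

import torch

x_old_path = torch.randn(2, 8, dtype=torch.bfloat16)  # unfused path: float input
x_new_path = torch.zeros(2, 8, dtype=torch.int8)      # fused path: already quantized

# Old logic, output_dtype = x.dtype: correct for the bfloat16 input, but it would
# wrongly request int8 output for the already-quantized input. The class-level
# params_dtype, captured in get_pertensor_param/get_perchannel_param, keeps the
# matmul output in the model's activation dtype in both cases.
print(x_old_path.dtype, x_new_path.dtype)  # torch.bfloat16 torch.int8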
