Commit 6f8239c

Refactor scattered w8a8 dynamic quantization operations
AscendW8A8DynamicLinearMethod was integrated into CustomDeepseekV2MLP in an awkward way, scattering quantization operations across the model scripts. Refactor to solve this problem.

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
1 parent 95414ba commit 6f8239c
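Before this change, the W8A8-dynamic path was inlined into the model MLPs as explicit npu_dynamic_quant / npu_quant_matmul / npu_dequant_swiglu_quant calls; after it, the linear quant method and a custom SiluAndMul hand a (quantized_x, dynamic_scale) pair to each other, so model code only chains gate_up_proj -> act_fn -> down_proj. The torch-free sketch below illustrates that hand-off only; the class names and arithmetic are made-up stand-ins, not vllm-ascend APIs (the real pieces are CustomDeepseekV2MLP, CustomDeepseekV2SiluAndMul and AscendW8A8DynamicLinearMethod in the diffs that follow).

from typing import List, Tuple, Union

Quantized = Tuple[List[float], float]  # (quantized activations, dynamic scale)


class TinyLinear:
    """Stand-in for a projection whose quant method can pass the scale downstream."""

    def __init__(self, return_scale: bool):
        # Mirrors the `_dynamic_quant_config` the MLP attaches to its projections.
        self.return_scale = return_scale

    def __call__(self, x: Union[List[float], Quantized]):
        values, scale = x if isinstance(x, tuple) else (x, 1.0)
        out = [2.0 * v for v in values]  # stand-in for the quantized matmul
        return (out, scale) if self.return_scale else out


class TinySiluAndMul:
    """Stand-in for the activation: accepts a plain tensor or a (tensor, scale) pair."""

    def __call__(self, x: Union[List[float], Quantized]):
        if isinstance(x, tuple):  # pre-quantized path (fused dequant-swiglu-quant)
            values, scale = x
            return ([max(v, 0.0) for v in values], scale)
        return [max(v, 0.0) for v in x]  # unquantized path


gate_up_proj = TinyLinear(return_scale=True)   # emits (quantized values, scale)
down_proj = TinyLinear(return_scale=False)     # consumes the pair, emits plain output
act_fn = TinySiluAndMul()

hidden = [0.5, -1.0, 2.0]
# Model-level code stays quantization-agnostic: gate_up -> act_fn -> down.
print(down_proj(act_fn(gate_up_proj(hidden))))

In the real code the scale-carrying tuple is produced by AscendW8A8DynamicLinearMethod.apply() when "return_scale" is set in `_dynamic_quant_config`, and consumed either by CustomDeepseekV2SiluAndMul.forward_oot() or by the next apply() call.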

3 files changed, +83 / -181 lines

vllm_ascend/models/deepseek_dbo.py

Lines changed: 4 additions & 106 deletions
@@ -29,7 +29,7 @@
 
 import torch
 import torch.distributed as dist
-import torch_npu
+import torch_npu  # noqa: F401
 import vllm.envs as envs
 from torch import nn
 from transformers import PretrainedConfig
@@ -40,13 +40,10 @@
                               get_tp_group, tensor_model_parallel_all_reduce)
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               MergedColumnParallelLinear,
                                                ReplicatedLinear,
-                                               RowParallelLinear,
-                                               UnquantizedLinearMethod)
+                                               RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -67,6 +64,7 @@
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MLP
 from vllm_ascend.multistream.base import MSEventKey
 from vllm_ascend.multistream.context import (
     advance_step_multistream_layer_context, get_multistream_comm_context,
@@ -78,117 +76,17 @@
     make_multistream_metadata_ds)
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
-from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 from vllm_ascend.utils import dispose_tensor
 
 VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 
 
-class CustomDeepseekDBOMLP(nn.Module):
-
-    def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-        hidden_act: str,
-        quant_config: Optional[QuantizationConfig] = None,
-        reduce_results: bool = True,
-        prefix: str = "",
-    ) -> None:
-        super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config,
-                                           reduce_results=reduce_results,
-                                           prefix=f"{prefix}.down_proj")
-        if hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
-        self.act_fn = SiluAndMul()
-
-        # NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant
-        self.is_dynamic_quant = not isinstance(
-            self.gate_up_proj.quant_method,
-            UnquantizedLinearMethod) and isinstance(
-                self.gate_up_proj.quant_method.quant_method,
-                AscendW8A8DynamicLinearMethod)
-
-    def forward(self, x):
-        if self.is_dynamic_quant:
-            x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.gate_up_proj.weight,
-                self.gate_up_proj.weight_scale,
-                output_dtype=torch.int32,
-            )
-            x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
-                x=x,
-                weight_scale=self.gate_up_proj.weight_scale_fp32,
-                activation_scale=dynamic_scale,
-                bias=None,
-                quant_scale=None,
-                quant_offset=None,
-                group_index=None,
-                activate_left=True,
-                quant_mode=1)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.down_proj.weight,
-                self.down_proj.weight_scale,
-                pertoken_scale=dynamic_scale,
-                output_dtype=torch.bfloat16,
-            )
-            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
-                x = tensor_model_parallel_all_reduce(x)
-            return x
-        gate_up, _ = self.gate_up_proj(x)
-        x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
-        return x
+class CustomDeepseekDBOMLP(CustomDeepseekV2MLP):
 
     def _forward_ms_mlp(self, x):
         current_ms_metadata = get_multistream_comm_context()
         assert current_ms_metadata is not None
-        if self.is_dynamic_quant:
-            x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.gate_up_proj.weight,
-                self.gate_up_proj.weight_scale,
-                output_dtype=torch.int32,
-            )
-            x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
-                x=x,
-                weight_scale=self.gate_up_proj.weight_scale_fp32,
-                activation_scale=dynamic_scale,
-                bias=None,
-                quant_scale=None,
-                quant_offset=None,
-                group_index=None,
-                activate_left=True,
-                quant_mode=1)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.down_proj.weight,
-                self.down_proj.weight_scale,
-                pertoken_scale=dynamic_scale,
-                output_dtype=torch.bfloat16,
-            )
-            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
-                current_ms_metadata.before_comm_event.record()
-                with torch.npu.stream(current_ms_metadata.comm_stream):
-                    current_ms_metadata.before_comm_event.wait()
-                    x = tensor_model_parallel_all_reduce(x)
-                    current_ms_metadata.after_comm_event.record()
-            return x
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         current_ms_metadata.before_comm_event.record()

vllm_ascend/models/deepseek_v2.py

Lines changed: 54 additions & 36 deletions
@@ -25,7 +25,7 @@
 # # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
 # """Inference-only DeepseekV2/DeepseekV3 model."""
 
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -69,12 +69,38 @@
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
+from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 from vllm_ascend.utils import dispose_tensor
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 
 
+class CustomDeepseekV2SiluAndMul(SiluAndMul):
+
+    def __init__(self,
+                 *,
+                 weight_scale: Optional[Callable[[], torch.Tensor]] = None):
+        super().__init__()
+        self.weight_scale = weight_scale
+
+    def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor,
+                                                       torch.Tensor]]):
+        if isinstance(x, tuple):
+            assert self.weight_scale is not None
+            # For AscendW8A8DynamicLinearMethod:
+            # a dynamic scale is passed along with the quantized value.
+            quantized_x, dynamic_scale = x
+            return torch_npu.npu_dequant_swiglu_quant(
+                x=quantized_x,
+                weight_scale=self.weight_scale(),
+                activation_scale=dynamic_scale,
+                activate_left=True,
+                quant_mode=1)
+        else:
+            return super().forward_oot(x)
+
+
 class CustomDeepseekV2MLP(nn.Module):
 
     def __init__(
@@ -101,44 +127,36 @@ def __init__(
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
-        self.act_fn = SiluAndMul()
 
-        # NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant
-        self.is_dynamic_quant = not isinstance(
-            self.gate_up_proj.quant_method,
-            UnquantizedLinearMethod) and isinstance(
-                self.gate_up_proj.quant_method.quant_method,
-                AscendW8A8DynamicLinearMethod)
+        quant_method = self.gate_up_proj.quant_method
+        if isinstance(quant_method, UnquantizedLinearMethod):
+            self.act_fn = CustomDeepseekV2SiluAndMul()
+        elif (isinstance(quant_method, AscendLinearMethod) and isinstance(
+                quant_method.quant_method, AscendW8A8DynamicLinearMethod)):
+            # TODO(sdmyzlp): Currently preserved as before:
+            # 1. The only quantization supported for silu is W8A8Dynamic
+            # 2. Output dtype of gate_up/down is fixed to be int32/bfloat16
+            #
+            # Maybe one can implement a better and more general configuration
+            # scheme, e.g. by somehow passing around the tweaked `quant_config`
+            self.act_fn = CustomDeepseekV2SiluAndMul(
+                # Use lazy binding, for `weight_scale_fp32` is accessible
+                # only after `process_weights_after_loading`.
+                weight_scale=lambda: self.gate_up_proj.weight_scale_fp32)
+            # To be consumed by AscendW8A8DynamicLinearMethod.apply()
+            self.gate_up_proj._dynamic_quant_config = {
+                "output_dtype": torch.int32,
+                "return_scale": True,
+            }
+            self.down_proj._dynamic_quant_config = {
+                "output_dtype": torch.bfloat16,
+                "return_scale": False,
+            }
+        else:
+            raise NotImplementedError(
+                f"Quantization with [{type(quant_method)}] is NOT supported")
 
     def forward(self, x):
-        if self.is_dynamic_quant:
-            x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.gate_up_proj.weight,
-                self.gate_up_proj.weight_scale,
-                output_dtype=torch.int32,
-            )
-            x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
-                x=x,
-                weight_scale=self.gate_up_proj.weight_scale_fp32,
-                activation_scale=dynamic_scale,
-                bias=None,
-                quant_scale=None,
-                quant_offset=None,
-                group_index=None,
-                activate_left=True,
-                quant_mode=1)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.down_proj.weight,
-                self.down_proj.weight_scale,
-                pertoken_scale=dynamic_scale,
-                output_dtype=torch.bfloat16,
-            )
-            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
-                x = tensor_model_parallel_all_reduce(x)
-            return x
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 25 additions & 39 deletions
@@ -15,13 +15,13 @@
 # limitations under the License.
 #
 
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
 import torch_npu
 import torchair as tng  # type: ignore
-from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce
+from vllm.distributed import GroupCoordinator
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
@@ -77,19 +77,9 @@ def apply_mlp(hidden_states: torch.Tensor,
     shared_experts = kwargs.get('shared_experts', None)
     if shared_experts:
         shared_gate_up = kwargs.get('shared_gate_up', None)
-        shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
         with tng.scope.npu_stream_switch('cv'):
-            tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
-            shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
-                x=shared_gate_up,
-                weight_scale=shared_experts.gate_up_proj.weight_scale_fp32,
-                activation_scale=shared_dynamic_scale,
-                bias=None,
-                quant_scale=None,
-                quant_offset=None,
-                group_index=None,
-                activate_left=True,
-                quant_mode=1)
+            tng.scope.npu_wait_tensor(shared_gate_up[0], hidden_states)
+            shared_act = shared_experts.act_fn(shared_gate_up)
 
     # gmm1: gate_up_proj
     hidden_states = torch_npu.npu_grouped_matmul(
@@ -122,16 +112,9 @@ def apply_mlp(hidden_states: torch.Tensor,
 
     if shared_experts:
         with tng.scope.npu_stream_switch('cv'):
-            tng.scope.npu_wait_tensor(shared_x, hidden_states)
-            shared_output = torch_npu.npu_quant_matmul(
-                shared_x,
-                shared_experts.down_proj.weight,
-                shared_experts.down_proj.weight_scale,
-                pertoken_scale=shared_dynamic_scale,
-                output_dtype=torch.bfloat16,
-            )
-            if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
-                shared_output = tensor_model_parallel_all_reduce(shared_output)
+            tng.scope.npu_wait_tensor(shared_act[0], hidden_states)
+            shared_output, _ = shared_experts.down_proj(shared_act)
+
     if shared_experts:
         return hidden_states, shared_output
     return hidden_states
@@ -193,17 +176,10 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
         shared_hidden_states = kwargs.get('shared_hidden_states', None)
         with tng.scope.npu_stream_switch('cv'):
             tng.scope.npu_wait_tensor(shared_hidden_states, hidden_states)
-            shared_x, shared_dynamic_scale = torch_npu.npu_dynamic_quant(
+            shared_gate_up, _ = shared_experts.gate_up_proj(
                 shared_hidden_states)
-            shared_gate_up = torch_npu.npu_quant_matmul(
-                shared_x,
-                shared_experts.gate_up_proj.weight,
-                shared_experts.gate_up_proj.weight_scale,
-                output_dtype=torch.int32,
-            )
         kwargs.update({
             "shared_gate_up": shared_gate_up,
-            "shared_dynamic_scale": shared_dynamic_scale,
         })
 
     output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
@@ -541,21 +517,31 @@ def get_perchannel_param(
     @staticmethod
     def apply(
         layer: torch.nn.Module,
-        x: torch.Tensor,
+        x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
         bias: Optional[torch.Tensor] = None,
         tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
-        original_dtype = x.dtype
-        # use ATB quantize
-        quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
-        return torch_npu.npu_quant_matmul(
-            quant_out,
+        config = getattr(layer, "_dynamic_quant_config", {})
+        if not isinstance(x, tuple):
+            output_dtype = config.get("output_dtype", x.dtype)
+            quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
+        else:
+            assert "output_dtype" in config.keys(), (
+                f"DynamicLinearMethod needs explicitly specified `output_dtype`"
+                f"for pre-quantized input, got config [{config}]")
+            output_dtype = config["output_dtype"]
+            quantized_x, dynamic_scale = x
+
+        output = torch_npu.npu_quant_matmul(
+            quantized_x,
             layer.weight,
             layer.weight_scale,
             pertoken_scale=dynamic_scale,
             bias=bias,
-            output_dtype=original_dtype,
+            output_dtype=output_dtype,
        )
+        return ((output, dynamic_scale)
+                if config.get("return_scale", False) else output)
 
     def process_weights_after_loading(self, layer):
         if self.transpose_weight:
