|
15 | 15 | # limitations under the License. |
16 | 16 | # |
17 | 17 |
|
18 | | -from typing import Any, Callable, Dict, Optional |
| 18 | +from typing import Any, Callable, Dict, Optional, Tuple, Union |
19 | 19 |
|
20 | 20 | import torch |
21 | 21 | import torch.distributed as dist |
22 | 22 | import torch_npu |
23 | 23 | import torchair as tng # type: ignore |
24 | | -from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce |
| 24 | +from vllm.distributed import GroupCoordinator |
25 | 25 |
|
26 | 26 | import vllm_ascend.envs as envs_ascend |
27 | 27 | from vllm_ascend.ascend_config import get_ascend_config |
@@ -77,19 +77,9 @@ def apply_mlp(hidden_states: torch.Tensor, |
77 | 77 | shared_experts = kwargs.get('shared_experts', None) |
78 | 78 | if shared_experts: |
79 | 79 | shared_gate_up = kwargs.get('shared_gate_up', None) |
80 | | - shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None) |
81 | 80 | with tng.scope.npu_stream_switch('cv'): |
82 | | - tng.scope.npu_wait_tensor(shared_gate_up, hidden_states) |
83 | | - shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant( |
84 | | - x=shared_gate_up, |
85 | | - weight_scale=shared_experts.gate_up_proj.weight_scale_fp32, |
86 | | - activation_scale=shared_dynamic_scale, |
87 | | - bias=None, |
88 | | - quant_scale=None, |
89 | | - quant_offset=None, |
90 | | - group_index=None, |
91 | | - activate_left=True, |
92 | | - quant_mode=1) |
| 81 | + tng.scope.npu_wait_tensor(shared_gate_up[0], hidden_states) |
| 82 | + shared_act = shared_experts.act_fn(shared_gate_up) |
93 | 83 |
|
94 | 84 | # gmm1: gate_up_proj |
95 | 85 | hidden_states = torch_npu.npu_grouped_matmul( |
@@ -122,16 +112,9 @@ def apply_mlp(hidden_states: torch.Tensor, |
122 | 112 |
|
123 | 113 | if shared_experts: |
124 | 114 | with tng.scope.npu_stream_switch('cv'): |
125 | | - tng.scope.npu_wait_tensor(shared_x, hidden_states) |
126 | | - shared_output = torch_npu.npu_quant_matmul( |
127 | | - shared_x, |
128 | | - shared_experts.down_proj.weight, |
129 | | - shared_experts.down_proj.weight_scale, |
130 | | - pertoken_scale=shared_dynamic_scale, |
131 | | - output_dtype=torch.bfloat16, |
132 | | - ) |
133 | | - if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1: |
134 | | - shared_output = tensor_model_parallel_all_reduce(shared_output) |
| 115 | + tng.scope.npu_wait_tensor(shared_act[0], hidden_states) |
| 116 | + shared_output, _ = shared_experts.down_proj(shared_act) |
| 117 | + |
135 | 118 | if shared_experts: |
136 | 119 | return hidden_states, shared_output |
137 | 120 | return hidden_states |
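
The inline dequant-SwiGLU-quant removed above presumably now lives inside `shared_experts.act_fn`, which is why the wait keys on `shared_gate_up[0]`, the tensor half of a `(gemm output, per-token scale)` pair. Below is a hedged sketch of what such an activation wrapper could look like; it reuses the exact `npu_dequant_swiglu_quant` call deleted in this hunk, but the class name and attribute layout are assumptions, not code from this PR.

```python
# Hedged sketch: a quantized SwiGLU wrapper that absorbs the
# npu_dequant_swiglu_quant call removed from apply_mlp above.
# Class name and attributes are assumptions, not the PR's code.
from typing import Tuple

import torch
import torch_npu


class QuantSwiGLU(torch.nn.Module):

    def __init__(self, weight_scale_fp32: torch.Tensor):
        super().__init__()
        # Per-channel weight scale of the preceding gate_up projection,
        # needed to dequantize its int32 GEMM output.
        self.weight_scale_fp32 = weight_scale_fp32

    def forward(
        self, x: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        gemm_out, per_token_scale = x
        # Dequantize the int32 GEMM output, apply SwiGLU, and re-quantize
        # to int8 in one fused kernel; returns (int8 tensor, new scale).
        return torch_npu.npu_dequant_swiglu_quant(
            x=gemm_out,
            weight_scale=self.weight_scale_fp32,
            activation_scale=per_token_scale,
            bias=None,
            quant_scale=None,
            quant_offset=None,
            group_index=None,
            activate_left=True,
            quant_mode=1)
```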
@@ -189,17 +172,10 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor, |
189 | 172 | shared_hidden_states = kwargs.get('shared_hidden_states', None) |
190 | 173 | with tng.scope.npu_stream_switch('cv'): |
191 | 174 | tng.scope.npu_wait_tensor(shared_hidden_states, hidden_states) |
192 | | - shared_x, shared_dynamic_scale = torch_npu.npu_dynamic_quant( |
| 175 | + shared_gate_up, _ = shared_experts.gate_up_proj( |
193 | 176 | shared_hidden_states) |
194 | | - shared_gate_up = torch_npu.npu_quant_matmul( |
195 | | - shared_x, |
196 | | - shared_experts.gate_up_proj.weight, |
197 | | - shared_experts.gate_up_proj.weight_scale, |
198 | | - output_dtype=torch.int32, |
199 | | - ) |
200 | 177 | kwargs.update({ |
201 | 178 | "shared_gate_up": shared_gate_up, |
202 | | - "shared_dynamic_scale": shared_dynamic_scale, |
203 | 179 | }) |
204 | 180 |
|
205 | 181 | output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2) |
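
Pieced together across the two functions, the shared-expert side stream after this change flows as in the sketch below: `gate_up_proj` runs here in `fused_experts_with_mc2`, its `(gemm output, per-token scale)` result is forwarded through `kwargs`, and `apply_mlp` finishes with `act_fn` and `down_proj`, each step waiting on the tensor element of the incoming tuple before the `'cv'` stream touches it. This is a consolidation for readability only; the helper signature is hypothetical and the routed-expert work on the main stream is elided.

```python
# Hedged consolidation of the shared-expert side stream ('cv') across
# fused_experts_with_mc2 and apply_mlp; names follow the diff above.
import torch
import torchair as tng  # type: ignore


def shared_expert_side_stream(shared_experts: torch.nn.Module,
                              shared_hidden_states: torch.Tensor,
                              hidden_states: torch.Tensor) -> torch.Tensor:
    # 1) Gate/up projection (fused_experts_with_mc2): with the tuple-aware
    #    DynamicLinearMethod, shared_gate_up is a (gemm output, scale) pair.
    with tng.scope.npu_stream_switch('cv'):
        tng.scope.npu_wait_tensor(shared_hidden_states, hidden_states)
        shared_gate_up, _ = shared_experts.gate_up_proj(shared_hidden_states)

    # ...routed-expert dispatch and grouped GEMMs run on the main stream...

    # 2) Activation (apply_mlp): fused dequant-SwiGLU-quant on the pair.
    with tng.scope.npu_stream_switch('cv'):
        tng.scope.npu_wait_tensor(shared_gate_up[0], hidden_states)
        shared_act = shared_experts.act_fn(shared_gate_up)

    # 3) Down projection (apply_mlp): consumes the pre-quantized pair and
    #    returns a plain activation tensor.
    with tng.scope.npu_stream_switch('cv'):
        tng.scope.npu_wait_tensor(shared_act[0], hidden_states)
        shared_output, _ = shared_experts.down_proj(shared_act)
    return shared_output
```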
@@ -532,21 +508,31 @@ def get_perchannel_param( |
532 | 508 | @staticmethod |
533 | 509 | def apply( |
534 | 510 | layer: torch.nn.Module, |
535 | | - x: torch.Tensor, |
| 511 | + x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], |
536 | 512 | bias: Optional[torch.Tensor] = None, |
537 | 513 | tp_rank: Optional[int] = 0, |
538 | 514 | ) -> torch.Tensor: |
539 | | - original_dtype = x.dtype |
540 | | - # use ATB quantize |
541 | | - quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x) |
542 | | - return torch_npu.npu_quant_matmul( |
543 | | - quant_out, |
| 515 | + config = getattr(layer, "_dynamic_quant_config", {}) |
| 516 | + if not isinstance(x, tuple): |
| 517 | + output_dtype = config.get("output_dtype", x.dtype) |
| 518 | + quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x) |
| 519 | + else: |
| 520 | + assert "output_dtype" in config.keys(), ( |
| 521 | +                f"DynamicLinearMethod needs explicitly specified `output_dtype` "
| 522 | + f"for pre-quantized input, got config [{config}]") |
| 523 | + output_dtype = config["output_dtype"] |
| 524 | + quantized_x, dynamic_scale = x |
| 525 | + |
| 526 | + output = torch_npu.npu_quant_matmul( |
| 527 | + quantized_x, |
544 | 528 | layer.weight, |
545 | 529 | layer.weight_scale, |
546 | 530 | pertoken_scale=dynamic_scale, |
547 | 531 | bias=bias, |
548 | | - output_dtype=original_dtype, |
| 532 | + output_dtype=output_dtype, |
549 | 533 | ) |
| 534 | + return ((output, dynamic_scale) |
| 535 | + if config.get("return_scale", False) else output) |
550 | 536 |
|
551 | 537 | def process_weights_after_loading(self, layer): |
552 | 538 | if self.transpose_weight: |
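
The new branch keys off a `_dynamic_quant_config` attribute that `apply` only reads via `getattr`; nothing in this diff sets it, so the shared-expert construction code presumably attaches it to the affected projections. Below is a hedged sketch of wiring that would reproduce the old inline behaviour (int32 gate/up GEMM output carrying the per-token scale forward, bfloat16 output from the down projection). The attribute values, the attachment point, and the helper name are assumptions.

```python
import torch


def attach_dynamic_quant_config(shared_experts: torch.nn.Module) -> None:
    # Hypothetical wiring; dtype choices mirror the inline code removed
    # above (int32 gate/up GEMM output, bfloat16 down-projection output).
    shared_experts.gate_up_proj._dynamic_quant_config = {
        # keep the quantized GEMM accumulator so act_fn can dequantize it
        "output_dtype": torch.int32,
        # also return the per-token input scale alongside the output
        "return_scale": True,
    }
    shared_experts.down_proj._dynamic_quant_config = {
        # input arrives as a pre-quantized (int8, scale) tuple from the
        # fused activation, so only the final output dtype is declared
        "output_dtype": torch.bfloat16,
    }
```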