1 change: 1 addition & 0 deletions docs/source/user_guide/additional_config.md
@@ -32,6 +32,7 @@ The following table lists the additional configuration options available in vLLM
| `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf or ut/e2e test case. |
| `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
| `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. |
| `kv_cache_dtype` | str | `None` | When using KV cache quantization, the KV cache dtype must be set explicitly; currently only `int8` is supported. |
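For example, a minimal sketch of enabling int8 KV cache quantization through `additional_config` (assuming the option is passed via vLLM's offline `LLM` entry point as elsewhere in this guide; the model path is only a placeholder, and the checkpoint must already contain KV cache quantization scales):

```python
from vllm import LLM

# Sketch only: pass the new option through additional_config.
# "/path/to/kv-int8-quantized-model" is a placeholder checkpoint path.
llm = LLM(
    model="/path/to/kv-int8-quantized-model",
    additional_config={"kv_cache_dtype": "int8"},
)
```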

The details of each config option are as follows:

23 changes: 19 additions & 4 deletions vllm_ascend/attention/attention_v1.py
@@ -69,6 +69,15 @@
16)
return (2, num_blocks, block_size, num_kv_heads, head_size)

@staticmethod
def get_bsh_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (2, num_blocks, block_size, num_kv_heads * head_size)

@staticmethod
def swap_blocks(
src_kv_cache: List[torch.Tensor],
@@ -279,6 +288,13 @@
value=value,
output=output,
layer_name=layer.layer_name)

elif hasattr(layer, 'quant_method'):
output = layer.quant_method.apply(layer, query, key, value,
kv_cache, attn_metadata,
self.attn_type, self.scale,
output)

else:
if attn_metadata is None:
return output.view(num_tokens, self.hidden_size)
@@ -308,11 +324,8 @@
value_cache=self.value_cache,
slot_indices=slots)

if hasattr(layer, 'quant_method'):
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
pass
# V0-Style scheduler situation.
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
mask = attn_metadata.attn_mask
@@ -414,6 +427,8 @@
out=output)

# to make in-place change to the output tensor
if hasattr(layer, 'quant_method'):
output = output.view(num_tokens, self.num_heads, self.head_size)
ori_output[:, :, :] = output[:num_tokens, :, :]
return output.view(num_tokens, self.hidden_size)

49 changes: 48 additions & 1 deletion vllm_ascend/models/pangu_moe.py
@@ -505,7 +505,7 @@ def forward(
# native FusedMoE. here we need to design a better FusedMoE
# (maybe using AscendFusedMoE) to enable these different
# communication schema.
final_hidden_states = self.experts.quant_method(
final_hidden_states = self.experts.quant_method.apply(
layer=self.experts,
x=hidden_states,
router_logits=router_logits,
@@ -937,6 +937,8 @@ def sample(
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
tp_size = get_tp_group().world_size
tp_rank = get_tp_group().rank_in_group
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
Expand Down Expand Up @@ -972,6 +974,51 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if "module" in name:
continue

if name.endswith('kv_cache_offset'):
continue

if name.endswith("k_proj.kv_cache_scale"):
remapped_kv_scale_name = name.replace(
"k_proj.kv_cache_scale", "attn.key_antiquant_scale")
if remapped_kv_scale_name not in params_dict:
logger.warning_once(
"Found kv scale in the checkpoint "
f"(e.g. {name}), but not found the expected "
f"name in the model "
f"(e.g. {remapped_kv_scale_name}). "
"kv-scale is not loaded.")
continue
else:
name = remapped_kv_scale_name
param = params_dict[name]
loaded_weight = torch.tensor_split(loaded_weight,
tp_size,
dim=0)[tp_rank]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

if name.endswith("v_proj.kv_cache_scale"):
remapped_kv_scale_name = name.replace(
"v_proj.kv_cache_scale", "attn.value_antiquant_scale")
if remapped_kv_scale_name not in params_dict:
logger.warning_once(
"Found kv scale in the checkpoint "
f"(e.g. {name}), but not found the expected "
f"name in the model "
f"(e.g. {remapped_kv_scale_name}). "
"kv-scale is not loaded.")
continue
else:
name = remapped_kv_scale_name
param = params_dict[name]
loaded_weight = torch.tensor_split(loaded_weight,
tp_size,
dim=0)[tp_rank]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
4 changes: 4 additions & 0 deletions vllm_ascend/platform.py
@@ -124,6 +124,10 @@
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
cache_config = vllm_config.cache_config
kv_cache_dtype = vllm_config.additional_config.get(
"kv_cache_dtype", None)
if kv_cache_dtype is not None:
vllm_config.cache_config.cache_dtype = kv_cache_dtype

if parallel_config:
# Default value for expert tensor parallel size
34 changes: 8 additions & 26 deletions vllm_ascend/quantization/quant_config.py
@@ -98,6 +98,9 @@
'fa_quant_type' in self.quant_description.keys() and \
self.quant_description['fa_quant_type'] is not None:
return AscendKVCacheMethod(self, prefix)
elif isinstance(layer, Attention) and self.quant_description.get(
'kv_quant_type') == 'C8':
return AscendKVCacheMethod(self, prefix)
elif isinstance(layer, FusedMoE):
if self.is_layer_skipped_ascend(prefix,
self.packed_modules_mapping):
@@ -235,32 +238,11 @@
if hasattr(self.quant_method, "process_weights_after_loading"):
self.quant_method.process_weights_after_loading(layer)

def apply(self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
k_cache: List[torch.Tensor],
v_cache: List[torch.Tensor],
scale: torch.Tensor,
block_tables: torch.Tensor,
isPrefill: bool,
attn_metadata,
output,
seq_lens_tensor_cpu: Optional[int] = None) -> torch.Tensor:
return self.quant_method.apply(layer,
query,
key,
value,
k_cache,
v_cache,
scale,
block_tables,
isPrefill,
attn_metadata.attn_mask,
attn_metadata.slot_mapping,
output,
seq_lens_tensor_cpu=seq_lens_tensor_cpu)
def apply(self, layer: torch.nn.Module, query: torch.Tensor,
key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata,
attn_type, scale, output) -> torch.Tensor:
return self.quant_method.apply(layer, query, key, value, kv_cache,
attn_metadata, attn_type, scale, output)


class AscendFusedMoEMethod(FusedMoEMethodBase):
14 changes: 13 additions & 1 deletion vllm_ascend/quantization/quantizer.py
@@ -24,7 +24,8 @@

from .func_wrapper import (wrapper_load_model, wrapper_rmsnorm_forward_oot,
wrapper_rmsnorm_init)
from .w8a8 import AscendW8A8LinearMethod
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
AscendW8A8LinearMethod)
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
AscendW8A8DynamicLinearMethod)

@@ -250,6 +251,8 @@
# Attention
if '.attn' in prefix and 'fa_quant_type' in quant_description.keys():
quant_type = quant_description['fa_quant_type']
if '.attn' in prefix and 'kv_quant_type' in quant_description.keys():
quant_type = quant_description['kv_quant_type']
# Linear
else:
quant_type = cls.get_linear_quant_type(quant_description, prefix,
@@ -269,6 +272,14 @@
def build_linear_method():
return AscendW8A8LinearMethod()

@staticmethod
def build_moe_method():
return AscendW8A8FusedMoEMethod()

@staticmethod
def build_attention_method():
return AscendC8KVCacheMethod()


class W8A8DYNAMICQuantizer(VLLMAscendQuantizer):

@@ -284,4 +295,5 @@
SUPPORT_ASCEND_QUANTIZER_TYPE = {
"W8A8": W8A8Quantizer,
"W8A8_DYNAMIC": W8A8DYNAMICQuantizer,
"C8": W8A8Quantizer,
}
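For orientation, a rough sketch of how the new `"C8"` entry can resolve to the KV cache method added above (illustrative only; the lookup helper below is hypothetical and not part of this diff):

```python
# Hypothetical helper, for illustration only: resolve a checkpoint's
# quant_description to the attention quant method via the table above.
def resolve_attention_quant_method(quant_description: dict):
    quant_type = quant_description.get("kv_quant_type")        # e.g. "C8"
    quantizer_cls = SUPPORT_ASCEND_QUANTIZER_TYPE[quant_type]  # -> W8A8Quantizer
    return quantizer_cls.build_attention_method()              # -> AscendC8KVCacheMethod()
```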