
Commit 49b047a

Fix bugs and CI and optimize codes
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
1 parent 3ec1c35 commit 49b047a

11 files changed (+118 lines, -179 lines)

tests/ut/models/test_deepseek_v2.py (17 additions, 1 deletion)

@@ -27,7 +27,7 @@
                              CustomDeepseekV2MLP, CustomDeepseekV2MoE,
                              CustomDeepseekV2RowParallelLinear,
                              CustomDeepseekV2RowParallelLinearReplaceAllreduce,
-                             CustomDeepseekV2SiluAndMul)
+                             CustomDeepseekV2SiluAndMul, CustomLogitsProcessor, CustomParallelLMHead)


 @pytest.fixture
@@ -317,3 +317,19 @@ def test_custom_deepseek_v2_for_causal_lm(mock_distributed, vllm_config):
     ):
         loaded = model.load_weights(weights)
     assert loaded is not None
+
+
+def test_custom_deepseek_v2_lmhead(mock_distributed, vllm_config):
+    model = CustomDeepseekV2ForCausalLM(vllm_config=vllm_config)
+    lmhead = CustomParallelLMHead(model.model.config.vocab_size,
+                                  model.model.config.hidden_size)
+    logits_processor = CustomLogitsProcessor(model.model.config.vocab_size)
+
+    input_ids = torch.randint(0, 10000, (2, 4))
+    positions = torch.arange(4).repeat(2, 1)
+    with patch.object(model.model,
+                      "forward",
+                      return_value=torch.randn(2, 4, 128)):
+        output = model(input_ids, positions)
+    logits = logits_processor(lmhead, output)
+    assert logits.shape == (2, 4, 10000)

vllm_ascend/ascend_config.py (7 additions, 8 deletions)

@@ -53,16 +53,15 @@ def __init__(self, vllm_config):
         self.lmhead_tensor_parallel_size = additional_config.get(
             "lmhead_tensor_parallel_size", None)
         if self.lmhead_tensor_parallel_size is not None:
-            logger.info(f"Enable lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size} in pure DP scenario")
-            assert(
+            logger.info(
+                f"Enable lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size} in pure DP scenario"
+            )
+            assert (
                 vllm_config.parallel_config.tensor_parallel_size == 1
-            ),"lmhead_tensor_parallel_size is only supported in the pure DP scenario"
-            assert(
-                self.torchair_graph_config.enabled == True
+            ), "lmhead_tensor_parallel_size is only supported in the pure DP scenario"
+            assert (
+                self.torchair_graph_config.enabled
             ), "lmhead_tensor_parallel_size is only supported in graph mode"
-            assert(
-                vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_consumer
-            ),"lmhead_tensor_parallel_size is only supported in pd scenario and can only be used in D node."


 class TorchairGraphConfig:
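
For readers of this change: lmhead_tensor_parallel_size is read from additional_config and is only accepted in the pure DP scenario (tensor_parallel_size == 1) with the torchair graph mode enabled; the kv-consumer assertion is dropped by this commit. Below is a minimal, self-contained sketch of that validation logic; the function name check_lmhead_tp and the flat torchair_graph_enabled flag are illustrative stand-ins, not part of the patch.

# Hedged sketch: mirrors the assertions kept in AscendConfig.__init__ after
# this commit. `check_lmhead_tp` and its parameters are illustrative names.
from typing import Optional


def check_lmhead_tp(lmhead_tensor_parallel_size: Optional[int],
                    tensor_parallel_size: int,
                    torchair_graph_enabled: bool) -> None:
    if lmhead_tensor_parallel_size is None:
        return  # feature disabled, nothing to validate
    # Only the pure DP scenario is supported (no regular tensor parallelism).
    assert tensor_parallel_size == 1, (
        "lmhead_tensor_parallel_size is only supported in the pure DP scenario")
    # The feature is only supported when the torchair graph mode is enabled.
    assert torchair_graph_enabled, (
        "lmhead_tensor_parallel_size is only supported in graph mode")


# Example: an additional_config dict a user might pass (key as in the diff).
additional_config = {"lmhead_tensor_parallel_size": 4}
check_lmhead_tp(additional_config.get("lmhead_tensor_parallel_size", None),
                tensor_parallel_size=1,
                torchair_graph_enabled=True)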

vllm_ascend/distributed/parallel_state.py (10 additions, 6 deletions)

@@ -9,15 +9,18 @@
 _MC2: Optional[GroupCoordinator] = None
 _LMTP: Optional[GroupCoordinator] = None

+
 def get_mc2_group() -> GroupCoordinator:
     assert _MC2 is not None, ("mc2 group is not initialized")
     return _MC2

+
 def get_lmheadtp_group() -> GroupCoordinator:
     assert _LMTP is not None, (
         "lm head tensor parallel group is not initialized")
     return _LMTP

+
 def model_parallel_initialized():
     return (_MC2 is not None)

@@ -43,22 +46,23 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
                                       get_world_group().local_rank,
                                       backend,
                                       group_name="mc2")
-
+
     lmhead_tensor_parallel_size = parallel_config.lmhead_tensor_parallel_size
     if lmhead_tensor_parallel_size is not None:
         group_ranks = []
         global _LMTP
         num_lmhead_tensor_parallel_groups: int = (world_size //
-                                                lmhead_tensor_parallel_size)
+                                                  lmhead_tensor_parallel_size)
         for i in range(num_lmhead_tensor_parallel_groups):
             ranks = list(
                 range(i * lmhead_tensor_parallel_size,
-                    (i + 1) * lmhead_tensor_parallel_size))
+                      (i + 1) * lmhead_tensor_parallel_size))
             group_ranks.append(ranks)
         _LMTP = init_model_parallel_group(group_ranks,
-                                        get_world_group().local_rank,
-                                        backend,
-                                        group_name="lmheadtp")
+                                          get_world_group().local_rank,
+                                          backend,
+                                          group_name="lmheadtp")
+

 def destroy_ascend_model_parallel():
     global _MC2
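
As a quick illustration of the grouping added in init_ascend_model_parallel, the standalone sketch below reproduces the rank-partitioning arithmetic on its own; the world size and group size values are made up for the example.

# Standalone sketch of the LM-head TP grouping: consecutive ranks are packed
# into world_size // lmhead_tensor_parallel_size groups, as in the diff above.
def build_lmhead_group_ranks(world_size, lmhead_tensor_parallel_size):
    assert world_size % lmhead_tensor_parallel_size == 0
    num_groups = world_size // lmhead_tensor_parallel_size
    return [
        list(range(i * lmhead_tensor_parallel_size,
                   (i + 1) * lmhead_tensor_parallel_size))
        for i in range(num_groups)
    ]


# Example: 8 DP ranks with lmhead_tensor_parallel_size=4
# -> [[0, 1, 2, 3], [4, 5, 6, 7]]
print(build_lmhead_group_ranks(8, 4))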

vllm_ascend/models/deepseek_mtp.py (10 additions, 8 deletions)

@@ -25,18 +25,20 @@
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import get_sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import \
+    VocabParallelEmbedding
 from vllm.model_executor.models.deepseek_mtp import (
     DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
     SharedHead)
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

+from vllm_ascend.ops.vocab_parallel_embedding import (CustomLogitsProcessor,
+                                                      CustomParallelLMHead)
+
 from .deepseek_v2 import CustomDeepseekV2DecoderLayer


@@ -48,10 +50,10 @@ def __init__(self,
                  prefix: str = "") -> None:
         nn.Module.__init__(self)
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.head = ParallelLMHead(config.vocab_size,
-                                   config.hidden_size,
-                                   quant_config=quant_config,
-                                   prefix=maybe_prefix(prefix, "head"))
+        self.head = CustomParallelLMHead(config.vocab_size,
+                                         config.hidden_size,
+                                         quant_config=quant_config,
+                                         prefix=maybe_prefix(prefix, "head"))


 class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
@@ -141,7 +143,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             for idx in range(self.mtp_start_layer_idx,
                              self.mtp_start_layer_idx + self.num_mtp_layers)
         ]
-        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.logits_processor = CustomLogitsProcessor(config.vocab_size)

     def forward(
         self,

vllm_ascend/models/deepseek_v2.py (6 additions, 6 deletions)

@@ -49,12 +49,11 @@
                                                ReplicatedLinear,
                                                RowParallelLinear,
                                                UnquantizedLinearMethod)
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import get_sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import \
+    VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.deepseek_v2 import \
@@ -68,14 +67,15 @@
     PPMissingLayer, is_pp_missing_parameter,
     make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
 from vllm.sequence import IntermediateTensors
-from vllm.model_executor.sampling_metadata import SamplingMetadata

 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
+from vllm_ascend.ops.vocab_parallel_embedding import (CustomLogitsProcessor,
+                                                      CustomParallelLMHead)
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 from vllm_ascend.utils import dispose_tensor, npu_prefetch
-from vllm_ascend.ops.vocab_parallel_embedding import CustomParallelLMHead, CustomLogitsProcessor
+

 class CustomDeepseekV2SiluAndMul(SiluAndMul):

@@ -930,7 +930,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                               config.hidden_size,
                                               quant_config=quant_config,
                                               prefix=maybe_prefix(
-                                                prefix, "lm_head"))
+                                                  prefix, "lm_head"))
         else:
             self.lm_head = PPMissingLayer()
         self.logits_processor = CustomLogitsProcessor(config.vocab_size)

vllm_ascend/ops/vocab_parallel_embedding.py (30 additions, 107 deletions)

@@ -19,38 +19,20 @@

 import torch
 from torch.nn import Module
-import torch.distributed as dist
-from torch.nn.parameter import Parameter, UninitializedParameter
-
-from vllm.distributed import (
-    divide,
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-    tensor_model_parallel_all_reduce
-)
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding,
-    DEFAULT_VOCAB_PADDING_SIZE,
-    pad_vocab_size,
-    UnquantizedEmbeddingMethod,
-    ParallelLMHead
-)
-from vllm.model_executor.layers.logits_processor import (
-    LogitsProcessor,
-    _apply_logits_processors,
-    _prune_hidden_states
-)
-from vllm.model_executor.parameter import BasevLLMParameter
-from vllm.model_executor.utils import set_weight_attrs, _enable_lmhead_tp
-from vllm.model_executor.sampling_metadata import SamplingMetadata
+from torch.nn.parameter import Parameter
+from vllm.distributed import (divide, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig,
-    QuantizeMethodBase,
-    method_has_implemented_embedding
-)
+    QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, UnquantizedEmbeddingMethod,
+    VocabParallelEmbedding, pad_vocab_size)
+from vllm.model_executor.utils import set_weight_attrs

 from vllm_ascend.distributed.parallel_state import get_lmheadtp_group
-from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.utils import lmhead_tp_enable


 def get_masked_input_and_mask(
@@ -105,8 +87,7 @@ def vocab_parallel_embedding_forward(self, input_):


 class CustomParallelLMHead(ParallelLMHead):
-
-    """Costom Parallelized LM head, added the feature of lmheadTP in pure dp scenario
+    """Custom Parallelized LM head, added the feature of lmheadTP in pure dp scenario

     Output logits weight matrices used in the Sampler. The weight and bias
     tensors are padded to make sure they are divisible by the number of
@@ -120,6 +101,7 @@ class CustomParallelLMHead(ParallelLMHead):
         org_num_embeddings: original vocabulary size (without LoRA).
         padding_size: padding size for the vocabulary.
     """
+
     def __init__(self,
                  num_embeddings: int,
                  embedding_dim: int,
@@ -128,16 +110,16 @@ def __init__(self,
                  org_num_embeddings: Optional[int] = None,
                  padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                  quant_config: Optional[QuantizationConfig] = None,
-                prefix: str = ""):
+                 prefix: str = ""):
         Module.__init__(self)

-        if _enable_lmhead_tp():
+        if lmhead_tp_enable():
             tp_rank = get_lmheadtp_group().rank_in_group
             self.tp_size = get_lmheadtp_group().world_size
         else:
             tp_rank = get_tensor_model_parallel_rank()
             self.tp_size = get_tensor_model_parallel_world_size()
-
+
         self.num_embeddings = num_embeddings
         self.padding_size = padding_size
         self.org_vocab_size = org_num_embeddings or num_embeddings
@@ -197,7 +179,7 @@ def __init__(self,
                                           self.num_embeddings_padded,
                                           params_dtype=params_dtype,
                                           weight_loader=self.weight_loader)
-
+
         self.quant_config = quant_config
         if bias:
             self.bias = Parameter(
@@ -209,90 +191,32 @@ def __init__(self,
                 })
         else:
             self.register_parameter("bias", None)
-
+
+
 class CustomLogitsProcessor(LogitsProcessor):
     """Custom logits processor extending base LogitsProcessor functionality.
     Added the feature of lmheadTP in pure dp scenario
     """
-
-    def __init__(self,
-                 vocab_size: int,
-                 org_vocab_size: Optional[int] = None,
-                 scale: float = 1.0,
-                 logits_as_input: bool = False,
-                 soft_cap: Optional[float] = None) -> None:
-        super().__init__(
-            vocab_size=vocab_size,
-            org_vocab_size=org_vocab_size,
-            scale=scale,
-            logits_as_input=logits_as_input,
-            soft_cap=soft_cap
-        )

-    def forward(
+    def _get_logits(
         self,
-        lm_head: CustomParallelLMHead,
         hidden_states: torch.Tensor,
-        sampling_metadata: Optional[SamplingMetadata] = None,
-        embedding_bias: Optional[torch.Tensor] = None,
+        lm_head: CustomParallelLMHead,
+        embedding_bias: Optional[torch.Tensor],
     ) -> Optional[torch.Tensor]:
-        if self.logits_as_input:
-            logits = hidden_states
-        else:
-            if sampling_metadata is not None:
-                hidden_states = _prune_hidden_states(hidden_states,
-                                                     sampling_metadata)
-
-            # Get the logits for the next tokens.
-            logits = self._get_logits(hidden_states, lm_head, embedding_bias)
-        if logits is not None:
-            if self.soft_cap is not None:
-                logits = logits / self.soft_cap
-                logits = torch.tanh(logits)
-                logits = logits * self.soft_cap
-
-            if self.scale != 1.0:
-                logits *= self.scale

-            # Apply logits processors (if any).
-            if sampling_metadata is not None and \
-                sampling_metadata.seq_groups is not None:
-                logits = _apply_logits_processors(logits, sampling_metadata)
-
-        return logits
-
-    def _get_logits(
-        self,
-        hidden_states: torch.Tensor,
-        lm_head: CustomParallelLMHead,
-        embedding_bias: Optional[torch.Tensor],
-    ) -> Optional[torch.Tensor]:
-        """
-        Compute logits for next token prediction using parallel processing.
-
-        Args:
-            hidden_states: Current hidden states from the model with shape [batch_size, hidden_size]
-            lm_head: Parallel embedding layer for vocabulary predictions
-            embedding_bias: Optional bias tensor to add to logits with shape [vocab_size]
-
-        Returns:
-            Logits tensor for next token prediction with shape [batch_size, vocab_size] or None
-        """
-
-        if _enable_lmhead_tp():
+        if lmhead_tp_enable():
             # Gather hidden states from all devices in tensor parallel group
-            gathered_hidden_states = get_lmheadtp_group().all_gather(hidden_states, dim=0)
+            gathered_hidden_states = get_lmheadtp_group().all_gather(
+                hidden_states, dim=0)
         else:
             gathered_hidden_states = hidden_states

-        # Compute logits using quantized matrix multiplication
-        local_logits = lm_head.quant_method.apply(
-            lm_head,
-            gathered_hidden_states,
-            bias=embedding_bias
-        )
+        local_logits = lm_head.quant_method.apply(lm_head,
+                                                  gathered_hidden_states,
+                                                  bias=embedding_bias)

-        if _enable_lmhead_tp():
+        if lmhead_tp_enable():
             logits = get_lmheadtp_group().all_to_all(local_logits)
         else:
             # Gather logits for tensor parallel
@@ -301,6 +225,5 @@ def _get_logits(
         # Remove paddings in vocab (if any)
         if logits is not None:
             logits = logits[..., :self.org_vocab_size]
-
+
         return logits
-
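
To see what the new _get_logits does with shapes, here is a collective-free sketch of the lmhead-TP path: hidden states are all-gathered across the LM-head TP group, each rank matmuls against its vocabulary shard, and an all-to-all returns full-vocabulary logits for the rank's own tokens. The sizes below are illustrative, and the collectives are simulated by their shape effects only; this is not the patched code itself.

# Shape-level sketch of the lmhead-TP logits path (no distributed backend).
import torch

tp = 4         # LM-head TP group size (illustrative)
n = 8          # tokens held by this DP rank (illustrative)
hidden = 128   # hidden size (illustrative)
vocab = 1024   # padded vocabulary; each rank owns vocab // tp output columns

hidden_states = torch.randn(n, hidden)

# all_gather(dim=0): every rank contributes its n tokens; simulated here by
# repeating this rank's tokens tp times.
gathered = hidden_states.repeat(tp, 1)        # [tp * n, hidden]

# Each rank holds a shard of the lm_head weight and computes partial logits.
lm_head_shard = torch.randn(vocab // tp, hidden)
local_logits = gathered @ lm_head_shard.t()   # [tp * n, vocab // tp]

# all_to_all: keep only this rank's n tokens but collect every rank's vocab
# shard, yielding full-vocabulary logits. Simulated by reshaping locally.
logits = local_logits.view(tp, n, vocab // tp).permute(1, 0, 2).reshape(n, vocab)
assert logits.shape == (n, vocab)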
