Commit 43f5388

[fix]: reduced dependency on vllm for dbo
1 parent 68070f1 commit 43f5388

File tree

7 files changed: +112 -89 lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 31 additions & 18 deletions
@@ -803,31 +803,44 @@ def forward(
             # FIX: aicore move/copy should be also placed on the comm stream in dbo,
             # otherwise it may affect the accuracy or disturb the overlap of next stage
             # TODO: use an elegant way here to avoid it
-            output_prefill = self._forward_prefill(
-                prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
-                attn_metadata)
-            from vllm.multistream.context import get_multistream_comm_context
+            from vllm_ascend.multistream.context import get_multistream_comm_context
             current_ms_metadata = get_multistream_comm_context()
-            if current_ms_metadata is not None:
+            if current_ms_metadata is None:
+                output[num_decode_tokens:] = self._forward_prefill(
+                    prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
+                    attn_metadata)
+            else:
+                current_ms_metadata.before_comm_event.record()
                 with torch.npu.stream(current_ms_metadata.comm_stream):
-                    output[num_decode_tokens:] = output_prefill
+                    current_ms_metadata.before_comm_event.wait()
+                    output[num_decode_tokens:] = self._forward_prefill(
+                        prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
+                        attn_metadata)
                     current_ms_metadata.after_comm_event.record()
-            else:
-                output[num_decode_tokens:] = output_prefill
+
+
+
         if has_decode:
             if self.running_in_graph:
                 return self._forward_decode(decode_ql_nope, decode_q_pe,
                                             decode_k_nope, decode_k_pe,
                                             kv_cache, attn_metadata)
             else:
-                from vllm.multistream.context import get_multistream_comm_context
-                current_ms_metadata = get_multistream_comm_context()
-                output_decode = self._forward_decode(
-                    decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
-                    kv_cache, attn_metadata)
-                if current_ms_metadata is not None:
-                    with torch.npu.stream(current_ms_metadata.comm_stream):
-                        output[:num_decode_tokens] = output_decode
-                else:
-                    output[:num_decode_tokens] = output_decode
+
+                from vllm_ascend.multistream.context import get_multistream_comm_context
+                current_ms_metadata = get_multistream_comm_context()
+                if current_ms_metadata is None:
+                    output[:num_decode_tokens] = self._forward_decode(
+                        decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
+                        kv_cache, attn_metadata)
+                else:
+                    current_ms_metadata.before_comm_event.record()
+                    with torch.npu.stream(current_ms_metadata.comm_stream):
+                        current_ms_metadata.before_comm_event.wait()
+                        output[:num_decode_tokens] = self._forward_decode(
+                            decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
+                            kv_cache, attn_metadata)
+                        current_ms_metadata.after_comm_event.record()
+
+
         return output_padded
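Note: the rewritten branches compute attention directly on the multistream comm stream and order it with a record/wait event pair, instead of computing on the default stream and only copying the result on the comm stream. A minimal sketch of that record/wait ordering, written against torch.cuda streams as a stand-in for the torch.npu API used in the diff (names, shapes, and the helper below are illustrative, not from the repo):

import torch

def run_on_comm_stream(fn, out, sl, comm_stream, before_evt, after_evt):
    # Record a point on the current (default) stream that the comm-stream work depends on.
    before_evt.record()
    with torch.cuda.stream(comm_stream):
        # Block the comm stream until the default stream reaches the recorded point.
        before_evt.wait()
        out[sl] = fn()
        # Mark completion so a later stage can wait on this event before reading `out`.
        after_evt.record()

if torch.cuda.is_available():
    comm_stream = torch.cuda.Stream()
    before_evt, after_evt = torch.cuda.Event(), torch.cuda.Event()
    x = torch.randn(8, 4, device="cuda")
    out = torch.empty(8, 4, device="cuda")
    run_on_comm_stream(lambda: x[4:] * 2, out, slice(4, None),
                       comm_stream, before_evt, after_evt)
    torch.cuda.synchronize()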

vllm_ascend/envs.py

Lines changed: 2 additions & 0 deletions
@@ -66,6 +66,8 @@
     lambda: os.getenv("C_COMPILER", None),
     "VLLM_VERSION":
     lambda: os.getenv("VLLM_VERSION", None),
+    "VLLM_ENABLE_MS":
+    lambda: bool(int(os.getenv("VLLM_ENABLE_MS", '0'))),
 }

 # end-env-vars-definition
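The new flag follows the same lazy-lambda pattern as the neighboring entries: the value is read from the environment at access time and the '0'/'1' string is turned into a boolean. A small standalone sketch of that parsing (the dict name here is illustrative, not the repo's accessor):

import os

env_variables = {
    # Same lazy pattern as vllm_ascend/envs.py: evaluated when the entry is called.
    "VLLM_ENABLE_MS": lambda: bool(int(os.getenv("VLLM_ENABLE_MS", "0"))),
}

os.environ["VLLM_ENABLE_MS"] = "1"   # e.g. exported before launching the server
assert env_variables["VLLM_ENABLE_MS"]() is True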

vllm_ascend/models/deepseek_v2.py

Lines changed: 33 additions & 15 deletions
@@ -72,11 +72,12 @@
 from vllm_ascend.multistream.context import (set_multistream_context,get_multistream_layer_context,
                                              advance_step_multistream_layer_context, get_multistream_comm_context)
 from vllm_ascend.multistream.layers import (MultiStreamPreTransformerLayer, MultiStreamPostTransformerLayer)
-from vllm_ascend.multistream.metadata import make_multistream_metadata_ds, MultiStreamStepMetadata
+from vllm_ascend.multistream.metadata import make_multistream_metadata_ds, MultiStreamStepMetadata, MultiStreamConfig
 from vllm_ascend.multistream.base import MSEventKey
 from vllm_ascend.multistream.ms_split import compute_split_seq_index

 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
+VLLM_ENABLE_MS: bool = envs_ascend.VLLM_ENABLE_MS


 class CustomDeepseekV2MLP(nn.Module):

@@ -305,8 +306,10 @@ def _forward_ms_op_tp_allreduce(
             dist.all_gather(list(chunk_hidden_states), hidden_states,
                             self.tp_group)
             final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
-            if num_tokens < self.tp_size:
-                final_hidden_states = final_hidden_states[:num_tokens]
+            #if num_tokens < self.tp_size:
+            #    final_hidden_states = final_hidden_states[:num_tokens]
+            if num_tokens > 0:
+                final_hidden_states = final_hidden_states[:-num_tokens]
         else:
             final_hidden_states = hidden_states

@@ -641,6 +644,10 @@ def _forward_ms_layer(
             )

             with set_multistream_context(context, i):
+                context = get_forward_context()
+                layer_index, ms_metadata, attn_metadata = get_multistream_layer_context()
+                context.attn_metadata = attn_metadata[i]
+
                 # input layernorm
                 hidden_states[i], residual[i] = self._forward_ms_op_input_layernorm(hidden_states[i], residual[i])
                 # attention and tp allreducea

@@ -664,7 +671,7 @@ def _forward_ms_layer(

             num_token, hidden_dim = hidden_states[i].shape
             hidden_states[i] = hidden_states[i].view(-1, hidden_dim)
-            num_tokens.append(num_token)
+            #num_tokens.append(num_token)
             hidden_dims.append(hidden_dim)
             if self.mlp.n_shared_experts is not None:
                 # TODO: we can move shared expert computation into next block if reduce results is false

@@ -686,13 +693,20 @@ def _forward_ms_layer(
                 enable_force_load_balance = False

             if self.mlp.tp_size > 1:
-                if num_tokens[i] < self.mlp.tp_size:
-                    target_size = self.mlp.tp_size
-                    new_hidden_states = torch.empty([target_size, hidden_dims[i]],
-                                                    dtype=hidden_states[i].dtype,
-                                                    device=hidden_states[i].device)
-                    new_hidden_states[:num_tokens[i]] = hidden_states[i]
-                    hidden_states[i] = new_hidden_states
+                #if num_tokens[i] < self.mlp.tp_size:
+                #    target_size = self.mlp.tp_size
+                #    new_hidden_states = torch.empty([target_size, hidden_dims[i]],
+                #                                    dtype=hidden_states[i].dtype,
+                #                                    device=hidden_states[i].device)
+                #    new_hidden_states[:num_tokens[i]] = hidden_states[i]
+                #    hidden_states[i] = new_hidden_states
+                num_token, _ = hidden_states[i].shape
+                padded_num_tokens = (self.mlp.tp_size -
+                                     num_token % self.mlp.tp_size) % self.mlp.tp_size
+                if padded_num_tokens > 0:
+                    hidden_states[i] = nn.functional.pad(hidden_states[i],
+                                                         (0, 0, 0, padded_num_tokens))
+                num_tokens.append(padded_num_tokens)
                 chunk_hidden_state = torch.tensor_split(hidden_states[i],
                                                         self.mlp.tp_size,
                                                         dim=0)

@@ -713,7 +727,7 @@ def _forward_ms_layer(
             if VLLM_ENABLE_MC2 and not is_prefill:
                 ...

-            hidden_states[i] = self.mlp.experts._forward_ms_fused_moe_comp(hidden_states[i], router_logits[i], is_prefill, real_top_k, enable_force_load_balance)
+            hidden_states[i] = self.mlp.experts._forward_ms_fused_moe_comp(local_hidden_states, router_logits[i], is_prefill, real_top_k, enable_force_load_balance)

             if VLLM_ENABLE_MC2 and not is_prefill:
                 ...

@@ -847,7 +861,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 ["hidden_states", "residual"], config.hidden_size))

         # tbo related members
-        self.multistream_config = vllm_config.model_config.multistream_config
+        if VLLM_ENABLE_MS:
+            self.multistream_config = MultiStreamConfig()
+        else:
+            self.multistream_config = None
         self.use_mla = model_config.use_mla
         self.multistream_metadata = make_multistream_metadata_ds(
             start_layer=self.start_layer + self.first_k_dense_replace,

@@ -929,13 +946,14 @@ def can_run_ms(self):
             return False
         num_microbatchs = self.multistream_config.num_micro_batches
         # check whether there is a dp rank that not use dual batch
-        if dp_metadata is not None:
+        '''if dp_metadata is not None:
             for i in range(num_microbatchs):
                 cu_tokens = dp_metadata.cu_dbo_tokens_across_dp_cpu[i]
                 if torch.any(cu_tokens == 0).item():
                     return False
         [token_index, seq_index] = compute_split_seq_index(attn_metadata.query_lens,
-                attn_metadata.attn_state, attn_metadata.num_decode_tokens)
+            attn_metadata.attn_state, attn_metadata.num_decode_tokens)
+        '''
         if token_index == 0 or seq_index == 0 or seq_index == len(attn_metadata.query_lens):
             return False
         # check whether the total tokens exceed the threshold
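The padding change above replaces the old "pad only when tokens < tp_size" branch: the token dimension is now padded up to the next multiple of tp_size so torch.tensor_split yields equally sized chunks for the all-gather, and num_tokens records the pad amount (not the token count), which is why _forward_ms_op_tp_allreduce now trims the last num_tokens rows. A worked example of that arithmetic, with tp_size and shapes chosen purely for illustration:

import torch
import torch.nn as nn

tp_size = 4
hidden_states = torch.randn(10, 16)                            # 10 tokens, hidden dim 16
num_token = hidden_states.shape[0]
padded_num_tokens = (tp_size - num_token % tp_size) % tp_size  # -> 2
if padded_num_tokens > 0:
    # Pad rows at the end so the token count becomes a multiple of tp_size.
    hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, padded_num_tokens))
chunks = torch.tensor_split(hidden_states, tp_size, dim=0)
assert all(c.shape[0] == 3 for c in chunks)                    # equal 3-token chunks

# After the all-gather/concat, the padding rows are sliced off again, mirroring
# the new `if num_tokens > 0: final_hidden_states[:-num_tokens]` branch.
final_hidden_states = torch.cat(chunks, dim=0)
if padded_num_tokens > 0:
    final_hidden_states = final_hidden_states[:-padded_num_tokens]
assert final_hidden_states.shape[0] == 10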

vllm_ascend/multistream/metadata.py

Lines changed: 41 additions & 48 deletions
@@ -2,7 +2,6 @@
 import torch
 from typing import Dict, List, Optional, Union, Tuple
 from vllm.sequence import IntermediateTensors
-from vllm.config import MultiStreamConfig
 from .base import MSAttentionMetadataSplitConfig, MSEventKey
 from vllm.attention.backends.abstract import AttentionMetadata

@@ -31,57 +30,21 @@ def split_micro_batches_tensors(input_tensors, split_index: int, keys: List[str]
         return [micro_batches_pre, micro_batches_post]
     else:
         raise NotImplementedError
-def make_multistream_metadata(
-    start_layer: int,
-    end_layer: int,
-    causal_lm: bool = True,
-    multistream_config: Optional[MultiStreamConfig] = None,
-):
-    if multistream_config is None:
-        return None
-    return MultiStreamMetadata(
-        calculate_stream=torch.npu.current_stream(),
-        communicate_stream=torch.npu.Stream(),
-        start_layer=start_layer,
-        end_layer=end_layer,
-        multistream_config=multistream_config,
-        event_keys=[MSEventKey.ATTN_COM_FINISH, MSEventKey.ATTN_AR_FINISH,
-                    MSEventKey.FFN_COM_FINISH, MSEventKey.FFN_AR_FINISH],
-        causal_lm=causal_lm,
-    )
-def make_multistream_metadata_ds(
-    start_layer: int,
-    end_layer: int,
-    causal_lm: bool = True,
-    multistream_config: Optional[MultiStreamConfig] = None,
-):
-    if multistream_config is None:
-        return None
-    event_keylist = [
-        MSEventKey.ATTN_COM_FINISH,
-        MSEventKey.ATTN_AR_FINISH,
-        MSEventKey.FFN_COM_FINISH,
-        MSEventKey.FFN_AR_FINISH,
-        MSEventKey.MOE_BEFORE_COMM,
-        MSEventKey.MOE_AFTER_COMM,
-        MSEventKey.MOE_SE_COMM_FINISH,
-        MSEventKey.MOE_SE_COMP_FINISH,
-        MSEventKey.MOE_GATE_FINISH,
-    ]
-    return MultiStreamMetadata(
-        calculate_stream=torch.npu.current_stream(),
-        communicate_stream=torch.npu.Stream(),
-        start_layer=start_layer,
-        end_layer=end_layer,
-        multistream_config=multistream_config,
-        event_keys=event_keylist,
-        causal_lm=causal_lm,
-    )
+
 @dataclass
 class MultiStreamStepMetadata:
     comm_stream: torch.npu.Stream = None
     before_comm_event: torch.npu.Event = None
     after_comm_event: torch.npu.Event = None
+
+@dataclass
+class MultiStreamConfig:
+    """Controls the behavior of multi-stream models."""
+    min_total_tokens_to_split: int = 256
+    min_prefill_tokens_to_split: int = 64
+    num_micro_batches: int = 2
+    imbalance_ratio: float = 0.1
+
 class MultiStreamMetadata:
     # direct stream
     calculate_stream = None

@@ -157,4 +120,34 @@ def merge_micro_batches(self,
                 batch.append(None)
             else:
                 batch.append(torch.cat(tensors, dim=0))
-        return batch
+        return batch
+
+
+def make_multistream_metadata_ds(
+    start_layer: int,
+    end_layer: int,
+    causal_lm: bool = True,
+    multistream_config: Optional[MultiStreamConfig] = None,
+):
+    if multistream_config is None:
+        return None
+    event_keylist = [
+        MSEventKey.ATTN_COM_FINISH,
+        MSEventKey.ATTN_AR_FINISH,
+        MSEventKey.FFN_COM_FINISH,
+        MSEventKey.FFN_AR_FINISH,
+        MSEventKey.MOE_BEFORE_COMM,
+        MSEventKey.MOE_AFTER_COMM,
+        MSEventKey.MOE_SE_COMM_FINISH,
+        MSEventKey.MOE_SE_COMP_FINISH,
+        MSEventKey.MOE_GATE_FINISH,
+    ]
+    return MultiStreamMetadata(
+        calculate_stream=torch.npu.current_stream(),
+        communicate_stream=torch.npu.Stream(),
+        start_layer=start_layer,
+        end_layer=end_layer,
+        multistream_config=multistream_config,
+        event_keys=event_keylist,
+        causal_lm=causal_lm,
+    )
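With MultiStreamConfig no longer imported from vllm.config, the model builds the config locally (when VLLM_ENABLE_MS is set) and hands it to the factory, which returns None when dbo is disabled. A minimal standalone sketch of that wiring, with MultiStreamMetadata construction stubbed out because it requires NPU streams (layer numbers and the stand-in return value are illustrative):

from dataclasses import dataclass
from typing import Optional

@dataclass
class MultiStreamConfig:
    """Controls the behavior of multi-stream models."""
    min_total_tokens_to_split: int = 256
    min_prefill_tokens_to_split: int = 64
    num_micro_batches: int = 2
    imbalance_ratio: float = 0.1

def make_multistream_metadata_ds(start_layer: int, end_layer: int, causal_lm: bool = True,
                                 multistream_config: Optional[MultiStreamConfig] = None):
    if multistream_config is None:
        return None                      # dbo disabled: callers fall back to the single-stream path
    return {"start_layer": start_layer, "end_layer": end_layer,
            "config": multistream_config}  # stand-in for MultiStreamMetadata(...)

VLLM_ENABLE_MS = True                    # i.e. VLLM_ENABLE_MS=1 in the environment
multistream_config = MultiStreamConfig() if VLLM_ENABLE_MS else None
metadata = make_multistream_metadata_ds(start_layer=3, end_layer=61,
                                        multistream_config=multistream_config)
assert metadata is not None and metadata["config"].num_micro_batches == 2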

vllm_ascend/multistream/ms_split.py

Lines changed: 5 additions & 6 deletions
@@ -9,7 +9,7 @@ def compute_split_seq_index(
     num_tokens: int,
     imbalance_ratio: float = 0.1,
 )->Optional[list[int]]:
-    if attn_state == AscendAttentionState.PrefillOnly or attn_state == AscendAttentionState.ChunkedPrefill:
+    if attn_state != AscendAttentionState.DecodeOnly:
         assert query_lens is not None
         total_tokens = sum(query_lens)
         # the first index in last split

@@ -28,11 +28,10 @@ def compute_split_seq_index(
         # TODO: split tokens in seq
         else :
             return [0, 0]
-    elif attn_state == AscendAttentionState.DecodeOnly:
+    else:
         tokens = num_tokens // 2
         return [tokens, tokens]
-    else:
-        return [0, 0]
+
 def split_attn_tensor_type(
     input_tensor: torch.Tensor,
     index: int,

@@ -69,10 +68,10 @@ def model_input_split_v1_mla_attn(
     seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills>0 else attn_metadata.decode.seq_lens
     [seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens,seq_index)

-    if attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
+    if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
         # the attn_mla kernel in torch npu only accept 128*128 attn mask
         attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
-        attn_state_pre = attn_state_post = AscendAttentionState.PrefillOnly
+        attn_state_pre = attn_state_post = attn_metadata.attn_state
     elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
         # should be none in decode only state
         attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
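After this change every non-decode state goes through the sequence-boundary search, and only DecodeOnly takes the simple halving path. A toy sketch of the DecodeOnly split and how the resulting seq_index could be used to cut per-sequence tensors (values are made up; split_attn_tensor_type internals are not shown in the diff):

import torch

def compute_split_seq_index_decode_only(num_tokens: int) -> list:
    # DecodeOnly: each request contributes one token, so split the batch in half.
    tokens = num_tokens // 2
    return [tokens, tokens]

token_index, seq_index = compute_split_seq_index_decode_only(8)
seq_lens = torch.ones(8, dtype=torch.int32)          # one token per decoding sequence
seq_lens_pre, seq_lens_post = seq_lens[:seq_index], seq_lens[seq_index:]
assert token_index == 4 and seq_lens_pre.numel() == 4 and seq_lens_post.numel() == 4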

vllm_ascend/ops/fused_moe.py

Lines changed: 0 additions & 1 deletion
@@ -37,7 +37,6 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)

-from vllm.multistream.context import set_multistream_context,get_multistream_comm_context
 from vllm_ascend.multistream.base import MSEventKey
 from vllm_ascend.multistream.metadata import MultiStreamStepMetadata, MultiStreamMetadata
 import vllm_ascend.envs as envs_ascend

vllm_ascend/worker/model_runner_v1.py

Lines changed: 0 additions & 1 deletion
@@ -633,7 +633,6 @@ def _process_reqs(
         # Run forward pass
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
-                                 query_lens=self.query_lens,
                                  num_tokens=num_input_tokens):
             model_kwargs = {}
             if self.enable_torchair_graph_mode:
