
Commit 52da0d4

Add profile execute duration observation
Signed-off-by: depeng1994 <depengzhang@foxmail.com>
1 parent 5a1689f commit 52da0d4

File tree

3 files changed: +189 -122 lines

vllm_ascend/envs.py
vllm_ascend/utils.py
vllm_ascend/worker/model_runner_v1.py

vllm_ascend/envs.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@
     lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
     "VLLM_ENABLE_MC2":
     lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
+    "VLLM_MODEL_EXECUTE_TIME_OBSERVE":
+    lambda: bool(int(os.getenv("VLLM_MODEL_EXECUTE_TIME_OBSERVE", '0'))),
     "USING_LCCL_COM":
     lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
     "SOC_VERSION":

vllm_ascend/utils.py

Lines changed: 53 additions & 1 deletion
@@ -17,11 +17,15 @@
 # Adapted from vllm-project/vllm/vllm/worker/worker.py
 #

+import atexit
 import math
-from typing import TYPE_CHECKING
+from contextlib import contextmanager
+from threading import Lock
+from typing import TYPE_CHECKING, List, Tuple

 import torch
 from packaging.version import InvalidVersion, Version
+from torch_npu.npu.streams import Event
 from vllm.logger import logger

 import vllm_ascend.envs as envs
@@ -175,3 +179,51 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:

 def dispose_tensor(x: torch.Tensor):
     x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype))
+
+
+class ProfileExecuteDuration:
+    _instance = None
+    _observations: List[Tuple[str, Event, Event]] = []
+    _lock = Lock()
+
+    def __new__(cls):
+        with cls._lock:
+            if cls._instance is None:
+                cls._instance = super().__new__(cls)
+                atexit.register(cls._instance.destroy)
+            return cls._instance
+
+    def destroy(self):
+        with self._lock:
+            self._observations.clear()
+
+    @contextmanager
+    def capture_async(self, duration_tag: str):
+        if not envs.VLLM_MODEL_EXECUTE_TIME_OBSERVE:
+            yield
+            return
+
+        observe_start = Event(enable_timing=True)
+        observe_start.record()
+        try:
+            yield
+        finally:
+            observe_end = Event(enable_timing=True)
+            observe_end.record()
+            with self._lock:
+                self._observations.append(
+                    (duration_tag, observe_start, observe_end))
+
+    def pop_captured_sync(self, captured_name: str):
+        """Pop and synchronize all events in the observation list, print all duration"""
+        if not envs.VLLM_MODEL_EXECUTE_TIME_OBSERVE:
+            return
+
+        log = f"Profile execute duration [{captured_name}]:"
+        while self._observations:
+            with self._lock:
+                tag, observe_start, observe_end = self._observations.pop()
+                observe_end.synchronize()
+                duration = observe_start.elapsed_time(observe_end)
+                log += f" [{tag}]:{duration:.2f}ms"
+        print(log)
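For context, a minimal usage sketch of the helper added above, used outside the model runner (assumes an Ascend environment with torch_npu installed, an NPU device available, and VLLM_MODEL_EXECUTE_TIME_OBSERVE=1; the tag names and tensor shapes are illustrative):

    import torch
    import torch_npu  # noqa: F401  # registers the "npu" device with torch
    from vllm_ascend.utils import ProfileExecuteDuration

    profiler = ProfileExecuteDuration()  # __new__ always returns the same singleton

    with profiler.capture_async("matmul"):
        # Start/end NPU events are recorded around the block; no sync happens here.
        a = torch.randn(1024, 1024, device="npu")
        b = torch.randn(1024, 1024, device="npu")
        c = a @ b

    # Synchronizes the recorded end events and prints a single line such as
    # "Profile execute duration [demo]: [matmul]:<duration>ms".
    profiler.pop_captured_sync("demo")

Because the events are only synchronized in pop_captured_sync, capture_async itself adds no stream synchronization to the timed block.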

vllm_ascend/worker/model_runner_v1.py

Lines changed: 134 additions & 121 deletions
@@ -62,6 +62,7 @@
 from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
+from vllm_ascend.utils import ProfileExecuteDuration
 from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer

 if TYPE_CHECKING:
@@ -663,36 +664,38 @@ def _process_reqs(
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
-            model_kwargs = {}
-            if self.enable_torchair_graph_mode:
-                model_kwargs["kv_caches"] = self.kv_caches
-                model_kwargs["attn_metadata"] = attn_metadata
-            if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-                torch._dynamo.mark_static(input_ids)
-                torch._dynamo.mark_static(positions)
-                torch._dynamo.mark_static(attn_metadata.decode.block_table)
-                torch._dynamo.mark_static(attn_metadata.decode.input_positions)
-                torch._dynamo.mark_static(attn_metadata.slot_mapping)
-                for kv in self.kv_caches:
-                    if isinstance(kv, tuple):
-                        torch._dynamo.mark_static(kv[0])
-                        torch._dynamo.mark_static(kv[1])
-                hidden_states = self.compile_model(
-                    input_ids=input_ids,
-                    positions=positions,
-                    intermediate_tensors=intermediate_tensors,
-                    inputs_embeds=None,
-                    **model_kwargs,
-                )
-            else:
-                assert self.model is not None
-                hidden_states = self.model(
-                    input_ids=input_ids,
-                    positions=positions,
-                    intermediate_tensors=intermediate_tensors,
-                    inputs_embeds=None,
-                    **model_kwargs,
-                )
+            with ProfileExecuteDuration().capture_async("forward"):
+                model_kwargs = {}
+                if self.enable_torchair_graph_mode:
+                    model_kwargs["kv_caches"] = self.kv_caches
+                    model_kwargs["attn_metadata"] = attn_metadata
+                if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                    torch._dynamo.mark_static(input_ids)
+                    torch._dynamo.mark_static(positions)
+                    torch._dynamo.mark_static(attn_metadata.decode.block_table)
+                    torch._dynamo.mark_static(
+                        attn_metadata.decode.input_positions)
+                    torch._dynamo.mark_static(attn_metadata.slot_mapping)
+                    for kv in self.kv_caches:
+                        if isinstance(kv, tuple):
+                            torch._dynamo.mark_static(kv[0])
+                            torch._dynamo.mark_static(kv[1])
+                    hidden_states = self.compile_model(
+                        input_ids=input_ids,
+                        positions=positions,
+                        intermediate_tensors=intermediate_tensors,
+                        inputs_embeds=None,
+                        **model_kwargs,
+                    )
+                else:
+                    assert self.model is not None
+                    hidden_states = self.model(
+                        input_ids=input_ids,
+                        positions=positions,
+                        intermediate_tensors=intermediate_tensors,
+                        inputs_embeds=None,
+                        **model_kwargs,
+                    )

         use_spec_decode = len(
             scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -885,103 +888,113 @@ def execute_model(
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> Union[ModelRunnerOutput, torch.Tensor]:
-        self._update_states(scheduler_output)
-        if not scheduler_output.total_num_scheduled_tokens:
-            # Return empty ModelRunnerOuptut if there's no work to do.
-            return EMPTY_MODEL_RUNNER_OUTPUT
-        (attn_metadata, hidden_states, spec_decode_metadata, positions,
-         num_scheduled_tokens,
-         sample_indices) = (self._process_reqs(scheduler_output,
-                                               intermediate_tensors))
-        logits = self.model.compute_logits(hidden_states[sample_indices], None)
-
-        # Apply structured output bitmasks if present
-        if scheduler_output.grammar_bitmask is not None:
-            logits = self.apply_grammar_bitmask(scheduler_output, logits)
-
-        # Sample the next token and get logprobs if needed.
-        sampling_metadata = self.input_batch.sampling_metadata
-        if spec_decode_metadata is None:
-            sampler_output = self.sampler(
-                logits=logits,
-                sampling_metadata=sampling_metadata,
-            )
-        else:
-            # When indexing with a tensor (bonus_logits_indices), PyTorch
-            # creates a new tensor with separate storage from the original
-            # logits tensor. This means any in-place operations on bonus_logits
-            # won't affect the original logits tensor.
-            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
-            sampler_output = self.sampler(
-                logits=bonus_logits,
-                sampling_metadata=sampling_metadata,
-            )
-            bonus_token_ids = sampler_output.sampled_token_ids
+        with ProfileExecuteDuration().capture_async(
+                "prepare input and forward"):
+            self._update_states(scheduler_output)
+            if not scheduler_output.total_num_scheduled_tokens:
+                # Return empty ModelRunnerOuptut if there's no work to do.
+                return EMPTY_MODEL_RUNNER_OUTPUT
+            (attn_metadata, hidden_states, spec_decode_metadata, positions,
+             num_scheduled_tokens,
+             sample_indices) = (self._process_reqs(scheduler_output,
+                                                   intermediate_tensors))
+
+        with ProfileExecuteDuration().capture_async("post process"):
+            logits = self.model.compute_logits(hidden_states[sample_indices],
+                                               None)
+
+            # Apply structured output bitmasks if present
+            if scheduler_output.grammar_bitmask is not None:
+                logits = self.apply_grammar_bitmask(scheduler_output, logits)
+
+            # Sample the next token and get logprobs if needed.
+            sampling_metadata = self.input_batch.sampling_metadata
+            if spec_decode_metadata is None:
+                sampler_output = self.sampler(
+                    logits=logits,
+                    sampling_metadata=sampling_metadata,
+                )
+            else:
+                # When indexing with a tensor (bonus_logits_indices), PyTorch
+                # creates a new tensor with separate storage from the original
+                # logits tensor. This means any in-place operations on bonus_logits
+                # won't affect the original logits tensor.
+                bonus_logits = logits[
+                    spec_decode_metadata.bonus_logits_indices]
+                sampler_output = self.sampler(
+                    logits=bonus_logits,
+                    sampling_metadata=sampling_metadata,
+                )
+                bonus_token_ids = sampler_output.sampled_token_ids
+
+                # Just like `bonus_logits`, `target_logits` is a new tensor with
+                # separate storage from the original `logits` tensor. Therefore,
+                # it is safe to update `target_logits` in place.
+                target_logits = logits[
+                    spec_decode_metadata.target_logits_indices]
+                output_token_ids = self.rejection_sampler(
+                    spec_decode_metadata,
+                    None,  # draft_probs
+                    target_logits,
+                    bonus_token_ids,
+                    sampling_metadata,
+                )
+                sampler_output.sampled_token_ids = output_token_ids

-            # Just like `bonus_logits`, `target_logits` is a new tensor with
-            # separate storage from the original `logits` tensor. Therefore,
-            # it is safe to update `target_logits` in place.
-            target_logits = logits[spec_decode_metadata.target_logits_indices]
-            output_token_ids = self.rejection_sampler(
-                spec_decode_metadata,
-                None,  # draft_probs
-                target_logits,
-                bonus_token_ids,
+            # TODO(woosuk): The following loop can be slow since it iterates over
+            # the requests one by one. Optimize.
+            for i, req_id in enumerate(self.input_batch.req_ids):
+                req_state = self.requests[req_id]
+                seq_len = (req_state.num_computed_tokens +
+                           scheduler_output.num_scheduled_tokens[req_id])
+                if seq_len < req_state.num_tokens:
+                    # Ignore the sampled token.
+                    # Rewind the generator state as if the token was not sampled.
+                    generator = self.input_batch.generators.get(i)
+                    if generator is not None:
+                        generator.set_offset(generator.get_offset() - 4)
+
+            # NOTE: NPU -> CPU Sync happens here.
+            # Move as many CPU operations as possible before this sync point.
+            logprobs_tensors = sampler_output.logprobs_tensors
+            logprobs_lists = logprobs_tensors.tolists() \
+                if logprobs_tensors is not None else None
+
+            # Get the valid generated tokens.
+            sampled_token_ids = sampler_output.sampled_token_ids
+            max_gen_len = sampled_token_ids.shape[-1]
+            if max_gen_len == 1:
+                # No spec decode tokens.
+                valid_sampled_token_ids = sampled_token_ids.tolist()
+            else:
+                # Includes spec decode tokens.
+                valid_sampled_token_ids = self.rejection_sampler.parse_output(
+                    sampled_token_ids,
+                    self.input_batch.vocab_size,
+                )
+
+            spec_token_ids = self._get_spec_token_ids(
+                valid_sampled_token_ids,
                 sampling_metadata,
+                scheduler_output,
+                spec_decode_metadata,
+                positions,
+                num_scheduled_tokens,
+                hidden_states,
+                attn_metadata,
             )
-            sampler_output.sampled_token_ids = output_token_ids

-        # TODO(woosuk): The following loop can be slow since it iterates over
-        # the requests one by one. Optimize.
-        for i, req_id in enumerate(self.input_batch.req_ids):
-            req_state = self.requests[req_id]
-            seq_len = (req_state.num_computed_tokens +
-                       scheduler_output.num_scheduled_tokens[req_id])
-            if seq_len < req_state.num_tokens:
-                # Ignore the sampled token.
-                # Rewind the generator state as if the token was not sampled.
-                generator = self.input_batch.generators.get(i)
-                if generator is not None:
-                    generator.set_offset(generator.get_offset() - 4)
-
-        # NOTE: NPU -> CPU Sync happens here.
-        # Move as many CPU operations as possible before this sync point.
-        logprobs_tensors = sampler_output.logprobs_tensors
-        logprobs_lists = logprobs_tensors.tolists() \
-            if logprobs_tensors is not None else None
-
-        # Get the valid generated tokens.
-        sampled_token_ids = sampler_output.sampled_token_ids
-        max_gen_len = sampled_token_ids.shape[-1]
-        if max_gen_len == 1:
-            # No spec decode tokens.
-            valid_sampled_token_ids = sampled_token_ids.tolist()
-        else:
-            # Includes spec decode tokens.
-            valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                sampled_token_ids,
-                self.input_batch.vocab_size,
+            model_runner_output = ModelRunnerOutput(
+                req_ids=self.input_batch.req_ids,
+                req_id_to_index=self.input_batch.req_id_to_index,
+                sampled_token_ids=valid_sampled_token_ids,
+                spec_token_ids=spec_token_ids,
+                logprobs=logprobs_lists,
+                prompt_logprobs_dict={},
             )

-        spec_token_ids = self._get_spec_token_ids(
-            valid_sampled_token_ids,
-            sampling_metadata,
-            scheduler_output,
-            spec_decode_metadata,
-            positions,
-            num_scheduled_tokens,
-            hidden_states,
-            attn_metadata,
-        )
-
-
-        model_runner_output = ModelRunnerOutput(
-            req_ids=self.input_batch.req_ids,
-            req_id_to_index=self.input_batch.req_id_to_index,
-            sampled_token_ids=valid_sampled_token_ids,
-            spec_token_ids=spec_token_ids,
-            logprobs=logprobs_lists,
-            prompt_logprobs_dict={},
+        capture_name = "Decode" if self.attn_state == AscendAttentionState.DecodeOnly else "Prefill"
+        ProfileExecuteDuration().pop_captured_sync(capture_name)
         return model_runner_output

     def _profile_multimodal(self) -> None:
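With the flag enabled, every execute_model step ends by popping whatever was captured during that step, so one line is printed per step: the header is "Decode" or "Prefill" depending on self.attn_state, followed by the captured tags in reverse capture order (the observation list is popped LIFO), roughly of the form

    Profile execute duration [Decode]: [post process]:<ms> [prepare input and forward]:<ms> [forward]:<ms>

where <ms> stands for the measured elapsed_time of each span. Note that the "forward" span is recorded inside _process_reqs, so its duration is also contained within "prepare input and forward".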
