1 change: 1 addition & 0 deletions docs/source/developer_guide/evaluation/index.md
@@ -12,4 +12,5 @@ using_evalscope
:caption: Performance
:maxdepth: 1
performance_benchmark
profile_execute_duration
:::
34 changes: 34 additions & 0 deletions docs/source/developer_guide/evaluation/profile_execute_duration.md
@@ -0,0 +1,34 @@
# Profile Execute Duration

During a complete inference pass, you usually need to capture the execution duration of each stage (pre/post-processing, model forward, etc.). Typically this is done by calling `torch.npu.synchronize()` and taking CPU timestamps, which adds host/device synchronization overhead.
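
For reference, a rough sketch of that synchronous approach (the matmul workload and tensor shapes are only illustrative; it assumes `torch_npu` is installed):

```python
import time

import torch
import torch_npu  # noqa: F401

a = torch.randn(4096, 4096).npu()
b = torch.randn(4096, 4096).npu()

# Host timestamps are only meaningful once the device stream has drained,
# so every measurement forces a blocking synchronize on the host.
torch.npu.synchronize()
cpu_start = time.perf_counter()
torch.matmul(a, b)
torch.npu.synchronize()
print(f"forward: {(time.perf_counter() - cpu_start) * 1000:.2f}ms")
```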

**To reduce this overhead, this feature uses the NPU event timestamp mechanism to observe device execution time asynchronously.**
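
Under the hood this relies on NPU events with timing enabled, as implemented in `vllm_ascend/utils.py`; a minimal standalone sketch of the mechanism (the workload is illustrative):

```python
import torch
import torch_npu  # noqa: F401
from torch_npu.npu.streams import Event

a = torch.randn(4096, 4096).npu()
b = torch.randn(4096, 4096).npu()

# Recording events is asynchronous: no host/device synchronization is
# required at the observation point itself.
start = Event(enable_timing=True)
end = Event(enable_timing=True)
start.record()
torch.matmul(a, b)
end.record()

# The host only waits when the duration is actually read back.
end.synchronize()
print(f"device time: {start.elapsed_time(end):.2f}ms")
```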

## Usage
* Use the environment variable `VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE` to enable this feature.
* Use the non-blocking API `ProfileExecuteDuration().capture_async` to set observation points asynchronously around the stages you want to measure.
* Use the blocking API `ProfileExecuteDuration().pop_captured_sync` at an appropriate time to retrieve and print the execution durations of all observed stages, as shown in the sketch below.
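
Putting the two APIs together, a minimal sketch (run with the environment variable set, e.g. `VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE=1 python3 demo.py`; the `forward` tag and matmul workload are illustrative and mirror `tests/singlecard/test_profile_execute_duration.py`):

```python
import torch
import torch_npu  # noqa: F401

from vllm_ascend.utils import ProfileExecuteDuration

a = torch.randn(1024, 1024).npu()
b = torch.randn(1024, 1024).npu()

# Non-blocking: records a pair of NPU timing events around the observed stage.
with ProfileExecuteDuration().capture_async("forward"):
    torch.matmul(a, b)

# Blocking: synchronizes the recorded events and returns {tag: duration_ms}.
durations = ProfileExecuteDuration().pop_captured_sync()
for tag, duration_ms in durations.items():
    print(f"[{tag}]: {duration_ms:.2f}ms")
```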

## Example Output
Collaborator review comment:
The doc is good, but we could provide an e2e guide to help devs understand. For example:

We already add the key stages of inference (including pre-processing, model forward, etc.), so you can run an inference script with the variable set:

`VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE=1 python3 vllm-ascend/examples/offline_inference_npu.py`


```
5691:(IntegratedWorker pid=1502285) Profile execute duration [Decode]: [post process]:14.17ms [prepare input and forward]:9.57ms [forward]:4.14ms
5695:(IntegratedWorker pid=1502285) Profile execute duration [Decode]: [post process]:14.29ms [prepare input and forward]:10.19ms [forward]:4.14ms
5697:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.81ms [prepare input and forward]:10.29ms [forward]:3.99ms
5701:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.10ms [prepare input and forward]:10.62ms [forward]:4.33ms
5705:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.65ms [prepare input and forward]:9.58ms [forward]:4.20ms
5709:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.43ms [prepare input and forward]:9.88ms [forward]:4.20ms
5711:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.89ms [prepare input and forward]:10.49ms [forward]:4.19ms
5715:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.14ms [prepare input and forward]:11.21ms [forward]:4.18ms
5719:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.71ms [prepare input and forward]:10.15ms [forward]:4.42ms
5723:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.62ms [prepare input and forward]:10.31ms [forward]:4.25ms
5725:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:14.12ms [prepare input and forward]:10.33ms [forward]:4.24ms
5729:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:14.58ms [prepare input and forward]:10.85ms [forward]:4.32ms
5733:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:14.32ms [prepare input and forward]:9.79ms [forward]:4.28ms
5737:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:15.06ms [prepare input and forward]:9.89ms [forward]:4.32ms
5739:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:14.62ms [prepare input and forward]:10.48ms [forward]:4.27ms
5743:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:14.60ms [prepare input and forward]:10.71ms [forward]:4.61ms
5747:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:14.21ms [prepare input and forward]:10.10ms [forward]:4.52ms
5751:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:15.03ms [prepare input and forward]:10.00ms [forward]:4.42ms

```
62 changes: 62 additions & 0 deletions tests/singlecard/test_profile_execute_duration.py
@@ -0,0 +1,62 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import time
from unittest.mock import patch

import torch
import vllm # noqa: F401

from vllm_ascend.utils import ProfileExecuteDuration


@patch.dict(os.environ, {"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE": "1"})
def test_execute_duration_enabled_discrepancy():
a = torch.randn(10000, 10000).npu()
b = torch.randn(10000, 10000).npu()

# warmup
torch.matmul(a, b)
torch.npu.synchronize()

cpu_start = time.perf_counter()
with ProfileExecuteDuration().capture_async("forward"):
torch.matmul(a, b)
torch.npu.synchronize()
cpu_duration = (time.perf_counter() - cpu_start) * 1000
npu_durations = ProfileExecuteDuration().pop_captured_sync()
assert npu_durations and 'forward' in npu_durations
assert not ProfileExecuteDuration._observations

# Assert discrepancy between CPU and NPU duration is within 50% roughly
diff = abs(cpu_duration - npu_durations['forward']) / max(
cpu_duration, npu_durations['forward'])
assert diff <= 0.5, (
f"CPU={cpu_duration:.2f}ms, NPU={npu_durations['forward']:.2f}ms")


def test_execute_duration_disabled():
a = torch.randn(100, 100).npu()
b = torch.randn(100, 100).npu()

with ProfileExecuteDuration().capture_async("forward"):
torch.matmul(a, b)
torch.npu.synchronize()
npu_durations = ProfileExecuteDuration().pop_captured_sync()
assert not npu_durations
3 changes: 3 additions & 0 deletions vllm_ascend/envs.py
@@ -70,6 +70,9 @@
lambda: os.getenv("VLLM_VERSION", None),
"VLLM_ASCEND_TRACE_RECOMPILES":
lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0'))
),
}

# end-env-vars-definition
54 changes: 53 additions & 1 deletion vllm_ascend/utils.py
@@ -17,11 +17,15 @@
# Adapted from vllm-project/vllm/vllm/worker/worker.py
#

import atexit
import math
from typing import TYPE_CHECKING
from contextlib import contextmanager
from threading import Lock
from typing import TYPE_CHECKING, List, Tuple

import torch
from packaging.version import InvalidVersion, Version
from torch_npu.npu.streams import Event
from vllm.logger import logger

import vllm_ascend.envs as envs
@@ -175,3 +179,51 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:

def dispose_tensor(x: torch.Tensor):
x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype))


class ProfileExecuteDuration:
_instance = None
_observations: List[Tuple[str, Event, Event]] = []
_lock = Lock()

def __new__(cls):
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
atexit.register(cls._instance.destroy)
return cls._instance

def destroy(self):
with self._lock:
self._observations.clear()

@contextmanager
def capture_async(self, duration_tag: str):
if not envs.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE:
yield
return

observe_start = Event(enable_timing=True)
observe_start.record()
try:
yield
finally:
observe_end = Event(enable_timing=True)
observe_end.record()
with self._lock:
self._observations.append(
(duration_tag, observe_start, observe_end))

def pop_captured_sync(self) -> dict:
"""Pop and synchronize all events in the observation list"""
durations: dict[str, float] = {}
if not envs.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE:
return durations

while self._observations:
with self._lock:
tag, observe_start, observe_end = self._observations.pop()
observe_end.synchronize()
durations[tag] = observe_start.elapsed_time(observe_end)

return durations