51 changes: 51 additions & 0 deletions vllm_ascend/torchair/torchair_model_runner.py
@@ -23,7 +23,11 @@
import torch_npu
from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context
from vllm.logger import logger

from vllm_ascend.platform import NPUPlatform
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
write_kv_cache_bytes_to_file)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
maybe_converting_weight_acl_format)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -37,6 +41,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
def _get_forward_metadata_across_dp_and_pad(
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
"""Override from NPUModelRunner to pad num_tokens"""
if self.dp_size == 1:
if not with_prefill:
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
@@ -118,3 +123,49 @@ def _generate_dummy_run_hidden_states(self, with_prefill,
def _convert_torch_format(self, kv_cache):
kv_cache = torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND)
return kv_cache

def _compile_torchair_graph(self, torchair_graph_batch_sizes) -> None:
# Trigger torchair graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(num_tokens, is_torchair_compile=True)
self._dummy_run(num_tokens, is_torchair_compile=True)
logger.info("Batchsize %d is compiled successfully: %d/%d.",
num_tokens, idx + 1, len(torchair_graph_batch_sizes))

def _capture_model(self):
"""Override from NPUModelRunner to use torchair graph capture."""
# TODO(NeverRaR): Calling graph_capture(device=self.device) in
# torchair graph capture can cause some issues, so now we just
# temporarily split the codepath for the two different graph patterns.
torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
graph_num = len(torchair_graph_batch_sizes)

if self.use_cached_npu_graph and not check_torchair_cache_exist():
# If caching is enabled but the cache does not exist, the model is compiled twice: the
# first pass generates the cache, and the second pass loads it to skip the overhead of
# the Dynamo guard mechanism.
logger.info(
"Use cached npu graph but cache doesn't exist! Now we compile graph to genetate torchair cache, this usually takes %.1f~%.1f mins.",
0.5 * graph_num, 1.5 * graph_num)
self._compile_torchair_graph(torchair_graph_batch_sizes)
NPUPlatform.synchronize()
torch._dynamo.reset()
self.torchair_compiled_models.clear()
if self.use_cached_npu_graph:
logger.info(
"Loading torchair graph cache, this usually takes %.1f~%.1f mins.",
0.3 * graph_num, 0.5 * graph_num)
self._compile_torchair_graph(torchair_graph_batch_sizes)
else:
logger.info(
"Capturing torchair graph, this usually takes %.1f~%.1f mins.",
0.5 * graph_num, 1.5 * graph_num)
self._compile_torchair_graph(torchair_graph_batch_sizes)

if self.new_kv_cache_bytes > 0:
write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
self.new_kv_cache_bytes)
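
For review context, here is a minimal runnable sketch of the capture order that `_compile_torchair_graph` relies on. `dummy_run`, the warmup count, and the batch sizes below are placeholders for illustration, not the vllm-ascend APIs.

```python
# Illustration only: capture the largest shape first so that smaller shapes
# can reuse the memory pool allocated for it.
def dummy_run(num_tokens: int) -> None:
    # Stand-in for self._dummy_run(num_tokens, is_torchair_compile=True).
    print(f"capturing graph for {num_tokens} tokens")

def compile_graphs(batch_sizes: list[int], num_warmups: int = 1) -> None:
    for idx, num_tokens in enumerate(reversed(batch_sizes)):
        for _ in range(num_warmups):   # warmup passes
            dummy_run(num_tokens)
        dummy_run(num_tokens)          # the pass that performs the capture
        print(f"Batch size {num_tokens} compiled: {idx + 1}/{len(batch_sizes)}")

compile_graphs([1, 8, 32, 128])  # visits 128, 32, 8, then 1
```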
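Similarly, a hedged sketch of the cache-handling branches in the new `_capture_model`; `_compile` and the print calls stand in for `_compile_torchair_graph`, `NPUPlatform.synchronize()`, `torch._dynamo.reset()`, and `self.torchair_compiled_models.clear()`.

```python
def _compile(label: str) -> None:
    # Stand-in for self._compile_torchair_graph(torchair_graph_batch_sizes).
    print(f"compiling torchair graphs ({label})")

def capture_model(use_cached_npu_graph: bool, cache_exists: bool) -> None:
    if use_cached_npu_graph and not cache_exists:
        # Cache requested but missing: compile once to generate it, then reset
        # the compiler state so the next pass loads the cache and skips the
        # Dynamo guard overhead.
        _compile("generate cache")
        print("synchronize NPU, reset dynamo, clear compiled models")
    if use_cached_npu_graph:
        _compile("load from cache")
    else:
        _compile("plain torchair capture")

capture_model(use_cached_npu_graph=True, cache_exists=False)
```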
72 changes: 15 additions & 57 deletions vllm_ascend/worker/model_runner_v1.py
@@ -82,8 +82,6 @@
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
write_kv_cache_bytes_to_file)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
ProfileExecuteDuration, is_310p,
maybe_converting_weight_acl_format,
@@ -2323,67 +2321,27 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:

return kv_cache_spec

def _compile_torchair_graph(self, torchair_graph_batch_sizes) -> None:
# Trigger torchair graph capture for specific shapes.
def _capture_model(self):
if not self.use_aclgraph:
logger.info("Skipping NPU graph capture for eager mode.")
return
# Trigger ACL graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(num_tokens, is_torchair_compile=True)
self._dummy_run(num_tokens, is_torchair_compile=True)
logger.info("Batchsize %d is compiled successfully: %d/%d.",
num_tokens, idx + 1, len(torchair_graph_batch_sizes))
# TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode
with graph_capture(device=self.device):
for num_tokens in reversed(self.aclgraph_batch_sizes):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(num_tokens)
self._dummy_run(num_tokens)

def capture_model(self) -> None:
start_time = time.perf_counter()
start_free_npu_memory = torch.npu.mem_get_info()[0]
# TODO(NeverRaR): Calling graph_capture(device=self.device) in
# torchair graph capture can cause some issues, so now we just
# temporarily split the codepath for the two different graph patterns.
if self.torchair_graph_enabled:
torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
graph_num = len(torchair_graph_batch_sizes)

if self.use_cached_npu_graph and not check_torchair_cache_exist():
# If caching is enabled but does not exist, we will compile the model twice. The first
# time is used to generate the cache, and the second time is used to load the cache to
# skip the overhead caused by Dynamo guard mechanism.
logger.info(
"Use cached npu graph but cache doesn't exist! Now we compile graph to genetate torchair cache, this usually takes %.1f~%.1f mins.",
0.5 * graph_num, 1.5 * graph_num)
self._compile_torchair_graph(torchair_graph_batch_sizes)
NPUPlatform.synchronize()
torch._dynamo.reset()
self.torchair_compiled_models.clear()
if self.use_cached_npu_graph:
logger.info(
"Loading torchair graph cache, this usually takes %.1f~%.1f mins.",
0.3 * graph_num, 0.5 * graph_num)
self._compile_torchair_graph(torchair_graph_batch_sizes)
else:
logger.info(
"Capturing torchair graph, this usually takes %.1f~%.1f mins.",
0.5 * graph_num, 1.5 * graph_num)
self._compile_torchair_graph(torchair_graph_batch_sizes)

if self.new_kv_cache_bytes > 0:
write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
self.new_kv_cache_bytes)
elif self.use_aclgraph:
# Trigger ACL graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
# TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode
with graph_capture(device=self.device):
for num_tokens in reversed(self.aclgraph_batch_sizes):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(num_tokens)
self._dummy_run(num_tokens)
else:
logger.info("Skipping NPU graph capture for eager mode.")
return

self._capture_model()

end_time = time.perf_counter()
end_free_npu_memory = torch.npu.mem_get_info()[0]
elapsed_time = end_time - start_time
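
After this change the base `capture_model` is reduced to a timing and memory wrapper around the `_capture_model` hook that subclasses override. A small hedged sketch of that pattern, using a placeholder `Runner` rather than the real `NPUModelRunner`:

```python
import time

class Runner:
    def _capture_model(self) -> None:
        # Overridden by the torchair runner; here just a stand-in for the
        # ACL graph capture loop.
        time.sleep(0.01)

    def capture_model(self) -> None:
        # Base method only measures how long the subclass-specific capture takes.
        start = time.perf_counter()
        self._capture_model()
        elapsed = time.perf_counter() - start
        print(f"Graph capturing finished in {elapsed:.2f} secs")

Runner().capture_model()
```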