
Commit 817c738

Author: 闫鹏全 (committed)
Commit message: support aclgraph
1 parent 830fd30 commit 817c738

File tree: 5 files changed (+299, -119 lines)


csrc/torch_binding.cpp

Lines changed: 29 additions & 1 deletion

@@ -20,6 +20,7 @@
 #include <torch_npu/csrc/core/npu/NPUStream.h>
 #include <torch_npu/csrc/framework/OpCommand.h>
 #include <torch_npu/csrc/npu/Module.h>
+#include <torch_npu/csrc/aten/common/from_blob.h>
 #include <pybind11/pybind11.h>
 #include "acl/acl.h"
 #include "tiling/platform/platform_ascendc.h"
@@ -73,7 +74,7 @@ void rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
     aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
     at_npu::native::OpCommand cmd;
     cmd.Name("rotary_embedding");
-    cmd.SetCustomHandler([scalar_type, is_neox, num_tokens, stream, position_ids_ptr,
+    cmd.SetCustomHandler([scalar_type, is_neox, num_tokens, stream, position_ids_ptr,
                           query_ptr, key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride,
                           num_heads, num_kv_heads, head_size]() -> int {
         auto dtype_num = get_dtype_from_torch(scalar_type);
@@ -90,6 +91,30 @@ void rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
     cmd.Run();
     return ;
 }
+
+torch::Tensor weak_ref_tensor(torch::Tensor& tensor)
+{
+    // Ensure the tensor is on an NPU device
+    if (!tensor.is_privateuseone()) {
+        throw std::runtime_error("Tensor must be on NPU device");
+    }
+
+    // Get the raw data pointer
+    void* data_ptr = tensor.data_ptr();
+
+    // Get tensor sizes and strides
+    std::vector<int64_t> sizes = tensor.sizes().vec();
+    std::vector<int64_t> strides = tensor.strides().vec();
+
+    // Get tensor options (dtype, device)
+    auto options = tensor.options();
+
+    // Create a new, non-owning tensor over the same raw data pointer
+    auto new_tensor = at_npu::native::from_blob(data_ptr, sizes, strides, options);
+
+    return new_tensor;
+}
+
 } // namespace vllm_ascend

 TORCH_LIBRARY_EXPAND(_C, ops)
@@ -103,6 +128,9 @@ TORCH_LIBRARY_EXPAND(_C, ops)
             " Tensor! key, int head_size,"
             " Tensor cos_sin_cache, bool is_neox) -> ()");
     ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
+
+    ops.def("weak_ref_tensor", &vllm_ascend::weak_ref_tensor);
+    ops.impl("weak_ref_tensor", &vllm_ascend::weak_ref_tensor);
 }

 REGISTER_EXTENSION(_C)
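
The new weak_ref_tensor op wraps at_npu::native::from_blob, so the returned tensor aliases the same NPU storage without owning it. A hypothetical Python-side check of the binding (not part of this commit; it assumes the compiled vllm_ascend_C extension is importable, that the op lands in the _C namespace as TORCH_LIBRARY_EXPAND(_C, ops) suggests, and that torch_npu provides the .npu() device move):

```python
# Hypothetical usage sketch, not from this commit.
import torch
import torch_npu  # noqa: F401  # enables the NPU (PrivateUse1) device
import vllm_ascend.vllm_ascend_C  # noqa: F401  # loads the _C custom ops

t = torch.randn(4, 8).npu()
alias = torch.ops._C.weak_ref_tensor(t)

# The alias shares storage with `t` but does not keep it alive on its own.
assert alias.data_ptr() == t.data_ptr()
assert alias.shape == t.shape
```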

vllm_ascend/attention/attention_v1.py

Lines changed: 127 additions & 51 deletions

@@ -23,9 +23,10 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
-
+from vllm.utils import direct_register_custom_op

 class AscendAttentionBackend(AttentionBackend):
+    accept_output_buffer: bool = True

     @staticmethod
     def get_name() -> str:
@@ -167,59 +168,134 @@ def forward(
         shape = [batch_size * seq_len, num_heads, head_size]
         """
         num_tokens = query.shape[0]
-        output = torch.empty(num_tokens,
+        if output is None:
+            output = torch.empty(num_tokens,
                              self.num_heads,
                              self.head_size,
                              dtype=query.dtype,
                              device=query.device)
+        torch.ops.vllm.unified_ascend_attention_with_output(
+            layer=layer,
+            query=query,
+            key=key,
+            value=value,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+            output=output,
+            self_num_heads=self.num_heads,
+            self_head_size=self.head_size,
+            self_scale=self.scale,
+            self_num_kv_heads=self.num_kv_heads,
+            self_hidden_size=self.hidden_size,
+            self_kv_cache_dtype=self.kv_cache_dtype,
+            self_sliding_window=self.sliding_window,
+            self_alibi_slopes=self.alibi_slopes,
+            self_attn_type=self.attn_type,
+            self_num_queries_per_kv=self.num_queries_per_kv,
+            self_seq_len_cpu_tensor=self.seq_len_cpu_tensor,
+        )

-        if attn_metadata is None:
-            # Profiling run.
-            return output.view(num_tokens, self.hidden_size)
-        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
-        attn_type = self.attn_type
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "PallasAttentionBackendImpl")
-        # View q k v to BSH.
-        query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)
-        # TODO: Remove this contiguous in the future.
-        value = value.contiguous()
-
-        if hasattr(layer, 'quant_method'):
-            # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
-            pass
-        else:
-            if kv_cache.numel() > 0:
-                key_cache, value_cache = kv_cache[0], kv_cache[1]
-                num_blocks, block_size, _ = key_cache.shape
-                key_cache = key_cache.view(num_blocks, block_size,
-                                           self.num_kv_heads, self.head_size)
-                value_cache = value_cache.view(num_blocks, block_size,
-                                               self.num_kv_heads,
-                                               self.head_size)
-                slots = attn_metadata.slot_mapping
-                torch_npu._npu_reshape_and_cache(key=key,
-                                                 value=value,
-                                                 key_cache=key_cache,
-                                                 value_cache=value_cache,
-                                                 slot_indices=slots)
-
-            # use paged attention
-            torch_npu._npu_paged_attention_splitfuse(
-                query=query,
-                key_cache=key_cache,
-                value_cache=value_cache,
-                mask=attn_metadata.attn_mask,
-                block_table=attn_metadata.block_tables,
-                seq_len=attn_metadata.seq_lens,
-                context_lens=attn_metadata.context_lens,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                out=output)
         return output.view(num_tokens, self.hidden_size)
+
+
+def unified_ascend_attention_with_output(
+        layer: AttentionLayer,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: torch.Tensor,
+        self_num_heads: int,
+        self_head_size: int,
+        self_scale: float,
+        self_num_kv_heads: int,
+        self_hidden_size: int,
+        self_kv_cache_dtype: str,
+        self_sliding_window: Optional[int],
+        self_alibi_slopes: torch.Tensor,
+        self_attn_type: str,
+        self_num_queries_per_kv: int,
+        self_seq_len_cpu_tensor: int,
+) -> None:
+    num_tokens = query.shape[0]
+    if attn_metadata is None:
+        return output.view(num_tokens, self_hidden_size)
+    assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
+    attn_type = self_attn_type
+    if attn_type != AttentionType.DECODER:
+        raise NotImplementedError("Encoder self-attention and "
+                                  "encoder/decoder cross-attention "
+                                  "are not implemented for "
+                                  "PallasAttentionBackendImpl")
+    # View q k v to BSH.
+    query = query.view(-1, self_num_heads, self_head_size)
+    key = key.view(-1, self_num_kv_heads, self_head_size)
+    value = value.view(-1, self_num_kv_heads, self_head_size)
+    # TODO: Remove this contiguous in the future.
+    value = value.contiguous()
+
+    if hasattr(layer, 'quant_method'):
+        # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
+        pass
+    else:
+        if kv_cache.numel() > 0:
+            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            num_blocks, block_size, _ = key_cache.shape
+            key_cache = key_cache.view(num_blocks, block_size,
+                                       self_num_kv_heads, self_head_size)
+            value_cache = value_cache.view(num_blocks, block_size,
+                                           self_num_kv_heads,
+                                           self_head_size)
+            slots = attn_metadata.slot_mapping
+            torch_npu._npu_reshape_and_cache(key=key,
+                                             value=value,
+                                             key_cache=key_cache,
+                                             value_cache=value_cache,
+                                             slot_indices=slots)
+
+        # use paged attention
+        torch_npu._npu_paged_attention_splitfuse(
+            query=query,
+            key_cache=key_cache,
+            value_cache=value_cache,
+            mask=attn_metadata.attn_mask,
+            block_table=attn_metadata.block_tables,
+            seq_len=attn_metadata.seq_lens,
+            context_lens=attn_metadata.context_lens,
+            num_kv_heads=self_num_kv_heads,
+            num_heads=self_num_heads,
+            scale_value=self_scale,
+            out=output)
+
+
+def unified_attention_with_output_fake(
+        layer: AttentionLayer,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: torch.Tensor,
+        self_num_heads: int,
+        self_head_size: int,
+        self_scale: float,
+        self_num_kv_heads: int,
+        self_hidden_size: int,
+        self_kv_cache_dtype: str,
+        self_sliding_window: Optional[int],
+        self_alibi_slopes: torch.Tensor,
+        self_attn_type: str,
+        self_num_queries_per_kv: int,
+        self_seq_len_cpu_tensor: int,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="unified_ascend_attention_with_output",
+    op_func=unified_ascend_attention_with_output,
+    mutates_args=["output"],
+    fake_impl=unified_attention_with_output_fake,
+    dispatch_key="PrivateUse1",
+)
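
Registering the module-level wrapper through direct_register_custom_op is what makes the whole attention call appear as a single opaque custom op with a no-op fake implementation during tracing. A minimal sketch of the same registration pattern, assuming only the vllm.utils.direct_register_custom_op signature used above; the toy op npu_add_one and its two functions are illustrative, not from this commit:

```python
import torch
from vllm.utils import direct_register_custom_op


def npu_add_one(x: torch.Tensor, out: torch.Tensor) -> None:
    # Real implementation: writes into the preallocated output buffer.
    out.copy_(x + 1)


def npu_add_one_fake(x: torch.Tensor, out: torch.Tensor) -> None:
    # Fake (meta) implementation: no computation, used while tracing.
    return


direct_register_custom_op(
    op_name="npu_add_one",
    op_func=npu_add_one,
    mutates_args=["out"],
    fake_impl=npu_add_one_fake,
    dispatch_key="PrivateUse1",
)

# The op is then reachable as torch.ops.vllm.npu_add_one and, like the
# attention wrapper above, can be listed in compilation_config.splitting_ops
# so the compiled graph is split around the opaque call.
```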

vllm_ascend/platform.py

Lines changed: 16 additions & 30 deletions

@@ -15,36 +15,19 @@
 # limitations under the License.
 #

-import logging
 import os
 from typing import TYPE_CHECKING, Optional, Tuple

 import torch
 import torch_npu  # noqa: F401
 import vllm.envs as envs
-from vllm.config import CompilationLevel
+from vllm.config import CompilationLevel, VllmConfig, ModelConfig
 from vllm.logger import init_logger
-
-try:
-    # register custom ops into torch_library here
-    import vllm_ascend.vllm_ascend_C  # type: ignore  # noqa: F401
-
-except ImportError as e:
-    if not str(
-            e
-    ) == "dynamic module does not define module export function (PyInit_vllm_ascend_C)":
-        logging.warning(
-            "Warning: Failed to register custom ops, all custom ops will be disabled"
-        )
-
 from vllm.platforms import Platform, PlatformEnum

 if TYPE_CHECKING:
-    from vllm.config import ModelConfig, VllmConfig
     from vllm.utils import FlexibleArgumentParser
 else:
-    ModelConfig = None
-    VllmConfig = None
     FlexibleArgumentParser = None

 os.environ["RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES"] = "1"
@@ -106,7 +89,13 @@ def mem_get_info(cls) -> Tuple[int, int]:
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         compilation_config = vllm_config.compilation_config
-        if compilation_config and compilation_config.level != CompilationLevel.NO_COMPILATION:
+        if compilation_config and compilation_config.level == CompilationLevel.PIECEWISE:
+            logger.warning(
+                "Compilation level %s is supported on NPU, but use_inductor is not supported; disabling use_inductor",
+                compilation_config.level)
+            compilation_config.use_inductor = False
+            compilation_config.splitting_ops = ["vllm.unified_ascend_attention_with_output"]
+        elif compilation_config and compilation_config.level != CompilationLevel.NO_COMPILATION:
             logger.warning(
                 "Compilation level %s is not supported on NPU now, forcing compilation level to NO_COMPILATION",
                 compilation_config.level)
@@ -125,14 +114,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"

         cache_config = vllm_config.cache_config
-        if cache_config:
-            if cache_config.block_size is None:
-                cache_config.block_size = 128
-            if envs.VLLM_USE_V1 and cache_config.enable_prefix_caching:
-                logger.warning(
-                    "Prefix caching is not supported for V1 now, disable prefix caching"
-                )
-                cache_config.enable_prefix_caching = False
+        if cache_config and cache_config.block_size is None:
+            cache_config.block_size = 128
+
+        if envs.VLLM_USE_V1 and cache_config and cache_config.enable_prefix_caching:
+            logger.warning(
+                "Prefix caching is not supported for V1 now, disable prefix caching"
+            )
+            cache_config.enable_prefix_caching = False

     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
@@ -160,7 +149,4 @@ def is_pin_memory_available(cls):

     @classmethod
     def supports_v1(cls, model_config: ModelConfig) -> bool:
-        """Returns whether the current platform can support v1 for the supplied
-        model configuration.
-        """
         return True
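
With this branch in place, opting into piecewise compilation makes check_and_update_config disable Inductor and split the graph at the Ascend attention op. An illustrative sketch only, not part of this commit: the model name is a placeholder and the exact forms accepted by compilation_config can vary across vLLM versions.

```python
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    compilation_config=3,              # CompilationLevel.PIECEWISE
)
# On NPU, check_and_update_config above then sets use_inductor = False and
# splitting_ops = ["vllm.unified_ascend_attention_with_output"].
```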
