
Commit edce3b8

support npugraph to default
Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
1 parent 66a0837 commit edce3b8

File tree

13 files changed: +422 −90 lines


.github/workflows/vllm_ascend_test.yaml

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ jobs:
       - name: Install vllm-project/vllm-ascend
         run: |
           pip install -r requirements-dev.txt
-          pip install -e .
+          pip install -v --no-build-isolation -e .
 
       - name: Run vllm-project/vllm-ascend test on V0 engine
         env:
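
The editable install now runs with --no-build-isolation, so pip builds the C++ extension against the packages already installed in the environment instead of a throwaway isolated build environment; this is also why pip and wheel are added to requirements.txt below, since build isolation would otherwise have provided them automatically. The equivalent local install, assuming CANN and torch_npu are already set up, would be:

    pip install -r requirements-dev.txt
    pip install -v --no-build-isolation -e .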

csrc/ops.h

Lines changed: 18 additions & 1 deletion
@@ -21,6 +21,7 @@
 
 #include <vector>
 #include "kernels/types.h"
+#include "torch_npu/csrc/aten/common/from_blob.h"
 
 namespace vllm_ascend {
 extern void rotary_embedding_impl(AscendType type, bool isNeox, void *stream, int64_t *positions, void *queryDst,
@@ -29,4 +30,20 @@ namespace vllm_ascend {
                                   const int64_t dstKeyStride, const int numHeads, const int numKvHeads,
                                   const int headSize, const int64_t numTokens, const uint32_t loopCnt,
                                   uint32_t aivNum);
-}
+
+torch::Tensor weak_ref_tensor(torch::Tensor& tensor) {
+    if (!tensor.is_privateuseone()) {
+        throw std::runtime_error("Tensor must be on NPU device");
+    }
+    // Get the raw data pointer
+    void* data_ptr = tensor.data_ptr();
+    // Get tensor sizes and strides
+    std::vector<int64_t> sizes = tensor.sizes().vec();
+    std::vector<int64_t> strides = tensor.strides().vec();
+    // Get tensor options (dtype, device)
+    auto options = tensor.options();
+    // Create a new tensor from the raw data pointer
+    auto new_tensor = at_npu::native::from_blob(data_ptr, sizes, strides, options);
+    return new_tensor;
+}
+}
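
weak_ref_tensor produces a non-owning alias of an NPU tensor: at_npu::native::from_blob wraps the existing data pointer with the original sizes, strides, and options, so the returned tensor shares storage but does not extend its lifetime. This mirrors the weak_ref_tensor op vLLM uses on CUDA, where the piecewise graph backend hands out weak references so that tensors captured into a graph do not keep their allocations alive. The Python-visible binding is registered in torch_binding.cpp below.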

csrc/torch_binding.cpp

Lines changed: 2 additions & 0 deletions
@@ -103,6 +103,8 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
 TORCH_LIBRARY_EXPAND(_C, ops)
 {
     // vLLM-Ascend custom ops
+    ops.def("weak_ref_tensor(Tensor input) -> Tensor");
+    ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
 
     // Rotary embedding
     // Apply GPT-NeoX style rotary embedding to query and key.
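
With the schema defined and the PrivateUse1 (NPU) kernel bound, the op becomes callable from Python. A minimal sketch of a call, assuming TORCH_LIBRARY_EXPAND(_C, ...) expands to the _C namespace (as in vLLM) and that importing vllm_ascend loads the _C extension:

    import torch
    import torch_npu  # noqa: F401  (provides the NPU / PrivateUse1 backend)
    import vllm_ascend  # noqa: F401  (loads the _C extension with the custom ops)

    x = torch.ones(4, device="npu")
    # The result aliases x's storage without owning it, so it must not
    # outlive the original allocation.
    alias = torch.ops._C.weak_ref_tensor(x)
    assert alias.data_ptr() == x.data_ptr()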

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 -r requirements-lint.txt
+-r requirements.txt
 modelscope
 pytest >= 6.0
 pytest-asyncio

requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,7 @@ cmake>=3.26
 decorator
 numpy<2.0.0
 packaging
+pip
 pybind11
 pyyaml
 scipy
@@ -11,3 +12,4 @@ setuptools-scm>=8
 torch_npu
 torch >= 2.5.1
 torchvision<0.21.0
+wheel

tests/compile/__init__.py

Whitespace-only changes.

tests/compile/test_simple.py

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Test the piecewise compilation with a simple model so that we
+can exactly calculate the expected output and side effects.
+"""
+
+import os
+
+import torch
+import vllm_ascend  # noqa: F401
+from torch import nn
+from torch.library import Library
+from torch_npu.contrib import transfer_to_npu  # noqa: F401
+from vllm.compilation.counter import compilation_counter
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
+                         set_current_vllm_config)
+from vllm.utils import direct_register_custom_op
+
+global_counter = 0
+
+# create a library to hold the custom op
+silly_lib = Library("silly", "FRAGMENT")  # noqa
+
+
+def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                    out: torch.Tensor) -> None:
+    global global_counter
+    global_counter += 1
+    print(f"{global_counter=}")
+    out.copy_(q)
+    out[0] += 1
+
+
+def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                         out: torch.Tensor) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="attention",
+    op_func=silly_attention,
+    mutates_args=["out"],
+    fake_impl=silly_attention_fake,
+    dispatch_key="PrivateUse1",
+    target_lib=silly_lib,
+)
+
+
+@support_torch_compile
+class SillyModel(nn.Module):
+
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = '',
+                 **kwargs) -> None:
+        super().__init__()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Overall effect:
+        x += 1
+        x[0] += 2
+        global_counter += 2
+        """
+        x = x + 1
+        x = x + 2
+        out = torch.empty_like(x)
+        torch.ops.silly.attention(x, x, x, out)
+        x = out
+        x = x - 2
+        x = x - 1
+        out = torch.empty_like(x)
+        torch.ops.silly.attention(x, x, x, out)
+        x = out
+        x = x + 1
+        return x
+
+
+def test_simple_piecewise_compile():
+
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        level=CompilationLevel.PIECEWISE,
+        use_inductor=False,
+        use_cudagraph=True,
+        splitting_ops=["silly.attention"],
+        cudagraph_copy_inputs=True,
+        cudagraph_capture_sizes=[1, 2],
+    ))
+    vllm_config.compilation_config.pass_config.enable_fusion = False
+    with set_current_vllm_config(vllm_config):
+        model = SillyModel(vllm_config=vllm_config, prefix='')
+
+    inputs = torch.randn(100).npu()
+
+    with compilation_counter.expect(
+            num_graphs_seen=1,  # one graph for the model
+            num_piecewise_graphs_seen=5,  # 2 * num_layers + 1
+            num_piecewise_capturable_graphs_seen=3,  # 1 + num_layers
+            num_backend_compilations=3,  # num_piecewise_capturable_graphs_seen
+            num_cudagraph_caputured=
+            6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+    ):
+
+        model(inputs)
+
+        model(torch.randn(2).npu())
+        model(torch.randn(1).npu())
+
+        input = torch.zeros(2).npu()
+        global global_counter
+        global_counter = 0
+        output = model(input)
+        assert global_counter == 2
+        assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
+
+
+if __name__ == "__main__":
+    os.environ["VLLM_USE_V1"] = "1"
+    test_simple_piecewise_compile()
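
The expected counters follow from the model shape: splitting at the two silly.attention calls yields 2 × 2 + 1 = 5 piecewise graphs, of which the 3 attention-free pieces are capturable; with two capture sizes ([1, 2]) that gives 3 × 2 = 6 captured graphs. The asserted output can also be checked by hand; a CPU-only sketch of the same arithmetic, with a plain function standing in for the registered op:

    import torch

    def silly_attention_ref(q, k, v, out):
        # same effect as the custom op: copy q, bump element 0
        out.copy_(q)
        out[0] += 1

    x = torch.zeros(2)
    x = x + 1
    x = x + 2                            # x = [3., 3.]
    out = torch.empty_like(x)
    silly_attention_ref(x, x, x, out)
    x = out                              # x = [4., 3.]
    x = x - 2
    x = x - 1                            # x = [1., 0.]
    out = torch.empty_like(x)
    silly_attention_ref(x, x, x, out)
    x = out                              # x = [2., 0.]
    x = x + 1                            # x = [3., 1.]
    assert torch.allclose(x, torch.tensor([3., 1.]))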

vllm_ascend/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -15,6 +15,8 @@
 # This file is a part of the vllm-ascend project.
 #
 
+from .utils import register_dummy_fusion_op
+
 
 def register():
     """Register the NPU platform."""
@@ -28,3 +30,6 @@ def register():
 def register_model():
     from .models import register_model
     register_model()
+
+
+register_dummy_fusion_op()
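
register_dummy_fusion_op() now runs at import time; its implementation lives in vllm_ascend/utils.py, one of the 13 changed files not shown in this excerpt. Judging by the name, and by the test above disabling the fusion pass (enable_fusion = False), it presumably registers placeholder fusion ops so that vLLM's compilation config can be instantiated on NPU without the CUDA fusion kernels.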
