Commit ba0cd81

support npugraph to default
Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
1 parent 5442b46 commit ba0cd81

File tree

13 files changed: +439, -102 lines

.github/workflows/vllm_ascend_test.yaml

Lines changed: 3 additions & 1 deletion
@@ -115,7 +115,7 @@ jobs:
      - name: Install vllm-project/vllm-ascend
        run: |
          pip install -r requirements-dev.txt
-         pip install -e .
+         pip install -v --no-build-isolation -e .

      - name: Run vllm-project/vllm-ascend test on V0 engine
        env:
@@ -137,9 +137,11 @@ jobs:
          if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
            pytest -sv tests/singlecard/test_offline_inference.py
            pytest -sv tests/ops
+           pytest -sv tests/compile
          else
            pytest -sv tests/multicard/test_offline_inference_distributed.py
            pytest -sv tests/ops
+           pytest -sv tests/compile
          fi

      # only run test on spec decode when the related code changed

csrc/ops.h

Lines changed: 18 additions & 1 deletion
@@ -21,6 +21,7 @@

 #include <vector>
 #include "kernels/types.h"
+#include "torch_npu/csrc/aten/common/from_blob.h"

 namespace vllm_ascend {
     extern void rotary_embedding_impl(AscendType type, bool isNeox, void *stream, int64_t *positions, void *queryDst,
@@ -29,4 +30,20 @@ namespace vllm_ascend {
                                       const int64_t dstKeyStride, const int numHeads, const int numKvHeads,
                                       const int headSize, const int64_t numTokens, const uint32_t loopCnt,
                                       uint32_t aivNum);
-}
+
+    torch::Tensor weak_ref_tensor(torch::Tensor& tensor) {
+        if (!tensor.is_privateuseone()) {
+            throw std::runtime_error("Tensor must be on NPU device");
+        }
+        // Get the raw data pointer
+        void* data_ptr = tensor.data_ptr();
+        // Get tensor sizes and strides
+        std::vector<int64_t> sizes = tensor.sizes().vec();
+        std::vector<int64_t> strides = tensor.strides().vec();
+        // Get tensor options (dtype, device)
+        auto options = tensor.options();
+        // Create a new tensor from the raw data pointer
+        auto new_tensor = at_npu::native::from_blob(data_ptr, sizes, strides, options);
+        return new_tensor;
+    }
+}

csrc/torch_binding.cpp

Lines changed: 2 additions & 0 deletions
@@ -103,6 +103,8 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
 TORCH_LIBRARY_EXPAND(_C, ops)
 {
     // vLLM-Ascend custom ops
+    ops.def("weak_ref_tensor(Tensor input) -> Tensor");
+    ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);

     // Rotary embedding
     // Apply GPT-NeoX style rotary embedding to query and key.
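
With the schema and PrivateUse1 implementation registered above, weak_ref_tensor becomes reachable from Python through the extension's op namespace. A minimal usage sketch — assuming the compiled vllm-ascend extension is loaded, an NPU device is available, and (as the literal name in TORCH_LIBRARY_EXPAND(_C, ops) suggests) the ops land under torch.ops._C:

    import torch
    import torch_npu  # noqa: F401  (registers the Ascend/PrivateUse1 backend)

    t = torch.ones(4, device="npu")
    # weak_ref_tensor returns a new tensor aliasing the same NPU storage
    # without owning it, so `t` must outlive the returned weak reference.
    w = torch.ops._C.weak_ref_tensor(t)
    assert w.data_ptr() == t.data_ptr()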

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 -r requirements-lint.txt
+-r requirements.txt
 modelscope
 pytest >= 6.0
 pytest-asyncio

requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -3,11 +3,12 @@ cmake>=3.26
 decorator
 numpy<2.0.0
 packaging
+pip
 pybind11
 pyyaml
 scipy
 setuptools>=64
 setuptools-scm>=8
-torch_npu
 torch >= 2.5.1
 torchvision<0.21.0
+wheel

tests/compile/__init__.py

Whitespace-only changes.

tests/compile/test_simple.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Test the piecewise compilation with a simple model so that we
+can exactly calculate the expected output and side effects.
+"""
+
+import torch
+from torch import nn
+from torch.library import Library
+from vllm.compilation.counter import compilation_counter
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
+                         set_current_vllm_config)
+from vllm.utils import direct_register_custom_op
+
+
+global_counter = 0
+
+# create a library to hold the custom op
+silly_lib = Library("silly", "FRAGMENT")  # noqa
+
+
+def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                    out: torch.Tensor) -> None:
+    global global_counter
+    global_counter += 1
+    print(f"{global_counter=}")
+    out.copy_(q)
+    out[0] += 1
+
+
+def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                         out: torch.Tensor) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="attention",
+    op_func=silly_attention,
+    mutates_args=["out"],
+    fake_impl=silly_attention_fake,
+    dispatch_key="PrivateUse1",
+    target_lib=silly_lib,
+)
+
+
+@support_torch_compile
+class SillyModel(nn.Module):
+
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 **kwargs) -> None:
+        super().__init__()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Overall effect:
+        x += 1
+        x[0] += 2
+        global_counter += 2
+        """
+        x = x + 1
+        x = x + 2
+        out = torch.empty_like(x)
+        torch.ops.silly.attention(x, x, x, out)
+        x = out
+        x = x - 2
+        x = x - 1
+        out = torch.empty_like(x)
+        torch.ops.silly.attention(x, x, x, out)
+        x = out
+        x = x + 1
+        return x
+
+
+def test_simple_piecewise_compile():
+
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        level=CompilationLevel.PIECEWISE,
+        use_inductor=False,
+        use_cudagraph=True,
+        splitting_ops=["silly.attention"],
+        cudagraph_copy_inputs=True,
+        cudagraph_capture_sizes=[1, 2],
+    ))
+    vllm_config.compilation_config.pass_config.enable_fusion = False
+    with set_current_vllm_config(vllm_config):
+        model = SillyModel(vllm_config=vllm_config, prefix="")
+
+    inputs = torch.randn(100).npu()
+
+    with compilation_counter.expect(
+            num_graphs_seen=1,  # one graph for the model
+            num_piecewise_graphs_seen=5,  # 2 * num_layers + 1
+            num_piecewise_capturable_graphs_seen=3,  # 1 + num_layers
+            num_backend_compilations=3,  # num_piecewise_capturable_graphs_seen
+            num_cudagraph_caputured=
+            6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+    ):
+
+        model(inputs)
+
+        model(torch.randn(2).npu())
+        model(torch.randn(1).npu())
+
+        input = torch.zeros(2).npu()
+        global global_counter
+        global_counter = 0
+        output = model(input)
+        assert global_counter == 2
+        assert torch.allclose(output.cpu(), torch.tensor([3.0, 1.0]))
+
+
+if __name__ == "__main__":
+    test_simple_piecewise_compile()
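
The expected counters follow from how the graph is split: forward() contains two torch.ops.silly.attention calls, and splitting_ops=["silly.attention"] cuts the traced graph at each of them, so the attention pieces themselves are not capturable while the pieces between them are. A small illustrative re-derivation of the asserted numbers (the variable names below are ours, not part of the test):

    num_splits = 2          # two silly.attention calls in forward()
    num_capture_sizes = 2   # cudagraph_capture_sizes=[1, 2]

    num_piecewise_graphs_seen = 2 * num_splits + 1                    # 5
    num_piecewise_capturable_graphs_seen = num_splits + 1             # 3
    num_backend_compilations = num_piecewise_capturable_graphs_seen   # 3
    num_cudagraphs_captured = (num_piecewise_capturable_graphs_seen *
                               num_capture_sizes)                     # 6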

vllm_ascend/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,8 @@
 # This file is a part of the vllm-ascend project.
 #

+from torch_npu.contrib import transfer_to_npu  # noqa: F401
+

 def register():
     """Register the NPU platform."""
