|
5 | 5 | #include "pyopenvino/graph/ops/paged_attention_extension.hpp"
|
6 | 6 |
|
7 | 7 | #include "openvino/op/op.hpp"
|
| 8 | +#include "openvino/op/paged_attention.hpp" |
8 | 9 | #include "pyopenvino/core/common.hpp"
|
9 | 10 |
|
10 | 11 | namespace py = pybind11;
|
11 | 12 |
|
12 |
| -namespace { |
13 |
| - |
14 |
| -// This is an experimental operation that is implemented in the plugins. |
15 |
| -// Do not use in user applications, backward compatibility is not guaranteed in future releases. |
16 |
| -class PagedAttentionExtension : public ov::op::Op { |
17 |
| -public: |
18 |
| - OPENVINO_OP("PagedAttentionExtension"); |
19 |
| - |
20 |
| - PagedAttentionExtension(const ov::OutputVector& args) : ov::op::Op(args) { |
21 |
| - constructor_validate_and_infer_types(); |
22 |
| - } |
23 |
| - |
24 |
| - void validate_and_infer_types() override { |
25 |
| - auto value_cache_shape = get_input_partial_shape(4); |
26 |
| - // m_num_kv_heads = value_cache_shape[1]; |
27 |
| - // m_head_size = value_cache_shape[2]; |
28 |
| - // m_block_size = value_cache_shape[3]; |
29 |
| - NODE_VALIDATION_CHECK(this, value_cache_shape.size() == 4, "Value cache shape must be 4 dims"); |
30 |
| - |
31 |
| - // key_cache: shape [num_blocks, num_kv_heads, head_size/x, block_size, x] |
32 |
| - auto key_cache_shape = get_input_partial_shape(3); |
33 |
| - NODE_VALIDATION_CHECK(this, |
34 |
| - value_cache_shape.size() == 4, |
35 |
| - // value_cache_shape[0] == key_cache_shape[0] && // num_blocks |
36 |
| - // key_cache_shape[1] == m_num_kv_heads && |
37 |
| - // key_cache_shape[2] * key_cache_shape[4] == m_head_size && |
38 |
| - // m_block_size == key_cache_shape[3], // block_size, |
39 |
| - "Key cache shape must be 4 dims"); |
40 |
| - |
41 |
| - // query: shape [batch_size, seq_len, num_heads * head_size] |
42 |
| - auto query_type = get_input_element_type(0); |
43 |
| - auto query_shape = get_input_partial_shape(0); |
44 |
| - NODE_VALIDATION_CHECK( |
45 |
| - this, |
46 |
| - // query_type.is_real() && |
47 |
| - query_shape.size() == 3, |
48 |
| - // query_shape[2] == m_num_heads * m_head_size, |
49 |
| - "Query type must be real, shape must be like [batch_size, seq_len, num_heads * head_size]. ", |
50 |
| - "Got element type ", |
51 |
| - query_type, |
52 |
| - ", shape ", |
53 |
| - query_shape); |
54 |
| - |
55 |
| - // key: shape [batch_size, seq_len, num_kv_heads * head_size] |
56 |
| - auto key_type = get_input_element_type(1); |
57 |
| - auto key_shape = get_input_partial_shape(1); |
58 |
| - NODE_VALIDATION_CHECK(this, |
59 |
| - // query_type == key_type && |
60 |
| - key_shape.size() == 3, |
61 |
| - "Key type must be the same as query, shape must be the same as query. " |
62 |
| - "Got element type ", |
63 |
| - key_type, |
64 |
| - ", shape ", |
65 |
| - key_shape); |
66 |
| - |
67 |
| - // value: shape [batch_size, seq_len, num_kv_heads * head_size] |
68 |
| - // auto value_type = get_input_element_type(2); |
69 |
| - auto value_shape = get_input_partial_shape(2); |
70 |
| - |
71 |
| - // is_prompt: boolean scalar |
72 |
| - NODE_VALIDATION_CHECK(this, |
73 |
| - // get_input_element_type(5) == ov::element::boolean && |
74 |
| - get_input_shape(5) == ov::Shape({}), |
75 |
| - "is_prompt validation failed. ", |
76 |
| - "Got element type ", |
77 |
| - get_input_element_type(5), |
78 |
| - ", shape ", |
79 |
| - get_input_shape(5)); |
80 |
| - |
81 |
| - // slot_mapping: shape [batch_size, max_context_len] |
82 |
| - auto slot_mapping_shape = get_input_partial_shape(6); |
83 |
| - NODE_VALIDATION_CHECK(this, |
84 |
| - // get_input_element_type(6) == ov::element::i64 && |
85 |
| - slot_mapping_shape.size() == 2, |
86 |
| - "slot_mapping validation failed. ", |
87 |
| - "Got element type ", |
88 |
| - get_input_element_type(6), |
89 |
| - ", shape ", |
90 |
| - slot_mapping_shape); |
91 |
| - |
92 |
| - // max_context_len: integer scalar |
93 |
| - NODE_VALIDATION_CHECK(this, |
94 |
| - // get_input_element_type(7) == ov::element::i32 && |
95 |
| - get_input_shape(7) == ov::Shape({}), |
96 |
| - "max_context_len validation failed. ", |
97 |
| - "Got element type ", |
98 |
| - get_input_element_type(7), |
99 |
| - ", shape ", |
100 |
| - get_input_shape(7)); |
101 |
| - |
102 |
| - // context_lens: shape [batch_size] |
103 |
| - auto context_lens_shape = get_input_partial_shape(8); |
104 |
| - NODE_VALIDATION_CHECK(this, |
105 |
| - // get_input_element_type(8) == ov::element::i32 && |
106 |
| - context_lens_shape.size() == 1, |
107 |
| - "context_lens validation failed. ", |
108 |
| - "Got element type ", |
109 |
| - get_input_element_type(8), |
110 |
| - ", shape ", |
111 |
| - context_lens_shape); |
112 |
| - |
113 |
| - // block_tables: shape [batch_size, max_block_per_request] |
114 |
| - NODE_VALIDATION_CHECK(this, |
115 |
| - // get_input_element_type(9) == ov::element::i32 && |
116 |
| - get_input_partial_shape(9).size() == 2, |
117 |
| - "block_tables validation failed. ", |
118 |
| - "Got element type ", |
119 |
| - get_input_element_type(9), |
120 |
| - ", shape ", |
121 |
| - get_input_partial_shape(9)); |
122 |
| - |
123 |
| - // scale: float scalar |
124 |
| - NODE_VALIDATION_CHECK(this, |
125 |
| - // get_input_element_type(10) == ov::element::f32 && |
126 |
| - get_input_shape(10) == ov::Shape({}), |
127 |
| - "block_tables validation failed. ", |
128 |
| - "Got element type ", |
129 |
| - get_input_element_type(10), |
130 |
| - ", shape ", |
131 |
| - get_input_shape(10)); |
132 |
| - |
133 |
| - // alibi_slopes: 1D float tensor |
134 |
| - NODE_VALIDATION_CHECK(this, |
135 |
| - // get_input_element_type(11) == ov::element::f32 && |
136 |
| - get_input_partial_shape(11).rank().get_length() == 1, |
137 |
| - "alibi_slopes should be a 1D float tensor. ", |
138 |
| - "Got element type ", |
139 |
| - get_input_element_type(11), |
140 |
| - ", shape ", |
141 |
| - get_input_partial_shape(11)); |
142 |
| - |
143 |
| - // sliding_window: int scalar |
144 |
| - NODE_VALIDATION_CHECK(this, |
145 |
| - // get_input_element_type(12) == ov::element::i32 && |
146 |
| - get_input_partial_shape(12).rank().get_length() == 0, |
147 |
| - "sliding_window argument should be an i32 scalar. ", |
148 |
| - "Got element type ", |
149 |
| - get_input_element_type(12), |
150 |
| - ", shape ", |
151 |
| - get_input_partial_shape(12)); |
152 |
| - |
153 |
| - set_output_type(0, query_type, query_shape); |
154 |
| - } |
155 |
| - |
156 |
| - std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override { |
157 |
| - return std::make_shared<PagedAttentionExtension>(new_args); |
158 |
| - } |
159 |
| -}; |
160 |
| - |
161 |
| -} // namespace |
162 |
| - |
163 | 13 | void regclass_graph_op_PagedAttentionExtension(py::module m) {
|
| 14 | + using ov::op::PagedAttentionExtension; |
164 | 15 | py::class_<PagedAttentionExtension, std::shared_ptr<PagedAttentionExtension>, ov::Node> cls(
|
165 | 16 | m,
|
166 | 17 | "_PagedAttentionExtension");
|
|
0 commit comments