2
2
// SPDX-License-Identifier: Apache-2.0
3
3
//
4
4
5
- #include " openvino/op/op .hpp"
5
+ #include " pyopenvino/graph/ops/paged_attention_extension .hpp"
6
6
7
+ #include " openvino/op/op.hpp"
7
8
#include " pyopenvino/core/common.hpp"
8
- #include " pyopenvino/graph/ops/paged_attention_extension.hpp"
9
9
10
10
namespace py = pybind11;
11
11
@@ -26,100 +26,129 @@ class PagedAttentionExtension : public ov::op::Op {
26
26
// m_num_kv_heads = value_cache_shape[1];
27
27
// m_head_size = value_cache_shape[2];
28
28
// m_block_size = value_cache_shape[3];
29
- NODE_VALIDATION_CHECK (this ,
30
- value_cache_shape.size () == 4 ,
31
- " Value cache shape must be 4 dims" );
29
+ NODE_VALIDATION_CHECK (this , value_cache_shape.size () == 4 , " Value cache shape must be 4 dims" );
32
30
33
31
// key_cache: shape [num_blocks, num_kv_heads, head_size/x, block_size, x]
34
32
auto key_cache_shape = get_input_partial_shape (3 );
35
33
NODE_VALIDATION_CHECK (this ,
36
- value_cache_shape.size () == 4 ,
37
- // value_cache_shape[0] == key_cache_shape[0] && // num_blocks
38
- // key_cache_shape[1] == m_num_kv_heads &&
39
- // key_cache_shape[2] * key_cache_shape[4] == m_head_size &&
40
- // m_block_size == key_cache_shape[3], // block_size,
41
- " Key cache shape must be 4 dims" );
34
+ value_cache_shape.size () == 4 ,
35
+ // value_cache_shape[0] == key_cache_shape[0] && // num_blocks
36
+ // key_cache_shape[1] == m_num_kv_heads &&
37
+ // key_cache_shape[2] * key_cache_shape[4] == m_head_size &&
38
+ // m_block_size == key_cache_shape[3], // block_size,
39
+ " Key cache shape must be 4 dims" );
42
40
43
41
// query: shape [batch_size, seq_len, num_heads * head_size]
44
42
auto query_type = get_input_element_type (0 );
45
43
auto query_shape = get_input_partial_shape (0 );
46
- NODE_VALIDATION_CHECK (this ,
44
+ NODE_VALIDATION_CHECK (
45
+ this ,
47
46
// query_type.is_real() &&
48
47
query_shape.size () == 3 ,
49
48
// query_shape[2] == m_num_heads * m_head_size,
50
49
" Query type must be real, shape must be like [batch_size, seq_len, num_heads * head_size]. " ,
51
- " Got element type " , query_type, " , shape " , query_shape);
50
+ " Got element type " ,
51
+ query_type,
52
+ " , shape " ,
53
+ query_shape);
52
54
53
55
// key: shape [batch_size, seq_len, num_kv_heads * head_size]
54
56
auto key_type = get_input_element_type (1 );
55
57
auto key_shape = get_input_partial_shape (1 );
56
58
NODE_VALIDATION_CHECK (this ,
57
- // query_type == key_type &&
58
- key_shape.size () == 3 ,
59
- " Key type must be the same as query, shape must be the same as query. "
60
- " Got element type " , key_type, " , shape " , key_shape);
59
+ // query_type == key_type &&
60
+ key_shape.size () == 3 ,
61
+ " Key type must be the same as query, shape must be the same as query. "
62
+ " Got element type " ,
63
+ key_type,
64
+ " , shape " ,
65
+ key_shape);
61
66
62
67
// value: shape [batch_size, seq_len, num_kv_heads * head_size]
63
68
// auto value_type = get_input_element_type(2);
64
69
auto value_shape = get_input_partial_shape (2 );
65
70
66
71
// is_prompt: boolean scalar
67
72
NODE_VALIDATION_CHECK (this ,
68
- // get_input_element_type(5) == ov::element::boolean &&
69
- get_input_shape (5 ) == ov::Shape ({}),
70
- " is_prompt validation failed. " ,
71
- " Got element type " , get_input_element_type (5 ), " , shape " , get_input_shape (5 ));
73
+ // get_input_element_type(5) == ov::element::boolean &&
74
+ get_input_shape (5 ) == ov::Shape ({}),
75
+ " is_prompt validation failed. " ,
76
+ " Got element type " ,
77
+ get_input_element_type (5 ),
78
+ " , shape " ,
79
+ get_input_shape (5 ));
72
80
73
81
// slot_mapping: shape [batch_size, max_context_len]
74
82
auto slot_mapping_shape = get_input_partial_shape (6 );
75
83
NODE_VALIDATION_CHECK (this ,
76
- // get_input_element_type(6) == ov::element::i64 &&
77
- slot_mapping_shape.size () == 2 ,
78
- " slot_mapping validation failed. " ,
79
- " Got element type " , get_input_element_type (6 ), " , shape " , slot_mapping_shape);
84
+ // get_input_element_type(6) == ov::element::i64 &&
85
+ slot_mapping_shape.size () == 2 ,
86
+ " slot_mapping validation failed. " ,
87
+ " Got element type " ,
88
+ get_input_element_type (6 ),
89
+ " , shape " ,
90
+ slot_mapping_shape);
80
91
81
92
// max_context_len: integer scalar
82
93
NODE_VALIDATION_CHECK (this ,
83
- // get_input_element_type(7) == ov::element::i32 &&
84
- get_input_shape (7 ) == ov::Shape ({}),
85
- " max_context_len validation failed. " ,
86
- " Got element type " , get_input_element_type (7 ), " , shape " , get_input_shape (7 ));
94
+ // get_input_element_type(7) == ov::element::i32 &&
95
+ get_input_shape (7 ) == ov::Shape ({}),
96
+ " max_context_len validation failed. " ,
97
+ " Got element type " ,
98
+ get_input_element_type (7 ),
99
+ " , shape " ,
100
+ get_input_shape (7 ));
87
101
88
102
// context_lens: shape [batch_size]
89
103
auto context_lens_shape = get_input_partial_shape (8 );
90
104
NODE_VALIDATION_CHECK (this ,
91
- // get_input_element_type(8) == ov::element::i32 &&
92
- context_lens_shape.size () == 1 ,
93
- " context_lens validation failed. " ,
94
- " Got element type " , get_input_element_type (8 ), " , shape " , context_lens_shape);
105
+ // get_input_element_type(8) == ov::element::i32 &&
106
+ context_lens_shape.size () == 1 ,
107
+ " context_lens validation failed. " ,
108
+ " Got element type " ,
109
+ get_input_element_type (8 ),
110
+ " , shape " ,
111
+ context_lens_shape);
95
112
96
113
// block_tables: shape [batch_size, max_block_per_request]
97
114
NODE_VALIDATION_CHECK (this ,
98
- // get_input_element_type(9) == ov::element::i32 &&
99
- get_input_partial_shape (9 ).size () == 2 ,
100
- " block_tables validation failed. " ,
101
- " Got element type " , get_input_element_type (9 ), " , shape " , get_input_partial_shape (9 ));
115
+ // get_input_element_type(9) == ov::element::i32 &&
116
+ get_input_partial_shape (9 ).size () == 2 ,
117
+ " block_tables validation failed. " ,
118
+ " Got element type " ,
119
+ get_input_element_type (9 ),
120
+ " , shape " ,
121
+ get_input_partial_shape (9 ));
102
122
103
123
// scale: float scalar
104
124
NODE_VALIDATION_CHECK (this ,
105
- // get_input_element_type(10) == ov::element::f32 &&
106
- get_input_shape (10 ) == ov::Shape ({}),
107
- " block_tables validation failed. " ,
108
- " Got element type " , get_input_element_type (10 ), " , shape " , get_input_shape (10 ));
125
+ // get_input_element_type(10) == ov::element::f32 &&
126
+ get_input_shape (10 ) == ov::Shape ({}),
127
+ " block_tables validation failed. " ,
128
+ " Got element type " ,
129
+ get_input_element_type (10 ),
130
+ " , shape " ,
131
+ get_input_shape (10 ));
109
132
110
133
// alibi_slopes: 1D float tensor
111
134
NODE_VALIDATION_CHECK (this ,
112
- // get_input_element_type(11) == ov::element::f32 &&
113
- get_input_partial_shape (11 ).rank ().get_length () == 1 ,
114
- " alibi_slopes should be a 1D float tensor. " ,
115
- " Got element type " , get_input_element_type (11 ), " , shape " , get_input_partial_shape (11 ));
135
+ // get_input_element_type(11) == ov::element::f32 &&
136
+ get_input_partial_shape (11 ).rank ().get_length () == 1 ,
137
+ " alibi_slopes should be a 1D float tensor. " ,
138
+ " Got element type " ,
139
+ get_input_element_type (11 ),
140
+ " , shape " ,
141
+ get_input_partial_shape (11 ));
116
142
117
143
// sliding_window: int scalar
118
144
NODE_VALIDATION_CHECK (this ,
119
- // get_input_element_type(12) == ov::element::i32 &&
120
- get_input_partial_shape (12 ).rank ().get_length () == 0 ,
121
- " sliding_window argument should be an i32 scalar. " ,
122
- " Got element type " , get_input_element_type (12 ), " , shape " , get_input_partial_shape (12 ));
145
+ // get_input_element_type(12) == ov::element::i32 &&
146
+ get_input_partial_shape (12 ).rank ().get_length () == 0 ,
147
+ " sliding_window argument should be an i32 scalar. " ,
148
+ " Got element type " ,
149
+ get_input_element_type (12 ),
150
+ " , shape " ,
151
+ get_input_partial_shape (12 ));
123
152
124
153
set_output_type (0 , query_type, query_shape);
125
154
}
@@ -129,10 +158,13 @@ class PagedAttentionExtension : public ov::op::Op {
129
158
}
130
159
};
131
160
132
- }
161
+ } // namespace
133
162
134
163
// Registers the experimental PagedAttention extension operation with the given
// pyopenvino module under the private name "_PagedAttentionExtension".
// @param m  pybind11 module object the class binding is added to.
void regclass_graph_op_PagedAttentionExtension(py::module m) {
    py::class_<PagedAttentionExtension, std::shared_ptr<PagedAttentionExtension>, ov::Node> cls(
        m,
        "_PagedAttentionExtension");
    // Fix user-visible docstring typo: "extention" -> "extension".
    cls.doc() = "Experimental extension for PagedAttention operation. Use with care: no backward compatibility is "
                "guaranteed in future releases.";
    // Single constructor: takes the full ordered vector of input nodes
    // (indices 0..12 are validated in validate_and_infer_types above).
    cls.def(py::init<const ov::OutputVector&>());
}
0 commit comments