Skip to content

Commit 704b722

Browse files
committed
Code style fixes
1 parent 2332579 commit 704b722

File tree

1 file changed

+84
-52
lines changed

1 file changed

+84
-52
lines changed

src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.cpp

+84-52
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
// SPDX-License-Identifier: Apache-2.0
33
//
44

5-
#include "openvino/op/op.hpp"
5+
#include "pyopenvino/graph/ops/paged_attention_extension.hpp"
66

7+
#include "openvino/op/op.hpp"
78
#include "pyopenvino/core/common.hpp"
8-
#include "pyopenvino/graph/ops/paged_attention_extension.hpp"
99

1010
namespace py = pybind11;
1111

@@ -26,100 +26,129 @@ class PagedAttentionExtension : public ov::op::Op {
2626
// m_num_kv_heads = value_cache_shape[1];
2727
// m_head_size = value_cache_shape[2];
2828
// m_block_size = value_cache_shape[3];
29-
NODE_VALIDATION_CHECK(this,
30-
value_cache_shape.size() == 4,
31-
"Value cache shape must be 4 dims");
29+
NODE_VALIDATION_CHECK(this, value_cache_shape.size() == 4, "Value cache shape must be 4 dims");
3230

3331
// key_cache: shape [num_blocks, num_kv_heads, head_size/x, block_size, x]
3432
auto key_cache_shape = get_input_partial_shape(3);
3533
NODE_VALIDATION_CHECK(this,
36-
value_cache_shape.size() == 4,
37-
// value_cache_shape[0] == key_cache_shape[0] && // num_blocks
38-
// key_cache_shape[1] == m_num_kv_heads &&
39-
// key_cache_shape[2] * key_cache_shape[4] == m_head_size &&
40-
// m_block_size == key_cache_shape[3], // block_size,
41-
"Key cache shape must be 4 dims");
34+
value_cache_shape.size() == 4,
35+
// value_cache_shape[0] == key_cache_shape[0] && // num_blocks
36+
// key_cache_shape[1] == m_num_kv_heads &&
37+
// key_cache_shape[2] * key_cache_shape[4] == m_head_size &&
38+
// m_block_size == key_cache_shape[3], // block_size,
39+
"Key cache shape must be 4 dims");
4240

4341
// query: shape [batch_size, seq_len, num_heads * head_size]
4442
auto query_type = get_input_element_type(0);
4543
auto query_shape = get_input_partial_shape(0);
46-
NODE_VALIDATION_CHECK(this,
44+
NODE_VALIDATION_CHECK(
45+
this,
4746
// query_type.is_real() &&
4847
query_shape.size() == 3,
4948
// query_shape[2] == m_num_heads * m_head_size,
5049
"Query type must be real, shape must be like [batch_size, seq_len, num_heads * head_size]. ",
51-
"Got element type ", query_type, ", shape ", query_shape);
50+
"Got element type ",
51+
query_type,
52+
", shape ",
53+
query_shape);
5254

5355
// key: shape [batch_size, seq_len, num_kv_heads * head_size]
5456
auto key_type = get_input_element_type(1);
5557
auto key_shape = get_input_partial_shape(1);
5658
NODE_VALIDATION_CHECK(this,
57-
// query_type == key_type &&
58-
key_shape.size() == 3,
59-
"Key type must be the same as query, shape must be the same as query. "
60-
"Got element type ", key_type, ", shape ", key_shape);
59+
// query_type == key_type &&
60+
key_shape.size() == 3,
61+
"Key type must be the same as query, shape must be the same as query. "
62+
"Got element type ",
63+
key_type,
64+
", shape ",
65+
key_shape);
6166

6267
// value: shape [batch_size, seq_len, num_kv_heads * head_size]
6368
// auto value_type = get_input_element_type(2);
6469
auto value_shape = get_input_partial_shape(2);
6570

6671
// is_prompt: boolean scalar
6772
NODE_VALIDATION_CHECK(this,
68-
// get_input_element_type(5) == ov::element::boolean &&
69-
get_input_shape(5) == ov::Shape({}),
70-
"is_prompt validation failed. ",
71-
"Got element type ", get_input_element_type(5), ", shape ", get_input_shape(5));
73+
// get_input_element_type(5) == ov::element::boolean &&
74+
get_input_shape(5) == ov::Shape({}),
75+
"is_prompt validation failed. ",
76+
"Got element type ",
77+
get_input_element_type(5),
78+
", shape ",
79+
get_input_shape(5));
7280

7381
// slot_mapping: shape [batch_size, max_context_len]
7482
auto slot_mapping_shape = get_input_partial_shape(6);
7583
NODE_VALIDATION_CHECK(this,
76-
// get_input_element_type(6) == ov::element::i64 &&
77-
slot_mapping_shape.size() == 2,
78-
"slot_mapping validation failed. ",
79-
"Got element type ", get_input_element_type(6), ", shape ", slot_mapping_shape);
84+
// get_input_element_type(6) == ov::element::i64 &&
85+
slot_mapping_shape.size() == 2,
86+
"slot_mapping validation failed. ",
87+
"Got element type ",
88+
get_input_element_type(6),
89+
", shape ",
90+
slot_mapping_shape);
8091

8192
// max_context_len: integer scalar
8293
NODE_VALIDATION_CHECK(this,
83-
// get_input_element_type(7) == ov::element::i32 &&
84-
get_input_shape(7) == ov::Shape({}),
85-
"max_context_len validation failed. ",
86-
"Got element type ", get_input_element_type(7), ", shape ", get_input_shape(7));
94+
// get_input_element_type(7) == ov::element::i32 &&
95+
get_input_shape(7) == ov::Shape({}),
96+
"max_context_len validation failed. ",
97+
"Got element type ",
98+
get_input_element_type(7),
99+
", shape ",
100+
get_input_shape(7));
87101

88102
// context_lens: shape [batch_size]
89103
auto context_lens_shape = get_input_partial_shape(8);
90104
NODE_VALIDATION_CHECK(this,
91-
// get_input_element_type(8) == ov::element::i32 &&
92-
context_lens_shape.size() == 1,
93-
"context_lens validation failed. ",
94-
"Got element type ", get_input_element_type(8), ", shape ", context_lens_shape);
105+
// get_input_element_type(8) == ov::element::i32 &&
106+
context_lens_shape.size() == 1,
107+
"context_lens validation failed. ",
108+
"Got element type ",
109+
get_input_element_type(8),
110+
", shape ",
111+
context_lens_shape);
95112

96113
// block_tables: shape [batch_size, max_block_per_request]
97114
NODE_VALIDATION_CHECK(this,
98-
// get_input_element_type(9) == ov::element::i32 &&
99-
get_input_partial_shape(9).size() == 2,
100-
"block_tables validation failed. ",
101-
"Got element type ", get_input_element_type(9), ", shape ", get_input_partial_shape(9));
115+
// get_input_element_type(9) == ov::element::i32 &&
116+
get_input_partial_shape(9).size() == 2,
117+
"block_tables validation failed. ",
118+
"Got element type ",
119+
get_input_element_type(9),
120+
", shape ",
121+
get_input_partial_shape(9));
102122

103123
// scale: float scalar
104124
NODE_VALIDATION_CHECK(this,
105-
// get_input_element_type(10) == ov::element::f32 &&
106-
get_input_shape(10) == ov::Shape({}),
107-
"block_tables validation failed. ",
108-
"Got element type ", get_input_element_type(10), ", shape ", get_input_shape(10));
125+
// get_input_element_type(10) == ov::element::f32 &&
126+
get_input_shape(10) == ov::Shape({}),
127+
"block_tables validation failed. ",
128+
"Got element type ",
129+
get_input_element_type(10),
130+
", shape ",
131+
get_input_shape(10));
109132

110133
// alibi_slopes: 1D float tensor
111134
NODE_VALIDATION_CHECK(this,
112-
// get_input_element_type(11) == ov::element::f32 &&
113-
get_input_partial_shape(11).rank().get_length() == 1,
114-
"alibi_slopes should be a 1D float tensor. ",
115-
"Got element type ", get_input_element_type(11), ", shape ", get_input_partial_shape(11));
135+
// get_input_element_type(11) == ov::element::f32 &&
136+
get_input_partial_shape(11).rank().get_length() == 1,
137+
"alibi_slopes should be a 1D float tensor. ",
138+
"Got element type ",
139+
get_input_element_type(11),
140+
", shape ",
141+
get_input_partial_shape(11));
116142

117143
// sliding_window: int scalar
118144
NODE_VALIDATION_CHECK(this,
119-
// get_input_element_type(12) == ov::element::i32 &&
120-
get_input_partial_shape(12).rank().get_length() == 0,
121-
"sliding_window argument should be an i32 scalar. ",
122-
"Got element type ", get_input_element_type(12), ", shape ", get_input_partial_shape(12));
145+
// get_input_element_type(12) == ov::element::i32 &&
146+
get_input_partial_shape(12).rank().get_length() == 0,
147+
"sliding_window argument should be an i32 scalar. ",
148+
"Got element type ",
149+
get_input_element_type(12),
150+
", shape ",
151+
get_input_partial_shape(12));
123152

124153
set_output_type(0, query_type, query_shape);
125154
}
@@ -129,10 +158,13 @@ class PagedAttentionExtension : public ov::op::Op {
129158
}
130159
};
131160

132-
}
161+
} // namespace
133162

134163
void regclass_graph_op_PagedAttentionExtension(py::module m) {
135-
py::class_<PagedAttentionExtension, std::shared_ptr<PagedAttentionExtension>, ov::Node> cls(m, "_PagedAttentionExtension");
136-
cls.doc() = "Experimental extention for PagedAttention operation. Use with care: no backward compatibility is guaranteed in future releases.";
164+
py::class_<PagedAttentionExtension, std::shared_ptr<PagedAttentionExtension>, ov::Node> cls(
165+
m,
166+
"_PagedAttentionExtension");
167+
cls.doc() = "Experimental extention for PagedAttention operation. Use with care: no backward compatibility is "
168+
"guaranteed in future releases.";
137169
cls.def(py::init<const ov::OutputVector&>());
138170
}

0 commit comments

Comments
 (0)