Skip to content

Commit cc94eb7

Browse files
committed
Upgraded PagedAttention op to the new version
1 parent 7160874 commit cc94eb7

File tree

1 file changed

+96
-94
lines changed

1 file changed

+96
-94
lines changed

src/core/src/op/paged_attention.cpp

+96-94
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
// SPDX-License-Identifier: Apache-2.0
33
//
44

5+
#include "itt.hpp"
56
#include "openvino/op/paged_attention.hpp"
6-
77
#include "openvino/op/op.hpp"
88

99
namespace ov {
@@ -14,123 +14,125 @@ PagedAttentionExtension::PagedAttentionExtension(const ov::OutputVector& args) :
1414
}
1515

1616
void PagedAttentionExtension::validate_and_infer_types() {
17-
// const auto& value_cache_shape = get_input_partial_shape(4);
18-
// m_num_kv_heads = value_cache_shape[1];
19-
// m_head_size = value_cache_shape[2];
20-
// m_block_size = value_cache_shape[3];
21-
// Do not check shapes for cache K and cache V inputs, because they are hardware dependent
17+
OV_OP_SCOPE(PagedAttentionExtension_validate_and_infer_types);
2218

23-
// query: shape [batch_size, seq_len, num_heads * head_size]
24-
const auto& query_type = get_input_element_type(0);
25-
const auto& query_shape = get_input_partial_shape(0);
26-
NODE_VALIDATION_CHECK(this,
27-
// query_type.is_real() &&
28-
query_shape.size() == 3,
29-
// query_shape[2] == m_num_heads * m_head_size,
30-
"Query type must be real, shape must be like [batch_size, seq_len, num_heads * head_size]. ",
31-
"Got element type ",
32-
query_type,
33-
", shape ",
34-
query_shape);
19+
NODE_VALIDATION_CHECK(this,
20+
get_input_size() == 13,
21+
"PagedAttentionExtension expects 13 inputs, but it has ",
22+
get_input_size());
3523

36-
// key: shape [batch_size, seq_len, num_kv_heads * head_size]
37-
const auto& key_type = get_input_element_type(1);
38-
const auto& key_shape = get_input_partial_shape(1);
39-
NODE_VALIDATION_CHECK(this,
40-
// query_type == key_type &&
41-
key_shape.size() == 3,
42-
"Key type must be the same as query, shape must be the same as query. "
43-
"Got element type ",
44-
key_type,
45-
", shape ",
46-
key_shape);
24+
NODE_VALIDATION_CHECK(this,
25+
get_input_partial_shape(0).rank().is_dynamic() || get_input_partial_shape(0).rank().get_length() == 2,
26+
"Rank of `query` input should be 2, but it is ",
27+
get_input_partial_shape(0).rank().get_length(),
28+
".");
29+
NODE_VALIDATION_CHECK(this,
30+
get_input_partial_shape(1).rank().is_dynamic() || get_input_partial_shape(1).rank().get_length() == 2,
31+
"Rank of `key` input should be 2, but it is ",
32+
get_input_partial_shape(1).rank().get_length(),
33+
".");
34+
NODE_VALIDATION_CHECK(this,
35+
get_input_partial_shape(2).rank().is_dynamic() || get_input_partial_shape(2).rank().get_length() == 2,
36+
"Rank of `value` input should be 2, but it is ",
37+
get_input_partial_shape(2).rank().get_length(),
38+
".");
4739

48-
// value: shape [batch_size, seq_len, num_kv_heads * head_size]
49-
// auto value_type = get_input_element_type(2);
40+
NODE_VALIDATION_CHECK(this,
41+
get_input_partial_shape(3).rank().is_dynamic() || get_input_partial_shape(3).rank().get_length() >= 2,
42+
"Rank of `key_cache` input should be at least 2, but it is ",
43+
get_input_partial_shape(3).rank().get_length(),
44+
".");
45+
NODE_VALIDATION_CHECK(this,
46+
get_input_partial_shape(4).rank().is_dynamic() || get_input_partial_shape(4).rank().get_length() >= 2,
47+
"Rank of `value_cache` input should be at least 2, but it is ",
48+
get_input_partial_shape(4).rank().get_length(),
49+
".");
5050

51-
// is_prompt: boolean scalar
5251
NODE_VALIDATION_CHECK(this,
53-
// get_input_element_type(5) == ov::element::boolean &&
54-
get_input_shape(5) == ov::Shape({}),
55-
"is_prompt validation failed. ",
56-
"Got element type ",
52+
get_input_partial_shape(5).rank().is_dynamic() || get_input_partial_shape(5).rank().get_length() == 1,
53+
"Rank of `context_lens` input should be 1, but it is ",
54+
get_input_partial_shape(5).rank().get_length(),
55+
".");
56+
NODE_VALIDATION_CHECK(this,
57+
get_input_element_type(5).is_dynamic() || get_input_element_type(5) == element::i32,
58+
"Element type of `context_lens` input should be i32, but it is ",
5759
get_input_element_type(5),
58-
", shape ",
59-
get_input_shape(5));
60-
61-
// slot_mapping: shape [batch_size, max_context_len]
62-
const auto& slot_mapping_shape = get_input_partial_shape(6);
60+
".");
61+
NODE_VALIDATION_CHECK(this,
62+
get_input_partial_shape(6).rank().is_dynamic() || get_input_partial_shape(6).rank().get_length() == 1,
63+
"Rank of `subsequence_begins` input should be 1, but it is ",
64+
get_input_partial_shape(6).rank().get_length(),
65+
".");
6366
NODE_VALIDATION_CHECK(this,
64-
// get_input_element_type(6) == ov::element::i64 &&
65-
slot_mapping_shape.size() == 2,
66-
"slot_mapping validation failed. ",
67-
"Got element type ",
67+
get_input_element_type(6).is_dynamic() || get_input_element_type(6) == element::i32,
68+
"Element type of `subsequence_begins` input should be i32, but it is ",
6869
get_input_element_type(6),
69-
", shape ",
70-
slot_mapping_shape);
70+
".");
7171

72-
// max_context_len: integer scalar
7372
NODE_VALIDATION_CHECK(this,
74-
// get_input_element_type(7) == ov::element::i32 &&
75-
get_input_shape(7) == ov::Shape({}),
76-
"max_context_len validation failed. ",
77-
"Got element type ",
73+
get_input_partial_shape(7).rank().is_dynamic() || get_input_partial_shape(7).rank().get_length() == 1,
74+
"Rank of `block_indices` input should be 1, but it is ",
75+
get_input_partial_shape(7).rank().get_length(),
76+
".");
77+
NODE_VALIDATION_CHECK(this,
78+
get_input_element_type(7).is_dynamic() || get_input_element_type(7) == element::i32,
79+
"Element type of `block_indices` input should be i32, but it is ",
7880
get_input_element_type(7),
79-
", shape ",
80-
get_input_shape(7));
81-
82-
// context_lens: shape [batch_size]
83-
const auto& context_lens_shape = get_input_partial_shape(8);
81+
".");
82+
NODE_VALIDATION_CHECK(this,
83+
get_input_partial_shape(8).rank().is_dynamic() || get_input_partial_shape(8).rank().get_length() == 1,
84+
"Rank of `block_indices_begins` input should be 1, but it is ",
85+
get_input_partial_shape(8).rank().get_length(),
86+
".");
8487
NODE_VALIDATION_CHECK(this,
85-
// get_input_element_type(8) == ov::element::i32 &&
86-
context_lens_shape.size() == 1,
87-
"context_lens validation failed. ",
88-
"Got element type ",
88+
get_input_element_type(8).is_dynamic() || get_input_element_type(8) == element::i32,
89+
"Element type of `block_indices_begins` input should be i32, but it is ",
8990
get_input_element_type(8),
90-
", shape ",
91-
context_lens_shape);
91+
".");
9292

93-
// block_tables: shape [batch_size, max_block_per_request]
9493
NODE_VALIDATION_CHECK(this,
95-
// get_input_element_type(9) == ov::element::i32 &&
96-
get_input_partial_shape(9).size() == 2,
97-
"block_tables validation failed. ",
98-
"Got element type ",
94+
get_input_partial_shape(9).rank().is_dynamic() || get_input_partial_shape(9).rank().get_length() == 0,
95+
"Input `scale` should be a scalar but it has rank ",
96+
get_input_partial_shape(9).rank().get_length(),
97+
".");
98+
NODE_VALIDATION_CHECK(this,
99+
get_input_element_type(9).is_dynamic() || get_input_element_type(9).is_real(),
100+
"Element type of `scale` input should be a floating type, but it is ",
99101
get_input_element_type(9),
100-
", shape ",
101-
get_input_partial_shape(9));
102-
103-
// scale: float scalar
102+
".");
104103
NODE_VALIDATION_CHECK(this,
105-
// get_input_element_type(10) == ov::element::f32 &&
106-
get_input_shape(10) == ov::Shape({}),
107-
"block_tables validation failed. ",
108-
"Got element type ",
104+
get_input_partial_shape(10).rank().is_dynamic() || get_input_partial_shape(10).rank().get_length() == 0,
105+
"Input `sliding_window` should be a scalar but it has rank ",
106+
get_input_partial_shape(10).rank().get_length(),
107+
".");
108+
NODE_VALIDATION_CHECK(this,
109+
get_input_element_type(10).is_dynamic() || get_input_element_type(10) == element::i32,
110+
"Element type of `sliding_window` input should be i32, but it is ",
109111
get_input_element_type(10),
110-
", shape ",
111-
get_input_shape(10));
112+
".");
112113

113-
// alibi_slopes: 1D float tensor
114114
NODE_VALIDATION_CHECK(this,
115-
// get_input_element_type(11) == ov::element::f32 &&
116-
get_input_partial_shape(11).rank().get_length() == 1,
117-
"alibi_slopes should be a 1D float tensor. ",
118-
"Got element type ",
115+
get_input_partial_shape(11).rank().is_dynamic() || get_input_partial_shape(11).rank().get_length() == 1,
116+
"Rank of `alibi_slopes` input should be 1, but it is ",
117+
get_input_partial_shape(11).rank().get_length(),
118+
".");
119+
NODE_VALIDATION_CHECK(this,
120+
get_input_element_type(11).is_dynamic() || get_input_element_type(11).is_real(),
121+
"Element type of `alibi_slopes` input should be a floating type, but it is ",
119122
get_input_element_type(11),
120-
", shape ",
121-
get_input_partial_shape(11));
122-
123-
// sliding_window: int scalar
123+
".");
124+
NODE_VALIDATION_CHECK(this,
125+
get_input_partial_shape(12).rank().is_dynamic() || get_input_partial_shape(12).rank().get_length() == 0,
126+
"Input `max_context_len` should be a scalar but it has rank ",
127+
get_input_partial_shape(12).rank().get_length(),
128+
".");
124129
NODE_VALIDATION_CHECK(this,
125-
// get_input_element_type(12) == ov::element::i32 &&
126-
get_input_partial_shape(12).rank().get_length() == 0,
127-
"sliding_window argument should be an i32 scalar. ",
128-
"Got element type ",
130+
get_input_element_type(12).is_dynamic() || get_input_element_type(12) == element::i32,
131+
"Element type of `max_context_len` input should be i32, but it is ",
129132
get_input_element_type(12),
130-
", shape ",
131-
get_input_partial_shape(12));
133+
".");
132134

133-
set_output_type(0, query_type, query_shape);
135+
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
134136
}
135137

136138
std::shared_ptr<ov::Node> PagedAttentionExtension::clone_with_new_inputs(const ov::OutputVector& new_args) const {

0 commit comments

Comments
 (0)