2
2
// SPDX-License-Identifier: Apache-2.0
3
3
//
4
4
5
+ #include " itt.hpp"
5
6
#include " openvino/op/paged_attention.hpp"
6
-
7
7
#include " openvino/op/op.hpp"
8
8
9
9
namespace ov {
@@ -14,123 +14,125 @@ PagedAttentionExtension::PagedAttentionExtension(const ov::OutputVector& args) :
14
14
}
15
15
16
16
void PagedAttentionExtension::validate_and_infer_types () {
17
- // const auto& value_cache_shape = get_input_partial_shape(4);
18
- // m_num_kv_heads = value_cache_shape[1];
19
- // m_head_size = value_cache_shape[2];
20
- // m_block_size = value_cache_shape[3];
21
- // Do not check shapes for cache K and cache V inputs, because they are hardware dependent
17
+ OV_OP_SCOPE (PagedAttentionExtension_validate_and_infer_types);
22
18
23
- // query: shape [batch_size, seq_len, num_heads * head_size]
24
- const auto & query_type = get_input_element_type (0 );
25
- const auto & query_shape = get_input_partial_shape (0 );
26
- NODE_VALIDATION_CHECK (this ,
27
- // query_type.is_real() &&
28
- query_shape.size () == 3 ,
29
- // query_shape[2] == m_num_heads * m_head_size,
30
- " Query type must be real, shape must be like [batch_size, seq_len, num_heads * head_size]. " ,
31
- " Got element type " ,
32
- query_type,
33
- " , shape " ,
34
- query_shape);
19
+ NODE_VALIDATION_CHECK (this ,
20
+ get_input_size () == 13 ,
21
+ " PagedAttensionExtension expects 13 inputs, but it has " ,
22
+ get_input_size ());
35
23
36
- // key: shape [batch_size, seq_len, num_kv_heads * head_size]
37
- const auto & key_type = get_input_element_type (1 );
38
- const auto & key_shape = get_input_partial_shape (1 );
39
- NODE_VALIDATION_CHECK (this ,
40
- // query_type == key_type &&
41
- key_shape.size () == 3 ,
42
- " Key type must be the same as query, shape must be the same as query. "
43
- " Got element type " ,
44
- key_type,
45
- " , shape " ,
46
- key_shape);
24
+ NODE_VALIDATION_CHECK (this ,
25
+ get_input_partial_shape (0 ).rank ().is_dynamic () || get_input_partial_shape (0 ).rank ().get_length () == 2 ,
26
+ " Rank of `query` input should be 2, but it is " ,
27
+ get_input_partial_shape (0 ).rank ().get_length (),
28
+ " ." );
29
+ NODE_VALIDATION_CHECK (this ,
30
+ get_input_partial_shape (1 ).rank ().is_dynamic () || get_input_partial_shape (1 ).rank ().get_length () == 2 ,
31
+ " Rank of `key` input should be 2, but it is " ,
32
+ get_input_partial_shape (1 ).rank ().get_length (),
33
+ " ." );
34
+ NODE_VALIDATION_CHECK (this ,
35
+ get_input_partial_shape (2 ).rank ().is_dynamic () || get_input_partial_shape (2 ).rank ().get_length () == 2 ,
36
+ " Rank of `value` input should be 2, but it is " ,
37
+ get_input_partial_shape (2 ).rank ().get_length (),
38
+ " ." );
47
39
48
- // value: shape [batch_size, seq_len, num_kv_heads * head_size]
49
- // auto value_type = get_input_element_type(2);
40
+ NODE_VALIDATION_CHECK (this ,
41
+ get_input_partial_shape (3 ).rank ().is_dynamic () || get_input_partial_shape (3 ).rank ().get_length () >= 2 ,
42
+ " Rank of `key_cache` input should be at least 2, but it is " ,
43
+ get_input_partial_shape (3 ).rank ().get_length (),
44
+ " ." );
45
+ NODE_VALIDATION_CHECK (this ,
46
+ get_input_partial_shape (4 ).rank ().is_dynamic () || get_input_partial_shape (4 ).rank ().get_length () >= 2 ,
47
+ " Rank of `value_cache` input should be at least 2, but it is " ,
48
+ get_input_partial_shape (4 ).rank ().get_length (),
49
+ " ." );
50
50
51
- // is_prompt: boolean scalar
52
51
NODE_VALIDATION_CHECK (this ,
53
- // get_input_element_type(5) == ov::element::boolean &&
54
- get_input_shape (5 ) == ov::Shape ({}),
55
- " is_prompt validation failed. " ,
56
- " Got element type " ,
52
+ get_input_partial_shape (5 ).rank ().is_dynamic () || get_input_partial_shape (5 ).rank ().get_length () == 1 ,
53
+ " Rank of `context_lens` input should be 1, but it is " ,
54
+ get_input_partial_shape (5 ).rank ().get_length (),
55
+ " ." );
56
+ NODE_VALIDATION_CHECK (this ,
57
+ get_input_element_type (5 ).is_dynamic () || get_input_element_type (5 ) == element::i32,
58
+ " Element type of `context_lens` input should be i32, but it is " ,
57
59
get_input_element_type (5 ),
58
- " , shape " ,
59
- get_input_shape (5 ));
60
-
61
- // slot_mapping: shape [batch_size, max_context_len]
62
- const auto & slot_mapping_shape = get_input_partial_shape (6 );
60
+ " ." );
61
+ NODE_VALIDATION_CHECK (this ,
62
+ get_input_partial_shape (6 ).rank ().is_dynamic () || get_input_partial_shape (6 ).rank ().get_length () == 1 ,
63
+ " Rank of `subsequence_begins` input should be 1, but it is " ,
64
+ get_input_partial_shape (6 ).rank ().get_length (),
65
+ " ." );
63
66
NODE_VALIDATION_CHECK (this ,
64
- // get_input_element_type(6) == ov::element::i64 &&
65
- slot_mapping_shape.size () == 2 ,
66
- " slot_mapping validation failed. " ,
67
- " Got element type " ,
67
+ get_input_element_type (6 ).is_dynamic () || get_input_element_type (6 ) == element::i32,
68
+ " Element type of `subsequence_begins` input should be i32, but it is " ,
68
69
get_input_element_type (6 ),
69
- " , shape " ,
70
- slot_mapping_shape);
70
+ " ." );
71
71
72
- // max_context_len: integer scalar
73
72
NODE_VALIDATION_CHECK (this ,
74
- // get_input_element_type(7) == ov::element::i32 &&
75
- get_input_shape (7 ) == ov::Shape ({}),
76
- " max_context_len validation failed. " ,
77
- " Got element type " ,
73
+ get_input_partial_shape (7 ).rank ().is_dynamic () || get_input_partial_shape (7 ).rank ().get_length () == 1 ,
74
+ " Rank of `block_indices` input should be 1, but it is " ,
75
+ get_input_partial_shape (7 ).rank ().get_length (),
76
+ " ." );
77
+ NODE_VALIDATION_CHECK (this ,
78
+ get_input_element_type (7 ).is_dynamic () || get_input_element_type (7 ) == element::i32,
79
+ " Element type of `block_indices` input should be i32, but it is " ,
78
80
get_input_element_type (7 ),
79
- " , shape " ,
80
- get_input_shape (7 ));
81
-
82
- // context_lens: shape [batch_size]
83
- const auto & context_lens_shape = get_input_partial_shape (8 );
81
+ " ." );
82
+ NODE_VALIDATION_CHECK (this ,
83
+ get_input_partial_shape (8 ).rank ().is_dynamic () || get_input_partial_shape (8 ).rank ().get_length () == 1 ,
84
+ " Rank of `block_indices_begins` input should be 1, but it is " ,
85
+ get_input_partial_shape (8 ).rank ().get_length (),
86
+ " ." );
84
87
NODE_VALIDATION_CHECK (this ,
85
- // get_input_element_type(8) == ov::element::i32 &&
86
- context_lens_shape.size () == 1 ,
87
- " context_lens validation failed. " ,
88
- " Got element type " ,
88
+ get_input_element_type (8 ).is_dynamic () || get_input_element_type (8 ) == element::i32,
89
+ " Element type of `block_indices_begins` input should be i32, but it is " ,
89
90
get_input_element_type (8 ),
90
- " , shape " ,
91
- context_lens_shape);
91
+ " ." );
92
92
93
- // block_tables: shape [batch_size, max_block_per_request]
94
93
NODE_VALIDATION_CHECK (this ,
95
- // get_input_element_type(9) == ov::element::i32 &&
96
- get_input_partial_shape (9 ).size () == 2 ,
97
- " block_tables validation failed. " ,
98
- " Got element type " ,
94
+ get_input_partial_shape (9 ).rank ().is_dynamic () || get_input_partial_shape (9 ).rank ().get_length () == 0 ,
95
+ " Input `scale` should be a scalar but it has rank " ,
96
+ get_input_partial_shape (9 ).rank ().get_length (),
97
+ " ." );
98
+ NODE_VALIDATION_CHECK (this ,
99
+ get_input_element_type (9 ).is_dynamic () || get_input_element_type (9 ).is_real (),
100
+ " Element type of `scale` input should be a floating type, but it is " ,
99
101
get_input_element_type (9 ),
100
- " , shape " ,
101
- get_input_partial_shape (9 ));
102
-
103
- // scale: float scalar
102
+ " ." );
104
103
NODE_VALIDATION_CHECK (this ,
105
- // get_input_element_type(10) == ov::element::f32 &&
106
- get_input_shape (10 ) == ov::Shape ({}),
107
- " block_tables validation failed. " ,
108
- " Got element type " ,
104
+ get_input_partial_shape (10 ).rank ().is_dynamic () || get_input_partial_shape (10 ).rank ().get_length () == 0 ,
105
+ " Input `sliding_window` should be a scalar but it has rank " ,
106
+ get_input_partial_shape (10 ).rank ().get_length (),
107
+ " ." );
108
+ NODE_VALIDATION_CHECK (this ,
109
+ get_input_element_type (10 ).is_dynamic () || get_input_element_type (10 ) == element::i32,
110
+ " Element type of `sliding_window` input should be i32, but it is " ,
109
111
get_input_element_type (10 ),
110
- " , shape " ,
111
- get_input_shape (10 ));
112
+ " ." );
112
113
113
- // alibi_slopes: 1D float tensor
114
114
NODE_VALIDATION_CHECK (this ,
115
- // get_input_element_type(11) == ov::element::f32 &&
116
- get_input_partial_shape (11 ).rank ().get_length () == 1 ,
117
- " alibi_slopes should be a 1D float tensor. " ,
118
- " Got element type " ,
115
+ get_input_partial_shape (11 ).rank ().is_dynamic () || get_input_partial_shape (11 ).rank ().get_length () == 1 ,
116
+ " Rank of `alibi_slopes` input should be 1, but it is " ,
117
+ get_input_partial_shape (11 ).rank ().get_length (),
118
+ " ." );
119
+ NODE_VALIDATION_CHECK (this ,
120
+ get_input_element_type (11 ).is_dynamic () || get_input_element_type (11 ).is_real (),
121
+ " Element type of `alibi_slopes` input should be a floating type, but it is " ,
119
122
get_input_element_type (11 ),
120
- " , shape " ,
121
- get_input_partial_shape (11 ));
122
-
123
- // sliding_window: int scalar
123
+ " ." );
124
+ NODE_VALIDATION_CHECK (this ,
125
+ get_input_partial_shape (12 ).rank ().is_dynamic () || get_input_partial_shape (12 ).rank ().get_length () == 0 ,
126
+ " Input `max_context_len` should be a scalar but it has rank " ,
127
+ get_input_partial_shape (12 ).rank ().get_length (),
128
+ " ." );
124
129
NODE_VALIDATION_CHECK (this ,
125
- // get_input_element_type(12) == ov::element::i32 &&
126
- get_input_partial_shape (12 ).rank ().get_length () == 0 ,
127
- " sliding_window argument should be an i32 scalar. " ,
128
- " Got element type " ,
130
+ get_input_element_type (12 ).is_dynamic () || get_input_element_type (12 ) == element::i32,
131
+ " Element type of `max_context_len` input should be i32, but it is " ,
129
132
get_input_element_type (12 ),
130
- " , shape " ,
131
- get_input_partial_shape (12 ));
133
+ " ." );
132
134
133
- set_output_type (0 , query_type, query_shape );
135
+ set_output_type (0 , get_input_element_type ( 0 ), get_input_partial_shape ( 0 ) );
134
136
}
135
137
136
138
std::shared_ptr<ov::Node> PagedAttentionExtension::clone_with_new_inputs (const ov::OutputVector& new_args) const {
0 commit comments