@@ -120,12 +120,16 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
120
120
auto q = pattern::any_input ();
121
121
auto scale_input = pattern::any_input ();
122
122
123
- auto k_to_sdpa = std::make_shared<pattern::op::Or>(OutputVector{k_concat, k_shaped, k_shaped_transposed, k_simply_shaped});
124
- auto v_to_sdpa = std::make_shared<pattern::op::Or>(OutputVector{v_concat, v_shaped, v_shaped_transposed, v_simply_shaped});
123
+ auto k_to_sdpa =
124
+ std::make_shared<pattern::op::Or>(OutputVector{k_concat, k_shaped, k_shaped_transposed, k_simply_shaped});
125
+ auto v_to_sdpa =
126
+ std::make_shared<pattern::op::Or>(OutputVector{v_concat, v_shaped, v_shaped_transposed, v_simply_shaped});
125
127
auto mask_to_sdpa = std::make_shared<pattern::op::Or>(OutputVector{sdpa_mask, pattern::any_input ()});
126
128
127
- auto sdpa_with_4_inputs = pattern::wrap_type<v13::ScaledDotProductAttention>({q, k_to_sdpa, v_to_sdpa, mask_to_sdpa});
128
- auto sdpa_with_5_inputs = pattern::wrap_type<v13::ScaledDotProductAttention>({q, k_to_sdpa, v_to_sdpa, mask_to_sdpa, scale_input});
129
+ auto sdpa_with_4_inputs =
130
+ pattern::wrap_type<v13::ScaledDotProductAttention>({q, k_to_sdpa, v_to_sdpa, mask_to_sdpa});
131
+ auto sdpa_with_5_inputs =
132
+ pattern::wrap_type<v13::ScaledDotProductAttention>({q, k_to_sdpa, v_to_sdpa, mask_to_sdpa, scale_input});
129
133
130
134
auto sdpa_variants = std::make_shared<pattern::op::Or>(OutputVector{sdpa_with_4_inputs, sdpa_with_5_inputs});
131
135
@@ -157,13 +161,15 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
157
161
auto real_k = take_4d (k_current, k_current_reshaped, k_current2);
158
162
auto real_v = take_4d (v_current, v_current_reshaped, v_current2);
159
163
160
- auto sdpa_node = pattern_map.at (pattern_map.count (sdpa_with_4_inputs) ? sdpa_with_4_inputs : sdpa_with_5_inputs).get_node ();
164
+ auto sdpa_node =
165
+ pattern_map.at (pattern_map.count (sdpa_with_4_inputs) ? sdpa_with_4_inputs : sdpa_with_5_inputs).get_node ();
161
166
// E and Ev are from the SDPA specification at
162
167
// https://docs.openvino.ai/2024/documentation/openvino-ir-format/operation-sets/operation-specs/sequence/scaled-dot-product-attention.html
163
168
auto E = sdpa_node->get_input_tensor (1 ).get_partial_shape ()[-1 ];
164
169
auto Ev = sdpa_node->get_input_tensor (2 ).get_partial_shape ()[-1 ]; // in common case may not match E
165
170
166
- auto extract_num_kv_heads = [=, &pattern_map](std::shared_ptr<Node> unsqueeze, const Dimension& default_heads_num) {
171
+ auto extract_num_kv_heads = [=, &pattern_map](std::shared_ptr<Node> unsqueeze,
172
+ const Dimension& default_heads_num) {
167
173
// Deduce number of k/v heads from Unsqueeze-Broadcast-Reshape (UBR pattern, if present)
168
174
// pattern that appears in case of MQA/GQA
169
175
// If the UBR pattern doesn't appear, the number of heads passed as default_heads_num is used instead
@@ -197,8 +203,10 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
197
203
}
198
204
};
199
205
200
- auto num_k_heads = extract_num_kv_heads (k_heads_unsqueeze, sdpa_node->get_input_tensor (1 ).get_partial_shape ()[-3 ]);
201
- auto num_v_heads = extract_num_kv_heads (v_heads_unsqueeze, sdpa_node->get_input_tensor (1 ).get_partial_shape ()[-3 ]);
206
+ auto num_k_heads =
207
+ extract_num_kv_heads (k_heads_unsqueeze, sdpa_node->get_input_tensor (1 ).get_partial_shape ()[-3 ]);
208
+ auto num_v_heads =
209
+ extract_num_kv_heads (v_heads_unsqueeze, sdpa_node->get_input_tensor (1 ).get_partial_shape ()[-3 ]);
202
210
const ov::element::Type kv_cache_type = real_q.get_element_type ();
203
211
std::string layer_index_str = std::to_string (layer_index);
204
212
auto k_parameter = setName (std::make_shared<v0::Parameter>(kv_cache_type, PartialShape{-1 , num_k_heads, E}),
@@ -243,12 +251,12 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
243
251
v0::Constant::create (element::i64, Shape{}, {-1 }),
244
252
v0::Constant::create (element::i64, Shape{}, {0 }));
245
253
std::shared_ptr<ov::Node> scale;
246
- if (pattern_map.count (scale_input)) {
254
+ if (pattern_map.count (scale_input)) {
247
255
scale = pattern_map.at (scale_input).get_node_shared_ptr ();
248
256
} else {
249
257
// most likely `scale` below will always be a constant in real inference, but dynamic dimension
250
- // propagation may not always derive it as a constant. That's why a sub-graph computing `scale` is built instead
251
- // of just a constant node representing one of the dimensions.
258
+ // propagation may not always derive it as a constant. That's why a sub-graph computing `scale` is built
259
+ // instead of just a constant node representing one of the dimensions.
252
260
scale = std::make_shared<v1::Divide>(
253
261
v0::Constant::create (element::f32, Shape{}, {1 }),
254
262
std::make_shared<v0::Sqrt>(std::make_shared<v0::Convert>(hidden_dim, element::f32)));
0 commit comments