Commit 2442243

Fix code style

1 parent cfc75d4 commit 2442243

4 files changed: +38 −25 lines

src/common/transformations/include/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp

+2 −1

@@ -23,5 +23,6 @@ class PrevSequenceLengthPattern;
 class ov::pass::PrevSequenceLengthPattern : public ov::pass::MatcherPass {
 public:
     OPENVINO_RTTI("PrevSequenceLengthPattern", "0");
-    explicit PrevSequenceLengthPattern(const std::shared_ptr<ov::op::v1::Subtract>& prev_max_seq_len, std::shared_ptr<ov::Node>);
+    explicit PrevSequenceLengthPattern(const std::shared_ptr<ov::op::v1::Subtract>& prev_max_seq_len,
+                                       std::shared_ptr<ov::Node>);
 };
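
The only functional content here is the constructor signature; the commit merely wraps it to the column limit. For orientation, a hedged sketch of how a caller would construct the pass (prev_max_seq_len and batch_dim are assumed to exist in the calling scope, as the last file in this commit arranges):

// Illustrative only, not part of this commit: constructing the pass with the
// two arguments declared above. prev_max_seq_len is a v1::Subtract node;
// batch_dim is any ov::Node providing the batch dimension.
auto pass = std::make_shared<ov::pass::PrevSequenceLengthPattern>(prev_max_seq_len, batch_dim);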

src/common/transformations/src/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.cpp

+14 −11

@@ -5,21 +5,21 @@
 #include "transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp"
 
 #include "openvino/cc/pass/itt.hpp"
+#include "openvino/core/validation_util.hpp"
 #include "openvino/op/gather.hpp"
-#include "openvino/op/shape_of.hpp"
 #include "openvino/op/reshape.hpp"
+#include "openvino/op/shape_of.hpp"
 #include "openvino/pass/pattern/op/wrap_type.hpp"
-#include "openvino/core/validation_util.hpp"
 #include "transformations/utils/utils.hpp"
 
 using namespace ov::op;
 
-
 ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(
-    const std::shared_ptr<ov::op::v1::Subtract>& prev_max_seq_len, std::shared_ptr<ov::Node> batch_dim) {
+    const std::shared_ptr<ov::op::v1::Subtract>& prev_max_seq_len,
+    std::shared_ptr<ov::Node> batch_dim) {
     MATCHER_SCOPE(PrevSequenceLengthPattern);
-    // The transformation addresses two cases that look similar: (1) previous sequence length, (2) batch size in kv-cache state
-    // In first case it should replace it by prev_max_seq_len. For the second case, connect to batch_dim.
+    // The transformation addresses two cases that look similar: (1) previous sequence length, (2) batch size in
+    // kv-cache state In first case it should replace it by prev_max_seq_len. For the second case, connect to batch_dim.
 
     auto kv_past = pattern::wrap_type<v6::ReadValue>({pattern::any_input()});
     auto kv_gather = pattern::wrap_type<v8::Gather>({kv_past, pattern::any_input(), pattern::any_input()});

@@ -33,23 +33,26 @@ ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(
     const auto& pattern_map = m.get_pattern_value_map();
     auto gather = m.get_match_root();
     auto gather_index = ov::util::get_constant_from_source(gather->input_value(1));
-    if(!gather_index) {
-        return false; // cannot detect axis
+    if (!gather_index) {
+        return false;  // cannot detect axis
     }
     auto axis = gather_index->cast_vector<int64_t>().at(0);
     auto kv_init_shape = pattern_map.at(kv_past).get_node()->get_input_partial_shape(0);
     auto target_type = gather->get_output_element_type(0);
-    if(kv_init_shape[axis].is_static() && kv_init_shape[axis].get_length() == 0) {
+    if (kv_init_shape[axis].is_static() && kv_init_shape[axis].get_length() == 0) {
         // this is a sequence dimension based on how the initialization expression is build for stateful models
         std::shared_ptr<ov::Node> replacement;
         if (prev_max_seq_len->get_output_element_type(0) != target_type) {
             replacement = std::make_shared<v0::Convert>(prev_max_seq_len, target_type);
         } else {
             replacement = prev_max_seq_len;
         }
-        replace_node(gather, std::make_shared<v1::Reshape>(replacement, v0::Constant::create(element::i64, Shape{1}, {1}), false));
+        replace_node(
+            gather,
+            std::make_shared<v1::Reshape>(replacement, v0::Constant::create(element::i64, Shape{1}, {1}), false));
         return true;
-    } else { // assumption that any other axis should point to batch dimension, precise reasoning is too complex (TODO)
+    } else {  // assumption that any other axis should point to batch dimension, precise reasoning is too complex
+              // (TODO)
         // this is a batch dimension
         std::shared_ptr<ov::Node> replacement;
         if (batch_dim->get_output_element_type(0) != target_type) {
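
Both branches of this matcher use the same convert-if-needed idiom before substituting a node. A minimal standalone sketch of that idiom, assuming only OpenVINO core headers; the helper name is illustrative, not from this commit:

#include "openvino/op/convert.hpp"

// Sketch: return `node` unchanged if it already has `target_type`, otherwise
// wrap it in a v0::Convert, mirroring the replacement logic in the pass above.
static std::shared_ptr<ov::Node> match_element_type(const std::shared_ptr<ov::Node>& node,
                                                    const ov::element::Type& target_type) {
    if (node->get_output_element_type(0) != target_type) {
        return std::make_shared<ov::op::v0::Convert>(node, target_type);
    }
    return node;
}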

src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp

+19 −11

@@ -120,12 +120,16 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
     auto q = pattern::any_input();
     auto scale_input = pattern::any_input();
 
-    auto k_to_sdpa = std::make_shared<pattern::op::Or>(OutputVector{k_concat, k_shaped, k_shaped_transposed, k_simply_shaped});
-    auto v_to_sdpa = std::make_shared<pattern::op::Or>(OutputVector{v_concat, v_shaped, v_shaped_transposed, v_simply_shaped});
+    auto k_to_sdpa =
+        std::make_shared<pattern::op::Or>(OutputVector{k_concat, k_shaped, k_shaped_transposed, k_simply_shaped});
+    auto v_to_sdpa =
+        std::make_shared<pattern::op::Or>(OutputVector{v_concat, v_shaped, v_shaped_transposed, v_simply_shaped});
     auto mask_to_sdpa = std::make_shared<pattern::op::Or>(OutputVector{sdpa_mask, pattern::any_input()});
 
-    auto sdpa_with_4_inputs = pattern::wrap_type<v13::ScaledDotProductAttention>({q, k_to_sdpa, v_to_sdpa, mask_to_sdpa});
-    auto sdpa_with_5_inputs = pattern::wrap_type<v13::ScaledDotProductAttention>({q, k_to_sdpa, v_to_sdpa, mask_to_sdpa, scale_input});
+    auto sdpa_with_4_inputs =
+        pattern::wrap_type<v13::ScaledDotProductAttention>({q, k_to_sdpa, v_to_sdpa, mask_to_sdpa});
+    auto sdpa_with_5_inputs =
+        pattern::wrap_type<v13::ScaledDotProductAttention>({q, k_to_sdpa, v_to_sdpa, mask_to_sdpa, scale_input});
 
     auto sdpa_variants = std::make_shared<pattern::op::Or>(OutputVector{sdpa_with_4_inputs, sdpa_with_5_inputs});

@@ -157,13 +161,15 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
     auto real_k = take_4d(k_current, k_current_reshaped, k_current2);
     auto real_v = take_4d(v_current, v_current_reshaped, v_current2);
 
-    auto sdpa_node = pattern_map.at(pattern_map.count(sdpa_with_4_inputs) ? sdpa_with_4_inputs : sdpa_with_5_inputs).get_node();
+    auto sdpa_node =
+        pattern_map.at(pattern_map.count(sdpa_with_4_inputs) ? sdpa_with_4_inputs : sdpa_with_5_inputs).get_node();
     // E and Ev are from the SDPA specification at
     // https://docs.openvino.ai/2024/documentation/openvino-ir-format/operation-sets/operation-specs/sequence/scaled-dot-product-attention.html
     auto E = sdpa_node->get_input_tensor(1).get_partial_shape()[-1];
     auto Ev = sdpa_node->get_input_tensor(2).get_partial_shape()[-1];  // in common case may not match E
 
-    auto extract_num_kv_heads = [=, &pattern_map](std::shared_ptr<Node> unsqueeze, const Dimension& default_heads_num) {
+    auto extract_num_kv_heads = [=, &pattern_map](std::shared_ptr<Node> unsqueeze,
+                                                  const Dimension& default_heads_num) {
         // Deduce number of k/v heads from Unsqueeze-Broadcast-Reshape (UBR pattern, if present)
         // pattern that appears in case of MQA/GQA
         // In case if UBR pattern doesn't appear, the default number of heads is used passed as default_heads_num

@@ -197,8 +203,10 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
         }
     };
 
-    auto num_k_heads = extract_num_kv_heads(k_heads_unsqueeze, sdpa_node->get_input_tensor(1).get_partial_shape()[-3]);
-    auto num_v_heads = extract_num_kv_heads(v_heads_unsqueeze, sdpa_node->get_input_tensor(1).get_partial_shape()[-3]);
+    auto num_k_heads =
+        extract_num_kv_heads(k_heads_unsqueeze, sdpa_node->get_input_tensor(1).get_partial_shape()[-3]);
+    auto num_v_heads =
+        extract_num_kv_heads(v_heads_unsqueeze, sdpa_node->get_input_tensor(1).get_partial_shape()[-3]);
     const ov::element::Type kv_cache_type = real_q.get_element_type();
     std::string layer_index_str = std::to_string(layer_index);
     auto k_parameter = setName(std::make_shared<v0::Parameter>(kv_cache_type, PartialShape{-1, num_k_heads, E}),

@@ -243,12 +251,12 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
                                     v0::Constant::create(element::i64, Shape{}, {-1}),
                                     v0::Constant::create(element::i64, Shape{}, {0}));
     std::shared_ptr<ov::Node> scale;
-    if(pattern_map.count(scale_input)) {
+    if (pattern_map.count(scale_input)) {
         scale = pattern_map.at(scale_input).get_node_shared_ptr();
     } else {
         // most likely `scale` below will always be a constant in real inference, but dynamic dimension
-        // propagation may not always derive it as a constant. That's why a sub-graph computing `scale` is built instead
-        // of just a constant node representing one of the dimensions.
+        // propagation may not always derive it as a constant. That's why a sub-graph computing `scale` is built
+        // instead of just a constant node representing one of the dimensions.
         scale = std::make_shared<v1::Divide>(
             v0::Constant::create(element::f32, Shape{}, {1}),
             std::make_shared<v0::Sqrt>(std::make_shared<v0::Convert>(hidden_dim, element::f32)));
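
The fallback branch builds `scale` as a sub-graph computing 1/sqrt(hidden_dim), which matches the default SDPA scale of 1/sqrt(E) from the specification linked above. A quick standalone check of the formula in plain C++; the head size 64 is an assumed example, not a value from this commit:

#include <cmath>
#include <cstdio>

// Verify the fallback scale formula on an example head size.
int main() {
    const float hidden_dim = 64.0f;  // assumed example head size
    std::printf("scale = %f\n", 1.0f / std::sqrt(hidden_dim));  // prints 0.125000
    return 0;
}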

src/core/src/pass/sdpa_to_paged_attention.cpp

+3 −2

@@ -7,9 +7,9 @@
 #include "openvino/cc/pass/itt.hpp"
 #include "openvino/op/constant.hpp"
 #include "openvino/op/gather.hpp"
+#include "openvino/op/shape_of.hpp"
 #include "openvino/op/unsqueeze.hpp"
 #include "openvino/pass/manager.hpp"
-#include "openvino/op/shape_of.hpp"
 #include "transformations/sdpa_to_paged_attention/position_ids_replacer.hpp"
 #include "transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp"
 #include "transformations/sdpa_to_paged_attention/state_management_pattern.hpp"

@@ -82,7 +82,8 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
 
     int layer_index = 0;
 
-    auto batch_dim = std::make_shared<v3::ShapeOf>(position_ids);  // it is not always required, so will be disposed if not needed
+    auto batch_dim =
+        std::make_shared<v3::ShapeOf>(position_ids);  // it is not always required, so will be disposed if not needed
 
     ov::pass::Manager manager;
     manager.set_per_pass_validation(false);
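
For context, a hedged sketch of how the manager configured above typically registers the matcher passes touched by this commit and runs them on the model; `register_pass` forwards its arguments to the pass constructor, and the `StateManagementPattern` argument list is elided because it is not shown here:

// Illustrative only: pass registration as it would follow the lines above.
manager.register_pass<ov::pass::PrevSequenceLengthPattern>(prev_max_seq_len, batch_dim);
// manager.register_pass<ov::pass::StateManagementPattern>(...);  // argument list elided
manager.run_passes(model);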
