Skip to content

Commit c0d197c

Browse files
slyalin, itikhono, ilya-lavrenov
committed
PagedAttention Transformation: Rank alignment for replacements (openvinotoolkit#24690)
During the elimination of dependencies on the `beam_idx` input and `ReadValue`(s), we replace them with the new PA-related inputs and with sub-expressions that depend on the other remaining inputs. Such replacements must guarantee that the shape and element type of the old and new nodes match. Before this PR, this was not guaranteed for the shape, and sometimes a scalar was replaced by a tensor of rank 1, which led to errors like `'start' input is not a scalar`. Now the shape is aligned. --------- Co-authored-by: Ivan Tikhonov <ivan.tikhonov@intel.com> Co-authored-by: Ilya Lavrenov <ilya.lavrenov@intel.com>
1 parent 65c3b17 commit c0d197c

File tree

3 files changed

+26
-32
lines changed

3 files changed

+26
-32
lines changed

src/common/transformations/include/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,5 @@ class PrevSequenceLengthPattern;
2323
class ov::pass::PrevSequenceLengthPattern : public ov::pass::MatcherPass {
2424
public:
2525
OPENVINO_RTTI("PrevSequenceLengthPattern", "0");
26-
explicit PrevSequenceLengthPattern(const std::shared_ptr<ov::op::v1::Subtract>& prev_max_seq_len,
27-
std::shared_ptr<ov::Node>);
26+
explicit PrevSequenceLengthPattern(std::shared_ptr<ov::Node> prev_max_seq_len, std::shared_ptr<ov::Node> batch_dim);
2827
};

src/common/transformations/src/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.cpp

+17-25
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@
1414

1515
using namespace ov::op;
1616

17-
ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(
18-
const std::shared_ptr<ov::op::v1::Subtract>& prev_max_seq_len,
19-
std::shared_ptr<ov::Node> batch_dim) {
17+
ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(std::shared_ptr<ov::Node> prev_max_seq_len,
18+
std::shared_ptr<ov::Node> batch_dim) {
2019
MATCHER_SCOPE(PrevSequenceLengthPattern);
2120
// The transformation addresses two cases that look similar: (1) previous sequence length, (2) batch size in
2221
// kv-cache state. In the first case it should be replaced by prev_max_seq_len. For the second case, connect to batch_dim.
@@ -39,30 +38,23 @@ ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(
3938
auto axis = gather_index->cast_vector<int64_t>().at(0);
4039
auto kv_init_shape = pattern_map.at(kv_past).get_node()->get_input_partial_shape(0);
4140
auto target_type = gather->get_output_element_type(0);
41+
std::shared_ptr<ov::Node> replacement;
4242
if (kv_init_shape[axis].is_static() && kv_init_shape[axis].get_length() == 0) {
43-
// this is a sequence dimension based on how the initialization expression is build for stateful models
44-
std::shared_ptr<ov::Node> replacement;
45-
if (prev_max_seq_len->get_output_element_type(0) != target_type) {
46-
replacement = std::make_shared<v0::Convert>(prev_max_seq_len, target_type);
47-
} else {
48-
replacement = prev_max_seq_len;
49-
}
50-
replace_node(
51-
gather,
52-
std::make_shared<v1::Reshape>(replacement, v0::Constant::create(element::i64, Shape{1}, {1}), false));
53-
return true;
54-
} else { // assumption that any other axis should point to batch dimension, precise reasoning is too complex
55-
// (TODO)
56-
// this is a batch dimension
57-
std::shared_ptr<ov::Node> replacement;
58-
if (batch_dim->get_output_element_type(0) != target_type) {
59-
replacement = std::make_shared<v0::Convert>(batch_dim, target_type);
60-
} else {
61-
replacement = batch_dim;
62-
}
63-
replace_node(gather, replacement);
64-
return true;
43+
replacement = prev_max_seq_len;
44+
} else {
45+
// assumption that any other axis should point to batch dimension, precise reasoning is too complex
46+
// TODO: provide more reliable check
47+
replacement = batch_dim;
6548
}
49+
if (replacement->get_output_element_type(0) != target_type) {
50+
replacement = std::make_shared<v0::Convert>(replacement, target_type);
51+
}
52+
auto required_shape = gather->get_output_partial_shape(0);
53+
if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) {
54+
replacement = op::util::reshapeTo(replacement, Shape(required_shape.rank().get_length(), 1));
55+
}
56+
replace_node(gather, replacement);
57+
return true;
6658
};
6759

6860
auto m = std::make_shared<ov::pass::pattern::Matcher>(seq, matcher_name);

src/common/transformations/src/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.cpp

+8-5
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "openvino/cc/pass/itt.hpp"
88
#include "openvino/op/concat.hpp"
99
#include "openvino/op/gather.hpp"
10+
#include "openvino/op/reshape.hpp"
1011
#include "openvino/op/shape_of.hpp"
1112
#include "openvino/pass/pattern/op/wrap_type.hpp"
1213
#include "transformations/utils/utils.hpp"
@@ -29,11 +30,13 @@ ov::pass::TotalSequenceLengthPattern::TotalSequenceLengthPattern(
2930
// use symbolic infra or look at the constant input
3031
auto gather = m.get_match_root();
3132
auto target_type = gather->get_output_element_type(0);
32-
std::shared_ptr<Node> replacement;
33-
if (max_context_len->get_output_element_type(0) != target_type) {
34-
replacement = std::make_shared<v0::Convert>(max_context_len, target_type);
35-
} else {
36-
replacement = max_context_len;
33+
std::shared_ptr<Node> replacement = max_context_len;
34+
if (replacement->get_output_element_type(0) != target_type) {
35+
replacement = std::make_shared<v0::Convert>(replacement, target_type);
36+
}
37+
auto required_shape = gather->get_output_partial_shape(0);
38+
if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) {
39+
replacement = op::util::reshapeTo(replacement, Shape(required_shape.rank().get_length(), 1));
3740
}
3841
replace_node(gather, replacement);
3942
return true;

0 commit comments

Comments
 (0)