Skip to content

Commit 23c61e0

Browse files
committed
PagedAttention transformation placeholder
1 parent 7d56aeb commit 23c61e0

File tree

7 files changed

+231
-151
lines changed

7 files changed

+231
-151
lines changed

src/bindings/python/src/openvino/_offline_transformations/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@
1717
from openvino._pyopenvino._offline_transformations import compress_model_transformation
1818
from openvino._pyopenvino._offline_transformations import compress_quantize_weights_transformation
1919
from openvino._pyopenvino._offline_transformations import convert_sequence_to_tensor_iterator_transformation
20+
from openvino._pyopenvino._offline_transformations import paged_attention_transformation

src/bindings/python/src/pyopenvino/core/offline_transformations.cpp

+11
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <transformations/flush_fp32_subnormals_to_zero.hpp>
1919
#include <transformations/op_conversions/convert_sequences_to_tensor_iterator.hpp>
2020
#include <transformations/smart_reshape/smart_reshape.hpp>
21+
#include <openvino/pass/sdpa_to_paged_attention.hpp>
2122

2223
#include "openvino/pass/low_latency.hpp"
2324
#include "openvino/pass/manager.hpp"
@@ -127,4 +128,14 @@ void regmodule_offline_transformations(py::module m) {
127128
manager.run_passes(model);
128129
},
129130
py::arg("model"));
131+
132+
133+
m_offline_transformations.def(
134+
"paged_attention_transformation",
135+
[](std::shared_ptr<ov::Model> model) {
136+
ov::pass::Manager manager;
137+
manager.register_pass<ov::pass::SDPAToPagedAttention>();
138+
manager.run_passes(model);
139+
},
140+
py::arg("model"));
130141
}

src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.cpp

+2-151
Original file line numberDiff line numberDiff line change
@@ -5,162 +5,13 @@
55
#include "pyopenvino/graph/ops/paged_attention_extension.hpp"
66

77
#include "openvino/op/op.hpp"
8+
#include "openvino/op/paged_attention.hpp"
89
#include "pyopenvino/core/common.hpp"
910

1011
namespace py = pybind11;
1112

12-
namespace {
13-
14-
// This is an experimental operation that is implemented in the plugins.
15-
// Do not use in user applications, backward compatibility is not guaranteed in future releases.
16-
class PagedAttentionExtension : public ov::op::Op {
17-
public:
18-
OPENVINO_OP("PagedAttentionExtension");
19-
20-
PagedAttentionExtension(const ov::OutputVector& args) : ov::op::Op(args) {
21-
constructor_validate_and_infer_types();
22-
}
23-
24-
void validate_and_infer_types() override {
25-
auto value_cache_shape = get_input_partial_shape(4);
26-
// m_num_kv_heads = value_cache_shape[1];
27-
// m_head_size = value_cache_shape[2];
28-
// m_block_size = value_cache_shape[3];
29-
NODE_VALIDATION_CHECK(this, value_cache_shape.size() == 4, "Value cache shape must be 4 dims");
30-
31-
// key_cache: shape [num_blocks, num_kv_heads, head_size/x, block_size, x]
32-
auto key_cache_shape = get_input_partial_shape(3);
33-
NODE_VALIDATION_CHECK(this,
34-
value_cache_shape.size() == 4,
35-
// value_cache_shape[0] == key_cache_shape[0] && // num_blocks
36-
// key_cache_shape[1] == m_num_kv_heads &&
37-
// key_cache_shape[2] * key_cache_shape[4] == m_head_size &&
38-
// m_block_size == key_cache_shape[3], // block_size,
39-
"Key cache shape must be 4 dims");
40-
41-
// query: shape [batch_size, seq_len, num_heads * head_size]
42-
auto query_type = get_input_element_type(0);
43-
auto query_shape = get_input_partial_shape(0);
44-
NODE_VALIDATION_CHECK(
45-
this,
46-
// query_type.is_real() &&
47-
query_shape.size() == 3,
48-
// query_shape[2] == m_num_heads * m_head_size,
49-
"Query type must be real, shape must be like [batch_size, seq_len, num_heads * head_size]. ",
50-
"Got element type ",
51-
query_type,
52-
", shape ",
53-
query_shape);
54-
55-
// key: shape [batch_size, seq_len, num_kv_heads * head_size]
56-
auto key_type = get_input_element_type(1);
57-
auto key_shape = get_input_partial_shape(1);
58-
NODE_VALIDATION_CHECK(this,
59-
// query_type == key_type &&
60-
key_shape.size() == 3,
61-
"Key type must be the same as query, shape must be the same as query. "
62-
"Got element type ",
63-
key_type,
64-
", shape ",
65-
key_shape);
66-
67-
// value: shape [batch_size, seq_len, num_kv_heads * head_size]
68-
// auto value_type = get_input_element_type(2);
69-
auto value_shape = get_input_partial_shape(2);
70-
71-
// is_prompt: boolean scalar
72-
NODE_VALIDATION_CHECK(this,
73-
// get_input_element_type(5) == ov::element::boolean &&
74-
get_input_shape(5) == ov::Shape({}),
75-
"is_prompt validation failed. ",
76-
"Got element type ",
77-
get_input_element_type(5),
78-
", shape ",
79-
get_input_shape(5));
80-
81-
// slot_mapping: shape [batch_size, max_context_len]
82-
auto slot_mapping_shape = get_input_partial_shape(6);
83-
NODE_VALIDATION_CHECK(this,
84-
// get_input_element_type(6) == ov::element::i64 &&
85-
slot_mapping_shape.size() == 2,
86-
"slot_mapping validation failed. ",
87-
"Got element type ",
88-
get_input_element_type(6),
89-
", shape ",
90-
slot_mapping_shape);
91-
92-
// max_context_len: integer scalar
93-
NODE_VALIDATION_CHECK(this,
94-
// get_input_element_type(7) == ov::element::i32 &&
95-
get_input_shape(7) == ov::Shape({}),
96-
"max_context_len validation failed. ",
97-
"Got element type ",
98-
get_input_element_type(7),
99-
", shape ",
100-
get_input_shape(7));
101-
102-
// context_lens: shape [batch_size]
103-
auto context_lens_shape = get_input_partial_shape(8);
104-
NODE_VALIDATION_CHECK(this,
105-
// get_input_element_type(8) == ov::element::i32 &&
106-
context_lens_shape.size() == 1,
107-
"context_lens validation failed. ",
108-
"Got element type ",
109-
get_input_element_type(8),
110-
", shape ",
111-
context_lens_shape);
112-
113-
// block_tables: shape [batch_size, max_block_per_request]
114-
NODE_VALIDATION_CHECK(this,
115-
// get_input_element_type(9) == ov::element::i32 &&
116-
get_input_partial_shape(9).size() == 2,
117-
"block_tables validation failed. ",
118-
"Got element type ",
119-
get_input_element_type(9),
120-
", shape ",
121-
get_input_partial_shape(9));
122-
123-
// scale: float scalar
124-
NODE_VALIDATION_CHECK(this,
125-
// get_input_element_type(10) == ov::element::f32 &&
126-
get_input_shape(10) == ov::Shape({}),
127-
"block_tables validation failed. ",
128-
"Got element type ",
129-
get_input_element_type(10),
130-
", shape ",
131-
get_input_shape(10));
132-
133-
// alibi_slopes: 1D float tensor
134-
NODE_VALIDATION_CHECK(this,
135-
// get_input_element_type(11) == ov::element::f32 &&
136-
get_input_partial_shape(11).rank().get_length() == 1,
137-
"alibi_slopes should be a 1D float tensor. ",
138-
"Got element type ",
139-
get_input_element_type(11),
140-
", shape ",
141-
get_input_partial_shape(11));
142-
143-
// sliding_window: int scalar
144-
NODE_VALIDATION_CHECK(this,
145-
// get_input_element_type(12) == ov::element::i32 &&
146-
get_input_partial_shape(12).rank().get_length() == 0,
147-
"sliding_window argument should be an i32 scalar. ",
148-
"Got element type ",
149-
get_input_element_type(12),
150-
", shape ",
151-
get_input_partial_shape(12));
152-
153-
set_output_type(0, query_type, query_shape);
154-
}
155-
156-
std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override {
157-
return std::make_shared<PagedAttentionExtension>(new_args);
158-
}
159-
};
160-
161-
} // namespace
162-
16313
void regclass_graph_op_PagedAttentionExtension(py::module m) {
14+
using ov::op::PagedAttentionExtension;
16415
py::class_<PagedAttentionExtension, std::shared_ptr<PagedAttentionExtension>, ov::Node> cls(
16516
m,
16617
"_PagedAttentionExtension");
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once

#include "openvino/op/op.hpp"

namespace ov {
namespace op {

/// \brief Experimental paged-attention operation; the actual kernels live in the plugins.
///
/// Do not use in user applications: backward compatibility is not guaranteed
/// in future releases.
class OPENVINO_API PagedAttentionExtension : public ov::op::Op {
public:
    OPENVINO_OP("PagedAttentionExtension");

    /// Constructs the op from its input vector and runs shape/type validation.
    PagedAttentionExtension(const ov::OutputVector& args);
    void validate_and_infer_types() override;
    std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;
};

}  // namespace op
}  // namespace ov
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <vector>

#include "openvino/pass/pass.hpp"

namespace ov {
namespace pass {

/**
 * @brief Model pass that replaces the KV-cache processing subgraph in LLMs
 * with the PagedAttention operation.
 * @ingroup ov_pass_cpp_api
 */
class OPENVINO_API SDPAToPagedAttention : public ModelPass {
public:
    OPENVINO_RTTI("SDPAToPagedAttention");

    /// Runs the transformation on @p model; returns true if the model was changed.
    bool run_on_model(const std::shared_ptr<ov::Model>& model) override;
};

}  // namespace pass
}  // namespace ov

src/core/src/op/paged_attention.cpp

+152
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/paged_attention.hpp"

#include "openvino/op/op.hpp"

namespace ov {
namespace op {

PagedAttentionExtension::PagedAttentionExtension(const ov::OutputVector& args) : ov::op::Op(args) {
    constructor_validate_and_infer_types();
}

// Validates the ranks (and, where statically known, shapes) of the inputs and
// sets the single output to the query's element type and shape.  Element-type
// checks are kept commented out, mirroring the placeholder state of this op.
void PagedAttentionExtension::validate_and_infer_types() {
    // value_cache: shape [num_blocks, num_kv_heads, head_size, block_size]
    auto value_cache_shape = get_input_partial_shape(4);
    // m_num_kv_heads = value_cache_shape[1];
    // m_head_size = value_cache_shape[2];
    // m_block_size = value_cache_shape[3];
    NODE_VALIDATION_CHECK(this, value_cache_shape.size() == 4, "Value cache shape must be 4 dims");

    // key_cache: shape [num_blocks, num_kv_heads, head_size/x, block_size, x]
    // NOTE(review): the layout comment above implies rank 5 while the check
    // below demands rank 4 — confirm the intended key-cache rank.
    auto key_cache_shape = get_input_partial_shape(3);
    NODE_VALIDATION_CHECK(this,
                          // Fixed: this previously re-checked value_cache_shape (copy-paste),
                          // which left the key cache entirely unvalidated.
                          key_cache_shape.size() == 4,
                          // value_cache_shape[0] == key_cache_shape[0] && // num_blocks
                          // key_cache_shape[1] == m_num_kv_heads &&
                          // key_cache_shape[2] * key_cache_shape[4] == m_head_size &&
                          // m_block_size == key_cache_shape[3], // block_size,
                          "Key cache shape must be 4 dims");

    // query: shape [batch_size, seq_len, num_heads * head_size]
    auto query_type = get_input_element_type(0);
    auto query_shape = get_input_partial_shape(0);
    NODE_VALIDATION_CHECK(
        this,
        // query_type.is_real() &&
        query_shape.size() == 3,
        // query_shape[2] == m_num_heads * m_head_size,
        "Query type must be real, shape must be like [batch_size, seq_len, num_heads * head_size]. ",
        "Got element type ",
        query_type,
        ", shape ",
        query_shape);

    // key: shape [batch_size, seq_len, num_kv_heads * head_size]
    auto key_type = get_input_element_type(1);
    auto key_shape = get_input_partial_shape(1);
    NODE_VALIDATION_CHECK(this,
                          // query_type == key_type &&
                          key_shape.size() == 3,
                          "Key type must be the same as query, shape must be the same as query. "
                          "Got element type ",
                          key_type,
                          ", shape ",
                          key_shape);

    // value: shape [batch_size, seq_len, num_kv_heads * head_size]
    // Input 2 (value) is intentionally not validated yet in this placeholder.
    // auto value_type = get_input_element_type(2);

    // is_prompt: boolean scalar.
    // NOTE(review): get_input_shape() requires the shape to be static and
    // throws otherwise — presumably acceptable for these scalar inputs; verify.
    NODE_VALIDATION_CHECK(this,
                          // get_input_element_type(5) == ov::element::boolean &&
                          get_input_shape(5) == ov::Shape({}),
                          "is_prompt validation failed. ",
                          "Got element type ",
                          get_input_element_type(5),
                          ", shape ",
                          get_input_shape(5));

    // slot_mapping: shape [batch_size, max_context_len]
    auto slot_mapping_shape = get_input_partial_shape(6);
    NODE_VALIDATION_CHECK(this,
                          // get_input_element_type(6) == ov::element::i64 &&
                          slot_mapping_shape.size() == 2,
                          "slot_mapping validation failed. ",
                          "Got element type ",
                          get_input_element_type(6),
                          ", shape ",
                          slot_mapping_shape);

    // max_context_len: integer scalar
    NODE_VALIDATION_CHECK(this,
                          // get_input_element_type(7) == ov::element::i32 &&
                          get_input_shape(7) == ov::Shape({}),
                          "max_context_len validation failed. ",
                          "Got element type ",
                          get_input_element_type(7),
                          ", shape ",
                          get_input_shape(7));

    // context_lens: shape [batch_size]
    auto context_lens_shape = get_input_partial_shape(8);
    NODE_VALIDATION_CHECK(this,
                          // get_input_element_type(8) == ov::element::i32 &&
                          context_lens_shape.size() == 1,
                          "context_lens validation failed. ",
                          "Got element type ",
                          get_input_element_type(8),
                          ", shape ",
                          context_lens_shape);

    // block_tables: shape [batch_size, max_block_per_request]
    NODE_VALIDATION_CHECK(this,
                          // get_input_element_type(9) == ov::element::i32 &&
                          get_input_partial_shape(9).size() == 2,
                          "block_tables validation failed. ",
                          "Got element type ",
                          get_input_element_type(9),
                          ", shape ",
                          get_input_partial_shape(9));

    // scale: float scalar
    NODE_VALIDATION_CHECK(this,
                          // get_input_element_type(10) == ov::element::f32 &&
                          get_input_shape(10) == ov::Shape({}),
                          // Fixed: message previously said "block_tables" (copy-paste
                          // from the check above), which would mislead debugging.
                          "scale validation failed. ",
                          "Got element type ",
                          get_input_element_type(10),
                          ", shape ",
                          get_input_shape(10));

    // alibi_slopes: 1D float tensor
    NODE_VALIDATION_CHECK(this,
                          // get_input_element_type(11) == ov::element::f32 &&
                          get_input_partial_shape(11).rank().get_length() == 1,
                          "alibi_slopes should be a 1D float tensor. ",
                          "Got element type ",
                          get_input_element_type(11),
                          ", shape ",
                          get_input_partial_shape(11));

    // sliding_window: int scalar
    NODE_VALIDATION_CHECK(this,
                          // get_input_element_type(12) == ov::element::i32 &&
                          get_input_partial_shape(12).rank().get_length() == 0,
                          "sliding_window argument should be an i32 scalar. ",
                          "Got element type ",
                          get_input_element_type(12),
                          ", shape ",
                          get_input_partial_shape(12));

    // The output mirrors the query input: same element type and shape.
    set_output_type(0, query_type, query_shape);
}

std::shared_ptr<ov::Node> PagedAttentionExtension::clone_with_new_inputs(const ov::OutputVector& new_args) const {
    return std::make_shared<PagedAttentionExtension>(new_args);
}

}  // namespace op
}  // namespace ov

0 commit comments

Comments
 (0)