
Commit 6a1d67a

Op shell for PagedAttentionExtension

1 parent 8a81500

4 files changed: +152 -0 lines changed

src/bindings/python/src/openvino/runtime/op/__init__.py (+1)

@@ -10,6 +10,7 @@

 from openvino._pyopenvino.op import Constant
 from openvino._pyopenvino.op import assign
+from openvino._pyopenvino.op import _PagedAttentionExtension
 from openvino._pyopenvino.op import Parameter
 from openvino._pyopenvino.op import if_op
 from openvino._pyopenvino.op import loop
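
With this re-export, the experimental class becomes reachable from the public runtime namespace. A quick import check, assuming a build that includes this commit:

from openvino.runtime.op import _PagedAttentionExtension  # experimental, no backward-compatibility guarantee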
src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.cpp (new file, +138)
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/op.hpp"
#include "pyopenvino/graph/ops/paged_attention_extension.hpp"

namespace py = pybind11;

namespace {

// This is an experimental operation that is implemented in the plugins.
// Do not use in user applications; backward compatibility is not guaranteed in future releases.
class PagedAttentionExtension : public ov::op::Op {
public:
    OPENVINO_OP("PagedAttentionExtension");

    PagedAttentionExtension(const ov::OutputVector& args) : ov::op::Op(args) {}

    void validate_and_infer_types() override {
        // value_cache: shape [num_blocks, num_kv_heads, head_size, block_size]
        auto value_cache_shape = get_input_partial_shape(4);
        // m_num_kv_heads = value_cache_shape[1];
        // m_head_size = value_cache_shape[2];
        // m_block_size = value_cache_shape[3];
        NODE_VALIDATION_CHECK(this,
                              value_cache_shape.size() == 4,
                              "Value cache shape must be 4 dims");

        // key_cache: shape [num_blocks, num_kv_heads, head_size/x, block_size, x]
        auto key_cache_shape = get_input_partial_shape(3);
        NODE_VALIDATION_CHECK(this,
                              key_cache_shape.size() == 4,
                              // value_cache_shape[0] == key_cache_shape[0] && // num_blocks
                              // key_cache_shape[1] == m_num_kv_heads &&
                              // key_cache_shape[2] * key_cache_shape[4] == m_head_size &&
                              // m_block_size == key_cache_shape[3], // block_size,
                              "Key cache shape must be 4 dims");

        // query: shape [batch_size, seq_len, num_heads * head_size]
        auto query_type = get_input_element_type(0);
        auto query_shape = get_input_partial_shape(0);
        NODE_VALIDATION_CHECK(this,
                              // query_type.is_real() &&
                              query_shape.size() == 3,
                              // query_shape[2] == m_num_heads * m_head_size,
                              "Query type must be real, shape must be like [batch_size, seq_len, num_heads * head_size]. ",
                              "Got element type ", query_type, ", shape ", query_shape);

        // key: shape [batch_size, seq_len, num_kv_heads * head_size]
        auto key_type = get_input_element_type(1);
        auto key_shape = get_input_partial_shape(1);
        NODE_VALIDATION_CHECK(this,
                              // query_type == key_type &&
                              key_shape.size() == 3,
                              "Key type must be the same as query, shape must be the same as query. "
                              "Got element type ", key_type, ", shape ", key_shape);

        // value: shape [batch_size, seq_len, num_kv_heads * head_size]
        // auto value_type = get_input_element_type(2);
        auto value_shape = get_input_partial_shape(2);

        // is_prompt: boolean scalar
        NODE_VALIDATION_CHECK(this,
                              // get_input_element_type(5) == ov::element::boolean &&
                              get_input_shape(5) == ov::Shape({}),
                              "is_prompt validation failed. ",
                              "Got element type ", get_input_element_type(5), ", shape ", get_input_shape(5));

        // slot_mapping: shape [batch_size, max_context_len]
        auto slot_mapping_shape = get_input_partial_shape(6);
        NODE_VALIDATION_CHECK(this,
                              // get_input_element_type(6) == ov::element::i64 &&
                              slot_mapping_shape.size() == 2,
                              "slot_mapping validation failed. ",
                              "Got element type ", get_input_element_type(6), ", shape ", slot_mapping_shape);

        // max_context_len: integer scalar
        NODE_VALIDATION_CHECK(this,
                              // get_input_element_type(7) == ov::element::i32 &&
                              get_input_shape(7) == ov::Shape({}),
                              "max_context_len validation failed. ",
                              "Got element type ", get_input_element_type(7), ", shape ", get_input_shape(7));

        // context_lens: shape [batch_size]
        auto context_lens_shape = get_input_partial_shape(8);
        NODE_VALIDATION_CHECK(this,
                              // get_input_element_type(8) == ov::element::i32 &&
                              context_lens_shape.size() == 1,
                              "context_lens validation failed. ",
                              "Got element type ", get_input_element_type(8), ", shape ", context_lens_shape);

        // block_tables: shape [batch_size, max_block_per_request]
        NODE_VALIDATION_CHECK(this,
                              // get_input_element_type(9) == ov::element::i32 &&
                              get_input_partial_shape(9).size() == 2,
                              "block_tables validation failed. ",
                              "Got element type ", get_input_element_type(9), ", shape ", get_input_partial_shape(9));

        // scale: float scalar
        NODE_VALIDATION_CHECK(this,
                              // get_input_element_type(10) == ov::element::f32 &&
                              get_input_shape(10) == ov::Shape({}),
                              "scale validation failed. ",
                              "Got element type ", get_input_element_type(10), ", shape ", get_input_shape(10));

        // alibi_slopes: 1D float tensor
        NODE_VALIDATION_CHECK(this,
                              // get_input_element_type(11) == ov::element::f32 &&
                              get_input_partial_shape(11).rank().get_length() == 1,
                              "alibi_slopes should be a 1D float tensor. ",
                              "Got element type ", get_input_element_type(11), ", shape ", get_input_partial_shape(11));

        // sliding_window: int scalar
        NODE_VALIDATION_CHECK(this,
                              // get_input_element_type(12) == ov::element::i32 &&
                              get_input_partial_shape(12).rank().get_length() == 0,
                              "sliding_window argument should be an i32 scalar. ",
                              "Got element type ", get_input_element_type(12), ", shape ", get_input_partial_shape(12));

        // The output mirrors the query: same element type and [batch_size, seq_len, num_heads * head_size] shape.
        set_output_type(0, query_type, query_shape);
    }

    std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override {
        return std::make_shared<PagedAttentionExtension>(new_args);
    }

    bool has_evaluate() const override {
        // Plugins implement this operation; this shell provides no reference evaluate().
        return true;
    }
};

}  // namespace

void regclass_graph_op_PagedAttentionExtension(py::module m) {
    py::class_<PagedAttentionExtension, std::shared_ptr<PagedAttentionExtension>, ov::Node> cls(m, "_PagedAttentionExtension");
    cls.doc() = "Experimental extension for PagedAttention operation. Use with care: no backward compatibility is guaranteed in future releases.";
    cls.def(py::init<const ov::OutputVector&>());
}
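
To make the expected signature concrete, here is a minimal Python sketch that builds the thirteen inputs in the order validated above and instantiates the op. This is a sketch under assumptions: it requires a build containing this commit, the param helper and the concrete sizes (num_heads, num_kv_heads, head_size, block_size) are illustrative inventions, and the rank-4 cache shapes merely satisfy the rank checks rather than reflect a canonical plugin layout.

# Minimal sketch: construct the 13 inputs checked by validate_and_infer_types().
# Assumes a build that includes this commit; sizes are illustrative placeholders.
from openvino.runtime import PartialShape, Type
from openvino.runtime.op import Parameter, _PagedAttentionExtension

def param(element_type, shape):
    # Hypothetical helper: a Parameter node; -1 marks a dynamic dimension.
    return Parameter(element_type, PartialShape(shape))

num_heads, num_kv_heads, head_size, block_size = 32, 32, 64, 16

inputs = [
    param(Type.f32, [-1, -1, num_heads * head_size]),            # 0: query
    param(Type.f32, [-1, -1, num_kv_heads * head_size]),         # 1: key
    param(Type.f32, [-1, -1, num_kv_heads * head_size]),         # 2: value
    param(Type.f32, [-1, num_kv_heads, head_size, block_size]),  # 3: key_cache (rank 4)
    param(Type.f32, [-1, num_kv_heads, head_size, block_size]),  # 4: value_cache (rank 4)
    param(Type.boolean, []),                                     # 5: is_prompt, scalar
    param(Type.i64, [-1, -1]),                                   # 6: slot_mapping [batch, max_context_len]
    param(Type.i32, []),                                         # 7: max_context_len, scalar
    param(Type.i32, [-1]),                                       # 8: context_lens [batch]
    param(Type.i32, [-1, -1]),                                   # 9: block_tables [batch, max_block_per_request]
    param(Type.f32, []),                                         # 10: scale, scalar
    param(Type.f32, [-1]),                                       # 11: alibi_slopes, 1D
    param(Type.i32, []),                                         # 12: sliding_window, scalar
]

pa = _PagedAttentionExtension([p.output(0) for p in inputs])
print(pa.get_type_name())  # -> "PagedAttentionExtension"

Shape inference happens in validate_and_infer_types(), which runs when the graph is validated; output 0 then mirrors the query's element type and shape.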
src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.hpp (new file, +11)

// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <pybind11/pybind11.h>

namespace py = pybind11;

void regclass_graph_op_PagedAttentionExtension(py::module m);

src/bindings/python/src/pyopenvino/pyopenvino.cpp (+2)

@@ -52,6 +52,7 @@
 #include "pyopenvino/graph/ops/constant.hpp"
 #include "pyopenvino/graph/ops/if.hpp"
 #include "pyopenvino/graph/ops/loop.hpp"
+#include "pyopenvino/graph/ops/paged_attention_extension.hpp"
 #include "pyopenvino/graph/ops/parameter.hpp"
 #include "pyopenvino/graph/ops/result.hpp"
 #include "pyopenvino/graph/ops/tensor_iterator.hpp"

@@ -234,6 +235,7 @@ PYBIND11_MODULE(_pyopenvino, m) {
     py::module m_op = m.def_submodule("op", "Package ngraph.impl.op that wraps ov::op");  // TODO(!)
     regclass_graph_op_Assign(m_op);
     regclass_graph_op_Constant(m_op);
+    regclass_graph_op_PagedAttentionExtension(m_op);
     regclass_graph_op_Parameter(m_op);
     regclass_graph_op_Result(m_op);
     regclass_graph_op_If(m_op);
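
These two lines are what expose the class: the include pulls in the declaration, and regclass_graph_op_PagedAttentionExtension(m_op) registers _PagedAttentionExtension on the low-level openvino._pyopenvino.op submodule that the __init__.py change above re-exports. A sketch of a smoke test against the low-level module, assuming an installed build with this commit:

from openvino._pyopenvino.op import _PagedAttentionExtension

# The pybind docstring carries the experimental-use warning set via cls.doc().
print(_PagedAttentionExtension.__doc__)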
