Skip to content

Commit 09fbdf8

Browse files
committed
Initial version of StatefulToStateless transformation (not tested, wrong names of new Parameter/Result pairs).
1 parent 96a534c commit 09fbdf8

File tree

2 files changed

+104
-0
lines changed

2 files changed

+104
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Copyright (C) 2018-2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "openvino/pass/pass.hpp"
8+
9+
namespace ov {
10+
namespace pass {
11+
/**
12+
* @brief The transformation converts KV cache state back to stateless form.
13+
* \ingroup ov_pass_cpp_api
14+
*/
15+
// Model-level pass: derives from ModelPass and does all of its work in run_on_model().
class OPENVINO_API StatefulToStateless : public ModelPass {
16+
public:
17+
// RTTI declaration for this pass, using the type name string "StatefulToStateless".
OPENVINO_RTTI("StatefulToStateless");
18+
19+
// Applies the transformation to `model` (override of the ModelPass entry point).
// Only declared here; the definition lives in the accompanying source file.
bool run_on_model(const std::shared_ptr<ov::Model>& model) override;
20+
};
21+
} // namespace pass
22+
} // namespace ov
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// Copyright (C) 2018-2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "openvino/pass/stateful_to_stateless.hpp"
6+
7+
#include "openvino/cc/pass/itt.hpp"
8+
#include "openvino/op/gather.hpp"
9+
#include "openvino/op/read_value.hpp"
10+
#include "openvino/op/assign.hpp"
11+
#include "openvino/pass/manager.hpp"
12+
#include "transformations/utils/utils.hpp"
13+
14+
using namespace ov::op;
15+
16+
// Gives a Parameter one unambiguous name: both the node's friendly name and the name set of its
// output tensor become `name` (the tensor's name set is replaced by this single name).
// Asserts that the node has exactly one output and hands the same node back so calls can be chained.
static std::shared_ptr<v0::Parameter> setName(std::shared_ptr<v0::Parameter> node, const std::string& name) {
    node->set_friendly_name(name);
    OPENVINO_ASSERT(node->get_output_size() == 1);

    auto& output_tensor = node->get_output_tensor(0);
    output_tensor.set_names({name});
    return node;
}
24+
25+
26+
// Finds the model Parameter whose output tensor carries `name` among its tensor names.
// Returns the first match in model->get_parameters() order, or nullptr when no Parameter matches
// (the nullptr result and the return type are what distinguish this from ov::Model::input(name)).
std::shared_ptr<v0::Parameter> get_parameter_by_tensor_name(const std::shared_ptr<ov::Model>& model, const std::string& name) {
    std::shared_ptr<v0::Parameter> found;
    for (const auto& candidate : model->get_parameters()) {
        const auto& tensor_names = candidate->get_output_tensor(0).get_names();
        if (tensor_names.find(name) != tensor_names.end()) {
            found = candidate;
            break;
        }
    }
    return found;
}
33+
34+
35+
// Converts a stateful model back to stateless form, in place.
//
// For every Gather that consumes the `beam_idx` parameter and reads its data from a ReadValue:
//   * the Gather is replaced by a new Parameter named "<variable_id>.restored_input";
//   * the Assign sink of the same variable is replaced by a Result fed with the Assign's input value,
//     the Result is added to the model outputs and the Assign is removed from the model sinks.
// The `beam_idx` parameter itself is removed from the model parameters.
//
// @param model Model to transform.
// @return Always true.
// @throws Via OPENVINO_THROW / OPENVINO_ASSERT when the model has no `beam_idx` parameter, when a Gather
//         fed by `beam_idx` does not take its data from a ReadValue, or when a collected variable has no
//         Assign sink. All of these checks run before the model is modified.
bool ov::pass::StatefulToStateless::run_on_model(const std::shared_ptr<ov::Model>& model) {
    RUN_ON_MODEL_SCOPE(StatefulToStateless);

    auto beam_idx = get_parameter_by_tensor_name(model, "beam_idx");
    if (!beam_idx) {
        OPENVINO_THROW("Stateful models without `beam_idx` input are not supported in StatefulToStateless transformation");
    }

    ov::NodeVector future_params;  // nodes, each with a single output, that will be replaced by new parameters
    std::vector<std::shared_ptr<op::util::Variable>> variables;  // variables corresponding to future_params
    for (const ov::Input<ov::Node>& input : beam_idx->get_output_target_inputs(0)) {
        if (auto gather = std::dynamic_pointer_cast<op::util::GatherBase>(input.get_node()->shared_from_this())) {
            auto read_value = std::dynamic_pointer_cast<op::util::ReadValueBase>(gather->get_input_node_shared_ptr(0));
            OPENVINO_ASSERT(
                read_value,
                "Unexpected model topology in StatefulToStateless: no ReadValue is found at the first input of Gather by `beam_idx` parameter");
            future_params.push_back(gather);
            variables.push_back(read_value->get_variable());
        }
    }

    // TODO: Use naming convention for variables to sort them in the original order and assign names for new
    // Parameters/Results

    std::unordered_map<std::string, std::shared_ptr<op::util::AssignBase>> assign_by_var_name;
    for (const auto& sink : model->get_sinks()) {
        if (auto assign = std::dynamic_pointer_cast<op::util::AssignBase>(sink)) {
            assign_by_var_name[assign->get_variable_id()] = assign;
        }
    }

    // Resolve the Assign of every variable before touching the model: a missing Assign is reported as a
    // topology error instead of dereferencing a null pointer, and the model stays unmodified on failure.
    std::vector<std::shared_ptr<op::util::AssignBase>> assigns;
    assigns.reserve(variables.size());
    for (const auto& variable : variables) {
        const std::string variable_id = variable->get_info().variable_id;
        const auto found = assign_by_var_name.find(variable_id);
        OPENVINO_ASSERT(found != assign_by_var_name.end(),
                        "Unexpected model topology in StatefulToStateless: no Assign is found for variable `",
                        variable_id,
                        "`");
        assigns.push_back(found->second);
    }

    model->remove_parameter(beam_idx);
    for (size_t i = 0; i < future_params.size(); ++i) {
        const auto& future_param = future_params[i];
        const std::string variable_id = variables[i]->get_info().variable_id;

        // The restored state input takes over the element type and shape of the Gather it replaces.
        auto parameter = setName(std::make_shared<v0::Parameter>(future_param->get_output_element_type(0),
                                                                 future_param->get_output_partial_shape(0)),
                                 variable_id + ".restored_input");
        model->add_parameters({parameter});
        replace_node(future_param, parameter);

        const auto& assign = assigns[i];
        auto result = std::make_shared<v0::Result>(assign->input_value(0));
        replace_node(assign, result);
        // Register the Result explicitly; otherwise it never becomes one of the model outputs.
        model->add_results({result});
        model->remove_sink(assign);
    }

    return true;
}

0 commit comments

Comments
 (0)