Skip to content

Commit 0733d60

Browse files
committed
Perform f16 compression to postponed constant input
1 parent 3ffdd2a commit 0733d60

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

src/common/transformations/src/transformations/common_optimizations/compress_float_constants.cpp

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,28 @@ ov::pass::CompressFloatConstantsImpl::CompressFloatConstantsImpl(bool postponed)
196196
return false;
197197
}
198198
auto constant_target_inputs = const_node->get_output_target_inputs(0);
199-
auto convert = std::make_shared<ov::op::v0::Convert>(new_const, const_node->get_element_type());
199+
200+
std::shared_ptr<ov::Node> postponed_constant_node;
201+
decltype(constant_target_inputs) postponed_constant_node_target_inputs;
202+
bool is_postponed_constant_next = [&]() {
203+
if (constant_target_inputs.size() == 1 &&
204+
constant_target_inputs.begin()->get_node()->get_rt_info().count("postponed_constant")) {
205+
postponed_constant_node = constant_target_inputs.begin()->get_node()->shared_from_this();
206+
postponed_constant_node_target_inputs = postponed_constant_node->get_output_target_inputs(0);
207+
return true;
208+
}
209+
return false;
210+
}();
211+
// is_postponed_constant_next flag means that the next node is to be constant_folded later during serialization.
212+
// If f16 conversion is also postponed, we need to insert Convert node after the postponed_constant node
213+
214+
std::shared_ptr<ov::Node> convert;
215+
if (is_postponed_constant_next && postponed) {
216+
convert = std::make_shared<ov::op::v0::Convert>(postponed_constant_node, const_node->get_element_type());
217+
postponed_constant_node->set_friendly_name(const_node->get_friendly_name() + "_compressed");
218+
} else {
219+
convert = std::make_shared<ov::op::v0::Convert>(new_const, const_node->get_element_type());
220+
}
200221

201222
convert->set_friendly_name(const_node->get_friendly_name());
202223
new_const->set_friendly_name(const_node->get_friendly_name() + "_compressed");
@@ -206,7 +227,10 @@ ov::pass::CompressFloatConstantsImpl::CompressFloatConstantsImpl(bool postponed)
206227
postpone_fp16_compression(new_const->get_rt_info());
207228
postpone_fp16_compression(new_const->get_output_tensor(0).get_rt_info());
208229

209-
for (const auto& target_input : constant_target_inputs) {
230+
auto target_inputs_to_replace = is_postponed_constant_next
231+
? postponed_constant_node_target_inputs
232+
: constant_target_inputs;
233+
for (const auto& target_input : target_inputs_to_replace) {
210234
target_input.replace_source_output(convert);
211235
}
212236
} else {

src/common/transformations/src/transformations/common_optimizations/matmul_const_transposes_extraction.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,6 @@ ov::pass::MatMulConstTransposesExtraction::MatMulConstTransposesExtraction() {
4343
transpose->get_rt_info()["postponed_constant"] = true;
4444
// disable constant folding here to postpone it to serialization step
4545
ov::pass::disable_constant_folding(transpose);
46-
// disable fp16 compression. Otherwise an additional conversion will be added after the constant, which
47-
// breaks postponed_constant serialization
48-
ov::disable_fp16_compression(weights.get_node_shared_ptr());
4946
}
5047
auto new_matmul = std::make_shared<ov::op::v0::MatMul>(pattern_value_map.at(data_pattern),
5148
transpose,

0 commit comments

Comments
 (0)