Skip to content

Commit

Permalink
update code
Browse files Browse the repository at this point in the history
  • Loading branch information
yeliang2258 committed Sep 2, 2022
1 parent 79394a3 commit 8924f9b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 18 deletions.
36 changes: 19 additions & 17 deletions paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,23 +118,22 @@ void QuantDequantMkldnnPass::CollectWeightScalesInfoFromONNXFormatDequantize(
std::vector<float> scale_v = var_quant_scales->at(x_var_name);
var_quant_scales->insert(std::make_pair(out_var_name, scale_v));
}
continue;
}
*onnx_format_quantize_model = true;
auto scale_name = op_desc->Input("Scale")[0];
auto* var = scope->FindVar(scale_name);
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::NotFound(
"The Scales variable [%s] of dequantize op is not found.", var));
} else {
*onnx_format_quantize_model = true;
auto scale_name = op_desc->Input("Scale")[0];
auto* var = scope->FindVar(scale_name);
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::NotFound(
"The Scales variable [%s] of dequantize op is not found.",
var));

auto* scale_tensor = var->GetMutable<LoDTensor>();
auto* scale_data = scale_tensor->data<float>();
std::vector<float> thresholds{};
for (int i = 0; i < scale_tensor->numel(); i++) {
thresholds.push_back(scale_data[i]);
auto* scale_tensor = var->GetMutable<LoDTensor>();
auto* scale_data = scale_tensor->data<float>();
std::vector<float> thresholds(scale_data,
scale_data + scale_tensor->numel());
weight_thresholds->insert(std::make_pair(x_var_name, thresholds));
}
weight_thresholds->insert(std::make_pair(x_var_name, thresholds));
}
}
}
Expand Down Expand Up @@ -357,15 +356,18 @@ void QuantDequantMkldnnPass::CollectQuantizeDequantizeOpsFromONNXFormat(
fake_quant_in,
platform::errors::NotFound(
"The input var [%s] of quantize op is not found.", x_var_name));
PADDLE_ENFORCE_NOT_NULL(
fake_quant_in_scale,
platform::errors::NotFound(
"The scale var [%s] of quantize op is not found.", in_scale_name));
PADDLE_ENFORCE_NOT_NULL(
fake_quant_out,
platform::errors::NotFound(
"The output var [%s] of quantize op is not found.", out_var_name));

std::string input_act_name = fake_quant_in->Var()->Name();
std::string output_act_name = fake_quant_out->Var()->Name();
auto outlinks = fake_quant_out->outputs;
for (auto* next_node : outlinks) {
for (auto* next_node : fake_quant_out->outputs) {
if (!next_node->IsOp()) continue;
next_node->Op()->RenameInput(output_act_name, input_act_name);
IR_NODE_LINK_TO(fake_quant_in, next_node);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
# copyright (c) 2022 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
Expand Down

0 comments on commit 8924f9b

Please sign in to comment.