Skip to content

Commit 9e36fea

Browse files
Suharsh Sivakumar and tensorflower-gardener
authored and committed
Add support to only quantize specified operators in the quantization tool.
The operators are keyed by their first output tensor name. PiperOrigin-RevId: 265096821
1 parent a148b74 commit 9e36fea

File tree

4 files changed

+102
-15
lines changed

4 files changed

+102
-15
lines changed

tensorflow/lite/tools/optimize/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ cc_library(
167167
":operator_property",
168168
":quantization_utils",
169169
"//tensorflow/lite:framework",
170+
"//tensorflow/lite:util",
170171
"//tensorflow/lite/core/api",
171172
"//tensorflow/lite/schema:schema_fbs",
172173
"@flatbuffers",

tensorflow/lite/tools/optimize/quantize_model.cc

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ limitations under the License.
2020
#include <memory>
2121
#include <string>
2222
#include <unordered_map>
23+
#include <unordered_set>
24+
#include <utility>
2325
#include <vector>
2426

2527
#include "flatbuffers/flexbuffers.h"
@@ -36,6 +38,24 @@ namespace optimize {
3638

3739
namespace {
3840

41+
// Gets the operator property from the operator_property list and additionally
42+
// modifies the quantizable parameter based on the user's specified
43+
// operator_names.
44+
operator_property::OperatorProperty GetOperatorProperty(
45+
const std::unordered_set<string>& operator_names, const BuiltinOperator& op,
46+
const string& operator_name) {
47+
operator_property::OperatorProperty property =
48+
operator_property::GetOperatorProperty(op);
49+
// The algorithm adds Dequantize and Quantize, so we don't require them to be
50+
// in the operator_names.
51+
if (op != BuiltinOperator_DEQUANTIZE && op != BuiltinOperator_QUANTIZE) {
52+
property.quantizable =
53+
property.quantizable &&
54+
(operator_names.find(operator_name) != operator_names.end());
55+
}
56+
return property;
57+
}
58+
3959
TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
4060
const TensorT* weight_tensor, TensorT* bias_tensor,
4161
bool is_per_channel, int channel_dim_index,
@@ -239,8 +259,8 @@ TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
239259
// TODO(suharshs): Add support for this case if it ever comes up.
240260
if (tensor->type == TensorType_FLOAT32 && output_type != tensor->type) {
241261
error_reporter->Report(
242-
"Unsupported output type %s for output tensor %d of type %s.",
243-
EnumNameTensorType(output_type), subgraph->outputs[i],
262+
"Unsupported output type %s for output tensor '%s' of type %s.",
263+
EnumNameTensorType(output_type), tensor->name.c_str(),
244264
EnumNameTensorType(tensor->type));
245265
return kTfLiteError;
246266
}
@@ -260,7 +280,9 @@ TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
260280
// outpus must have the same scale and zero point. The other ones with
261281
// constraints(averagepool, maxpool, gather, softmax, tanh etc) are handled in
262282
// QuantizeWeightsAndInput.
263-
TfLiteStatus ApplyConstraints(ModelT* model, ErrorReporter* error_reporter) {
283+
TfLiteStatus ApplyConstraints(ModelT* model,
284+
const std::unordered_set<string>& operator_names,
285+
ErrorReporter* error_reporter) {
264286
for (int subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
265287
subgraph_idx++) {
266288
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
@@ -269,8 +291,8 @@ TfLiteStatus ApplyConstraints(ModelT* model, ErrorReporter* error_reporter) {
269291
OperatorT* op = subgraph->operators[op_idx].get();
270292
const BuiltinOperator op_code =
271293
model->operator_codes[op->opcode_index]->builtin_code;
272-
operator_property::OperatorProperty property =
273-
operator_property::GetOperatorProperty(op_code);
294+
operator_property::OperatorProperty property = GetOperatorProperty(
295+
operator_names, op_code, subgraph->tensors[op->outputs[0]]->name);
274296
if (!property.quantizable) {
275297
continue;
276298
}
@@ -546,17 +568,19 @@ TfLiteStatus QuantizeOpOutput(
546568

547569
// Quantize inputs and weights.
548570
// Because of ops such as lstm, still need to do per op, instead of weights.
549-
TfLiteStatus QuantizeWeightsInputOutput(ModelT* model, bool allow_float,
550-
ErrorReporter* error_reporter) {
571+
TfLiteStatus QuantizeWeightsInputOutput(
572+
ModelT* model, bool allow_float,
573+
const std::unordered_set<string>& operator_names,
574+
ErrorReporter* error_reporter) {
551575
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
552576
subgraph_idx++) {
553577
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
554578
for (size_t op_idx = 0; op_idx < subgraph->operators.size(); op_idx++) {
555579
OperatorT* op = subgraph->operators[op_idx].get();
556580
const BuiltinOperator op_code =
557581
model->operator_codes[op->opcode_index]->builtin_code;
558-
operator_property::OperatorProperty property =
559-
operator_property::GetOperatorProperty(op_code);
582+
operator_property::OperatorProperty property = GetOperatorProperty(
583+
operator_names, op_code, subgraph->tensors[op->outputs[0]]->name);
560584

561585
if (!property.quantizable && !allow_float) {
562586
error_reporter->Report("Quantization not yet supported for op: %s",
@@ -583,16 +607,18 @@ TfLiteStatus QuantizeWeightsInputOutput(ModelT* model, bool allow_float,
583607
}
584608

585609
// Quantize bias.
586-
TfLiteStatus QuantizeBiases(ModelT* model, ErrorReporter* error_reporter) {
610+
TfLiteStatus QuantizeBiases(ModelT* model,
611+
const std::unordered_set<string>& operator_names,
612+
ErrorReporter* error_reporter) {
587613
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
588614
subgraph_idx++) {
589615
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
590616
for (size_t op_idx = 0; op_idx < subgraph->operators.size(); op_idx++) {
591617
OperatorT* op = subgraph->operators[op_idx].get();
592618
const BuiltinOperator op_code =
593619
model->operator_codes[op->opcode_index]->builtin_code;
594-
operator_property::OperatorProperty property =
595-
operator_property::GetOperatorProperty(op_code);
620+
operator_property::OperatorProperty property = GetOperatorProperty(
621+
operator_names, op_code, subgraph->tensors[op->outputs[0]]->name);
596622
if (!property.quantizable) {
597623
continue;
598624
}
@@ -639,17 +665,32 @@ TfLiteStatus QuantizeBiases(ModelT* model, ErrorReporter* error_reporter) {
639665
return kTfLiteOk;
640666
}
641667

668+
std::unordered_set<string> GetAllOperatorOutputs(ModelT* model) {
669+
std::unordered_set<string> operator_names;
670+
for (int32_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
671+
subgraph_idx++) {
672+
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
673+
for (int32_t tensor_idx = 0; tensor_idx < subgraph->tensors.size();
674+
tensor_idx++) {
675+
operator_names.insert(subgraph->tensors[tensor_idx]->name);
676+
}
677+
}
678+
return operator_names;
679+
}
680+
642681
} // namespace
643682

644683
// Assumes that the operators in the model have been topologically sorted.
645684
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
646685
ModelT* model, const TensorType& input_type,
647686
const TensorType& output_type, bool allow_float,
687+
const std::unordered_set<string>& operator_names,
648688
ErrorReporter* error_reporter) {
689+
TF_LITE_ENSURE_STATUS(QuantizeWeightsInputOutput(
690+
model, allow_float, operator_names, error_reporter));
649691
TF_LITE_ENSURE_STATUS(
650-
QuantizeWeightsInputOutput(model, allow_float, error_reporter));
651-
TF_LITE_ENSURE_STATUS(ApplyConstraints(model, error_reporter));
652-
TF_LITE_ENSURE_STATUS(QuantizeBiases(model, error_reporter));
692+
ApplyConstraints(model, operator_names, error_reporter));
693+
TF_LITE_ENSURE_STATUS(QuantizeBiases(model, operator_names, error_reporter));
653694
utils::SetOperatorCodeVersion(model);
654695
TF_LITE_ENSURE_STATUS(
655696
SetInputAndOutputTypes(model, input_type, output_type, error_reporter));
@@ -661,6 +702,14 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
661702
return kTfLiteOk;
662703
}
663704

705+
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
706+
ModelT* model, const TensorType& input_type,
707+
const TensorType& output_type, bool allow_float,
708+
ErrorReporter* error_reporter) {
709+
return QuantizeModel(builder, model, input_type, output_type, allow_float,
710+
GetAllOperatorOutputs(model), error_reporter);
711+
}
712+
664713
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
665714
ModelT* model, const TensorType& input_type,
666715
const TensorType& output_type,

tensorflow/lite/tools/optimize/quantize_model.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@ limitations under the License.
1616
#define TENSORFLOW_LITE_TOOLS_OPTIMIZE_QUANTIZE_MODEL_H_
1717

1818
#include <memory>
19+
#include <unordered_set>
1920

2021
#include "tensorflow/lite/context.h"
2122
#include "tensorflow/lite/core/api/error_reporter.h"
2223
#include "tensorflow/lite/model.h"
2324
#include "tensorflow/lite/schema/schema_generated.h"
25+
#include "tensorflow/lite/util.h"
2426

2527
namespace tflite {
2628
namespace optimize {
@@ -53,6 +55,16 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
5355
const TensorType& output_type, bool allow_float,
5456
ErrorReporter* error_reporter);
5557

58+
// Same as above, but enables only quantizing a whitelist of operations,
59+
// specified by their operator output name.
60+
//
61+
// Note: This is a private API, subject to change.
62+
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
63+
ModelT* input_model, const TensorType& input_type,
64+
const TensorType& output_type, bool allow_float,
65+
const std::unordered_set<string>& operator_names,
66+
ErrorReporter* error_reporter);
67+
5668
} // namespace optimize
5769
} // namespace tflite
5870

tensorflow/lite/tools/optimize/quantize_model_test.cc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,31 @@ TEST_F(QuantizeConvModelTest, QuantizationSucceeds) {
9898
ASSERT_TRUE(output_model);
9999
}
100100

101+
TEST_F(QuantizeConvModelTest, SkipUnspecifiedLayer) {
102+
auto status =
103+
QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
104+
/*allow_float=*/true, {}, &error_reporter_);
105+
EXPECT_EQ(status, kTfLiteOk);
106+
ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
107+
// The resulting model should be the same.
108+
ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
109+
for (size_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
110+
subgraph_idx++) {
111+
const auto quantized_graph = model_.subgraphs[subgraph_idx].get();
112+
const auto float_graph = readonly_model_->subgraphs()->Get(subgraph_idx);
113+
ASSERT_EQ(quantized_graph->tensors.size(), float_graph->tensors()->size());
114+
for (size_t i = 0; i < quantized_graph->tensors.size(); i++) {
115+
const auto quant_tensor = quantized_graph->tensors[i].get();
116+
const auto float_tensor = float_graph->tensors()->Get(i);
117+
EXPECT_EQ(quant_tensor->buffer, float_tensor->buffer());
118+
EXPECT_EQ(quant_tensor->is_variable, float_tensor->is_variable());
119+
EXPECT_EQ(quant_tensor->shape, GetAsVector(float_tensor->shape()));
120+
EXPECT_EQ(quant_tensor->name, float_tensor->name()->str());
121+
EXPECT_EQ(quant_tensor->type, float_tensor->type());
122+
}
123+
}
124+
}
125+
101126
TEST_F(QuantizeConvModelTest, TensorShapesAndStructureIsUnchanged) {
102127
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
103128
TensorType_INT8, &error_reporter_);

0 commit comments

Comments
 (0)