@@ -20,6 +20,8 @@ limitations under the License.
 #include <memory>
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
+#include <utility>
 #include <vector>
 
 #include "flatbuffers/flexbuffers.h"
@@ -36,6 +38,24 @@ namespace optimize {
 
 namespace {
 
+// Gets the operator property from the operator_property list and additionally
+// modifies the quantizable parameter based on the user's specified
+// operator_names.
+operator_property::OperatorProperty GetOperatorProperty(
+    const std::unordered_set<string>& operator_names, const BuiltinOperator& op,
+    const string& operator_name) {
+  operator_property::OperatorProperty property =
+      operator_property::GetOperatorProperty(op);
+  // The algorithm adds Dequantize and Quantize, so we don't require them to be
+  // in the operator_names.
+  if (op != BuiltinOperator_DEQUANTIZE && op != BuiltinOperator_QUANTIZE) {
+    property.quantizable =
+        property.quantizable &&
+        (operator_names.find(operator_name) != operator_names.end());
+  }
+  return property;
+}
+
 TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
                           const TensorT* weight_tensor, TensorT* bias_tensor,
                           bool is_per_channel, int channel_dim_index,
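The GetOperatorProperty wrapper added above is the heart of this change: it demotes a property to non-quantizable when the op's name was not requested. A minimal sketch of the intended behavior, assuming the surrounding TFLite headers are in scope (the tensor names are hypothetical):

    // Sketch only; "conv1"/"conv2" are made-up tensor names.
    std::unordered_set<string> names = {"conv1"};
    // Listed op: quantizable keeps the value from the property table.
    auto p1 = GetOperatorProperty(names, BuiltinOperator_CONV_2D, "conv1");
    // Unlisted op: quantizable is forced to false.
    auto p2 = GetOperatorProperty(names, BuiltinOperator_CONV_2D, "conv2");
    // Quantize/Dequantize bypass the filter: the algorithm inserts these
    // ops itself, so the user is never expected to list them.
    auto p3 = GetOperatorProperty(names, BuiltinOperator_QUANTIZE, "q");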
@@ -239,8 +259,8 @@ TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
       // TODO(suharshs): Add support for this case if it ever comes up.
       if (tensor->type == TensorType_FLOAT32 && output_type != tensor->type) {
         error_reporter->Report(
-            "Unsupported output type %s for output tensor %d of type %s.",
-            EnumNameTensorType(output_type), subgraph->outputs[i],
+            "Unsupported output type %s for output tensor '%s' of type %s.",
+            EnumNameTensorType(output_type), tensor->name.c_str(),
             EnumNameTensorType(tensor->type));
         return kTfLiteError;
       }
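With the tensor's name in place of its raw index, the diagnostic is easier to act on. Per the format string above, it now reads along these lines (the tensor name here is hypothetical):

    Unsupported output type INT8 for output tensor 'MobilenetV1/Predictions' of type FLOAT32.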
@@ -260,7 +280,9 @@ TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
 // outputs must have the same scale and zero point. The other ones with
 // constraints (averagepool, maxpool, gather, softmax, tanh, etc.) are handled
 // in QuantizeWeightsAndInput.
-TfLiteStatus ApplyConstraints(ModelT* model, ErrorReporter* error_reporter) {
+TfLiteStatus ApplyConstraints(ModelT* model,
+                              const std::unordered_set<string>& operator_names,
+                              ErrorReporter* error_reporter) {
   for (int subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
        subgraph_idx++) {
     SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
@@ -269,8 +291,8 @@ TfLiteStatus ApplyConstraints(ModelT* model, ErrorReporter* error_reporter) {
       OperatorT* op = subgraph->operators[op_idx].get();
       const BuiltinOperator op_code =
           model->operator_codes[op->opcode_index]->builtin_code;
-      operator_property::OperatorProperty property =
-          operator_property::GetOperatorProperty(op_code);
+      operator_property::OperatorProperty property = GetOperatorProperty(
+          operator_names, op_code, subgraph->tensors[op->outputs[0]]->name);
       if (!property.quantizable) {
         continue;
       }
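Note the convention established here and reused in the two hunks below: for filtering purposes an operator's "name" is the name of its first output tensor. A sketch of the lookup, assuming op and subgraph are in scope as in the loop above:

    // How an op is matched against operator_names.
    const string& op_name = subgraph->tensors[op->outputs[0]]->name;
    operator_property::OperatorProperty property =
        GetOperatorProperty(operator_names, op_code, op_name);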
@@ -546,17 +568,19 @@ TfLiteStatus QuantizeOpOutput(
 
 // Quantize inputs and weights.
 // Because of ops such as lstm, this still needs to be done per op, not per weight.
-TfLiteStatus QuantizeWeightsInputOutput(ModelT* model, bool allow_float,
-                                        ErrorReporter* error_reporter) {
+TfLiteStatus QuantizeWeightsInputOutput(
+    ModelT* model, bool allow_float,
+    const std::unordered_set<string>& operator_names,
+    ErrorReporter* error_reporter) {
   for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
        subgraph_idx++) {
     SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
     for (size_t op_idx = 0; op_idx < subgraph->operators.size(); op_idx++) {
       OperatorT* op = subgraph->operators[op_idx].get();
       const BuiltinOperator op_code =
           model->operator_codes[op->opcode_index]->builtin_code;
-      operator_property::OperatorProperty property =
-          operator_property::GetOperatorProperty(op_code);
+      operator_property::OperatorProperty property = GetOperatorProperty(
+          operator_names, op_code, subgraph->tensors[op->outputs[0]]->name);
 
       if (!property.quantizable && !allow_float) {
         error_reporter->Report("Quantization not yet supported for op: %s",
@@ -583,16 +607,18 @@ TfLiteStatus QuantizeWeightsInputOutput(ModelT* model, bool allow_float,
 }
 
 // Quantize bias.
-TfLiteStatus QuantizeBiases(ModelT* model, ErrorReporter* error_reporter) {
+TfLiteStatus QuantizeBiases(ModelT* model,
+                            const std::unordered_set<string>& operator_names,
+                            ErrorReporter* error_reporter) {
   for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
        subgraph_idx++) {
     SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
     for (size_t op_idx = 0; op_idx < subgraph->operators.size(); op_idx++) {
       OperatorT* op = subgraph->operators[op_idx].get();
       const BuiltinOperator op_code =
           model->operator_codes[op->opcode_index]->builtin_code;
-      operator_property::OperatorProperty property =
-          operator_property::GetOperatorProperty(op_code);
+      operator_property::OperatorProperty property = GetOperatorProperty(
+          operator_names, op_code, subgraph->tensors[op->outputs[0]]->name);
       if (!property.quantizable) {
         continue;
       }
@@ -639,17 +665,32 @@ TfLiteStatus QuantizeBiases(ModelT* model, ErrorReporter* error_reporter) {
   return kTfLiteOk;
 }
 
+std::unordered_set<string> GetAllOperatorOutputs(ModelT* model) {
+  std::unordered_set<string> operator_names;
+  for (int32_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
+       subgraph_idx++) {
+    SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
+    for (int32_t tensor_idx = 0; tensor_idx < subgraph->tensors.size();
+         tensor_idx++) {
+      operator_names.insert(subgraph->tensors[tensor_idx]->name);
+    }
+  }
+  return operator_names;
+}
+
 }  // namespace
 
 // Assumes that the operators in the model have been topologically sorted.
 TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
                            ModelT* model, const TensorType& input_type,
                            const TensorType& output_type, bool allow_float,
+                           const std::unordered_set<string>& operator_names,
                            ErrorReporter* error_reporter) {
+  TF_LITE_ENSURE_STATUS(QuantizeWeightsInputOutput(
+      model, allow_float, operator_names, error_reporter));
   TF_LITE_ENSURE_STATUS(
-      QuantizeWeightsInputOutput(model, allow_float, error_reporter));
-  TF_LITE_ENSURE_STATUS(ApplyConstraints(model, error_reporter));
-  TF_LITE_ENSURE_STATUS(QuantizeBiases(model, error_reporter));
+      ApplyConstraints(model, operator_names, error_reporter));
+  TF_LITE_ENSURE_STATUS(QuantizeBiases(model, operator_names, error_reporter));
   utils::SetOperatorCodeVersion(model);
   TF_LITE_ENSURE_STATUS(
       SetInputAndOutputTypes(model, input_type, output_type, error_reporter));
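With operator_names threaded through all three passes, callers can quantize just a subset of ops. A hedged usage sketch (model setup omitted; the tensor names are hypothetical). Note that ops left out of the set keep their float types, so allow_float generally needs to be true; otherwise QuantizeWeightsInputOutput reports "Quantization not yet supported" for the unselected ops:

    flatbuffers::FlatBufferBuilder builder;
    std::unordered_set<string> ops_to_quantize = {"conv1", "conv2"};
    TfLiteStatus status = QuantizeModel(
        &builder, model, TensorType_FLOAT32, TensorType_FLOAT32,
        /*allow_float=*/true, ops_to_quantize, error_reporter);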
@@ -661,6 +702,14 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
   return kTfLiteOk;
 }
 
+TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
+                           ModelT* model, const TensorType& input_type,
+                           const TensorType& output_type, bool allow_float,
+                           ErrorReporter* error_reporter) {
+  return QuantizeModel(builder, model, input_type, output_type, allow_float,
+                       GetAllOperatorOutputs(model), error_reporter);
+}
+
 TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
                            ModelT* model, const TensorType& input_type,
                            const TensorType& output_type,
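The new convenience overload preserves the pre-existing behavior: seeding the filter with every tensor name via GetAllOperatorOutputs (file-local, in the anonymous namespace) makes the name check a no-op, so from inside the file these two calls are equivalent:

    // Full-model quantization, old and new spellings (sketch).
    QuantizeModel(&builder, model, input_type, output_type, allow_float,
                  error_reporter);
    QuantizeModel(&builder, model, input_type, output_type, allow_float,
                  GetAllOperatorOutputs(model), error_reporter);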