diff --git a/tflite/toco/import_tensorflow.cc b/tflite/toco/import_tensorflow.cc index 4700aa25..86d93599 100644 --- a/tflite/toco/import_tensorflow.cc +++ b/tflite/toco/import_tensorflow.cc @@ -142,9 +142,9 @@ const AttrValue::ListValue& GetListAttr(const NodeDef& node, return attr.list(); } -tensorflow::Status CheckOptionalAttr(const NodeDef& node, - const std::string& attr_name, - const std::string& expected_value) { +absl::Status CheckOptionalAttr(const NodeDef& node, + const std::string& attr_name, + const std::string& expected_value) { if (HasAttr(node, attr_name)) { const std::string& value = GetStringAttr(node, attr_name); if (value != expected_value) { @@ -156,9 +156,9 @@ tensorflow::Status CheckOptionalAttr(const NodeDef& node, return absl::OkStatus(); } -tensorflow::Status CheckOptionalAttr( - const NodeDef& node, const std::string& attr_name, - const tensorflow::DataType& expected_value) { +absl::Status CheckOptionalAttr(const NodeDef& node, + const std::string& attr_name, + const tensorflow::DataType& expected_value) { if (HasAttr(node, attr_name)) { const tensorflow::DataType& value = GetDataTypeAttr(node, attr_name); if (value != expected_value) { @@ -171,8 +171,8 @@ tensorflow::Status CheckOptionalAttr( } template -tensorflow::Status ExpectValue(const T1& v1, const T2& v2, - const std::string& description) { +absl::Status ExpectValue(const T1& v1, const T2& v2, + const std::string& description) { if (v1 == v2) return absl::OkStatus(); return tensorflow::errors::InvalidArgument(absl::StrCat( "Unexpected ", description, ": got ", v1, ", expected ", v2)); @@ -204,10 +204,9 @@ ArrayDataType ConvertDataType(tensorflow::DataType dtype) { return ArrayDataType::kNone; } -tensorflow::Status ImportShape( - const TFLITE_PROTO_NS::RepeatedPtrField& - input_dims, - int* input_flat_size, Shape* shape) { +absl::Status ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField< + tensorflow::TensorShapeProto_Dim>& input_dims, + int* input_flat_size, Shape* shape) { std::vector input_dims_only_sizes; bool zero_sized_shape = false; for (auto& d : input_dims) { @@ -344,9 +343,9 @@ struct TensorTraits { }; template -tensorflow::Status ImportTensorData(const TensorProto& input_tensor, - int input_flat_size, - std::vector* output_data) { +absl::Status ImportTensorData(const TensorProto& input_tensor, + int input_flat_size, + std::vector* output_data) { CHECK_GE(output_data->size(), input_flat_size); int num_elements_in_tensor = TensorTraits::size(input_tensor); if (num_elements_in_tensor == input_flat_size) { @@ -384,8 +383,8 @@ tensorflow::Status ImportTensorData(const TensorProto& input_tensor, return absl::OkStatus(); } -tensorflow::Status ImportFloatArray(const TensorProto& input_tensor, - Array* output_array) { +absl::Status ImportFloatArray(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_FLOAT); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 6); @@ -402,8 +401,8 @@ tensorflow::Status ImportFloatArray(const TensorProto& input_tensor, &output_float_data); } -tensorflow::Status ImportComplex64Array(const TensorProto& input_tensor, - Array* output_array) { +absl::Status ImportComplex64Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_COMPLEX64); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -420,8 +419,8 @@ tensorflow::Status ImportComplex64Array(const TensorProto& input_tensor, &output_complex_data); } -tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor, - Array* output_array) { +absl::Status ImportQuint8Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_QUINT8); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 6); @@ -437,8 +436,8 @@ tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor, &output_int_data); } -tensorflow::Status ImportInt32Array(const TensorProto& input_tensor, - Array* output_array) { +absl::Status ImportInt32Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_INT32); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 6); @@ -454,8 +453,8 @@ tensorflow::Status ImportInt32Array(const TensorProto& input_tensor, &output_int_data); } -tensorflow::Status ImportUint32Array(const TensorProto& input_tensor, - Array* output_array) { +absl::Status ImportUint32Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_UINT32); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 6); @@ -471,8 +470,8 @@ tensorflow::Status ImportUint32Array(const TensorProto& input_tensor, &output_int_data); } -tensorflow::Status ImportInt64Array(const TensorProto& input_tensor, - Array* output_array) { +absl::Status ImportInt64Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_INT64); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 6); @@ -488,8 +487,8 @@ tensorflow::Status ImportInt64Array(const TensorProto& input_tensor, &output_int_data); } -tensorflow::Status ImportBoolArray(const TensorProto& input_tensor, - Array* output_array) { +absl::Status ImportBoolArray(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_BOOL); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 6); @@ -515,8 +514,8 @@ tensorflow::Status ImportBoolArray(const TensorProto& input_tensor, return status; } -tensorflow::Status ImportStringArray(const TensorProto& input_tensor, - Array* output_array) { +absl::Status ImportStringArray(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_STRING); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 6); @@ -556,9 +555,9 @@ int GetInputsCount(const NodeDef& node, return node.input_size(); } -tensorflow::Status CheckInputsCount( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - int expected_input_count) { +absl::Status CheckInputsCount(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + int expected_input_count) { if (GetInputsCount(node, tf_import_flags) != expected_input_count) { return tensorflow::errors::FailedPrecondition( node.op(), " node expects ", expected_input_count, @@ -689,7 +688,7 @@ void GetOutputTypesFromNodeDef(const NodeDef& node, } } -tensorflow::Status ConvertUnsupportedOperator( +absl::Status ConvertUnsupportedOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { // Names of special attributes in TF graph that are used by Toco. @@ -777,14 +776,14 @@ tensorflow::Status ConvertUnsupportedOperator( return absl::OkStatus(); } -tensorflow::Status ConvertConstOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertConstOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Const"); const auto& tensor = GetTensorAttr(node, "value"); const auto dtype = GetDataTypeAttr(node, "dtype"); - tensorflow::Status status = absl::OkStatus(); + absl::Status status = absl::OkStatus(); auto& array = model->GetOrCreateArray(node.name()); switch (dtype) { @@ -833,9 +832,9 @@ tensorflow::Status ConvertConstOperator( return absl::OkStatus(); } -tensorflow::Status ConvertConvOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertConvOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Conv2D"); TF_RETURN_IF_ERROR(CheckInputsCount(node, tf_import_flags, 2)); @@ -914,7 +913,7 @@ tensorflow::Status ConvertConvOperator( return absl::OkStatus(); } -tensorflow::Status ConvertDepthwiseConvOperator( +absl::Status ConvertDepthwiseConvOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "DepthwiseConv2dNative"); @@ -992,7 +991,7 @@ tensorflow::Status ConvertDepthwiseConvOperator( return absl::OkStatus(); } -tensorflow::Status ConvertDepthToSpaceOperator( +absl::Status ConvertDepthToSpaceOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "DepthToSpace"); @@ -1015,7 +1014,7 @@ tensorflow::Status ConvertDepthToSpaceOperator( return absl::OkStatus(); } -tensorflow::Status ConvertSpaceToDepthOperator( +absl::Status ConvertSpaceToDepthOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "SpaceToDepth"); @@ -1038,7 +1037,7 @@ tensorflow::Status ConvertSpaceToDepthOperator( return absl::OkStatus(); } -tensorflow::Status ConvertBiasAddOperator( +absl::Status ConvertBiasAddOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "BiasAdd"); @@ -1055,9 +1054,9 @@ tensorflow::Status ConvertBiasAddOperator( return absl::OkStatus(); } -tensorflow::Status ConvertRandomUniform( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertRandomUniform(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "RandomUniform"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); @@ -1073,7 +1072,7 @@ tensorflow::Status ConvertRandomUniform( return absl::OkStatus(); } -tensorflow::Status ConvertIdentityOperator( +absl::Status ConvertIdentityOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" || @@ -1096,7 +1095,7 @@ tensorflow::Status ConvertIdentityOperator( return absl::OkStatus(); } -tensorflow::Status ConvertIdentityNOperator( +absl::Status ConvertIdentityNOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "IdentityN"); @@ -1114,7 +1113,7 @@ tensorflow::Status ConvertIdentityNOperator( return absl::OkStatus(); } -tensorflow::Status ConvertFakeQuantWithMinMaxArgs( +absl::Status ConvertFakeQuantWithMinMaxArgs( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs"); @@ -1135,7 +1134,7 @@ tensorflow::Status ConvertFakeQuantWithMinMaxArgs( return absl::OkStatus(); } -tensorflow::Status ConvertFakeQuantWithMinMaxVars( +absl::Status ConvertFakeQuantWithMinMaxVars( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars"); @@ -1157,7 +1156,7 @@ tensorflow::Status ConvertFakeQuantWithMinMaxVars( return absl::OkStatus(); } -tensorflow::Status ConvertSqueezeOperator( +absl::Status ConvertSqueezeOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Squeeze"); @@ -1178,9 +1177,9 @@ tensorflow::Status ConvertSqueezeOperator( return absl::OkStatus(); } -tensorflow::Status ConvertSplitOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertSplitOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Split"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new TensorFlowSplitOperator; @@ -1196,9 +1195,10 @@ tensorflow::Status ConvertSplitOperator( return absl::OkStatus(); } -tensorflow::Status ConvertSplitVOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertSplitVOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { CHECK_EQ(node.op(), "SplitV"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); auto* op = new TensorFlowSplitVOperator; @@ -1215,9 +1215,10 @@ tensorflow::Status ConvertSplitVOperator( return absl::OkStatus(); } -tensorflow::Status ConvertSwitchOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertSwitchOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { CHECK_EQ(node.op(), "Switch"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new TensorFlowSwitchOperator; @@ -1230,7 +1231,7 @@ tensorflow::Status ConvertSwitchOperator( return absl::OkStatus(); } -tensorflow::Status ConvertSoftmaxOperator( +absl::Status ConvertSoftmaxOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Softmax"); @@ -1250,9 +1251,9 @@ tensorflow::Status ConvertSoftmaxOperator( return absl::OkStatus(); } -tensorflow::Status ConvertLRNOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertLRNOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "LRN"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto& input_name = node.input(0); @@ -1267,7 +1268,7 @@ tensorflow::Status ConvertLRNOperator( return absl::OkStatus(); } -tensorflow::Status ConvertMaxPoolOperator( +absl::Status ConvertMaxPoolOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "MaxPool"); @@ -1310,7 +1311,7 @@ tensorflow::Status ConvertMaxPoolOperator( return absl::OkStatus(); } -tensorflow::Status ConvertAvgPoolOperator( +absl::Status ConvertAvgPoolOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "AvgPool"); @@ -1349,7 +1350,7 @@ tensorflow::Status ConvertAvgPoolOperator( return absl::OkStatus(); } -tensorflow::Status ConvertBatchMatMulOperator( +absl::Status ConvertBatchMatMulOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); @@ -1372,9 +1373,10 @@ tensorflow::Status ConvertBatchMatMulOperator( return absl::OkStatus(); } -tensorflow::Status ConvertMatMulOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertMatMulOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); CHECK(!HasAttr(node, "adjoint_a") || @@ -1396,9 +1398,10 @@ tensorflow::Status ConvertMatMulOperator( return absl::OkStatus(); } -tensorflow::Status ConvertConcatOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertConcatOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { Operator* op = nullptr; if (node.op() == "Concat") { op = new TensorFlowConcatOperator; @@ -1421,7 +1424,7 @@ tensorflow::Status ConvertConcatOperator( return absl::OkStatus(); } -tensorflow::Status ConvertMirrorPadOperator( +absl::Status ConvertMirrorPadOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { if (node.op() != "MirrorPad") { @@ -1456,7 +1459,7 @@ enum FlexSupport { kFlexOk, kFlexNotOk }; // kAnyNumInputs is passed in. If kFlexOk is passed in the resulting operator // will be eligible for being exported as a flex op. template -tensorflow::Status ConvertSimpleOperatorGeneric( +absl::Status ConvertSimpleOperatorGeneric( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { if (NumInputs != kAnyNumInputs) { @@ -1484,16 +1487,17 @@ tensorflow::Status ConvertSimpleOperatorGeneric( // Convert a simple operator which is not valid as a flex op. template -tensorflow::Status ConvertSimpleOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertSimpleOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { return ConvertSimpleOperatorGeneric( node, tf_import_flags, model_flags, model); } // Convert a simple operator which is valid as a flex op. template -tensorflow::Status ConvertSimpleOperatorFlexOk( +absl::Status ConvertSimpleOperatorFlexOk( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { return ConvertSimpleOperatorGeneric( @@ -1503,7 +1507,7 @@ tensorflow::Status ConvertSimpleOperatorFlexOk( // Same as ConvertConstOperator, but revert to ConvertUnsupportedOperator if // the types are not supported. Converting Const operators here avoids // expensive copies of the protocol buffers downstream in the flex delegate. -tensorflow::Status ConditionallyConvertConstOperator( +absl::Status ConditionallyConvertConstOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { // We avoid incomplete and zero shapes because the resulting arrays @@ -1531,7 +1535,7 @@ tensorflow::Status ConditionallyConvertConstOperator( } } -tensorflow::Status ConvertStridedSliceOperator( +absl::Status ConvertStridedSliceOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "StridedSlice"); @@ -1560,7 +1564,7 @@ tensorflow::Status ConvertStridedSliceOperator( return absl::OkStatus(); } -tensorflow::Status ConvertPlaceholderOperator( +absl::Status ConvertPlaceholderOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput"); @@ -1600,15 +1604,15 @@ tensorflow::Status ConvertPlaceholderOperator( return absl::OkStatus(); } -tensorflow::Status ConvertNoOpOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertNoOpOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { return absl::OkStatus(); } -tensorflow::Status ConvertCastOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertCastOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Cast"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT"); @@ -1622,9 +1626,9 @@ tensorflow::Status ConvertCastOperator( return absl::OkStatus(); } -tensorflow::Status ConvertFloorOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertFloorOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Floor"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto data_type = GetDataTypeAttr(node, "T"); @@ -1636,9 +1640,9 @@ tensorflow::Status ConvertFloorOperator( return absl::OkStatus(); } -tensorflow::Status ConvertCeilOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertCeilOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Ceil"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto data_type = GetDataTypeAttr(node, "T"); @@ -1650,9 +1654,9 @@ tensorflow::Status ConvertCeilOperator( return absl::OkStatus(); } -tensorflow::Status ConvertRoundOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertRoundOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Round"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto data_type = GetDataTypeAttr(node, "T"); @@ -1664,9 +1668,10 @@ tensorflow::Status ConvertRoundOperator( return absl::OkStatus(); } -tensorflow::Status ConvertGatherOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertGatherOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { CHECK(node.op() == "Gather" || node.op() == "GatherV2"); if (node.op() == "Gather") TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); @@ -1693,7 +1698,7 @@ tensorflow::Status ConvertGatherOperator( return absl::OkStatus(); } -tensorflow::Status ConvertGatherNdOperator( +absl::Status ConvertGatherNdOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "GatherNd"); @@ -1709,7 +1714,7 @@ tensorflow::Status ConvertGatherNdOperator( } template -tensorflow::Status ConvertArgMinMaxOperator( +absl::Status ConvertArgMinMaxOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); @@ -1729,23 +1734,25 @@ tensorflow::Status ConvertArgMinMaxOperator( return absl::OkStatus(); } -tensorflow::Status ConvertArgMaxOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertArgMaxOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { CHECK_EQ(node.op(), "ArgMax"); return ConvertArgMinMaxOperator(node, tf_import_flags, model_flags, model); } -tensorflow::Status ConvertArgMinOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertArgMinOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { CHECK_EQ(node.op(), "ArgMin"); return ConvertArgMinMaxOperator(node, tf_import_flags, model_flags, model); } -tensorflow::Status ConvertResizeBilinearOperator( +absl::Status ConvertResizeBilinearOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "ResizeBilinear"); @@ -1768,7 +1775,7 @@ tensorflow::Status ConvertResizeBilinearOperator( return absl::OkStatus(); } -tensorflow::Status ConvertResizeNearestNeighborOperator( +absl::Status ConvertResizeNearestNeighborOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "ResizeNearestNeighbor"); @@ -1791,7 +1798,7 @@ tensorflow::Status ConvertResizeNearestNeighborOperator( return absl::OkStatus(); } -tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator( +absl::Status ConvertBatchNormWithGlobalNormalizationOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization"); @@ -1841,7 +1848,7 @@ tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator( return absl::OkStatus(); } -tensorflow::Status ConvertFusedBatchNormOperator( +absl::Status ConvertFusedBatchNormOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK((node.op() == "FusedBatchNorm") || (node.op() == "FusedBatchNormV3")); @@ -1896,7 +1903,7 @@ tensorflow::Status ConvertFusedBatchNormOperator( return absl::OkStatus(); } -tensorflow::Status ConvertSpaceToBatchNDOperator( +absl::Status ConvertSpaceToBatchNDOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "SpaceToBatchND"); @@ -1912,7 +1919,7 @@ tensorflow::Status ConvertSpaceToBatchNDOperator( return absl::OkStatus(); } -tensorflow::Status ConvertBatchToSpaceNDOperator( +absl::Status ConvertBatchToSpaceNDOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "BatchToSpaceND"); @@ -1929,9 +1936,10 @@ tensorflow::Status ConvertBatchToSpaceNDOperator( } template -tensorflow::Status ConvertReduceOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertReduceOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new T; op->inputs.push_back(node.input(0)); @@ -1947,9 +1955,9 @@ tensorflow::Status ConvertReduceOperator( } // TODO(b/139320642): Add test when fused op is supported. -tensorflow::Status ConvertSvdfOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertSvdfOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Svdf"); const int input_size = GetInputsCount(node, tf_import_flags); QCHECK(input_size == 4 || input_size == 5) @@ -1977,7 +1985,7 @@ tensorflow::Status ConvertSvdfOperator( } // This is just bare bones support to get the shapes to propagate. -tensorflow::Status ConvertTransposeConvOperator( +absl::Status ConvertTransposeConvOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Conv2DBackpropInput"); @@ -2048,9 +2056,9 @@ tensorflow::Status ConvertTransposeConvOperator( return absl::OkStatus(); } -tensorflow::Status ConvertRangeOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertRangeOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Range"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); auto* op = new RangeOperator; @@ -2073,9 +2081,9 @@ tensorflow::Status ConvertRangeOperator( // they aren't the same thing. tf.stack results in a "Pack" operator. "Stack" // operators also exist, but involve manipulating the TF runtime stack, and are // not directly related to tf.stack() usage. -tensorflow::Status ConvertPackOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertPackOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Pack"); auto op = std::make_unique(); const int num_inputs = GetInputsCount(node, tf_import_flags); @@ -2095,9 +2103,10 @@ tensorflow::Status ConvertPackOperator( return absl::OkStatus(); } -tensorflow::Status ConvertUnpackOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertUnpackOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { CHECK_EQ(node.op(), "Unpack"); auto op = std::make_unique(); const int num_inputs = GetInputsCount(node, tf_import_flags); @@ -2125,7 +2134,7 @@ tensorflow::Status ConvertUnpackOperator( // such ops as RNN back-edges, which is technically incorrect (does not // allow representing the op's semantics) but good enough to get a // graph visualization. -tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge( +absl::Status ConvertOperatorSpecialCasedAsRNNBackEdge( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { // At the moment, the only type of operator special-cased in this way is @@ -2144,9 +2153,9 @@ tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge( return absl::OkStatus(); } -tensorflow::Status ConvertShapeOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertShapeOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Shape"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto out_type = @@ -2160,7 +2169,7 @@ tensorflow::Status ConvertShapeOperator( return absl::OkStatus(); } -tensorflow::Status ConvertReverseSequenceOperator( +absl::Status ConvertReverseSequenceOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "ReverseSequence"); @@ -2327,9 +2336,10 @@ bool InlineAllFunctions(GraphDef* graphdef) { return graph_modified; } -tensorflow::Status ConvertTopKV2Operator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertTopKV2Operator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { CHECK((node.op() == "TopK") || (node.op() == "TopKV2")); auto op = std::make_unique(); op->inputs.push_back(node.input(0)); @@ -2349,7 +2359,7 @@ tensorflow::Status ConvertTopKV2Operator( return absl::OkStatus(); } -tensorflow::Status ConvertDynamicPartitionOperator( +absl::Status ConvertDynamicPartitionOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { auto op = std::make_unique(); @@ -2367,7 +2377,7 @@ tensorflow::Status ConvertDynamicPartitionOperator( return absl::OkStatus(); } -tensorflow::Status ConvertDynamicStitchOperator( +absl::Status ConvertDynamicStitchOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { // The parallel and non-parallel variants are the same besides whether they @@ -2386,7 +2396,7 @@ tensorflow::Status ConvertDynamicStitchOperator( return absl::OkStatus(); } -tensorflow::Status ConvertSparseToDenseOperator( +absl::Status ConvertSparseToDenseOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "SparseToDense"); @@ -2405,9 +2415,10 @@ tensorflow::Status ConvertSparseToDenseOperator( return absl::OkStatus(); } -tensorflow::Status ConvertOneHotOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertOneHotOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { CHECK_EQ(node.op(), "OneHot"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4)); @@ -2426,7 +2437,7 @@ tensorflow::Status ConvertOneHotOperator( return absl::OkStatus(); } -tensorflow::Status ConvertCTCBeamSearchDecoderOperator( +absl::Status ConvertCTCBeamSearchDecoderOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "CTCBeamSearchDecoder"); @@ -2456,7 +2467,7 @@ tensorflow::Status ConvertCTCBeamSearchDecoderOperator( // This isn't a TensorFlow builtin op. Currently this node can only be generated // with TfLite OpHint API. -tensorflow::Status ConvertUnidirectionalSequenceLstm( +absl::Status ConvertUnidirectionalSequenceLstm( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { DCHECK_EQ(node.op(), "UnidirectionalSequenceLstm"); @@ -2512,7 +2523,7 @@ tensorflow::Status ConvertUnidirectionalSequenceLstm( return absl::OkStatus(); } -tensorflow::Status ConvertLeakyReluOperator( +absl::Status ConvertLeakyReluOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "LeakyRelu"); @@ -2527,7 +2538,7 @@ tensorflow::Status ConvertLeakyReluOperator( return absl::OkStatus(); } -tensorflow::Status ConvertUnidirectionalSequenceRnn( +absl::Status ConvertUnidirectionalSequenceRnn( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { DCHECK_EQ(node.op(), "UnidirectionalSequenceRnn"); @@ -2552,7 +2563,7 @@ tensorflow::Status ConvertUnidirectionalSequenceRnn( namespace internal { -using ConverterType = tensorflow::Status (*)( +using ConverterType = absl::Status (*)( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model); using ConverterMapType = std::unordered_map; @@ -2721,10 +2732,10 @@ ConverterMapType GetTensorFlowNodeConverterMap() { }); } -tensorflow::Status ImportTensorFlowNode( - const tensorflow::NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, - Model* model, const ConverterMapType& converter_map) { +absl::Status ImportTensorFlowNode(const tensorflow::NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model, + const ConverterMapType& converter_map) { auto converter = converter_map.find(node.op()); if (converter == converter_map.end()) { return ConvertUnsupportedOperator(node, tf_import_flags, model_flags, diff --git a/tflite/toco/import_tensorflow_test.cc b/tflite/toco/import_tensorflow_test.cc index 90917871..023e98b3 100644 --- a/tflite/toco/import_tensorflow_test.cc +++ b/tflite/toco/import_tensorflow_test.cc @@ -47,7 +47,7 @@ using tensorflow::Status; using ::testing::ElementsAre; namespace internal { -using ConverterType = tensorflow::Status (*)( +using ConverterType = absl::Status (*)( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model); using ConverterMapType = std::unordered_map;