From 9ffdba6a311365f647e5f1ef1eb561d323f4c404 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 17 May 2023 11:08:50 -0700 Subject: [PATCH] Strips the op name suffix added from `PopulateFunctionalToRegionPatterns` inside `PopulateRegionToFunctionalPatterns` to reduce model size and improve debugging experience. PiperOrigin-RevId: 532845336 --- .../transforms/region_to_functional/BUILD | 1 + .../transforms/region_to_functional/impl.cc | 21 +++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/tensorflow/core/transforms/region_to_functional/BUILD b/tensorflow/core/transforms/region_to_functional/BUILD index 6da4f3fe8c19eb..8c1eace0470658 100644 --- a/tensorflow/core/transforms/region_to_functional/BUILD +++ b/tensorflow/core/transforms/region_to_functional/BUILD @@ -18,6 +18,7 @@ cc_library( "//tensorflow/core/ir:Dialect", "//tensorflow/core/ir/types:Dialect", "//tensorflow/core/transforms:utils", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", diff --git a/tensorflow/core/transforms/region_to_functional/impl.cc b/tensorflow/core/transforms/region_to_functional/impl.cc index 7fac57bb7fda91..318ab38c0c970d 100644 --- a/tensorflow/core/transforms/region_to_functional/impl.cc +++ b/tensorflow/core/transforms/region_to_functional/impl.cc @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "absl/strings/match.h" +#include "absl/strings/str_split.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" @@ -151,6 +153,9 @@ class BasePattern { ArrayAttr GetControlRetAttrs(ValueRange ctls, ValueRange args, NameUniquer *name_uniquer) const; + // Strip out added names. + void StripAddedSuffix(Region ®ion) const; + // Create a function with the given name and attributes. Use the types of the // block arguments and the given results types. Take the body of the region. GraphFuncOp CreateFunc(Location loc, const Twine &sym_name, Region ®ion, @@ -598,6 +603,20 @@ ArrayAttr BasePattern::GetControlRetAttrs(ValueRange ctls, ValueRange args, return ArrayAttr::get(ctx_, ctl_ret_attrs); } +void BasePattern::StripAddedSuffix(Region ®ion) const { + StringAttr name_id = dialect_.getNameAttrIdentifier(); + for (Operation &op : region.getOps()) { + if (auto name = op.getAttrOfType(name_id)) { + if (absl::StrContains(name.getValue().str(), "_tfg_inlined_")) { + std::vector name_components = + absl::StrSplit(name.getValue().str(), "_tfg_inlined_"); + auto new_name = StringAttr::get(op.getContext(), name_components[0]); + op.setAttr(name_id, new_name); + } + } + } +} + GraphFuncOp BasePattern::CreateFunc(Location loc, const Twine &sym_name, Region ®ion, TypeRange res_types, NamedAttrList attrs) const { @@ -644,6 +663,8 @@ FuncAttr BasePattern::Outline(Operation *op, PatternRewriter &rewriter, ValueRange args, Region ®ion, RegionAttr preserved, DictionaryAttr attrs, const Twine &func_name) const { + StripAddedSuffix(region); + // Create a name scope for the function. NameUniquer name_uniquer(ctx_);