@@ -20,17 +20,24 @@ limitations under the License.
2020#include < vector>
2121
2222#include " absl/container/btree_map.h"
23+ #include " absl/strings/str_cat.h"
24+ #include " absl/strings/str_split.h"
25+ #include " absl/strings/strip.h"
2326#include " tensorflow/core/framework/attr_value.pb.h"
2427#include " tensorflow/core/framework/function.pb.h"
2528#include " tensorflow/core/framework/op_def.pb.h"
2629#include " tensorflow/core/framework/types.pb.h"
2730#include " tensorflow/core/framework/versions.pb.h"
2831#include " tensorflow/core/grappler/op_types.h"
32+ #include " tensorflow/core/lib/strings/numbers.h"
2933#include " tensorflow/core/lib/strings/proto_serialization.h"
34+ #include " tensorflow/core/platform/errors.h"
3035#include " tensorflow/core/platform/fingerprint.h"
36+ #include " tensorflow/core/platform/statusor.h"
3137#include " tensorflow/core/protobuf/fingerprint.pb.h"
3238#include " tensorflow/core/protobuf/meta_graph.pb.h"
3339#include " tensorflow/core/protobuf/saved_model.pb.h"
40+ #include " tensorflow/core/protobuf/saved_object_graph.pb.h"
3441
3542namespace tensorflow ::fingerprinting {
3643
@@ -61,6 +68,17 @@ void CanonicalizeNodes(GraphDef* orig_graph_def) {
6168 }
6269}
6370
71+ // Returns the suffix UID of `function_name`.
72+ StatusOr<int > GetSuffixUID (absl::string_view function_name) {
73+ std::vector<std::string> v = absl::StrSplit (function_name, ' _' );
74+ int uid;
75+ if (!strings::safe_strto32 (v.back (), &uid)) {
76+ return errors::InvalidArgument (absl::StrCat (
77+ " Function name: `" , function_name, " ` does not end in an integer." ));
78+ }
79+ return uid;
80+ }
81+
6482} // namespace
6583
6684uint64 ComputeHash (const GraphDef& graph_def) {
@@ -84,6 +102,11 @@ FingerprintDef CreateFingerprintDef(const MetaGraphDef& metagraph) {
84102 // Set fingerprint field #3.
85103 fingerprint_def.set_signature_def_hash (
86104 RegularizeAndHashSignatureDefs (metagraph_copy.signature_def ()));
105+ // Set fingerprint field #4.
106+ StatusOr<uint64> object_graph_hash =
107+ RegularizeAndHashSavedObjectGraph (metagraph_copy.object_graph_def ());
108+ fingerprint_def.set_saved_object_graph_hash (
109+ RegularizeAndHashSavedObjectGraph (metagraph_copy.object_graph_def ()));
87110 return fingerprint_def;
88111}
89112
@@ -114,4 +137,40 @@ uint64 RegularizeAndHashSignatureDefs(
114137 return result_hash;
115138}
116139
140+ // The SavedObjectGraph contains two parts: the list of nodes and the map of
141+ // concrete functions. Regularization treats these two parts separately.
142+ uint64 RegularizeAndHashSavedObjectGraph (
143+ const SavedObjectGraph& object_graph_def) {
144+ // Sort `concrete_functions`, which is an unordered map from function names to
145+ // SavedConcreteFunction, using the suffix UID of the function name. Assumes
146+ // that the trackable children are listed in a deterministic order during
147+ // serialization.
148+ absl::btree_map<int , std::string> uid_to_function_names;
149+ for (const auto & [name, concrete_function] :
150+ object_graph_def.concrete_functions ()) {
151+ StatusOr<int > uid = GetSuffixUID (name);
152+ // All valid function names should end in an UID.
153+ if (uid.ok ()) {
154+ uid_to_function_names.insert ({*uid, name});
155+ } else {
156+ LOG (ERROR) << uid.status ().error_message ();
157+ }
158+ }
159+ uint64 result_hash = 0 ;
160+ for (const auto & [uid, function_name] : uid_to_function_names) {
161+ // Hash the function name (with the UID stripped).
162+ result_hash = FingerprintCat64 (result_hash,
163+ tensorflow::Fingerprint64 (absl::StripSuffix (
164+ function_name, std::to_string (uid))));
165+ // Hash the serialized concrete function.
166+ std::string concrete_function_string;
167+ SerializeToStringDeterministic (
168+ object_graph_def.concrete_functions ().at (function_name),
169+ &concrete_function_string);
170+ result_hash = FingerprintCat64 (
171+ result_hash, tensorflow::Fingerprint64 (concrete_function_string));
172+ }
173+ // TODO(b/241294832): Complete canonicalization of `object_graph_def.nodes`.
174+ return result_hash;
175+ }
117176} // namespace tensorflow::fingerprinting
0 commit comments