-
Notifications
You must be signed in to change notification settings - Fork 215
Draft: Java API to use tf.function available on SavedModel. #89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
karllessard
merged 1 commit into
tensorflow:shared-saved-model
from
Shajan:sd/tf.function
Jul 28, 2020
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
157 changes: 157 additions & 0 deletions
157
tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TfFunction.java
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
/* | ||
* Copyright 2020 The TensorFlow Authors. All rights reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.tensorflow; | ||
|
||
import com.google.protobuf.InvalidProtocolBufferException; | ||
|
||
import java.util.List; | ||
import java.util.ListIterator; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
/** | ||
* Invoke <a href="https://www.tensorflow.org/api_docs/python/tf/function">tf.function</a> | ||
* defined in a {@link SavedModelBundle}. | ||
* | ||
* <pre>{@code | ||
* TfFunction myFunction = savedModelBundle.function("myFunctionSignatureName"); | ||
* Map<String, Tensor<?>> outputTensorMap = myFunction.call(inputTensorMap); | ||
* }</pre> | ||
* | ||
*/ | ||
public class TfFunction { | ||
|
||
public TfFunction( | ||
String functionSignatureName, | ||
SavedModelBundle.SignatureToNodeName nameToNode, Session session) { | ||
this.nameToNode = nameToNode; | ||
this.session = session; | ||
this.functionSignatureName = functionSignatureName; | ||
} | ||
|
||
/** | ||
* Invokes a tf.function. | ||
* Caller is responsible for closing all Tensors. | ||
* | ||
* @param arguments map of input tensors | ||
* @return map of output tensors | ||
*/ | ||
public Map<String, Tensor<?>> call( | ||
Map<String, Tensor<?>> arguments) throws IllegalArgumentException { | ||
|
||
Session.Runner runner = this.session.runner(); | ||
|
||
Map<String, String> inputToNode = this.nameToNode.inputNameToNode(this.functionSignatureName); | ||
|
||
if (inputToNode == null) { | ||
throw new IllegalArgumentException( | ||
String.format("Function [%s] is missing input", this.functionSignatureName)); | ||
} | ||
|
||
// Join arguments.key, inputToNodeName.key | ||
for (Map.Entry<String, String> entry: inputToNode.entrySet()) { | ||
String argName = entry.getKey(); | ||
Tensor<?> tensor = arguments.get(argName); | ||
|
||
if (tensor == null) { | ||
throw new IllegalArgumentException(String.format("Missing argument [%s]", argName)); | ||
} | ||
|
||
// Node name in the tensorflow graph, corresponding to the tf.function argument | ||
runner = runner.feed(entry.getValue(), tensor); | ||
} | ||
|
||
Map<String, String> outputToNode = this.nameToNode.outputNameToNode(this.functionSignatureName); | ||
if (outputToNode == null) { | ||
throw new IllegalArgumentException( | ||
String.format("Function [%] is missing output", this.functionSignatureName)); | ||
} | ||
|
||
for (String nodeName: outputToNode.values()) { | ||
// Node names corresponding to the return value | ||
runner = runner.fetch(nodeName); | ||
} | ||
|
||
List<Tensor<?>> resultTensors = runner.run(); | ||
ListIterator<Tensor<?>> resultTensorIter = resultTensors.listIterator(); | ||
|
||
Map<String, Tensor<?>> returnMap = new HashMap<String, Tensor<?>>(); | ||
|
||
// Use the output names as present in the signature definition | ||
for (String nodeName: outputToNode.keySet()) { | ||
returnMap.put(nodeName, resultTensorIter.next()); | ||
} | ||
|
||
return returnMap; | ||
} | ||
|
||
/** | ||
* Invokes a tf.function. | ||
* Caller is responsible for closing all Tensors. | ||
* | ||
* Throws IllegalArgumentException if there are multiple input or output parameters defined | ||
* in the tf.function | ||
* | ||
* @param tensor input tensor | ||
* @return output tensor | ||
*/ | ||
public Tensor<?> call(Tensor<?> tensor) throws IllegalArgumentException { | ||
Session.Runner runner = this.session.runner(); | ||
|
||
Map<String, String> inputToNode = this.nameToNode.inputNameToNode(this.functionSignatureName); | ||
|
||
if (inputToNode == null) { | ||
throw new IllegalArgumentException( | ||
String.format("Function [%s] is missing input", this.functionSignatureName)); | ||
} | ||
|
||
if (inputToNode.size() != 1) { | ||
throw new IllegalArgumentException( | ||
String.format("Function [%s] requires multiple inputs", this.functionSignatureName)); | ||
} | ||
|
||
// Feed the single argument | ||
for (Map.Entry<String, String> entry: inputToNode.entrySet()) { | ||
// Node name in the tensorflow graph, corresponding to the tf.function argument | ||
runner = runner.feed(entry.getValue(), tensor); | ||
} | ||
|
||
Map<String, String> outputToNode = this.nameToNode.outputNameToNode(this.functionSignatureName); | ||
if (outputToNode == null) { | ||
throw new IllegalArgumentException( | ||
String.format("Function [%] is missing output", this.functionSignatureName)); | ||
} | ||
|
||
if (outputToNode.size() != 1) { | ||
throw new IllegalArgumentException( | ||
String.format("Function [%s] has multiple outputs", this.functionSignatureName)); | ||
} | ||
|
||
// Fetch the single return tensor | ||
for (String nodeName: outputToNode.values()) { | ||
// Node names corresponding to the return value | ||
runner = runner.fetch(nodeName); | ||
} | ||
|
||
List<Tensor<?>> resultTensors = runner.run(); | ||
|
||
return resultTensors.get(0); | ||
} | ||
|
||
private final Session session; | ||
private final SavedModelBundle.SignatureToNodeName nameToNode; | ||
private final String functionSignatureName; | ||
} |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Read through this change, and I believe the following API design would be both cleaner and more intuitive to use. Wdyt?
withSignature(String... signatures)
.Loader.load()
to also constructTfFunctions
corresponding to the specified signatures.SignatureToNodeName
can become private and not exposed to end user, preferrably living insideTfFunction
as private.TfFunction
should not have a reference to thesession
. Let SavedModelBundle manage the session (which it already does).TfFunction
can become a pure data class.SavedModelBundle.call(String signature, Map<String, Tensor<?>> inputs, Map<String, Tensor<?>> outputs)
toSavedModelBundle
.TfFunction
can become private and not exposed to end user as well.What do you think about the above proposed API design?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not exposing SignatureToNodeName (3) sounds good
TtFunction not having reference to session (4) is good as well
Regarding (5), it is desirable to have a way to do repeated call to the same function. Having a TfFuction class allows for this. I am also thinking removing runtime check currently done with each call in the Tensor call(Tensor) method and do that once.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your proposal @yzhuang , some thoughts on it:
Looking at it,
getSignatureToNodeName()
can already be private inSavedModelBundle
andSignatureToNodeName
could be restricted at the default-package level. I don't think it should be inTfFunction
though as it maps all signatures while an instance ofTfFunction
is only mapped to one of them.That would prevent though to just invoke
function.call(tensor)
to run the graph, which I personally like. While I understand that data classes have their lot of advantages, I'm not sure there is real gains for havingTfFunction
as one of them, especially thatSession
instances are thread-safe.Again, I personally like the OO approach of letting users manipulating callable entities instead of having
SavedModelBundle
acting as a "service". Is the intention here is just to hide theTfFunction
at the user level? Note that I was planning to reuse the same object to add new signatures to a model when exporting it (but I could also do it just withSavedModelBundle
if we decide to take that direction).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi Karl and Shajan,
I have been busy with traveling, and didn't have time to follow up on this thread. Please feel free to take whatever you find helpful from my suggestions, and discard the rest. Thank you for taking time to go through my comments!
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @karllessard and @Shajan ,
Sorry for the late reply on your comment—just landed in Hong Kong :)
My suggestion to let SavedModel manage the session is not about thread safety, but about resource ownership. SavedModel currently “owns” the session and is responsible for closing the resource to avoid memory leaks. If we create TFFunction objects holding references to the session, it complicate the resource ownership mental model. For example, we will need to do things such as the below:
My intention was to not have to think about the resource ownership problem, and keep the resource ownership unchanged. The API of having TFFunctions be callable and exposed to users sounds great to me, and we need to think about the above 1 & 2 though. This is not a contrived scenario: we already do model hot swapping at Twitter in our prediction servers, and SavedModels are hot swapped without restarting the JVM.
Thank you Karl and Shajan!