Draft: Java API to use tf.function available on SavedModel. #89

Merged: 1 commit, Jul 28, 2020
@@ -20,6 +20,9 @@
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig;

import com.google.protobuf.InvalidProtocolBufferException;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.PointerScope;
@@ -32,6 +35,7 @@
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.RunOptions;
import org.tensorflow.proto.framework.SignatureDef;

/**
* SavedModelBundle represents a model loaded from storage.
@@ -94,6 +98,101 @@ private Loader(String exportDir) {
private RunOptions runOptions = null;
}

/**
* SignatureToNodeName finds the node names in the {@link Graph} corresponding to the
* input / output parameters of a <a
* href="https://www.tensorflow.org/api_docs/python/tf/function">tf.function</a>
*/
public static final class SignatureToNodeName {
@yzhuang yzhuang Jul 30, 2020

Read through this change, and I believe the following API design would be both cleaner and more intuitive to use. Wdyt?

  1. Update Loader and add withSignature(String... signatures).
  2. Update Loader.load() to also construct TfFunctions corresponding to the specified signatures.
  3. SignatureToNodeName can become private and not exposed to the end user, preferably living inside TfFunction as private.
  4. TfFunction should not have a reference to the session. Let SavedModelBundle manage the session (which it already does). TfFunction can become a pure data class.
  5. Add SavedModelBundle.call(String signature, Map<String, Tensor<?>> inputs, Map<String, Tensor<?>> outputs) to SavedModelBundle. TfFunction can become private and not exposed to end user as well.

What do you think about the above proposed API design?
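
For illustration, the proposal above could look roughly like the following. This is only a sketch under stated assumptions, not the API that was merged: `FunctionSpec` and `BundleSketch` are hypothetical names, `float[]` stands in for `Tensor<?>`, a plain `java.util.function.Function` stands in for the `Session`, and `call` returns its outputs rather than taking an output map as in point 5.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

/** Hypothetical pure data class: signature names mapped to graph node names, no session. */
final class FunctionSpec {
  final Map<String, String> inputNameToNode;
  final Map<String, String> outputNameToNode;

  FunctionSpec(Map<String, String> inputNameToNode, Map<String, String> outputNameToNode) {
    this.inputNameToNode = inputNameToNode;
    this.outputNameToNode = outputNameToNode;
  }
}

/** Hypothetical stand-in for SavedModelBundle; it alone owns the "session". */
final class BundleSketch implements AutoCloseable {
  private final Map<String, FunctionSpec> functions = new HashMap<>();
  // Stand-in for Session: maps fed node values to fetched node values.
  private final Function<Map<String, float[]>, Map<String, float[]>> session;
  private boolean closed = false;

  BundleSketch(Function<Map<String, float[]>, Map<String, float[]>> session) {
    this.session = session;
  }

  /** Registers a signature, loosely mirroring the proposed Loader.withSignature(...). */
  BundleSketch withSignature(String name, FunctionSpec spec) {
    functions.put(name, spec);
    return this;
  }

  /** Runs the named signature; argument names are translated to node names here. */
  Map<String, float[]> call(String signature, Map<String, float[]> inputs) {
    if (closed) {
      throw new IllegalStateException("Bundle is closed");
    }
    FunctionSpec spec = functions.get(signature);
    if (spec == null) {
      throw new IllegalArgumentException(String.format("Unknown signature [%s]", signature));
    }
    Map<String, float[]> feeds = new HashMap<>();
    for (Map.Entry<String, String> e : spec.inputNameToNode.entrySet()) {
      float[] value = inputs.get(e.getKey());
      if (value == null) {
        throw new IllegalArgumentException(String.format("Missing argument [%s]", e.getKey()));
      }
      feeds.put(e.getValue(), value);
    }
    Map<String, float[]> fetched = session.apply(feeds);
    Map<String, float[]> outputs = new HashMap<>();
    for (Map.Entry<String, String> e : spec.outputNameToNode.entrySet()) {
      outputs.put(e.getKey(), fetched.get(e.getValue()));
    }
    return outputs;
  }

  @Override
  public void close() {
    closed = true;
  }
}
```

In this shape the bundle is the only object that can run the graph, so closing it cannot leave a callable object behind with a dead session.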

Contributor Author

Not exposing SignatureToNodeName (3) sounds good.
TfFunction not having a reference to the session (4) is good as well.

Regarding (5), it is desirable to have a way to make repeated calls to the same function, and having a TfFunction class allows for this. I am also thinking of removing the runtime check currently done on each call in the Tensor call(Tensor) method and doing it once instead.
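
A minimal sketch of the "check once" idea, assuming the validation moves into the constructor. `ResolvedSingleIoFunction` is a hypothetical name that is not part of this PR, and the sketch covers only the name resolution, not the session call itself.

```java
import java.util.Map;

/**
 * Hypothetical helper: resolves the single input and output node names once,
 * so that repeated calls do not have to re-validate the signature.
 */
final class ResolvedSingleIoFunction {
  private final String inputNode;
  private final String outputNode;

  ResolvedSingleIoFunction(
      String signature, Map<String, String> inputToNode, Map<String, String> outputToNode) {
    // Validation happens here, once, instead of on every call.
    this.inputNode = onlyValue(signature, "input", inputToNode);
    this.outputNode = onlyValue(signature, "output", outputToNode);
  }

  private static String onlyValue(String signature, String kind, Map<String, String> nameToNode) {
    if (nameToNode == null || nameToNode.size() != 1) {
      throw new IllegalArgumentException(
          String.format("Function [%s] must have exactly one %s", signature, kind));
    }
    return nameToNode.values().iterator().next();
  }

  // With the names resolved, a call would reduce to feeding inputNode and
  // fetching outputNode through the session runner, as in the diff.
  String inputNode() {
    return inputNode;
  }

  String outputNode() {
    return outputNode;
  }
}
```

The trade-off is that a bad signature name fails at construction time rather than at the first call.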

Collaborator

Thanks for your proposal @yzhuang , some thoughts on it:

SignatureToNodeName can become private and not exposed to the end user, preferably living inside TfFunction as private.

Looking at it, getSignatureToNodeName() can already be private in SavedModelBundle and SignatureToNodeName could be restricted at the default-package level. I don't think it should be in TfFunction though as it maps all signatures while an instance of TfFunction is only mapped to one of them.

TfFunction should not have a reference to the session. Let SavedModelBundle manage the session (which it already does). TfFunction can become a pure data class.

That would prevent users from simply invoking function.call(tensor) to run the graph, though, which I personally like. While I understand that data classes have their advantages, I'm not sure there are real gains in making TfFunction one of them, especially since Session instances are thread-safe.

Add SavedModelBundle.call(String signature, Map<String, Tensor<?>> inputs, Map<String, Tensor<?>> outputs) to SavedModelBundle. TfFunction can become private and not exposed to end user as well.

Again, I personally like the OO approach of letting users manipulate callable entities instead of having SavedModelBundle act as a "service". Is the intention here just to hide TfFunction at the user level? Note that I was planning to reuse the same object to add new signatures to a model when exporting it (but I could also do it with just SavedModelBundle if we decide to take that direction).

Hi Karl and Shajan,

I have been busy traveling and didn't have time to follow up on this thread. Please feel free to take whatever you find helpful from my suggestions and discard the rest. Thank you for taking the time to go through my comments!

@yzhuang yzhuang Aug 16, 2020

Hi @karllessard and @Shajan ,

Sorry for the late reply on your comment—just landed in Hong Kong :)

My suggestion to let SavedModel manage the session is not about thread safety, but about resource ownership. SavedModel currently “owns” the session and is responsible for closing the resource to avoid memory leaks. If we create TfFunction objects holding references to the session, it complicates the resource ownership mental model. For example, we would need to do one of the following:

  1. Perform reference counting on the session, and have the last “user” of the session close it. This is user friendly, but introduces complexity and room for bugs. OR
  2. Continue to let SavedModel own the session, but TfFunction objects need to anticipate that their underlying resource (the session) can become defunct at any time. This can be counterintuitive to users, as they would likely need to surround their calls with a try/catch.

Is the intention here just to hide TfFunction at the user level?

My intention was to avoid having to think about the resource ownership problem, and to keep the resource ownership unchanged. The API of having TfFunctions be callable and exposed to users sounds great to me, though we need to think about options 1 and 2 above. This is not a contrived scenario: we already do model hot swapping at Twitter in our prediction servers, and SavedModels are hot swapped without restarting the JVM.

Thank you Karl and Shajan!
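
To make option 1 concrete, here is a minimal reference-counting sketch. It rests on assumptions: `RefCountedSession` is a hypothetical wrapper that does not exist in this PR, and a `Runnable` stands in for whatever actually closes the session. The owner holds the first reference, each function object would retain one more, and the last release closes the resource.

```java
import java.util.concurrent.atomic.AtomicInteger;

/** Hypothetical wrapper that closes the underlying resource when the last user releases it. */
final class RefCountedSession {
  private final Runnable closer;
  // The creator (for example the bundle) implicitly holds the first reference.
  private final AtomicInteger refs = new AtomicInteger(1);

  RefCountedSession(Runnable closer) {
    this.closer = closer;
  }

  /** Takes an additional reference; fails if the resource was already closed. */
  RefCountedSession retain() {
    while (true) {
      int current = refs.get();
      if (current <= 0) {
        throw new IllegalStateException("Session already closed");
      }
      if (refs.compareAndSet(current, current + 1)) {
        return this;
      }
    }
  }

  /** Drops one reference; the caller that drops the last one closes the resource. */
  void release() {
    while (true) {
      int current = refs.get();
      if (current <= 0) {
        throw new IllegalStateException("Released more often than retained");
      }
      if (refs.compareAndSet(current, current - 1)) {
        if (current == 1) {
          closer.run();
        }
        return;
      }
    }
  }

  boolean isClosed() {
    return refs.get() <= 0;
  }
}
```

Even in this small form, every holder has to pair `retain()` with `release()` correctly, which is the complexity and room for bugs mentioned in option 1.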


public SignatureToNodeName(SavedModelBundle savedModelBundle) {
loadSignatures(savedModelBundle);
}

/**
* Given a tf.function signature name, find the node names corresponding
* to the input arguments
*
* @param functionSignatureName tf.function signature name
* @return a map from input arguments to node names in the {@link Graph}
*/
public Map<String, String> inputNameToNode(String functionSignatureName) {
NameContainer nc = this.functionMap.get(functionSignatureName);
return (nc == null) ? null : nc.inputNameToNode();
}

/**
* Given a tf.function signature name, find the node names corresponding
* to the output arguments
*
* @param functionSignatureName tf.function signature name
* @return a map from output arguments to node names in the {@link Graph}
*/
public Map<String, String> outputNameToNode(String functionSignatureName) {
NameContainer nc = this.functionMap.get(functionSignatureName);
return (nc == null) ? null : nc.outputNameToNode();
}

/**
* Given a tf.function signature name, find the method name
*/
public String methodName(String functionSignatureName) {
NameContainer nc = this.functionMap.get(functionSignatureName);
return (nc == null) ? null : nc.methodName();
}

private void loadSignatures(SavedModelBundle savedModelBundle) {
MetaGraphDef metaGraph = savedModelBundle.metaGraphDef();
Map<String, SignatureDef> signatureMap = metaGraph.getSignatureDefMap();

// A saved model can contain multiple SignatureDef
for (Map.Entry<String, SignatureDef> entry : signatureMap.entrySet()) {
NameContainer nc = new NameContainer(entry.getValue());
this.functionMap.put(entry.getKey(), nc);
}
}

private Map<String, NameContainer> functionMap = new HashMap<>();

private static final class NameContainer {
NameContainer(SignatureDef sd) {
this.inputNameToNodeName = sd.getInputsMap()
.entrySet()
.stream()
.collect(Collectors.toMap(
e -> e.getKey(),
e -> e.getValue().getName()
));

this.outputNameToNodeName = sd.getOutputsMap()
.entrySet()
.stream()
.collect(Collectors.toMap(
e -> e.getKey(),
e -> e.getValue().getName()
));

this.method = sd.getMethodName();
}

public Map<String, String> inputNameToNode() {
return this.inputNameToNodeName;
}

public Map<String, String> outputNameToNode() {
return this.outputNameToNodeName;
}

public String methodName() {
return this.method;
}

private Map<String, String> inputNameToNodeName;
private Map<String, String> outputNameToNodeName;
private String method;
}
}

/**
* Load a saved model from an export directory. The model that is being loaded should be created
* using the <a href="https://www.tensorflow.org/api_docs/python/tf/saved_model">Saved Model
@@ -148,6 +247,34 @@ public Session session() {
return session;
}

/**
* Returns the {@link SignatureToNodeName} translator for the model.
*
* @return SignatureToNodeName translator
*/
public SignatureToNodeName getSignatureToNodeName() {
if (this.sigToNodeName == null) {
// no need to lock, ok to create multiple instances
this.sigToNodeName = new SignatureToNodeName(this);
}
return this.sigToNodeName;
}

/**
* Return a {@link TfFunction} corresponding to the function signature.
*
* <pre>{@code
* TfFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
* Map<String, Tensor<?>> outputTensorMap = myFunction.call(inputTensorMap);
* }</pre>
*
* @param functionSignatureName name of the {@code SignatureDef} in the saved model.
* @return TfFunction object that can be used to make calls to the tf.function
*/
public TfFunction function(String functionSignatureName) {
return new TfFunction(functionSignatureName, this.getSignatureToNodeName(), this.session);
}

/**
* Releases resources (the {@link Graph} and {@link Session}) associated with the saved model
* bundle.
@@ -161,6 +288,7 @@ public void close() {
private final Graph graph;
private final Session session;
private final MetaGraphDef metaGraphDef;
private SignatureToNodeName sigToNodeName;

private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef) {
this.graph = graph;
@@ -0,0 +1,157 @@
/*
* Copyright 2020 The TensorFlow Authors. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.tensorflow;

import com.google.protobuf.InvalidProtocolBufferException;

import java.util.List;
import java.util.ListIterator;
import java.util.HashMap;
import java.util.Map;

/**
* Invoke <a href="https://www.tensorflow.org/api_docs/python/tf/function">tf.function</a>
* defined in a {@link SavedModelBundle}.
*
* <pre>{@code
* TfFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
* Map<String, Tensor<?>> outputTensorMap = myFunction.call(inputTensorMap);
* }</pre>
*
*/
public class TfFunction {

public TfFunction(
String functionSignatureName,
SavedModelBundle.SignatureToNodeName nameToNode, Session session) {
this.nameToNode = nameToNode;
this.session = session;
this.functionSignatureName = functionSignatureName;
}

/**
* Invokes a tf.function.
* Caller is responsible for closing all Tensors.
*
* @param arguments map of input tensors
* @return map of output tensors
*/
public Map<String, Tensor<?>> call(
Map<String, Tensor<?>> arguments) throws IllegalArgumentException {

Session.Runner runner = this.session.runner();

Map<String, String> inputToNode = this.nameToNode.inputNameToNode(this.functionSignatureName);

if (inputToNode == null) {
throw new IllegalArgumentException(
String.format("Function [%s] is missing input", this.functionSignatureName));
}

// Join arguments.key, inputToNodeName.key
for (Map.Entry<String, String> entry: inputToNode.entrySet()) {
String argName = entry.getKey();
Tensor<?> tensor = arguments.get(argName);

if (tensor == null) {
throw new IllegalArgumentException(String.format("Missing argument [%s]", argName));
}

// Node name in the tensorflow graph, corresponding to the tf.function argument
runner = runner.feed(entry.getValue(), tensor);
}

Map<String, String> outputToNode = this.nameToNode.outputNameToNode(this.functionSignatureName);
if (outputToNode == null) {
throw new IllegalArgumentException(
String.format("Function [%s] is missing output", this.functionSignatureName));
}

for (String nodeName: outputToNode.values()) {
// Node names corresponding to the return value
runner = runner.fetch(nodeName);
}

List<Tensor<?>> resultTensors = runner.run();
ListIterator<Tensor<?>> resultTensorIter = resultTensors.listIterator();

Map<String, Tensor<?>> returnMap = new HashMap<String, Tensor<?>>();

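// Key the results by the output names from the signature definition. keySet() iterates
// in the same order as the values() loop that issued the fetches above.
for (String outputName: outputToNode.keySet()) {
returnMap.put(outputName, resultTensorIter.next());
}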

return returnMap;
}

/**
* Invokes a tf.function.
* Caller is responsible for closing all Tensors.
*
* Throws IllegalArgumentException if there are multiple input or output parameters defined
* in the tf.function
*
* @param tensor input tensor
* @return output tensor
*/
public Tensor<?> call(Tensor<?> tensor) throws IllegalArgumentException {
Session.Runner runner = this.session.runner();

Map<String, String> inputToNode = this.nameToNode.inputNameToNode(this.functionSignatureName);

if (inputToNode == null) {
throw new IllegalArgumentException(
String.format("Function [%s] is missing input", this.functionSignatureName));
}

if (inputToNode.size() != 1) {
throw new IllegalArgumentException(
String.format("Function [%s] must have exactly one input", this.functionSignatureName));
}

// Feed the single argument
for (Map.Entry<String, String> entry: inputToNode.entrySet()) {
// Node name in the tensorflow graph, corresponding to the tf.function argument
runner = runner.feed(entry.getValue(), tensor);
}

Map<String, String> outputToNode = this.nameToNode.outputNameToNode(this.functionSignatureName);
if (outputToNode == null) {
throw new IllegalArgumentException(
String.format("Function [%s] is missing output", this.functionSignatureName));
}

if (outputToNode.size() != 1) {
throw new IllegalArgumentException(
String.format("Function [%s] must have exactly one output", this.functionSignatureName));
}

// Fetch the single return tensor
for (String nodeName: outputToNode.values()) {
// Node names corresponding to the return value
runner = runner.fetch(nodeName);
}

List<Tensor<?>> resultTensors = runner.run();

return resultTensors.get(0);
}

private final Session session;
private final SavedModelBundle.SignatureToNodeName nameToNode;
private final String functionSignatureName;
}