Skip to content

Commit

Permalink
Native functions v2 (#233)
Browse files Browse the repository at this point in the history
* Initial native function use

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* Allow body constants

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* Fix body forbids

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* Use default eager session for tensor calls

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* Use default eager for single tensor call too

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* Get functions from graph

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* Start of saver support

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* Update loading, detect statefulness, use PartitionedCall

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* Start of dependencies

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* Support dependencies

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* Remove unwrapping

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* Proper attribute setters

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* Add ignored gradient test

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* Rebase fix

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* Op generation for functions

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* Rebase fix

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* SavedFunction for running functions from SavedModelBundles

Signed-off-by: Ryan Nett <JNett96@gmail.com>

* Review fixes

Signed-off-by: Ryan Nett <JNett96@gmail.com>

* Generation and better javadoc

Signed-off-by: Ryan Nett <JNett96@gmail.com>

* Rework pointer scopes

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* SessionFunction instead of SavedModelBundle specific class

Signed-off-by: Ryan Nett <JNett96@gmail.com>

* Add CallableFunction javadoc

Signed-off-by: Ryan Nett <JNett96@gmail.com>

* Remove obsolete test

Signed-off-by: Ryan Nett <JNett96@gmail.com>

* Rebase fix

Signed-off-by: Ryan Nett <JNett96@gmail.com>

* Formatting fixes and nits

Signed-off-by: Ryan Nett <JNett96@gmail.com>

* Add session function test, Signature.builder with name

Signed-off-by: Ryan Nett <JNett96@gmail.com>

* Remove extra synchronization

Signed-off-by: Ryan Nett <JNett96@gmail.com>

* Formatting

Signed-off-by: Ryan Nett <JNett96@gmail.com>

* New names

Signed-off-by: Ryan Nett <JNett96@gmail.com>

* Note on SavedModel functions

Signed-off-by: Ryan Nett <JNett96@gmail.com>

* Fix tests

Signed-off-by: Ryan Nett <JNett96@gmail.com>

* Rename name method

Signed-off-by: Ryan Nett <JNett96@gmail.com>

* Re-add tests w/ SessionFunction

Signed-off-by: Ryan Nett <JNett96@gmail.com>

* Helper methods for saving

Signed-off-by: Ryan Nett <JNett96@gmail.com>
  • Loading branch information
rnett authored May 31, 2021
1 parent 3b4533c commit daeb257
Show file tree
Hide file tree
Showing 25 changed files with 2,588 additions and 864 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import java.nio.charset.Charset;
import java.util.List;
import java.util.Map;
import org.tensorflow.ConcreteFunction;
import org.tensorflow.DeviceSpec;
import org.tensorflow.EagerSession;
import org.tensorflow.ExecutionEnvironment;
Expand Down Expand Up @@ -87,6 +89,7 @@
import org.tensorflow.op.core.ExtractVolumePatches;
import org.tensorflow.op.core.Fill;
import org.tensorflow.op.core.Fingerprint;
import org.tensorflow.op.core.Function;
import org.tensorflow.op.core.Gather;
import org.tensorflow.op.core.GatherNd;
import org.tensorflow.op.core.GetSessionHandle;
Expand Down Expand Up @@ -1116,6 +1119,31 @@ public Bucketize bucketize(Operand<? extends TNumber> input, List<Float> boundar
return Bucketize.create(scope, input, boundaries);
}

/**
* Calls the function in an execution environment, adding its graph as a function if it isn't
* already present. Only works for functions with a single input and output.
*
* @param argument the argument to the call
* @return the output of the function
* @see ConcreteFunction#call(Ops, Operand)
*/
public Operand<?> call(ConcreteFunction function, Operand<?> argument) {
return Function.call(scope, function, argument);
}

/**
* Calls the function in an execution environment, adding its graph as a function if it isn't
* already present. The inputs and outputs are keyed by the names set in the {@code Signature}.
*
* @param arguments the arguments to the call
* @return the outputs of the function
* @see ConcreteFunction#call(Ops, Map)
*/
public Map<String, Operand<?>> call(ConcreteFunction function,
Map<String, Operand<?>> arguments) {
return Function.call(scope, function, arguments);
}

/**
* Clips tensor values to a specified min and max.
* Given a tensor {@code t}, this operation returns a tensor of the same type and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// Once created and added to graphs, functions can be invoked by creating an
// operation whose operation type matches the function name.
@Opaque @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
public class TF_Function extends Pointer {
public class TF_Function extends org.tensorflow.internal.c_api.AbstractTF_Function {
/** Empty constructor. Calls {@code super((Pointer)null)}. */
public TF_Function() { super((Pointer)null); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
Expand Down
Loading

0 comments on commit daeb257

Please sign in to comment.