diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java index 09e5a47f8fd..a5c2df84026 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java @@ -68,6 +68,11 @@ public String type() { return type; } + @Override + public EagerSession env() { + return session; + } + @Override public int numOutputs() { return outputHandles.length; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java index 37f3af7ca26..a865300bc5a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java @@ -75,6 +75,7 @@ public EagerOperation build() { @Override public EagerOperationBuilder addInput(Output> input) { + session.checkInput(input); addInput(opHandle, (TFE_TensorHandle) input.getUnsafeNativeHandle()); return this; } @@ -83,6 +84,7 @@ public EagerOperationBuilder addInput(Output> input) { public EagerOperationBuilder addInputList(Output>[] inputs) { TFE_TensorHandle[] inputHandles = new TFE_TensorHandle[inputs.length]; for (int i = 0; i < inputs.length; ++i) { + session.checkInput(inputs[i]); inputHandles[i] = (TFE_TensorHandle) inputs[i].getUnsafeNativeHandle(); } addInputList(opHandle, inputHandles); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java index 96ef5228a4f..75bc12b5a6c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java @@ -27,6 +27,7 @@ import org.tensorflow.internal.c_api.TFE_Context; import org.tensorflow.internal.c_api.TFE_ContextOptions; import org.tensorflow.internal.c_api.TF_Status; +import org.tensorflow.op.Op; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; @@ -297,6 +298,13 @@ public boolean isOpEnabled(String opType) { } } + @Override + public void checkInput(Op input) { + if (!input.env().isEager()) { + throw new IllegalArgumentException("Can't use graph operation " + input + " in eager mode."); + } + } + TFE_Context nativeHandle() { checkSession(); return nativeHandle; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java index a7a3363f690..d5389bcd0ad 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java @@ -15,7 +15,11 @@ package org.tensorflow; -/** Defines an environment for creating and executing TensorFlow {@link Operation}s. */ +import org.tensorflow.op.Op; + +/** + * Defines an environment for creating and executing TensorFlow {@link Operation}s. + */ public interface ExecutionEnvironment { enum Types { @@ -36,13 +40,23 @@ enum Types { /** * Returns true if the given operation is valid in this execution environment. + * * @param opType The op to check. * @return Whether the given operation is valid in this execution environment. */ - default boolean isOpEnabled(String opType){ + default boolean isOpEnabled(String opType) { return true; } + /** + * Checks that {@code input} is valid to use as an input in this execution environment. Throws {@link + * IllegalArgumentException} if not. + * + * @param input The op to check + * @throws IllegalArgumentException if input can't be used as an input in this execution environment. + */ + void checkInput(Op input); + /** * Get the type of this environment (from the `Environments` enumeration. * diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index f2717f263eb..988683895c4 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -158,6 +158,17 @@ public Types environmentType() { return Types.GRAPH; } + @Override + public void checkInput(Op input) { + if (input.env().isEager()) { + throw new IllegalArgumentException( + "Input " + input + " was from an eager session, can't use in a graph. Use tf.constantOf(input.asTensor())"); + } + if (input.env() != this) { + throw new IllegalArgumentException("Input " + input + " was from a different graph, can't use."); + } + } + /** * Import a representation of a TensorFlow graph. * diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java index e1255748c3b..fbad92160a2 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java @@ -73,6 +73,13 @@ public String type() { } } + @Override + public Graph env() { + try (Graph.Reference r = graph.ref()) { + return graph; + } + } + @Override public int numOutputs() { Graph.Reference r = graph.ref(); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java index 9c0f011bab4..72858ece572 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java @@ -92,6 +92,11 @@ public GraphOperationBuilder addControlInput(Operation control) { throw new IllegalArgumentException( "Only GraphOperation instances can be used as control inputs"); } + + if (control.env() != graph) { + throw new IllegalArgumentException("Control input " + control + " was from a different graph, can't use."); + } + Graph.Reference r = graph.ref(); try { addControlInput(unsafeNativeHandle, ((GraphOperation) control).getUnsafeNativeHandle()); @@ -103,6 +108,7 @@ public GraphOperationBuilder addControlInput(Operation control) { @Override public GraphOperationBuilder addInput(Output> input) { + graph.checkInput(input); Graph.Reference r = graph.ref(); try { addInput(unsafeNativeHandle, (TF_Operation) input.getUnsafeNativeHandle(), input.index()); @@ -114,6 +120,10 @@ public GraphOperationBuilder addInput(Output> input) { @Override public GraphOperationBuilder addInputList(Output>[] inputs) { + for (Output> input : inputs) { + graph.checkInput(input); + } + Graph.Reference r = graph.ref(); try { TF_Operation[] opHandles = new TF_Operation[inputs.length]; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operation.java index 1cc175da161..b47eee6850c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operation.java @@ -25,16 +25,24 @@ */ public interface Operation { - /** Returns the full name of the Operation. */ + /** + * Returns the full name of the Operation. + */ String name(); /** - * Returns the type of the operation, i.e., the name of the computation performed by the - * operation. + * Returns the type of the operation, i.e., the name of the computation performed by the operation. */ String type(); - /** Returns the number of tensors produced by this operation. */ + /** + * Returns the execution environment this operation was created in. + */ + ExecutionEnvironment env(); + + /** + * Returns the number of tensors produced by this operation. + */ int numOutputs(); /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Op.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Op.java index 40b54393c60..6051623414f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Op.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Op.java @@ -15,6 +15,7 @@ package org.tensorflow.op; +import org.tensorflow.ExecutionEnvironment; import org.tensorflow.Operation; /** @@ -48,4 +49,11 @@ public interface Op { * @return an {@link Operation} */ Operation op(); + + /** + * Return the execution environment this op was created in. + */ + default ExecutionEnvironment env() { + return op().env(); + } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java index f0b739e074f..73fa340a487 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java @@ -16,14 +16,12 @@ package org.tensorflow.op; import java.util.ArrayList; - import org.tensorflow.DeviceSpec; import org.tensorflow.ExecutionEnvironment; import org.tensorflow.OperationBuilder; /** - * Manages groups of related properties when creating Tensorflow Operations, such as a common name - * prefix. + * Manages groups of related properties when creating Tensorflow Operations, such as a common name prefix. * *
A {@code Scope} is a container for common properties applied to TensorFlow Ops. Normal user * code initializes a {@code Scope} and provides it to Operation building classes. For example: @@ -88,7 +86,9 @@ public Scope(ExecutionEnvironment env) { this(env, new NameScope(), new ArrayList<>(), DeviceSpec.newBuilder().build()); } - /** Returns the execution environment used by this scope. */ + /** + * Returns the execution environment used by this scope. + */ public ExecutionEnvironment env() { return env; } @@ -97,8 +97,7 @@ public ExecutionEnvironment env() { * Returns a new scope where added operations will have the provided name prefix. * *
Ops created with this scope will have {@code name/childScopeName/} as the prefix. The actual - * name will be unique in the returned scope. All other properties are inherited from the current - * scope. + * name will be unique in the returned scope. All other properties are inherited from the current scope. * *
The child scope name must match the regular expression {@code [A-Za-z0-9.][A-Za-z0-9_.\-]*} * @@ -129,7 +128,8 @@ public Scope withName(String opName) { /** * Return a new scope that uses the provided device specification for an op. * - *
Operations created within this scope will place the created operations on the device(s) matching the provided spec. + *
Operations created within this scope will place the created operations on the device(s) matching the provided + * spec. * * @param deviceSpec device specification for an operator in the returned scope * @return a new Scope that uses opName for operations. @@ -151,8 +151,8 @@ public Scope withDevice(DeviceSpec deviceSpec) { * } * *
Note: if you provide a composite operator building class (i.e, a class that creates a
- * set of related operations by calling other operator building code), the provided name will act
- * as a subscope to all underlying operators.
+ * set of related operations by calling other operator building code), the provided name will act as a subscope to all
+ * underlying operators.
*
* @param defaultName name for the underlying operator.
* @return unique name for the operator.
@@ -180,11 +180,15 @@ private Scope(
* @return a new scope with the provided control dependencies
*/
public Scope withControlDependencies(Iterable