diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java index 09e5a47f8fd..a5c2df84026 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java @@ -68,6 +68,11 @@ public String type() { return type; } + @Override + public EagerSession env() { + return session; + } + @Override public int numOutputs() { return outputHandles.length; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java index 37f3af7ca26..a865300bc5a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java @@ -75,6 +75,7 @@ public EagerOperation build() { @Override public EagerOperationBuilder addInput(Output input) { + session.checkInput(input); addInput(opHandle, (TFE_TensorHandle) input.getUnsafeNativeHandle()); return this; } @@ -83,6 +84,7 @@ public EagerOperationBuilder addInput(Output input) { public EagerOperationBuilder addInputList(Output[] inputs) { TFE_TensorHandle[] inputHandles = new TFE_TensorHandle[inputs.length]; for (int i = 0; i < inputs.length; ++i) { + session.checkInput(inputs[i]); inputHandles[i] = (TFE_TensorHandle) inputs[i].getUnsafeNativeHandle(); } addInputList(opHandle, inputHandles); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java index 96ef5228a4f..75bc12b5a6c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java @@ -27,6 +27,7 @@ import org.tensorflow.internal.c_api.TFE_Context; import org.tensorflow.internal.c_api.TFE_ContextOptions; import org.tensorflow.internal.c_api.TF_Status; +import org.tensorflow.op.Op; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; @@ -297,6 +298,13 @@ public boolean isOpEnabled(String opType) { } } + @Override + public void checkInput(Op input) { + if (!input.env().isEager()) { + throw new IllegalArgumentException("Can't use graph operation " + input + " in eager mode."); + } + } + TFE_Context nativeHandle() { checkSession(); return nativeHandle; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java index a7a3363f690..d5389bcd0ad 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java @@ -15,7 +15,11 @@ package org.tensorflow; -/** Defines an environment for creating and executing TensorFlow {@link Operation}s. */ +import org.tensorflow.op.Op; + +/** + * Defines an environment for creating and executing TensorFlow {@link Operation}s. + */ public interface ExecutionEnvironment { enum Types { @@ -36,13 +40,23 @@ enum Types { /** * Returns true if the given operation is valid in this execution environment. + * * @param opType The op to check. * @return Whether the given operation is valid in this execution environment. */ - default boolean isOpEnabled(String opType){ + default boolean isOpEnabled(String opType) { return true; } + /** + * Checks that {@code input} is valid to use as an input in this execution environment. Throws {@link + * IllegalArgumentException} if not. + * + * @param input The op to check + * @throws IllegalArgumentException if input can't be used as an input in this execution environment. + */ + void checkInput(Op input); + /** * Get the type of this environment (from the `Environments` enumeration. * diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index f2717f263eb..988683895c4 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -158,6 +158,17 @@ public Types environmentType() { return Types.GRAPH; } + @Override + public void checkInput(Op input) { + if (input.env().isEager()) { + throw new IllegalArgumentException( + "Input " + input + " was from an eager session, can't use in a graph. Use tf.constantOf(input.asTensor())"); + } + if (input.env() != this) { + throw new IllegalArgumentException("Input " + input + " was from a different graph, can't use."); + } + } + /** * Import a representation of a TensorFlow graph. * diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java index e1255748c3b..fbad92160a2 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java @@ -73,6 +73,13 @@ public String type() { } } + @Override + public Graph env() { + try (Graph.Reference r = graph.ref()) { + return graph; + } + } + @Override public int numOutputs() { Graph.Reference r = graph.ref(); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java index 9c0f011bab4..72858ece572 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java @@ -92,6 +92,11 @@ public GraphOperationBuilder addControlInput(Operation control) { throw new IllegalArgumentException( "Only GraphOperation instances can be used as control inputs"); } + + if (control.env() != graph) { + throw new IllegalArgumentException("Control input " + control + " was from a different graph, can't use."); + } + Graph.Reference r = graph.ref(); try { addControlInput(unsafeNativeHandle, ((GraphOperation) control).getUnsafeNativeHandle()); @@ -103,6 +108,7 @@ public GraphOperationBuilder addControlInput(Operation control) { @Override public GraphOperationBuilder addInput(Output input) { + graph.checkInput(input); Graph.Reference r = graph.ref(); try { addInput(unsafeNativeHandle, (TF_Operation) input.getUnsafeNativeHandle(), input.index()); @@ -114,6 +120,10 @@ public GraphOperationBuilder addInput(Output input) { @Override public GraphOperationBuilder addInputList(Output[] inputs) { + for (Output input : inputs) { + graph.checkInput(input); + } + Graph.Reference r = graph.ref(); try { TF_Operation[] opHandles = new TF_Operation[inputs.length]; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operation.java index 1cc175da161..b47eee6850c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operation.java @@ -25,16 +25,24 @@ */ public interface Operation { - /** Returns the full name of the Operation. */ + /** + * Returns the full name of the Operation. + */ String name(); /** - * Returns the type of the operation, i.e., the name of the computation performed by the - * operation. + * Returns the type of the operation, i.e., the name of the computation performed by the operation. */ String type(); - /** Returns the number of tensors produced by this operation. */ + /** + * Returns the execution environment this operation was created in. + */ + ExecutionEnvironment env(); + + /** + * Returns the number of tensors produced by this operation. + */ int numOutputs(); /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Op.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Op.java index 40b54393c60..6051623414f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Op.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Op.java @@ -15,6 +15,7 @@ package org.tensorflow.op; +import org.tensorflow.ExecutionEnvironment; import org.tensorflow.Operation; /** @@ -48,4 +49,11 @@ public interface Op { * @return an {@link Operation} */ Operation op(); + + /** + * Return the execution environment this op was created in. + */ + default ExecutionEnvironment env() { + return op().env(); + } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java index f0b739e074f..73fa340a487 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java @@ -16,14 +16,12 @@ package org.tensorflow.op; import java.util.ArrayList; - import org.tensorflow.DeviceSpec; import org.tensorflow.ExecutionEnvironment; import org.tensorflow.OperationBuilder; /** - * Manages groups of related properties when creating Tensorflow Operations, such as a common name - * prefix. + * Manages groups of related properties when creating Tensorflow Operations, such as a common name prefix. * *

A {@code Scope} is a container for common properties applied to TensorFlow Ops. Normal user * code initializes a {@code Scope} and provides it to Operation building classes. For example: @@ -88,7 +86,9 @@ public Scope(ExecutionEnvironment env) { this(env, new NameScope(), new ArrayList<>(), DeviceSpec.newBuilder().build()); } - /** Returns the execution environment used by this scope. */ + /** + * Returns the execution environment used by this scope. + */ public ExecutionEnvironment env() { return env; } @@ -97,8 +97,7 @@ public ExecutionEnvironment env() { * Returns a new scope where added operations will have the provided name prefix. * *

Ops created with this scope will have {@code name/childScopeName/} as the prefix. The actual - * name will be unique in the returned scope. All other properties are inherited from the current - * scope. + * name will be unique in the returned scope. All other properties are inherited from the current scope. * *

The child scope name must match the regular expression {@code [A-Za-z0-9.][A-Za-z0-9_.\-]*} * @@ -129,7 +128,8 @@ public Scope withName(String opName) { /** * Return a new scope that uses the provided device specification for an op. * - *

Operations created within this scope will place the created operations on the device(s) matching the provided spec. + *

Operations created within this scope will place the created operations on the device(s) matching the provided + * spec. * * @param deviceSpec device specification for an operator in the returned scope * @return a new Scope that uses opName for operations. @@ -151,8 +151,8 @@ public Scope withDevice(DeviceSpec deviceSpec) { * } * *

Note: if you provide a composite operator building class (i.e, a class that creates a - * set of related operations by calling other operator building code), the provided name will act - * as a subscope to all underlying operators. + * set of related operations by calling other operator building code), the provided name will act as a subscope to all + * underlying operators. * * @param defaultName name for the underlying operator. * @return unique name for the operator. @@ -180,11 +180,15 @@ private Scope( * @return a new scope with the provided control dependencies */ public Scope withControlDependencies(Iterable controls) { + for (Op control : controls) { + env.checkInput(control); + } return new Scope(env, nameScope, controls, deviceSpec); } /** - * Applies device specification and adds each Operand in controlDependencies as a control input to the provided builder. + * Applies device specification and adds each Operand in controlDependencies as a control input to the provided + * builder. * * @param builder OperationBuilder to add control inputs and device specification to */ @@ -210,7 +214,9 @@ public OperationBuilder applyControlDependencies(OperationBuilder builder) { private final NameScope nameScope; private final DeviceSpec deviceSpec; - /** Returns device string from the scope. */ + /** + * Returns device string from the scope. + */ public String getDeviceString() { return deviceSpec.toString(); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java index bbb9e23ec90..33ae979ccbd 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java @@ -23,7 +23,6 @@ import org.tensorflow.exceptions.TFInvalidArgumentException; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Constant; import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.TBool; import org.tensorflow.types.TInt32; @@ -31,22 +30,6 @@ /** Unit tests for {@link org.tensorflow.GraphOperationBuilder}. */ public class GraphOperationBuilderTest { - @Test - public void failWhenMixingOperationsOnDifferentGraphs() { - try (Graph g1 = new Graph(); - Graph g2 = new Graph()) { - Ops tf = Ops.create(g1); - Constant c1 = tf.constant(3); - tf.math.add(c1, c1); - try { - Ops tf2 = Ops.create(g2); - tf2.math.add(c1, c1); - } catch (Exception e) { - fail(e.toString()); - } - } - } - @Test public void failOnUseAfterBuild() { try (Graph g = new Graph(); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/WrongEnvTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/WrongEnvTest.java new file mode 100644 index 00000000000..b2fbc1e794a --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/WrongEnvTest.java @@ -0,0 +1,108 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import org.junit.Test; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TInt32; + +/** + * Tests for using Operands in different environments + */ +public class WrongEnvTest { + + /** + * Should work fine + */ + @Test + public void testTwoEagers() { + try (EagerSession e1 = EagerSession.create(); + EagerSession e2 = EagerSession.create()) { + Ops tf1 = Ops.create(e1); + Ops tf2 = Ops.create(e2); + + Operand a = tf1.constant(5); + Operand b = tf2.constant(6); + + Operand c = tf2.math.add(a, b); + + try (TInt32 tensor = c.asTensor()) { + assertEquals(11, tensor.getInt()); + } + + } + } + + @Test + public void testEagerInGraph() { + try (EagerSession e1 = EagerSession.create(); + Graph e2 = new Graph()) { + Ops tf1 = Ops.create(e1); + Ops tf2 = Ops.create(e2); + + Operand a = tf1.constant(5); + Operand b = tf2.constant(6); + + Operand c = tf2.math.add(a, b); + + fail(); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("was from an eager session, can't use in a graph")); + } + } + + @Test + public void testGraphInEager() { + try (Graph e1 = new Graph(); + EagerSession e2 = EagerSession.create()) { + Ops tf1 = Ops.create(e1); + Ops tf2 = Ops.create(e2); + + Operand a = tf1.constant(5); + Operand b = tf2.constant(6); + + Operand c = tf2.math.add(a, b); + + fail(); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("Can't use graph operation")); + } + } + + @Test + public void testTwoGraphs() { + try (Graph e1 = new Graph(); + Graph e2 = new Graph()) { + Ops tf1 = Ops.create(e1); + Ops tf2 = Ops.create(e2); + + Operand a = tf1.constant(5); + Operand b = tf2.constant(6); + + Operand c = tf2.math.add(a, b); + + fail(); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("was from a different graph")); + } + } + +}