Skip to content

Better cross-environment error messages #207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ public String type() {
return type;
}

@Override
public EagerSession env() {
return session;
}

@Override
public int numOutputs() {
return outputHandles.length;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ public EagerOperation build() {

@Override
public EagerOperationBuilder addInput(Output<?> input) {
session.checkInput(input);
addInput(opHandle, (TFE_TensorHandle) input.getUnsafeNativeHandle());
return this;
}
Expand All @@ -83,6 +84,7 @@ public EagerOperationBuilder addInput(Output<?> input) {
public EagerOperationBuilder addInputList(Output<?>[] inputs) {
TFE_TensorHandle[] inputHandles = new TFE_TensorHandle[inputs.length];
for (int i = 0; i < inputs.length; ++i) {
session.checkInput(inputs[i]);
inputHandles[i] = (TFE_TensorHandle) inputs[i].getUnsafeNativeHandle();
}
addInputList(opHandle, inputHandles);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.tensorflow.internal.c_api.TFE_Context;
import org.tensorflow.internal.c_api.TFE_ContextOptions;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.op.Op;
import org.tensorflow.op.core.Assign;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.Variable;
Expand Down Expand Up @@ -297,6 +298,13 @@ public boolean isOpEnabled(String opType) {
}
}

@Override
public void checkInput(Op input) {
if (!input.env().isEager()) {
throw new IllegalArgumentException("Can't use graph operation " + input + " in eager mode.");
}
}

TFE_Context nativeHandle() {
checkSession();
return nativeHandle;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@

package org.tensorflow;

/** Defines an environment for creating and executing TensorFlow {@link Operation}s. */
import org.tensorflow.op.Op;

/**
* Defines an environment for creating and executing TensorFlow {@link Operation}s.
*/
public interface ExecutionEnvironment {

enum Types {
Expand All @@ -36,13 +40,23 @@ enum Types {

/**
* Returns true if the given operation is valid in this execution environment.
*
* @param opType The op to check.
* @return Whether the given operation is valid in this execution environment.
*/
default boolean isOpEnabled(String opType){
default boolean isOpEnabled(String opType) {
return true;
}

/**
* Checks that {@code input} is valid to use as an input in this execution environment. Throws {@link
* IllegalArgumentException} if not.
*
* @param input The op to check
* @throws IllegalArgumentException if input can't be used as an input in this execution environment.
*/
void checkInput(Op input);

/**
* Get the type of this environment (from the `Environments` enumeration.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,17 @@ public Types environmentType() {
return Types.GRAPH;
}

@Override
public void checkInput(Op input) {
if (input.env().isEager()) {
throw new IllegalArgumentException(
"Input " + input + " was from an eager session, can't use in a graph. Use tf.constantOf(input.asTensor())");
}
if (input.env() != this) {
throw new IllegalArgumentException("Input " + input + " was from a different graph, can't use.");
}
}

/**
* Import a representation of a TensorFlow graph.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ public String type() {
}
}

@Override
public Graph env() {
try (Graph.Reference r = graph.ref()) {
return graph;
}
}

@Override
public int numOutputs() {
Graph.Reference r = graph.ref();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ public GraphOperationBuilder addControlInput(Operation control) {
throw new IllegalArgumentException(
"Only GraphOperation instances can be used as control inputs");
}

if (control.env() != graph) {
throw new IllegalArgumentException("Control input " + control + " was from a different graph, can't use.");
}

Graph.Reference r = graph.ref();
try {
addControlInput(unsafeNativeHandle, ((GraphOperation) control).getUnsafeNativeHandle());
Expand All @@ -103,6 +108,7 @@ public GraphOperationBuilder addControlInput(Operation control) {

@Override
public GraphOperationBuilder addInput(Output<?> input) {
graph.checkInput(input);
Graph.Reference r = graph.ref();
try {
addInput(unsafeNativeHandle, (TF_Operation) input.getUnsafeNativeHandle(), input.index());
Expand All @@ -114,6 +120,10 @@ public GraphOperationBuilder addInput(Output<?> input) {

@Override
public GraphOperationBuilder addInputList(Output<?>[] inputs) {
for (Output<?> input : inputs) {
graph.checkInput(input);
}

Graph.Reference r = graph.ref();
try {
TF_Operation[] opHandles = new TF_Operation[inputs.length];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,24 @@
*/
public interface Operation {

/** Returns the full name of the Operation. */
/**
* Returns the full name of the Operation.
*/
String name();

/**
* Returns the type of the operation, i.e., the name of the computation performed by the
* operation.
* Returns the type of the operation, i.e., the name of the computation performed by the operation.
*/
String type();

/** Returns the number of tensors produced by this operation. */
/**
* Returns the execution environment this operation was created in.
*/
ExecutionEnvironment env();

/**
* Returns the number of tensors produced by this operation.
*/
int numOutputs();

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package org.tensorflow.op;

import org.tensorflow.ExecutionEnvironment;
import org.tensorflow.Operation;

/**
Expand Down Expand Up @@ -48,4 +49,11 @@ public interface Op {
* @return an {@link Operation}
*/
Operation op();

/**
* Return the execution environment this op was created in.
*/
default ExecutionEnvironment env() {
return op().env();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@
package org.tensorflow.op;

import java.util.ArrayList;

import org.tensorflow.DeviceSpec;
import org.tensorflow.ExecutionEnvironment;
import org.tensorflow.OperationBuilder;

/**
* Manages groups of related properties when creating Tensorflow Operations, such as a common name
* prefix.
* Manages groups of related properties when creating Tensorflow Operations, such as a common name prefix.
*
* <p>A {@code Scope} is a container for common properties applied to TensorFlow Ops. Normal user
* code initializes a {@code Scope} and provides it to Operation building classes. For example:
Expand Down Expand Up @@ -88,7 +86,9 @@ public Scope(ExecutionEnvironment env) {
this(env, new NameScope(), new ArrayList<>(), DeviceSpec.newBuilder().build());
}

/** Returns the execution environment used by this scope. */
/**
* Returns the execution environment used by this scope.
*/
public ExecutionEnvironment env() {
return env;
}
Expand All @@ -97,8 +97,7 @@ public ExecutionEnvironment env() {
* Returns a new scope where added operations will have the provided name prefix.
*
* <p>Ops created with this scope will have {@code name/childScopeName/} as the prefix. The actual
* name will be unique in the returned scope. All other properties are inherited from the current
* scope.
* name will be unique in the returned scope. All other properties are inherited from the current scope.
*
* <p>The child scope name must match the regular expression {@code [A-Za-z0-9.][A-Za-z0-9_.\-]*}
*
Expand Down Expand Up @@ -129,7 +128,8 @@ public Scope withName(String opName) {
/**
* Return a new scope that uses the provided device specification for an op.
*
* <p>Operations created within this scope will place the created operations on the device(s) matching the provided spec.
* <p>Operations created within this scope will place the created operations on the device(s) matching the provided
* spec.
*
* @param deviceSpec device specification for an operator in the returned scope
* @return a new Scope that uses opName for operations.
Expand All @@ -151,8 +151,8 @@ public Scope withDevice(DeviceSpec deviceSpec) {
* }</pre>
*
* <p><b>Note:</b> if you provide a composite operator building class (i.e, a class that creates a
* set of related operations by calling other operator building code), the provided name will act
* as a subscope to all underlying operators.
* set of related operations by calling other operator building code), the provided name will act as a subscope to all
* underlying operators.
*
* @param defaultName name for the underlying operator.
* @return unique name for the operator.
Expand Down Expand Up @@ -180,11 +180,15 @@ private Scope(
* @return a new scope with the provided control dependencies
*/
public Scope withControlDependencies(Iterable<Op> controls) {
for (Op control : controls) {
env.checkInput(control);
}
return new Scope(env, nameScope, controls, deviceSpec);
}

/**
* Applies device specification and adds each Operand in controlDependencies as a control input to the provided builder.
* Applies device specification and adds each Operand in controlDependencies as a control input to the provided
* builder.
*
* @param builder OperationBuilder to add control inputs and device specification to
*/
Expand All @@ -210,7 +214,9 @@ public OperationBuilder applyControlDependencies(OperationBuilder builder) {
private final NameScope nameScope;
private final DeviceSpec deviceSpec;

/** Returns device string from the scope. */
/**
* Returns device string from the scope.
*/
public String getDeviceString() {
return deviceSpec.toString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,30 +23,13 @@
import org.tensorflow.exceptions.TFInvalidArgumentException;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.types.TBool;
import org.tensorflow.types.TInt32;

/** Unit tests for {@link org.tensorflow.GraphOperationBuilder}. */
public class GraphOperationBuilderTest {

@Test
public void failWhenMixingOperationsOnDifferentGraphs() {
try (Graph g1 = new Graph();
Graph g2 = new Graph()) {
Ops tf = Ops.create(g1);
Constant<TInt32> c1 = tf.constant(3);
tf.math.add(c1, c1);
try {
Ops tf2 = Ops.create(g2);
tf2.math.add(c1, c1);
} catch (Exception e) {
fail(e.toString());
}
}
}

@Test
public void failOnUseAfterBuild() {
try (Graph g = new Graph();
Expand Down
Loading