Skip to content

Nicer error messages for mode-forbidden ops #169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,15 @@ public EagerOperation build() {

@Override
public EagerOperationBuilder addInput(Output<?> input) {
addInput(opHandle, (TFE_TensorHandle)input.getUnsafeNativeHandle());
addInput(opHandle, (TFE_TensorHandle) input.getUnsafeNativeHandle());
return this;
}

@Override
public EagerOperationBuilder addInputList(Output<?>[] inputs) {
TFE_TensorHandle[] inputHandles = new TFE_TensorHandle[inputs.length];
for (int i = 0; i < inputs.length; ++i) {
inputHandles[i] = (TFE_TensorHandle)inputs[i].getUnsafeNativeHandle();
inputHandles[i] = (TFE_TensorHandle) inputs[i].getUnsafeNativeHandle();
}
addInputList(opHandle, inputHandles);
return this;
Expand Down Expand Up @@ -226,7 +226,9 @@ public EagerOperationBuilder setAttr(String name, Shape[] values) {
private final String type;
private final String name;

/** This value should be >= to the maximum number of outputs in any op */
/**
* This value should be >= to the maximum number of outputs in any op
*/
private static final int MAX_OUTPUTS_PER_OP = 1000;

private static void requireOp(TFE_Op handle) {
Expand Down Expand Up @@ -358,7 +360,7 @@ private static void setAttrFloatList(TFE_Op opHandle, String name, float[] value

private static void setAttrBool(TFE_Op opHandle, String name, boolean value) {
requireOp(opHandle);
TFE_OpSetAttrBool(opHandle, name, (byte)(value ? 1 : 0));
TFE_OpSetAttrBool(opHandle, name, (byte) (value ? 1 : 0));
}

private static void setAttrBoolList(TFE_Op opHandle, String name, boolean[] values) {
Expand Down Expand Up @@ -410,7 +412,7 @@ private static void setAttrShapeList(TFE_Op opHandle, String name, long[] shapes
}
TF_Status status = TF_Status.newStatus();
TFE_OpSetAttrShapeList(opHandle, new BytePointer(name), shapesPointers, new IntPointer(numDims),
numDims.length, status);
numDims.length, status);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
import org.tensorflow.internal.c_api.TFE_Context;
import org.tensorflow.internal.c_api.TFE_ContextOptions;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.op.core.Assign;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.Variable;
import org.tensorflow.proto.framework.ConfigProto;

/**
Expand Down Expand Up @@ -271,6 +274,9 @@ static void closeDefaultForTest() {
@Override
public OperationBuilder opBuilder(String type, String name) {
checkSession();
if (!isOpEnabled(type)) {
throw new IllegalArgumentException("Op " + type + " is not valid in eager mode.");
}
return new EagerOperationBuilder(this, type, name);
}

Expand All @@ -279,6 +285,18 @@ public Types environmentType() {
return Types.EAGER;
}

@Override
public boolean isOpEnabled(String opType) {
switch (opType) {
case Variable.OP_NAME:
case Placeholder.OP_NAME:
case Assign.OP_NAME:
return false;
default:
return true;
}
}

TFE_Context nativeHandle() {
checkSession();
return nativeHandle;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ enum Types {
*/
OperationBuilder opBuilder(String type, String name);

/**
* Returns true if the given operation is valid in this execution environment.
* @param opType The op to check.
* @return Whether the given operation is valid in this execution environment.
*/
default boolean isOpEnabled(String opType){
return true;
}

/**
* Get the type of this environment (from the `Environments` enumeration.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ public Iterator<Operation> operations() {
*/
@Override
public GraphOperationBuilder opBuilder(String type, String name) {
if (!isOpEnabled(type)) {
throw new IllegalArgumentException("Op " + type + " is not valid in graph mode.");
}
return new GraphOperationBuilder(this, type, name);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@
import org.tensorflow.ndarray.Shape;
import org.tensorflow.proto.framework.DataType;

/** An {@link OperationBuilder} for adding {@link GraphOperation}s to a {@link Graph}. */
/**
* An {@link OperationBuilder} for adding {@link GraphOperation}s to a {@link Graph}.
*/
public final class GraphOperationBuilder implements OperationBuilder {

GraphOperationBuilder(Graph graph, String type, String name) {
Expand Down Expand Up @@ -103,7 +105,7 @@ public GraphOperationBuilder addControlInput(Operation control) {
public GraphOperationBuilder addInput(Output<?> input) {
Graph.Reference r = graph.ref();
try {
addInput(unsafeNativeHandle, (TF_Operation)input.getUnsafeNativeHandle(), input.index());
addInput(unsafeNativeHandle, (TF_Operation) input.getUnsafeNativeHandle(), input.index());
} finally {
r.close();
}
Expand All @@ -117,7 +119,7 @@ public GraphOperationBuilder addInputList(Output<?>[] inputs) {
TF_Operation[] opHandles = new TF_Operation[inputs.length];
int[] indices = new int[inputs.length];
for (int i = 0; i < inputs.length; ++i) {
opHandles[i] = (TF_Operation)inputs[i].getUnsafeNativeHandle();
opHandles[i] = (TF_Operation) inputs[i].getUnsafeNativeHandle();
indices[i] = inputs[i].index();
}
addInputList(unsafeNativeHandle, opHandles, indices);
Expand Down Expand Up @@ -444,7 +446,7 @@ private static void setAttrFloatList(TF_OperationDescription handle, String name

private static void setAttrBool(TF_OperationDescription handle, String name, boolean value) {
requireHandle(handle);
TF_SetAttrBool(handle, name, (byte)(value ? 1 : 0));
TF_SetAttrBool(handle, name, (byte) (value ? 1 : 0));
}

private static void setAttrBoolList(TF_OperationDescription handle, String name, boolean[] value) {
Expand Down