From a42ed5b0c0198b43d1c2d0bdcf79fd87f5774ca8 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Tue, 8 Dec 2020 18:56:02 -0800 Subject: [PATCH 1/3] start fobbiden ops checks Signed-off-by: Ryan Nett --- .../org/tensorflow/EagerOperationBuilder.java | 3 +++ .../main/java/org/tensorflow/EagerSession.java | 15 +++++++++++++++ .../java/org/tensorflow/ExecutionEnvironment.java | 9 +++++++++ .../org/tensorflow/GraphOperationBuilder.java | 3 +++ 4 files changed, 30 insertions(+) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java index 9df8444a11f..816c5ec3cd3 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java @@ -56,6 +56,9 @@ final class EagerOperationBuilder implements OperationBuilder { EagerOperationBuilder(EagerSession session, String type, String name) { + if(!session.isOpEnabled(type)) + throw new IllegalArgumentException("Op " + type + " is not valid in eager mode."); + this.session = session; this.type = type; this.name = name; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java index 3f29245ce0b..edf8e53abdc 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java @@ -27,6 +27,9 @@ import org.tensorflow.internal.c_api.TFE_Context; import org.tensorflow.internal.c_api.TFE_ContextOptions; import org.tensorflow.internal.c_api.TF_Status; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.op.core.Variable; import org.tensorflow.proto.framework.ConfigProto; /** @@ -279,6 +282,18 @@ public Types environmentType() { return Types.EAGER; } + @Override + public boolean isOpEnabled(String opType) { + switch (opType) { + case Variable.OP_NAME: + case Placeholder.OP_NAME: + case Assign.OP_NAME: + return false; + default: + return true; + } + } + TFE_Context nativeHandle() { checkSession(); return nativeHandle; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java index a894b665763..a7a3363f690 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java @@ -34,6 +34,15 @@ enum Types { */ OperationBuilder opBuilder(String type, String name); + /** + * Returns true if the given operation is valid in this execution environment. + * @param opType The op to check. + * @return Whether the given operation is valid in this execution environment. + */ + default boolean isOpEnabled(String opType){ + return true; + } + /** * Get the type of this environment (from the `Environments` enumeration. * diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java index 927d9c52dd1..da7079879b7 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java @@ -58,6 +58,9 @@ public final class GraphOperationBuilder implements OperationBuilder { GraphOperationBuilder(Graph graph, String type, String name) { + if(!graph.isOpEnabled(type)) + throw new IllegalArgumentException("Op " + type + " is not valid in graph mode."); + this.graph = graph; Graph.Reference r = graph.ref(); try { From 463036dbcd41628290947245a850ee526a93d6db Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 1 Jan 2021 15:45:43 -0800 Subject: [PATCH 2/3] fix style Signed-off-by: Ryan Nett --- .../org/tensorflow/EagerOperationBuilder.java | 15 +++++++++------ .../org/tensorflow/GraphOperationBuilder.java | 13 ++++++++----- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java index 816c5ec3cd3..600a696ca97 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java @@ -56,8 +56,9 @@ final class EagerOperationBuilder implements OperationBuilder { EagerOperationBuilder(EagerSession session, String type, String name) { - if(!session.isOpEnabled(type)) + if (!session.isOpEnabled(type)) { throw new IllegalArgumentException("Op " + type + " is not valid in eager mode."); + } this.session = session; this.type = type; @@ -78,7 +79,7 @@ public EagerOperation build() { @Override public EagerOperationBuilder addInput(Output input) { - addInput(opHandle, (TFE_TensorHandle)input.getUnsafeNativeHandle()); + addInput(opHandle, (TFE_TensorHandle) input.getUnsafeNativeHandle()); return this; } @@ -86,7 +87,7 @@ public EagerOperationBuilder addInput(Output input) { public EagerOperationBuilder addInputList(Output[] inputs) { TFE_TensorHandle[] inputHandles = new TFE_TensorHandle[inputs.length]; for (int i = 0; i < inputs.length; ++i) { - inputHandles[i] = (TFE_TensorHandle)inputs[i].getUnsafeNativeHandle(); + inputHandles[i] = (TFE_TensorHandle) inputs[i].getUnsafeNativeHandle(); } addInputList(opHandle, inputHandles); return this; @@ -229,7 +230,9 @@ public EagerOperationBuilder setAttr(String name, Shape[] values) { private final String type; private final String name; - /** This value should be >= to the maximum number of outputs in any op */ + /** + * This value should be >= to the maximum number of outputs in any op + */ private static final int MAX_OUTPUTS_PER_OP = 1000; private static void requireOp(TFE_Op handle) { @@ -361,7 +364,7 @@ private static void setAttrFloatList(TFE_Op opHandle, String name, float[] value private static void setAttrBool(TFE_Op opHandle, String name, boolean value) { requireOp(opHandle); - TFE_OpSetAttrBool(opHandle, name, (byte)(value ? 1 : 0)); + TFE_OpSetAttrBool(opHandle, name, (byte) (value ? 1 : 0)); } private static void setAttrBoolList(TFE_Op opHandle, String name, boolean[] values) { @@ -413,7 +416,7 @@ private static void setAttrShapeList(TFE_Op opHandle, String name, long[] shapes } TF_Status status = TF_Status.newStatus(); TFE_OpSetAttrShapeList(opHandle, new BytePointer(name), shapesPointers, new IntPointer(numDims), - numDims.length, status); + numDims.length, status); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java index da7079879b7..aba7f0cd424 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java @@ -54,12 +54,15 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.proto.framework.DataType; -/** An {@link OperationBuilder} for adding {@link GraphOperation}s to a {@link Graph}. */ +/** + * An {@link OperationBuilder} for adding {@link GraphOperation}s to a {@link Graph}. + */ public final class GraphOperationBuilder implements OperationBuilder { GraphOperationBuilder(Graph graph, String type, String name) { - if(!graph.isOpEnabled(type)) + if (!graph.isOpEnabled(type)) { throw new IllegalArgumentException("Op " + type + " is not valid in graph mode."); + } this.graph = graph; Graph.Reference r = graph.ref(); @@ -106,7 +109,7 @@ public GraphOperationBuilder addControlInput(Operation control) { public GraphOperationBuilder addInput(Output input) { Graph.Reference r = graph.ref(); try { - addInput(unsafeNativeHandle, (TF_Operation)input.getUnsafeNativeHandle(), input.index()); + addInput(unsafeNativeHandle, (TF_Operation) input.getUnsafeNativeHandle(), input.index()); } finally { r.close(); } @@ -120,7 +123,7 @@ public GraphOperationBuilder addInputList(Output[] inputs) { TF_Operation[] opHandles = new TF_Operation[inputs.length]; int[] indices = new int[inputs.length]; for (int i = 0; i < inputs.length; ++i) { - opHandles[i] = (TF_Operation)inputs[i].getUnsafeNativeHandle(); + opHandles[i] = (TF_Operation) inputs[i].getUnsafeNativeHandle(); indices[i] = inputs[i].index(); } addInputList(unsafeNativeHandle, opHandles, indices); @@ -447,7 +450,7 @@ private static void setAttrFloatList(TF_OperationDescription handle, String name private static void setAttrBool(TF_OperationDescription handle, String name, boolean value) { requireHandle(handle); - TF_SetAttrBool(handle, name, (byte)(value ? 1 : 0)); + TF_SetAttrBool(handle, name, (byte) (value ? 1 : 0)); } private static void setAttrBoolList(TF_OperationDescription handle, String name, boolean[] value) { From 6d2dcf604d45c609141db0b9f5118455675b6e74 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 1 Jan 2021 15:48:49 -0800 Subject: [PATCH 3/3] move checks to builder method Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/EagerOperationBuilder.java | 4 ---- .../src/main/java/org/tensorflow/EagerSession.java | 3 +++ .../src/main/java/org/tensorflow/Graph.java | 3 +++ .../src/main/java/org/tensorflow/GraphOperationBuilder.java | 4 ---- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java index 600a696ca97..37f3af7ca26 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java @@ -56,10 +56,6 @@ final class EagerOperationBuilder implements OperationBuilder { EagerOperationBuilder(EagerSession session, String type, String name) { - if (!session.isOpEnabled(type)) { - throw new IllegalArgumentException("Op " + type + " is not valid in eager mode."); - } - this.session = session; this.type = type; this.name = name; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java index edf8e53abdc..96ef5228a4f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java @@ -274,6 +274,9 @@ static void closeDefaultForTest() { @Override public OperationBuilder opBuilder(String type, String name) { checkSession(); + if (!isOpEnabled(type)) { + throw new IllegalArgumentException("Op " + type + " is not valid in eager mode."); + } return new EagerOperationBuilder(this, type, name); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index d70460ee4ea..f2717f263eb 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -147,6 +147,9 @@ public Iterator operations() { */ @Override public GraphOperationBuilder opBuilder(String type, String name) { + if (!isOpEnabled(type)) { + throw new IllegalArgumentException("Op " + type + " is not valid in graph mode."); + } return new GraphOperationBuilder(this, type, name); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java index aba7f0cd424..9c0f011bab4 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java @@ -60,10 +60,6 @@ public final class GraphOperationBuilder implements OperationBuilder { GraphOperationBuilder(Graph graph, String type, String name) { - if (!graph.isOpEnabled(type)) { - throw new IllegalArgumentException("Op " + type + " is not valid in graph mode."); - } - this.graph = graph; Graph.Reference r = graph.ref(); try {